Spaces:
Sleeping
Sleeping
tasin
commited on
Commit
·
f075308
1
Parent(s):
3f8c938
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -0
- .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- README.md +84 -12
- Untitled.ipynb +68 -0
- app.py +188 -0
- distributed.py +126 -0
- model_structure.txt +113 -0
- models/diffusion_model.py +762 -0
- models/unet_dual_encoder.py +62 -0
- project_latent_space.py +75 -0
- requirements.txt +30 -0
- sample/blue.jpg +0 -0
- sample/green.jpg +0 -0
- sample/silver.jpg +0 -0
- src/deps/__init__.py +0 -0
- src/deps/facial_recognition/__init__.py +3 -0
- src/deps/facial_recognition/helpers.py +123 -0
- src/deps/facial_recognition/model_irse.py +88 -0
- src/dnnlib/__init__.py +9 -0
- src/dnnlib/util.py +480 -0
- src/infra/__init__.py +0 -0
- src/infra/experiments.yaml +60 -0
- src/infra/launch.py +113 -0
- src/infra/slurm_batch_launch.py +96 -0
- src/infra/slurm_job.py +46 -0
- src/infra/slurm_job_proxy.sh +4 -0
- src/infra/utils.py +140 -0
- src/metrics/__init__.py +9 -0
- src/metrics/frechet_inception_distance.py +54 -0
- src/metrics/frechet_video_distance.py +59 -0
- src/metrics/inception_score.py +47 -0
- src/metrics/kernel_inception_distance.py +46 -0
- src/metrics/metric_main.py +154 -0
- src/metrics/metric_utils.py +332 -0
- src/metrics/video_inception_score.py +54 -0
- src/scripts/__init__.py +0 -0
- src/scripts/calc_metrics.py +250 -0
- src/scripts/calc_metrics_for_dataset.py +169 -0
- src/scripts/clip_edit.py +403 -0
- src/scripts/construct_static_videos_dataset.py +46 -0
- src/scripts/convert_video_to_dataset.py +87 -0
- src/scripts/convert_videos_to_frames.py +105 -0
- src/scripts/crop_video_dataset.py +69 -0
- src/scripts/frames_to_video_grid.py +78 -0
- src/scripts/generate.py +148 -0
- src/scripts/preprocess_ffs.py +204 -0
- src/scripts/profile_model.py +104 -0
- src/scripts/project.py +479 -0
- src/torch_utils/__init__.py +9 -0
- src/torch_utils/custom_ops.py +126 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.idea/
|
.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [],
|
| 3 |
+
"metadata": {},
|
| 4 |
+
"nbformat": 4,
|
| 5 |
+
"nbformat_minor": 5
|
| 6 |
+
}
|
README.md
CHANGED
|
@@ -1,12 +1,84 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div id="top"></div>
|
| 2 |
+
|
| 3 |
+
<h3>FashionFlow: Leveraging Diffusion Models for Dynamic Fashion Video Synthesis from Static Imagery</h3>
|
| 4 |
+
|
| 5 |
+
<p>
|
| 6 |
+
This repository has the official code for 'FashionFlow: Leveraging Diffusion Models for Dynamic Fashion Video Synthesis from Static Imagery'.
|
| 7 |
+
We have included the pre-trained checkpoint, dataset and results.
|
| 8 |
+
</p>
|
| 9 |
+
|
| 10 |
+
> **Abstract:** *Our study introduces a new image-to-video generator called FashionFlow to generate fashion videos. By utilising a diffusion model, we are able to create short videos from still fashion images. Our approach involves developing and connecting relevant components with the diffusion model, which results in the creation of high-fidelity videos that are aligned with the conditional image. The components include the use of pseudo-3D convolutional layers to generate videos efficiently. VAE and CLIP encoders capture vital characteristics from still images to condition the diffusion model at a global level. Our research demonstrates a successful synthesis of fashion videos featuring models posing from various angles, showcasing the fit and appearance of the garment. Our findings hold great promise for improving and enhancing the shopping experience for the online fashion industry.*
|
| 11 |
+
|
| 12 |
+
<!-- Results -->
|
| 13 |
+
## Teaser
|
| 14 |
+

|
| 15 |
+
|
| 16 |
+
## Requirements
|
| 17 |
+
- Python 3.9
|
| 18 |
+
- PyTorch 1.11+
|
| 19 |
+
- Tensoboard
|
| 20 |
+
- cv2
|
| 21 |
+
- transformers
|
| 22 |
+
- diffusers
|
| 23 |
+
|
| 24 |
+
## Model Specification
|
| 25 |
+
|
| 26 |
+
The model was developed using PyTorch and loads pretrained weights for VAE and CLIP. The latent diffusion model consists of a 1D convolutional layer stacked against a 2D convolutional layer (forming a pseudo 3D convolution) and includes attention layers. See the ```model_structure.txt``` file to see the exact layers of our LDM.
|
| 27 |
+
|
| 28 |
+
## Installation
|
| 29 |
+
|
| 30 |
+
Clone this repository:
|
| 31 |
+
|
| 32 |
+
```
|
| 33 |
+
git clone https://github.com/1702609/FashionFlow
|
| 34 |
+
cd ./FashionFlow/
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Install PyTorch and other dependencies:
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Dataset
|
| 45 |
+
|
| 46 |
+
Download the Fashion dataset by clicking on this link:
|
| 47 |
+
[[Fashion dataset]](https://vision.cs.ubc.ca/datasets/fashion/)
|
| 48 |
+
|
| 49 |
+
Extract the files and place them in the ```fashion_dataset``` directory. The dataset should be organised as follows:
|
| 50 |
+
|
| 51 |
+
```
|
| 52 |
+
fashion_dataset
|
| 53 |
+
test
|
| 54 |
+
|-- 91-3003CN5S.mp4
|
| 55 |
+
|-- 91BjuE6irxS.mp4
|
| 56 |
+
|-- 91bxAN6BjAS.mp4
|
| 57 |
+
|-- ...
|
| 58 |
+
train
|
| 59 |
+
|-- 81FyMPk-WIS.mp4
|
| 60 |
+
|-- 91+bCFG1jOS.mp4
|
| 61 |
+
|-- 91+PxmDyrgS.mp4
|
| 62 |
+
|-- ...
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Feel free to add your own dataset while following the provided file and folder structure.
|
| 66 |
+
|
| 67 |
+
## Pre-trained Checkpoint
|
| 68 |
+
|
| 69 |
+
Download the checkpoint by clicking on this link:
|
| 70 |
+
[[Pre-trained checkpoints]](https://www.dropbox.com/scl/fi/p9fv7o3j7ti0yu2umsgmv/FashionFlow_checkpoint.pth?rlkey=mqsto9i4ujh6xhvab0e2s6n7d&dl=0)
|
| 71 |
+
Extract the files and place them in the ```checkpoint``` directory
|
| 72 |
+
|
| 73 |
+
## Inference
|
| 74 |
+
To run the inference of our model, execute ```python inference.py```. The results will be saved in the ```result``` directory.
|
| 75 |
+
|
| 76 |
+
## Train
|
| 77 |
+
|
| 78 |
+
Before training, images and videos have to be projected to latent space for efficient training. Execute ```python project_latent_space.py``` where the tensors will be saved in the ```fashion_dataset_tensor``` directory.
|
| 79 |
+
|
| 80 |
+
Run ```python -m torch.distributed.launch --nproc_per_node=<number of GPUs> train.py``` to train the model. The checkpoints will be saved in the ```checkpoint``` directory periodically. Also, you can view the training progress using tensorboardX located in ```video_progress``` or find the generated ```.mp4``` on ```training_sample```.
|
| 81 |
+
|
| 82 |
+
## Comparison
|
| 83 |
+
|
| 84 |
+

|
Untitled.ipynb
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "5b3c9fac-51c3-4ecc-8606-5a298076560e",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"from huggingface_hub import notebook_login\n"
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": 2,
|
| 16 |
+
"id": "51fe1170-c6eb-4b3a-a055-663faf35ab5a",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [
|
| 19 |
+
{
|
| 20 |
+
"data": {
|
| 21 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 22 |
+
"model_id": "6c6d20c8f5e847d7985f6b49a7206a2d",
|
| 23 |
+
"version_major": 2,
|
| 24 |
+
"version_minor": 0
|
| 25 |
+
},
|
| 26 |
+
"text/plain": [
|
| 27 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"output_type": "display_data"
|
| 32 |
+
}
|
| 33 |
+
],
|
| 34 |
+
"source": [
|
| 35 |
+
"notebook_login()"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "4297ea17-a4f8-4290-b561-f582b1adc189",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": []
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"metadata": {
|
| 48 |
+
"kernelspec": {
|
| 49 |
+
"display_name": "work",
|
| 50 |
+
"language": "python",
|
| 51 |
+
"name": "work"
|
| 52 |
+
},
|
| 53 |
+
"language_info": {
|
| 54 |
+
"codemirror_mode": {
|
| 55 |
+
"name": "ipython",
|
| 56 |
+
"version": 3
|
| 57 |
+
},
|
| 58 |
+
"file_extension": ".py",
|
| 59 |
+
"mimetype": "text/x-python",
|
| 60 |
+
"name": "python",
|
| 61 |
+
"nbconvert_exporter": "python",
|
| 62 |
+
"pygments_lexer": "ipython3",
|
| 63 |
+
"version": "3.11.9"
|
| 64 |
+
}
|
| 65 |
+
},
|
| 66 |
+
"nbformat": 4,
|
| 67 |
+
"nbformat_minor": 5
|
| 68 |
+
}
|
app.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import cv2
|
| 3 |
+
import torchvision.transforms as transforms
|
| 4 |
+
from models.unet_dual_encoder import Embedding_Adapter
|
| 5 |
+
from models.diffusion_model import SpaceTimeUnet
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torchvision.transforms.functional as TVF
|
| 8 |
+
from diffusers import AutoencoderKL
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from transformers import CLIPVisionModel, CLIPProcessor
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import gradio as gr
|
| 13 |
+
from huggingface_hub import hf_hub_download
|
| 14 |
+
|
| 15 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 16 |
+
frameLimit = 70
|
| 17 |
+
|
| 18 |
+
def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
|
| 19 |
+
betas = []
|
| 20 |
+
for i in reversed(range(timesteps)):
|
| 21 |
+
T = timesteps - 1
|
| 22 |
+
beta = start + 0.5 * (end - start) * (1 + np.cos((i / T) * np.pi))
|
| 23 |
+
betas.append(beta)
|
| 24 |
+
return torch.Tensor(betas)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_index_from_list(vals, t, x_shape):
|
| 28 |
+
batch_size = t.shape[0]
|
| 29 |
+
out = vals.gather(-1, t.cpu())
|
| 30 |
+
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
| 31 |
+
|
| 32 |
+
def forward_diffusion_sample(x_0, t):
|
| 33 |
+
noise = torch.randn_like(x_0)
|
| 34 |
+
sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
|
| 35 |
+
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
|
| 36 |
+
sqrt_one_minus_alphas_cumprod, t, x_0.shape
|
| 37 |
+
)
|
| 38 |
+
# mean + variance
|
| 39 |
+
return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
|
| 40 |
+
+ sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
|
| 41 |
+
|
| 42 |
+
T = 1000
|
| 43 |
+
betas = cosine_beta_schedule(timesteps=T)
|
| 44 |
+
# Pre-calculate different terms for closed form
|
| 45 |
+
alphas = 1. - betas
|
| 46 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
| 47 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 48 |
+
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
|
| 49 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
| 50 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
|
| 51 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
| 52 |
+
|
| 53 |
+
def get_transform():
|
| 54 |
+
image_transforms = transforms.Compose(
|
| 55 |
+
[
|
| 56 |
+
transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
|
| 57 |
+
transforms.ToTensor(),
|
| 58 |
+
])
|
| 59 |
+
return image_transforms
|
| 60 |
+
|
| 61 |
+
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",
|
| 62 |
+
subfolder="vae",
|
| 63 |
+
revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c")
|
| 64 |
+
vae.to(device)
|
| 65 |
+
vae.requires_grad_(False)
|
| 66 |
+
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
Net = SpaceTimeUnet(
|
| 69 |
+
dim = 64,
|
| 70 |
+
channels = 4,
|
| 71 |
+
dim_mult = (1, 2, 4, 8),
|
| 72 |
+
temporal_compression = (False, False, False, True),
|
| 73 |
+
self_attns = (False, False, False, True),
|
| 74 |
+
condition_on_timestep=True
|
| 75 |
+
).to(device)
|
| 76 |
+
adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)
|
| 77 |
+
|
| 78 |
+
clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 79 |
+
clip_encoder.requires_grad_(False)
|
| 80 |
+
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 81 |
+
|
| 82 |
+
checkpoint = torch.load(hf_hub_download(repo_id="sunjuice/FashionFlow_model", filename="FashionFlow_model.pth"))
|
| 83 |
+
|
| 84 |
+
Net.load_state_dict(checkpoint['net'])
|
| 85 |
+
adapter.load_state_dict(checkpoint['adapter'])
|
| 86 |
+
del checkpoint
|
| 87 |
+
torch.cuda.empty_cache()
|
| 88 |
+
|
| 89 |
+
def save_video_frames_as_mp4(frames, fps, save_path):
|
| 90 |
+
frame_h, frame_w = frames[0].shape[2:]
|
| 91 |
+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
|
| 92 |
+
video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
|
| 93 |
+
frames = frames[0]
|
| 94 |
+
for frame in frames:
|
| 95 |
+
frame = np.array(TVF.to_pil_image(frame))
|
| 96 |
+
video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
| 97 |
+
video.release()
|
| 98 |
+
|
| 99 |
+
@torch.no_grad()
|
| 100 |
+
def VAE_encode(image):
|
| 101 |
+
init_latent_dist = vae.encode(image).latent_dist.sample()
|
| 102 |
+
init_latent_dist *= 0.18215
|
| 103 |
+
encoded_image = (init_latent_dist).unsqueeze(1)
|
| 104 |
+
return encoded_image
|
| 105 |
+
|
| 106 |
+
@torch.no_grad()
|
| 107 |
+
def VAE_decode(video, vae_net):
|
| 108 |
+
decoded_video = None
|
| 109 |
+
for i in range(video.shape[1]):
|
| 110 |
+
image = video[:, i, :, :, :]
|
| 111 |
+
image = 1 / 0.18215 * image
|
| 112 |
+
image = vae_net.decode(image).sample
|
| 113 |
+
image = image.clamp(0,1)
|
| 114 |
+
if i == 0:
|
| 115 |
+
decoded_video = image.unsqueeze(1)
|
| 116 |
+
else:
|
| 117 |
+
decoded_video = torch.cat([decoded_video, image.unsqueeze(1)], 1)
|
| 118 |
+
return decoded_video
|
| 119 |
+
|
| 120 |
+
@torch.no_grad()
|
| 121 |
+
def sample_timestep(x, image, t):
|
| 122 |
+
betas_t = get_index_from_list(betas, t, x.shape)
|
| 123 |
+
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
|
| 124 |
+
sqrt_one_minus_alphas_cumprod, t, x.shape
|
| 125 |
+
)
|
| 126 |
+
sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
|
| 127 |
+
|
| 128 |
+
# Call model (current image - noise prediction)
|
| 129 |
+
with torch.cuda.amp.autocast():
|
| 130 |
+
sample_output = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
|
| 131 |
+
|
| 132 |
+
sample_output = sample_output.permute(0, 2, 1, 3, 4)
|
| 133 |
+
model_mean = sqrt_recip_alphas_t * (
|
| 134 |
+
x - betas_t * sample_output / sqrt_one_minus_alphas_cumprod_t
|
| 135 |
+
)
|
| 136 |
+
if t.item() == 0:
|
| 137 |
+
return model_mean
|
| 138 |
+
else:
|
| 139 |
+
noise = torch.randn_like(x)
|
| 140 |
+
posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
|
| 141 |
+
return model_mean + torch.sqrt(posterior_variance_t) * noise
|
| 142 |
+
|
| 143 |
+
def tensor2image(tensor):
|
| 144 |
+
numpy_image = tensor[0].cpu().detach().numpy()
|
| 145 |
+
rescaled_image = (numpy_image * 255).astype(np.uint8)
|
| 146 |
+
pil_image = Image.fromarray(rescaled_image.transpose(1, 2, 0))
|
| 147 |
+
return pil_image
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def get_image_embedding(input_image):
|
| 151 |
+
inputs = clip_processor(images=list(input_image), return_tensors="pt")
|
| 152 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 153 |
+
clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
|
| 154 |
+
vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
|
| 155 |
+
encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)
|
| 156 |
+
return encoder_hidden_states
|
| 157 |
+
|
| 158 |
+
def predict_fn(img_path, progress=gr.Progress()):
|
| 159 |
+
image = get_transform(Image.open(img_path).convert('RGB'))
|
| 160 |
+
encoder_hidden_states = get_image_embedding(input_image=image)
|
| 161 |
+
encoded_image = VAE_encode(image)
|
| 162 |
+
noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
|
| 163 |
+
noise_video[:, 0:1] = encoded_image
|
| 164 |
+
with torch.no_grad():
|
| 165 |
+
for i in progress.tqdm(range(0, T)[::-1]):
|
| 166 |
+
t = torch.full((1,), i, device=device).long()
|
| 167 |
+
noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
|
| 168 |
+
noise_video[:, 0:1] = encoded_image
|
| 169 |
+
final_video = VAE_decode(noise_video, vae)
|
| 170 |
+
save_video_frames_as_mp4(final_video, 25, "result.mp4")
|
| 171 |
+
return "result.mp4"
|
| 172 |
+
|
| 173 |
+
with gr.Tab("Image-to-Video"):
|
| 174 |
+
with gr.Row():
|
| 175 |
+
with gr.Column():
|
| 176 |
+
image_input = gr.Image(type="pil", label="Input Image")
|
| 177 |
+
img_generate = gr.Button("Generate Video")
|
| 178 |
+
with gr.Column():
|
| 179 |
+
img_output = gr.Video(label="Generated Video")
|
| 180 |
+
gr.Examples(
|
| 181 |
+
examples=[
|
| 182 |
+
['sample/blue.jpg',]
|
| 183 |
+
],
|
| 184 |
+
inputs=[image_input],
|
| 185 |
+
outputs=[img_output],
|
| 186 |
+
fn=predict_fn,
|
| 187 |
+
cache_examples='lazy',
|
| 188 |
+
)
|
distributed.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import distributed as dist
|
| 6 |
+
from torch.utils.data.sampler import Sampler
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_rank():
|
| 10 |
+
if not dist.is_available():
|
| 11 |
+
return 0
|
| 12 |
+
|
| 13 |
+
if not dist.is_initialized():
|
| 14 |
+
return 0
|
| 15 |
+
|
| 16 |
+
return dist.get_rank()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def synchronize():
|
| 20 |
+
if not dist.is_available():
|
| 21 |
+
return
|
| 22 |
+
|
| 23 |
+
if not dist.is_initialized():
|
| 24 |
+
return
|
| 25 |
+
|
| 26 |
+
world_size = dist.get_world_size()
|
| 27 |
+
|
| 28 |
+
if world_size == 1:
|
| 29 |
+
return
|
| 30 |
+
|
| 31 |
+
dist.barrier()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_world_size():
|
| 35 |
+
if not dist.is_available():
|
| 36 |
+
return 1
|
| 37 |
+
|
| 38 |
+
if not dist.is_initialized():
|
| 39 |
+
return 1
|
| 40 |
+
|
| 41 |
+
return dist.get_world_size()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def reduce_sum(tensor):
|
| 45 |
+
if not dist.is_available():
|
| 46 |
+
return tensor
|
| 47 |
+
|
| 48 |
+
if not dist.is_initialized():
|
| 49 |
+
return tensor
|
| 50 |
+
|
| 51 |
+
tensor = tensor.clone()
|
| 52 |
+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
| 53 |
+
|
| 54 |
+
return tensor
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def gather_grad(params):
|
| 58 |
+
world_size = get_world_size()
|
| 59 |
+
|
| 60 |
+
if world_size == 1:
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
for param in params:
|
| 64 |
+
if param.grad is not None:
|
| 65 |
+
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
|
| 66 |
+
param.grad.data.div_(world_size)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def all_gather(data):
|
| 70 |
+
world_size = get_world_size()
|
| 71 |
+
|
| 72 |
+
if world_size == 1:
|
| 73 |
+
return [data]
|
| 74 |
+
|
| 75 |
+
buffer = pickle.dumps(data)
|
| 76 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
| 77 |
+
tensor = torch.ByteTensor(storage).to('cuda')
|
| 78 |
+
|
| 79 |
+
local_size = torch.IntTensor([tensor.numel()]).to('cuda')
|
| 80 |
+
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
|
| 81 |
+
dist.all_gather(size_list, local_size)
|
| 82 |
+
size_list = [int(size.item()) for size in size_list]
|
| 83 |
+
max_size = max(size_list)
|
| 84 |
+
|
| 85 |
+
tensor_list = []
|
| 86 |
+
for _ in size_list:
|
| 87 |
+
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
|
| 88 |
+
|
| 89 |
+
if local_size != max_size:
|
| 90 |
+
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
|
| 91 |
+
tensor = torch.cat((tensor, padding), 0)
|
| 92 |
+
|
| 93 |
+
dist.all_gather(tensor_list, tensor)
|
| 94 |
+
|
| 95 |
+
data_list = []
|
| 96 |
+
|
| 97 |
+
for size, tensor in zip(size_list, tensor_list):
|
| 98 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
| 99 |
+
data_list.append(pickle.loads(buffer))
|
| 100 |
+
|
| 101 |
+
return data_list
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def reduce_loss_dict(loss_dict):
|
| 105 |
+
world_size = get_world_size()
|
| 106 |
+
|
| 107 |
+
if world_size < 2:
|
| 108 |
+
return loss_dict
|
| 109 |
+
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
keys = []
|
| 112 |
+
losses = []
|
| 113 |
+
|
| 114 |
+
for k in sorted(loss_dict.keys()):
|
| 115 |
+
keys.append(k)
|
| 116 |
+
losses.append(loss_dict[k])
|
| 117 |
+
|
| 118 |
+
losses = torch.stack(losses, 0)
|
| 119 |
+
dist.reduce(losses, dst=0)
|
| 120 |
+
|
| 121 |
+
if dist.get_rank() == 0:
|
| 122 |
+
losses /= world_size
|
| 123 |
+
|
| 124 |
+
reduced_losses = {k: v for k, v in zip(keys, losses)}
|
| 125 |
+
|
| 126 |
+
return reduced_losses
|
model_structure.txt
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
======================================================================================================================================================
|
| 2 |
+
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
|
| 3 |
+
======================================================================================================================================================
|
| 4 |
+
SpaceTimeUnet (SpaceTimeUnet) [1, 4, 70, 80, 64] [1, 4, 70, 80, 64] -- True
|
| 5 |
+
├─Sequential (to_timestep_cond) [1] [1, 256] -- True
|
| 6 |
+
│ └─SinusoidalPosEmb (0) [1] [1, 64] -- --
|
| 7 |
+
│ └─Linear (1) [1, 64] [1, 256] 16,640 True
|
| 8 |
+
│ └─SiLU (2) [1, 256] [1, 256] -- --
|
| 9 |
+
├─PseudoConv3d (conv_in) [1, 4, 70, 80, 64] [1, 64, 70, 80, 64] -- True
|
| 10 |
+
│ └─Conv2d (spatial_conv) [70, 4, 80, 64] [70, 64, 80, 64] 12,608 True
|
| 11 |
+
│ └─Conv1d (temporal_conv) [5120, 64, 70] [5120, 64, 70] 12,352 True
|
| 12 |
+
├─ModuleList (downs) -- -- -- True
|
| 13 |
+
│ └─ModuleList (0) -- -- -- True
|
| 14 |
+
│ │ └─ResnetBlock (0) [1, 64, 70, 80, 64] [1, 64, 70, 80, 64] 131,712 True
|
| 15 |
+
│ │ └─ModuleList (1) -- -- 197,632 True
|
| 16 |
+
│ │ └─Downsample (3) [1, 64, 70, 80, 64] [1, 64, 70, 40, 32] 16,384 True
|
| 17 |
+
│ │ └─AttentionBlock (4) [1, 64, 70, 40, 32] [1, 64, 70, 40, 32] 160,704 True
|
| 18 |
+
│ └─ModuleList (1) -- -- -- True
|
| 19 |
+
│ │ └─ResnetBlock (0) [1, 64, 70, 40, 32] [1, 128, 70, 40, 32] 394,624 True
|
| 20 |
+
│ │ └─ModuleList (1) -- -- 788,480 True
|
| 21 |
+
│ │ └─Downsample (3) [1, 128, 70, 40, 32] [1, 128, 70, 20, 16] 65,536 True
|
| 22 |
+
│ │ └─AttentionBlock (4) [1, 128, 70, 20, 16] [1, 128, 70, 20, 16] 444,288 True
|
| 23 |
+
│ └─ModuleList (2) -- -- -- True
|
| 24 |
+
│ │ └─ResnetBlock (0) [1, 128, 70, 20, 16] [1, 256, 70, 20, 16] 1,444,608 True
|
| 25 |
+
│ │ └─ModuleList (1) -- -- 3,149,824 True
|
| 26 |
+
│ │ └─Downsample (3) [1, 256, 70, 20, 16] [1, 256, 70, 10, 8] 262,144 True
|
| 27 |
+
│ │ └─AttentionBlock (4) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,380,096 True
|
| 28 |
+
│ └─ModuleList (3) -- -- -- True
|
| 29 |
+
│ │ └─ResnetBlock (0) [1, 256, 70, 10, 8] [1, 512, 70, 10, 8] 5,510,656 True
|
| 30 |
+
│ │ └─ModuleList (1) -- -- 12,591,104 True
|
| 31 |
+
│ │ └─SpatioTemporalAttention (2) [1, 512, 70, 10, 8] [1, 512, 70, 10, 8] 4,334,181 True
|
| 32 |
+
│ │ └─Downsample (3) [1, 512, 70, 10, 8] [1, 512, 35, 5, 4] 1,572,864 True
|
| 33 |
+
│ │ └─AttentionBlock (4) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 4,726,272 True
|
| 34 |
+
├─ResnetBlock (mid_block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 35 |
+
│ └─Sequential (timestep_mlp) [1, 256] [1, 1024] -- True
|
| 36 |
+
│ │ └─SiLU (0) [1, 256] [1, 256] -- --
|
| 37 |
+
│ │ └─Linear (1) [1, 256] [1, 1024] 263,168 True
|
| 38 |
+
│ └─Block (block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 39 |
+
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
|
| 40 |
+
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
|
| 41 |
+
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 42 |
+
│ └─Block (block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 43 |
+
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
|
| 44 |
+
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
|
| 45 |
+
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 46 |
+
│ └─Identity (res_conv) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 47 |
+
├─SpatioTemporalAttention (mid_attn) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 48 |
+
│ └─ContinuousPositionBias (spatial_rel_pos_bias) -- [8, 20, 20] -- True
|
| 49 |
+
│ │ └─ModuleList (net) -- -- 68,616 True
|
| 50 |
+
│ └─Attention (spatial_attn) [35, 20, 512] [35, 20, 512] -- True
|
| 51 |
+
│ │ └─LayerNorm (norm) [35, 20, 512] [35, 20, 512] 1,024 True
|
| 52 |
+
│ │ └─Linear (to_q) [35, 20, 512] [35, 20, 512] 262,144 True
|
| 53 |
+
│ │ └─Linear (to_kv) [35, 20, 512] [35, 20, 1024] 524,288 True
|
| 54 |
+
│ │ └─Linear (to_out) [35, 20, 512] [35, 20, 512] 262,144 True
|
| 55 |
+
│ └─ContinuousPositionBias (temporal_rel_pos_bias) -- [8, 35, 35] -- True
|
| 56 |
+
│ │ └─ModuleList (net) -- -- 68,360 True
|
| 57 |
+
│ └─Attention (temporal_attn) [20, 35, 512] [20, 35, 512] -- True
|
| 58 |
+
│ │ └─LayerNorm (norm) [20, 35, 512] [20, 35, 512] 1,024 True
|
| 59 |
+
│ │ └─Linear (to_q) [20, 35, 512] [20, 35, 512] 262,144 True
|
| 60 |
+
│ │ └─Linear (to_kv) [20, 35, 512] [20, 35, 1024] 524,288 True
|
| 61 |
+
│ │ └─Linear (to_out) [20, 35, 512] [20, 35, 512] 262,144 True
|
| 62 |
+
│ └─FeedForward (ff) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 63 |
+
│ │ └─Sequential (proj_in) [1, 512, 35, 5, 4] [1, 1365, 35, 5, 4] 1,397,760 True
|
| 64 |
+
│ │ └─Sequential (proj_out) [1, 1365, 35, 5, 4] [1, 512, 35, 5, 4] 700,245 True
|
| 65 |
+
├─ResnetBlock (mid_block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 66 |
+
│ └─Sequential (timestep_mlp) [1, 256] [1, 1024] -- True
|
| 67 |
+
│ │ └─SiLU (0) [1, 256] [1, 256] -- --
|
| 68 |
+
│ │ └─Linear (1) [1, 256] [1, 1024] 263,168 True
|
| 69 |
+
│ └─Block (block1) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 70 |
+
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
|
| 71 |
+
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
|
| 72 |
+
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 73 |
+
│ └─Block (block2) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- True
|
| 74 |
+
│ │ └─PseudoConv3d (project) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 3,146,752 True
|
| 75 |
+
│ │ └─GroupNorm (norm) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] 1,024 True
|
| 76 |
+
│ │ └─SiLU (act) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 77 |
+
│ └─Identity (res_conv) [1, 512, 35, 5, 4] [1, 512, 35, 5, 4] -- --
|
| 78 |
+
├─ModuleList (ups) -- -- -- True
|
| 79 |
+
│ └─ModuleList (3) -- -- -- True
|
| 80 |
+
│ │ └─Upsample (3) [1, 512, 35, 5, 4] [1, 512, 70, 10, 8] 1,575,936 True
|
| 81 |
+
│ │ └─ResnetBlock (0) [1, 1024, 70, 10, 8] [1, 256, 70, 10, 8] 3,738,368 True
|
| 82 |
+
│ │ └─ModuleList (1) -- -- 4,526,336 True
|
| 83 |
+
│ │ └─SpatioTemporalAttention (2) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,609,786 True
|
| 84 |
+
│ │ └─AttentionBlock (4) [1, 256, 70, 10, 8] [1, 256, 70, 10, 8] 1,380,096 True
|
| 85 |
+
│ └─ModuleList (2) -- -- -- True
|
| 86 |
+
│ │ └─Upsample (3) [1, 256, 70, 10, 8] [1, 256, 70, 20, 16] 263,168 True
|
| 87 |
+
│ │ └─ResnetBlock (0) [1, 512, 70, 20, 16] [1, 128, 70, 20, 16] 968,064 True
|
| 88 |
+
│ │ └─ModuleList (1) -- -- 1,132,672 True
|
| 89 |
+
│ │ └─AttentionBlock (4) [1, 128, 70, 20, 16] [1, 128, 70, 20, 16] 444,288 True
|
| 90 |
+
│ └─ModuleList (1) -- -- -- True
|
| 91 |
+
│ │ └─Upsample (3) [1, 128, 70, 20, 16] [1, 128, 70, 40, 32] 66,048 True
|
| 92 |
+
│ │ └─ResnetBlock (0) [1, 256, 70, 40, 32] [1, 64, 70, 40, 32] 258,752 True
|
| 93 |
+
│ │ └─ModuleList (1) -- -- 283,712 True
|
| 94 |
+
│ │ └─AttentionBlock (4) [1, 64, 70, 40, 32] [1, 64, 70, 40, 32] 160,704 True
|
| 95 |
+
│ └─ModuleList (0) -- -- -- True
|
| 96 |
+
│ │ └─Upsample (3) [1, 64, 70, 40, 32] [1, 64, 70, 80, 64] 16,640 True
|
| 97 |
+
│ │ └─ResnetBlock (0) [1, 128, 70, 80, 64] [1, 64, 70, 80, 64] 176,832 True
|
| 98 |
+
│ │ └─ModuleList (1) -- -- 242,752 True
|
| 99 |
+
│ │ └─AttentionBlock (4) [1, 64, 70, 80, 64] [1, 64, 70, 80, 64] 160,704 True
|
| 100 |
+
├─PseudoConv3d (conv_out) [1, 64, 70, 80, 64] [1, 4, 70, 80, 64] -- True
|
| 101 |
+
│ └─Conv2d (spatial_conv) [70, 64, 80, 64] [70, 4, 80, 64] 2,308 True
|
| 102 |
+
│ └─Conv1d (temporal_conv) [5120, 4, 70] [5120, 4, 70] 52 True
|
| 103 |
+
======================================================================================================================================================
|
| 104 |
+
Total params: 71,671,548
|
| 105 |
+
Trainable params: 71,671,548
|
| 106 |
+
Non-trainable params: 0
|
| 107 |
+
Total mult-adds (G): 732.56
|
| 108 |
+
======================================================================================================================================================
|
| 109 |
+
Input size (MB): 5.89
|
| 110 |
+
Forward/backward pass size (MB): 18136.46
|
| 111 |
+
Params size (MB): 286.69
|
| 112 |
+
Estimated Total Size (MB): 18429.04
|
| 113 |
+
======================================================================================================================================================
|
models/diffusion_model.py
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import functools
|
| 3 |
+
from operator import mul
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn, einsum
|
| 8 |
+
|
| 9 |
+
from einops import rearrange, repeat, pack, unpack
|
| 10 |
+
from einops.layers.torch import Rearrange
|
| 11 |
+
# helper functions
|
| 12 |
+
|
| 13 |
+
def exists(val):
|
| 14 |
+
return val is not None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def default(val, d):
|
| 18 |
+
return val if exists(val) else d
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def mul_reduce(tup):
|
| 22 |
+
return functools.reduce(mul, tup)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def divisible_by(numer, denom):
|
| 26 |
+
return (numer % denom) == 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
mlist = nn.ModuleList
|
| 30 |
+
|
| 31 |
+
# for time conditioning
|
| 32 |
+
|
| 33 |
+
class SinusoidalPosEmb(nn.Module):
|
| 34 |
+
def __init__(self, dim, theta=10000):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.theta = theta
|
| 37 |
+
self.dim = dim
|
| 38 |
+
|
| 39 |
+
def forward(self, x):
|
| 40 |
+
dtype, device = x.dtype, x.device
|
| 41 |
+
assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'
|
| 42 |
+
|
| 43 |
+
half_dim = self.dim // 2
|
| 44 |
+
emb = math.log(self.theta) / (half_dim - 1)
|
| 45 |
+
emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb)
|
| 46 |
+
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
|
| 47 |
+
return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# layernorm 3d
|
| 51 |
+
|
| 52 |
+
class ChanLayerNorm(nn.Module):
|
| 53 |
+
def __init__(self, dim):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
|
| 59 |
+
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
|
| 60 |
+
mean = torch.mean(x, dim=1, keepdim=True)
|
| 61 |
+
return (x - mean) * var.clamp(min=eps).rsqrt() * self.g
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# feedforward
|
| 65 |
+
|
| 66 |
+
def shift_token(t):
|
| 67 |
+
t, t_shift = t.chunk(2, dim=1)
|
| 68 |
+
t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value=0.)
|
| 69 |
+
return torch.cat((t, t_shift), dim=1)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class GEGLU(nn.Module):
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
x, gate = x.chunk(2, dim=1)
|
| 75 |
+
return x * F.gelu(gate)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class FeedForward(nn.Module):
|
| 79 |
+
def __init__(self, dim, mult=4):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
inner_dim = int(dim * mult * 2 / 3)
|
| 83 |
+
self.proj_in = nn.Sequential(
|
| 84 |
+
nn.Conv3d(dim, inner_dim * 2, 1, bias=False),
|
| 85 |
+
GEGLU()
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.proj_out = nn.Sequential(
|
| 89 |
+
ChanLayerNorm(inner_dim),
|
| 90 |
+
nn.Conv3d(inner_dim, dim, 1, bias=False)
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def forward(self, x, enable_time=True):
|
| 94 |
+
x = self.proj_in(x)
|
| 95 |
+
if enable_time:
|
| 96 |
+
x = shift_token(x)
|
| 97 |
+
return self.proj_out(x)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# best relative positional encoding
|
| 101 |
+
|
| 102 |
+
class ContinuousPositionBias(nn.Module):
|
| 103 |
+
""" from https://arxiv.org/abs/2111.09883 """
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
*,
|
| 108 |
+
dim,
|
| 109 |
+
heads,
|
| 110 |
+
num_dims=1,
|
| 111 |
+
layers=2
|
| 112 |
+
):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.num_dims = num_dims
|
| 115 |
+
|
| 116 |
+
self.net = nn.ModuleList([])
|
| 117 |
+
self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
|
| 118 |
+
|
| 119 |
+
for _ in range(layers - 1):
|
| 120 |
+
self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
|
| 121 |
+
|
| 122 |
+
self.net.append(nn.Linear(dim, heads))
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def device(self):
|
| 126 |
+
return next(self.parameters()).device
|
| 127 |
+
|
| 128 |
+
def forward(self, *dimensions):
|
| 129 |
+
device = self.device
|
| 130 |
+
|
| 131 |
+
shape = torch.tensor(dimensions, device=device)
|
| 132 |
+
rel_pos_shape = 2 * shape - 1
|
| 133 |
+
|
| 134 |
+
# calculate strides
|
| 135 |
+
|
| 136 |
+
strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim=-1)
|
| 137 |
+
strides = torch.flip(F.pad(strides, (1, -1), value=1), (0,))
|
| 138 |
+
|
| 139 |
+
# get all positions and calculate all the relative distances
|
| 140 |
+
|
| 141 |
+
positions = [torch.arange(d, device=device) for d in dimensions]
|
| 142 |
+
grid = torch.stack(torch.meshgrid(*positions, indexing='ij'), dim=-1)
|
| 143 |
+
grid = rearrange(grid, '... c -> (...) c')
|
| 144 |
+
rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
|
| 145 |
+
|
| 146 |
+
# get all relative positions across all dimensions
|
| 147 |
+
|
| 148 |
+
rel_positions = [torch.arange(-d + 1, d, device=device) for d in dimensions]
|
| 149 |
+
rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing='ij'), dim=-1)
|
| 150 |
+
rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')
|
| 151 |
+
|
| 152 |
+
# mlp input
|
| 153 |
+
|
| 154 |
+
bias = rel_pos_grid.float()
|
| 155 |
+
|
| 156 |
+
for layer in self.net:
|
| 157 |
+
bias = layer(bias)
|
| 158 |
+
|
| 159 |
+
# convert relative distances to indices of the bias
|
| 160 |
+
|
| 161 |
+
rel_dist += (shape - 1) # make sure all positive
|
| 162 |
+
rel_dist *= strides
|
| 163 |
+
rel_dist_indices = rel_dist.sum(dim=-1)
|
| 164 |
+
|
| 165 |
+
# now select the bias for each unique relative position combination
|
| 166 |
+
|
| 167 |
+
bias = bias[rel_dist_indices]
|
| 168 |
+
return rearrange(bias, 'i j h -> h i j')
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# helper classes
|
| 172 |
+
|
| 173 |
+
class CrossAttention(nn.Module):
|
| 174 |
+
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
|
| 177 |
+
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
| 178 |
+
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
|
| 179 |
+
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
|
| 180 |
+
self.n_heads = n_heads
|
| 181 |
+
self.d_head = d_embed // n_heads
|
| 182 |
+
|
| 183 |
+
def forward(self, x, y):
|
| 184 |
+
input_shape = x.shape
|
| 185 |
+
batch_size, sequence_length, d_embed = input_shape
|
| 186 |
+
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
|
| 187 |
+
|
| 188 |
+
q = self.q_proj(x)
|
| 189 |
+
k = self.k_proj(y)
|
| 190 |
+
v = self.v_proj(y)
|
| 191 |
+
|
| 192 |
+
q = q.view(interim_shape).transpose(1, 2)
|
| 193 |
+
k = k.view(interim_shape).transpose(1, 2)
|
| 194 |
+
v = v.view(interim_shape).transpose(1, 2)
|
| 195 |
+
|
| 196 |
+
weight = q @ k.transpose(-1, -2)
|
| 197 |
+
weight /= math.sqrt(self.d_head)
|
| 198 |
+
weight = F.softmax(weight, dim=-1)
|
| 199 |
+
|
| 200 |
+
output = weight @ v
|
| 201 |
+
output = output.transpose(1, 2).contiguous()
|
| 202 |
+
output = output.view(input_shape)
|
| 203 |
+
output = self.out_proj(output)
|
| 204 |
+
return output
|
| 205 |
+
|
| 206 |
+
class AttentionBlock(nn.Module):
|
| 207 |
+
def __init__(self, n_head: int, n_embd: int, d_context=768):
|
| 208 |
+
super().__init__()
|
| 209 |
+
channels = n_head * n_embd
|
| 210 |
+
|
| 211 |
+
#self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
|
| 212 |
+
#self.conv_input = PseudoConv3d(channels, channels, 1)
|
| 213 |
+
self.layernorm_2 = nn.LayerNorm(channels)
|
| 214 |
+
self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
|
| 215 |
+
self.layernorm_3 = nn.LayerNorm(channels)
|
| 216 |
+
self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
|
| 217 |
+
self.linear_geglu_2 = nn.Linear(4 * channels, channels)
|
| 218 |
+
self.conv_output = PseudoConv3d(channels, channels, 1, bias=False)
|
| 219 |
+
|
| 220 |
+
def forward(self, x, context):
|
| 221 |
+
b, c, *_, h, w = x.shape
|
| 222 |
+
#x = self.groupnorm(x)
|
| 223 |
+
#x = self.conv_input(x)
|
| 224 |
+
x = rearrange(x, 'b c f h w -> b (h w f) c')
|
| 225 |
+
|
| 226 |
+
residue_short = x
|
| 227 |
+
x = self.layernorm_2(x)
|
| 228 |
+
x = self.attention_2(x, context)
|
| 229 |
+
x += residue_short
|
| 230 |
+
|
| 231 |
+
residue_short = x
|
| 232 |
+
x = self.layernorm_3(x)
|
| 233 |
+
x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
|
| 234 |
+
x = x * F.gelu(gate)
|
| 235 |
+
x = self.linear_geglu_2(x)
|
| 236 |
+
x += residue_short
|
| 237 |
+
|
| 238 |
+
x = rearrange(x, 'b (h w f) c -> b c f h w', b=b, c=c, h=h, w=w)
|
| 239 |
+
x = self.conv_output(x)
|
| 240 |
+
return x
|
| 241 |
+
|
| 242 |
+
class Attention(nn.Module):
|
| 243 |
+
def __init__(
|
| 244 |
+
self,
|
| 245 |
+
dim,
|
| 246 |
+
dim_head=64,
|
| 247 |
+
heads=8
|
| 248 |
+
):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.heads = heads
|
| 251 |
+
self.scale = dim_head ** -0.5
|
| 252 |
+
inner_dim = dim_head * heads
|
| 253 |
+
|
| 254 |
+
self.norm = nn.LayerNorm(dim)
|
| 255 |
+
|
| 256 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 257 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
| 258 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 259 |
+
nn.init.zeros_(self.to_out.weight.data) # identity with skip connection
|
| 260 |
+
|
| 261 |
+
def forward(
|
| 262 |
+
self,
|
| 263 |
+
x,
|
| 264 |
+
rel_pos_bias=None
|
| 265 |
+
):
|
| 266 |
+
x = self.norm(x)
|
| 267 |
+
|
| 268 |
+
q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)
|
| 269 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
|
| 270 |
+
|
| 271 |
+
q = q * self.scale
|
| 272 |
+
|
| 273 |
+
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
| 274 |
+
|
| 275 |
+
if exists(rel_pos_bias):
|
| 276 |
+
sim = sim + rel_pos_bias
|
| 277 |
+
|
| 278 |
+
attn = sim.softmax(dim=-1)
|
| 279 |
+
|
| 280 |
+
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
| 281 |
+
|
| 282 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
| 283 |
+
return self.to_out(out)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# main contribution - pseudo 3d conv
|
| 287 |
+
|
| 288 |
+
class PseudoConv3d(nn.Module):
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
dim,
|
| 292 |
+
dim_out=None,
|
| 293 |
+
kernel_size=3,
|
| 294 |
+
*,
|
| 295 |
+
temporal_kernel_size=None,
|
| 296 |
+
**kwargs
|
| 297 |
+
):
|
| 298 |
+
super().__init__()
|
| 299 |
+
dim_out = default(dim_out, dim)
|
| 300 |
+
temporal_kernel_size = default(temporal_kernel_size, kernel_size)
|
| 301 |
+
|
| 302 |
+
self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2)
|
| 303 |
+
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size=temporal_kernel_size,
|
| 304 |
+
padding=temporal_kernel_size // 2) if kernel_size > 1 else None
|
| 305 |
+
|
| 306 |
+
if exists(self.temporal_conv):
|
| 307 |
+
nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
|
| 308 |
+
nn.init.zeros_(self.temporal_conv.bias.data)
|
| 309 |
+
|
| 310 |
+
def forward(
|
| 311 |
+
self,
|
| 312 |
+
x,
|
| 313 |
+
enable_time=True
|
| 314 |
+
):
|
| 315 |
+
b, c, *_, h, w = x.shape
|
| 316 |
+
|
| 317 |
+
is_video = x.ndim == 5
|
| 318 |
+
enable_time &= is_video
|
| 319 |
+
|
| 320 |
+
if is_video:
|
| 321 |
+
x = rearrange(x, 'b c f h w -> (b f) c h w')
|
| 322 |
+
|
| 323 |
+
x = self.spatial_conv(x)
|
| 324 |
+
|
| 325 |
+
if is_video:
|
| 326 |
+
x = rearrange(x, '(b f) c h w -> b c f h w', b=b)
|
| 327 |
+
|
| 328 |
+
if not enable_time or not exists(self.temporal_conv):
|
| 329 |
+
return x
|
| 330 |
+
|
| 331 |
+
x = rearrange(x, 'b c f h w -> (b h w) c f')
|
| 332 |
+
|
| 333 |
+
x = self.temporal_conv(x)
|
| 334 |
+
|
| 335 |
+
x = rearrange(x, '(b h w) c f -> b c f h w', h=h, w=w)
|
| 336 |
+
|
| 337 |
+
return x
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# factorized spatial temporal attention from Ho et al.
|
| 341 |
+
|
| 342 |
+
class SpatioTemporalAttention(nn.Module):
|
| 343 |
+
def __init__(
|
| 344 |
+
self,
|
| 345 |
+
dim,
|
| 346 |
+
*,
|
| 347 |
+
dim_head=64,
|
| 348 |
+
heads=8,
|
| 349 |
+
add_feed_forward=True,
|
| 350 |
+
ff_mult=4
|
| 351 |
+
):
|
| 352 |
+
super().__init__()
|
| 353 |
+
self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
|
| 354 |
+
self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2)
|
| 355 |
+
|
| 356 |
+
self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
|
| 357 |
+
self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1)
|
| 358 |
+
|
| 359 |
+
self.has_feed_forward = add_feed_forward
|
| 360 |
+
if not add_feed_forward:
|
| 361 |
+
return
|
| 362 |
+
|
| 363 |
+
self.ff = FeedForward(dim=dim, mult=ff_mult)
|
| 364 |
+
|
| 365 |
+
def forward(
|
| 366 |
+
self,
|
| 367 |
+
x,
|
| 368 |
+
enable_time=True
|
| 369 |
+
):
|
| 370 |
+
b, c, *_, h, w = x.shape
|
| 371 |
+
is_video = x.ndim == 5
|
| 372 |
+
enable_time &= is_video
|
| 373 |
+
|
| 374 |
+
if is_video:
|
| 375 |
+
x = rearrange(x, 'b c f h w -> (b f) (h w) c')
|
| 376 |
+
else:
|
| 377 |
+
x = rearrange(x, 'b c h w -> b (h w) c')
|
| 378 |
+
|
| 379 |
+
space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
|
| 380 |
+
|
| 381 |
+
x = self.spatial_attn(x, rel_pos_bias=space_rel_pos_bias) + x
|
| 382 |
+
|
| 383 |
+
if is_video:
|
| 384 |
+
x = rearrange(x, '(b f) (h w) c -> b c f h w', b=b, h=h, w=w)
|
| 385 |
+
else:
|
| 386 |
+
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
| 387 |
+
|
| 388 |
+
if enable_time:
|
| 389 |
+
x = rearrange(x, 'b c f h w -> (b h w) f c')
|
| 390 |
+
|
| 391 |
+
time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
|
| 392 |
+
|
| 393 |
+
x = self.temporal_attn(x, rel_pos_bias=time_rel_pos_bias) + x
|
| 394 |
+
|
| 395 |
+
x = rearrange(x, '(b h w) f c -> b c f h w', w=w, h=h)
|
| 396 |
+
|
| 397 |
+
if self.has_feed_forward:
|
| 398 |
+
x = self.ff(x, enable_time=enable_time) + x
|
| 399 |
+
|
| 400 |
+
return x
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
# resnet block
|
| 404 |
+
|
| 405 |
+
class Block(nn.Module):
|
| 406 |
+
def __init__(
|
| 407 |
+
self,
|
| 408 |
+
dim,
|
| 409 |
+
dim_out,
|
| 410 |
+
kernel_size=3,
|
| 411 |
+
temporal_kernel_size=None,
|
| 412 |
+
groups=8
|
| 413 |
+
):
|
| 414 |
+
super().__init__()
|
| 415 |
+
self.project = PseudoConv3d(dim, dim_out, 3)
|
| 416 |
+
self.norm = nn.GroupNorm(groups, dim_out)
|
| 417 |
+
self.act = nn.SiLU()
|
| 418 |
+
|
| 419 |
+
def forward(
|
| 420 |
+
self,
|
| 421 |
+
x,
|
| 422 |
+
scale_shift=None,
|
| 423 |
+
enable_time=False
|
| 424 |
+
):
|
| 425 |
+
x = self.project(x, enable_time=enable_time)
|
| 426 |
+
x = self.norm(x)
|
| 427 |
+
|
| 428 |
+
if exists(scale_shift):
|
| 429 |
+
scale, shift = scale_shift
|
| 430 |
+
x = x * (scale + 1) + shift
|
| 431 |
+
|
| 432 |
+
return self.act(x)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
class ResnetBlock(nn.Module):
|
| 436 |
+
def __init__(
|
| 437 |
+
self,
|
| 438 |
+
dim,
|
| 439 |
+
dim_out,
|
| 440 |
+
*,
|
| 441 |
+
timestep_cond_dim=None,
|
| 442 |
+
groups=8
|
| 443 |
+
):
|
| 444 |
+
super().__init__()
|
| 445 |
+
|
| 446 |
+
self.timestep_mlp = None
|
| 447 |
+
|
| 448 |
+
if exists(timestep_cond_dim):
|
| 449 |
+
self.timestep_mlp = nn.Sequential(
|
| 450 |
+
nn.SiLU(),
|
| 451 |
+
nn.Linear(timestep_cond_dim, dim_out * 2)
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
self.block1 = Block(dim, dim_out, groups=groups)
|
| 455 |
+
self.block2 = Block(dim_out, dim_out, groups=groups)
|
| 456 |
+
self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
| 457 |
+
|
| 458 |
+
def forward(
|
| 459 |
+
self,
|
| 460 |
+
x,
|
| 461 |
+
timestep_emb=None,
|
| 462 |
+
enable_time=True
|
| 463 |
+
):
|
| 464 |
+
assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))
|
| 465 |
+
|
| 466 |
+
scale_shift = None
|
| 467 |
+
|
| 468 |
+
if exists(self.timestep_mlp) and exists(timestep_emb):
|
| 469 |
+
time_emb = self.timestep_mlp(timestep_emb)
|
| 470 |
+
to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
|
| 471 |
+
time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
|
| 472 |
+
scale_shift = time_emb.chunk(2, dim=1)
|
| 473 |
+
|
| 474 |
+
h = self.block1(x, scale_shift=scale_shift, enable_time=enable_time)
|
| 475 |
+
|
| 476 |
+
h = self.block2(h, enable_time=enable_time)
|
| 477 |
+
|
| 478 |
+
return h + self.res_conv(x)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# pixelshuffle upsamples and downsamples
|
| 482 |
+
# where time dimension can be configured
|
| 483 |
+
|
| 484 |
+
class Downsample(nn.Module):
|
| 485 |
+
def __init__(
|
| 486 |
+
self,
|
| 487 |
+
dim,
|
| 488 |
+
downsample_space=True,
|
| 489 |
+
downsample_time=False,
|
| 490 |
+
nonlin=False
|
| 491 |
+
):
|
| 492 |
+
super().__init__()
|
| 493 |
+
assert downsample_space or downsample_time
|
| 494 |
+
|
| 495 |
+
self.down_space = nn.Sequential(
|
| 496 |
+
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
|
| 497 |
+
nn.Conv2d(dim * 4, dim, 1, bias=False),
|
| 498 |
+
nn.SiLU() if nonlin else nn.Identity()
|
| 499 |
+
) if downsample_space else None
|
| 500 |
+
|
| 501 |
+
self.down_time = nn.Sequential(
|
| 502 |
+
Rearrange('b c (f p) h w -> b (c p) f h w', p=2),
|
| 503 |
+
nn.Conv3d(dim * 2, dim, 1, bias=False),
|
| 504 |
+
nn.SiLU() if nonlin else nn.Identity()
|
| 505 |
+
) if downsample_time else None
|
| 506 |
+
|
| 507 |
+
def forward(
|
| 508 |
+
self,
|
| 509 |
+
x,
|
| 510 |
+
enable_time=True
|
| 511 |
+
):
|
| 512 |
+
is_video = x.ndim == 5
|
| 513 |
+
|
| 514 |
+
if is_video:
|
| 515 |
+
x = rearrange(x, 'b c f h w -> b f c h w')
|
| 516 |
+
x, ps = pack([x], '* c h w')
|
| 517 |
+
|
| 518 |
+
if exists(self.down_space):
|
| 519 |
+
x = self.down_space(x)
|
| 520 |
+
|
| 521 |
+
if is_video:
|
| 522 |
+
x, = unpack(x, ps, '* c h w')
|
| 523 |
+
x = rearrange(x, 'b f c h w -> b c f h w')
|
| 524 |
+
|
| 525 |
+
if not is_video or not exists(self.down_time) or not enable_time:
|
| 526 |
+
return x
|
| 527 |
+
|
| 528 |
+
x = self.down_time(x)
|
| 529 |
+
|
| 530 |
+
return x
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class Upsample(nn.Module):
|
| 534 |
+
def __init__(
|
| 535 |
+
self,
|
| 536 |
+
dim,
|
| 537 |
+
upsample_space=True,
|
| 538 |
+
upsample_time=False,
|
| 539 |
+
nonlin=False
|
| 540 |
+
):
|
| 541 |
+
super().__init__()
|
| 542 |
+
assert upsample_space or upsample_time
|
| 543 |
+
|
| 544 |
+
self.up_space = nn.Sequential(
|
| 545 |
+
nn.Conv2d(dim, dim * 4, 1),
|
| 546 |
+
nn.SiLU() if nonlin else nn.Identity(),
|
| 547 |
+
Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2)
|
| 548 |
+
) if upsample_space else None
|
| 549 |
+
|
| 550 |
+
self.up_time = nn.Sequential(
|
| 551 |
+
nn.Conv3d(dim, dim * 2, 1),
|
| 552 |
+
nn.SiLU() if nonlin else nn.Identity(),
|
| 553 |
+
Rearrange('b (c p) f h w -> b c (f p) h w', p=2)
|
| 554 |
+
) if upsample_time else None
|
| 555 |
+
|
| 556 |
+
self.init_()
|
| 557 |
+
|
| 558 |
+
def init_(self):
|
| 559 |
+
if exists(self.up_space):
|
| 560 |
+
self.init_conv_(self.up_space[0], 4)
|
| 561 |
+
|
| 562 |
+
if exists(self.up_time):
|
| 563 |
+
self.init_conv_(self.up_time[0], 2)
|
| 564 |
+
|
| 565 |
+
def init_conv_(self, conv, factor):
|
| 566 |
+
o, *remain_dims = conv.weight.shape
|
| 567 |
+
conv_weight = torch.empty(o // factor, *remain_dims)
|
| 568 |
+
nn.init.kaiming_uniform_(conv_weight)
|
| 569 |
+
conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r=factor)
|
| 570 |
+
|
| 571 |
+
conv.weight.data.copy_(conv_weight)
|
| 572 |
+
nn.init.zeros_(conv.bias.data)
|
| 573 |
+
|
| 574 |
+
def forward(
|
| 575 |
+
self,
|
| 576 |
+
x,
|
| 577 |
+
enable_time=True
|
| 578 |
+
):
|
| 579 |
+
is_video = x.ndim == 5
|
| 580 |
+
|
| 581 |
+
if is_video:
|
| 582 |
+
x = rearrange(x, 'b c f h w -> b f c h w')
|
| 583 |
+
x, ps = pack([x], '* c h w')
|
| 584 |
+
|
| 585 |
+
if exists(self.up_space):
|
| 586 |
+
x = self.up_space(x)
|
| 587 |
+
|
| 588 |
+
if is_video:
|
| 589 |
+
x, = unpack(x, ps, '* c h w')
|
| 590 |
+
x = rearrange(x, 'b f c h w -> b c f h w')
|
| 591 |
+
|
| 592 |
+
if not is_video or not exists(self.up_time) or not enable_time:
|
| 593 |
+
return x
|
| 594 |
+
|
| 595 |
+
x = self.up_time(x)
|
| 596 |
+
|
| 597 |
+
return x
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class SpaceTimeUnet(nn.Module):
|
| 601 |
+
def __init__(
|
| 602 |
+
self,
|
| 603 |
+
*,
|
| 604 |
+
dim,
|
| 605 |
+
channels=4,
|
| 606 |
+
dim_mult=(1, 2, 4, 8),
|
| 607 |
+
self_attns=(False, False, False, True),
|
| 608 |
+
temporal_compression=(False, True, True, True),
|
| 609 |
+
resnet_block_depths=(2, 2, 2, 2),
|
| 610 |
+
attn_dim_head=64,
|
| 611 |
+
attn_heads=8,
|
| 612 |
+
condition_on_timestep=False,
|
| 613 |
+
):
|
| 614 |
+
super().__init__()
|
| 615 |
+
assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
|
| 616 |
+
num_layers = len(dim_mult)
|
| 617 |
+
|
| 618 |
+
dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
|
| 619 |
+
dim_in_out = zip(dims[:-1], dims[1:])
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# determine the valid multiples of the image size and frames of the video
|
| 623 |
+
self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
|
| 624 |
+
self.image_size_multiple = 2 ** num_layers
|
| 625 |
+
|
| 626 |
+
# timestep conditioning for DDPM, not to be confused with the time dimension of the video
|
| 627 |
+
|
| 628 |
+
self.to_timestep_cond = None
|
| 629 |
+
timestep_cond_dim = (dim * 4) if condition_on_timestep else None
|
| 630 |
+
|
| 631 |
+
if condition_on_timestep:
|
| 632 |
+
self.to_timestep_cond = nn.Sequential(
|
| 633 |
+
SinusoidalPosEmb(dim),
|
| 634 |
+
nn.Linear(dim, timestep_cond_dim),
|
| 635 |
+
nn.SiLU()
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
# Cross Attention
|
| 639 |
+
cross_attention_D1 = AttentionBlock(1, 64) # 64
|
| 640 |
+
cross_attention_D2 = AttentionBlock(1, 128) # 128
|
| 641 |
+
cross_attention_D3 = AttentionBlock(2, 128) # 256
|
| 642 |
+
cross_attention_D4 = AttentionBlock(4, 128) # 512
|
| 643 |
+
|
| 644 |
+
cross_attention_U1 = AttentionBlock(4, 64) # 256
|
| 645 |
+
cross_attention_U2 = AttentionBlock(2, 64) # 128
|
| 646 |
+
cross_attention_U3 = AttentionBlock(1, 64) # 64
|
| 647 |
+
cross_attention_U4 = AttentionBlock(1, 64) # 64
|
| 648 |
+
|
| 649 |
+
cross_attns_down = (cross_attention_D1, cross_attention_D2, cross_attention_D3, cross_attention_D4)
|
| 650 |
+
cross_attns_up = (cross_attention_U4, cross_attention_U3, cross_attention_U2, cross_attention_U1)
|
| 651 |
+
# layers
|
| 652 |
+
|
| 653 |
+
self.downs = mlist([])
|
| 654 |
+
self.ups = mlist([])
|
| 655 |
+
|
| 656 |
+
attn_kwargs = dict(
|
| 657 |
+
dim_head=attn_dim_head,
|
| 658 |
+
heads=attn_heads
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
mid_dim = dims[-1]
|
| 662 |
+
|
| 663 |
+
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
|
| 664 |
+
self.mid_attn = SpatioTemporalAttention(dim=mid_dim)
|
| 665 |
+
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
|
| 666 |
+
for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth, cross_attns_d, cross_attns_u in zip(range(num_layers),
|
| 667 |
+
self_attns,
|
| 668 |
+
dim_in_out,
|
| 669 |
+
temporal_compression,
|
| 670 |
+
resnet_block_depths,
|
| 671 |
+
cross_attns_down,
|
| 672 |
+
cross_attns_up):
|
| 673 |
+
assert resnet_block_depth >= 1
|
| 674 |
+
self.downs.append(mlist([
|
| 675 |
+
ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim),
|
| 676 |
+
mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
|
| 677 |
+
SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None,
|
| 678 |
+
Downsample(dim_out, downsample_time=compress_time),
|
| 679 |
+
cross_attns_d if exists(cross_attns_d) else None
|
| 680 |
+
]))
|
| 681 |
+
self.ups.append(mlist([
|
| 682 |
+
ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim),
|
| 683 |
+
mlist(
|
| 684 |
+
[ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
|
| 685 |
+
SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None,
|
| 686 |
+
Upsample(dim_out, upsample_time=compress_time),
|
| 687 |
+
cross_attns_u if exists(cross_attns_u) else None
|
| 688 |
+
|
| 689 |
+
]))
|
| 690 |
+
self.skip_scale = 2 ** -0.5 # paper shows faster convergence
|
| 691 |
+
|
| 692 |
+
self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
|
| 693 |
+
self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)
|
| 694 |
+
|
| 695 |
+
def forward(
|
| 696 |
+
self,
|
| 697 |
+
x,
|
| 698 |
+
clip_vae_embed,
|
| 699 |
+
timestep=None,
|
| 700 |
+
enable_time=True
|
| 701 |
+
):
|
| 702 |
+
|
| 703 |
+
assert not (exists(self.to_timestep_cond) ^ exists(timestep))
|
| 704 |
+
is_video = x.ndim == 5
|
| 705 |
+
|
| 706 |
+
if enable_time and is_video:
|
| 707 |
+
frames = x.shape[2]
|
| 708 |
+
assert divisible_by(frames,
|
| 709 |
+
self.frame_multiple), f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'
|
| 710 |
+
|
| 711 |
+
height, width = x.shape[-2:]
|
| 712 |
+
assert divisible_by(height, self.image_size_multiple) and divisible_by(width,
|
| 713 |
+
self.image_size_multiple), f'height and width of the image or video must be a multiple of {self.image_size_multiple}'
|
| 714 |
+
|
| 715 |
+
# main logic
|
| 716 |
+
|
| 717 |
+
t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None
|
| 718 |
+
x = self.conv_in(x, enable_time=enable_time)
|
| 719 |
+
|
| 720 |
+
hiddens = []
|
| 721 |
+
for init_block, blocks, maybe_attention, downsample, cross_attn in self.downs:
|
| 722 |
+
x = init_block(x, t, enable_time=enable_time)
|
| 723 |
+
hiddens.append(x.clone())
|
| 724 |
+
for block in blocks:
|
| 725 |
+
x = block(x, enable_time=enable_time)
|
| 726 |
+
if exists(maybe_attention):
|
| 727 |
+
x = maybe_attention(x, enable_time=enable_time) # only happens in the last layer
|
| 728 |
+
hiddens.append(x.clone())
|
| 729 |
+
x = downsample(x, enable_time=enable_time)
|
| 730 |
+
if exists(cross_attn):
|
| 731 |
+
x = cross_attn(x, clip_vae_embed)
|
| 732 |
+
|
| 733 |
+
x = self.mid_block1(x, t, enable_time=enable_time)
|
| 734 |
+
x = self.mid_attn(x, enable_time=enable_time)
|
| 735 |
+
x = self.mid_block2(x, t, enable_time=enable_time)
|
| 736 |
+
|
| 737 |
+
for init_block, blocks, maybe_attention, upsample, cross_attn in reversed(self.ups):
|
| 738 |
+
x = upsample(x, enable_time=enable_time)
|
| 739 |
+
x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
|
| 740 |
+
x = init_block(x, t, enable_time=enable_time)
|
| 741 |
+
x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
|
| 742 |
+
for block in blocks:
|
| 743 |
+
x = block(x, enable_time=enable_time)
|
| 744 |
+
if exists(maybe_attention):
|
| 745 |
+
x = maybe_attention(x, enable_time=enable_time)
|
| 746 |
+
if exists(cross_attn):
|
| 747 |
+
x = cross_attn(x, clip_vae_embed)
|
| 748 |
+
|
| 749 |
+
x = self.conv_out(x, enable_time=enable_time)
|
| 750 |
+
return x
|
| 751 |
+
|
| 752 |
+
if __name__ == '__main__':
|
| 753 |
+
Net = SpaceTimeUnet(
|
| 754 |
+
dim=64,
|
| 755 |
+
channels=3,
|
| 756 |
+
dim_mult=(1, 2, 4, 8),
|
| 757 |
+
temporal_compression=(False, False, False, True),
|
| 758 |
+
self_attns=(False, False, False, True),
|
| 759 |
+
condition_on_timestep=False)
|
| 760 |
+
|
| 761 |
+
x = torch.randn([1,8,3,32,32])
|
| 762 |
+
sample_output = Net(x.permute(0, 2, 1, 3, 4))
|
models/unet_dual_encoder.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Load pretrained 2D UNet and modify with temporal attention
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.utils.checkpoint
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
from diffusers.models import UNet2DConditionModel
|
| 8 |
+
|
| 9 |
+
def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5):
|
| 10 |
+
# Load pretrained UNet layers
|
| 11 |
+
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4",
|
| 12 |
+
subfolder="unet",
|
| 13 |
+
revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c",
|
| 14 |
+
cache_dir="checkpoints/unet")
|
| 15 |
+
|
| 16 |
+
# Modify input layer to have 1 additional input channels (pose)
|
| 17 |
+
weights = unet.conv_in.weight.clone()
|
| 18 |
+
unet.conv_in = nn.Conv2d(4 + 2*n_poses, weights.shape[0], kernel_size=3, padding=(1, 1)) # input noise + n poses
|
| 19 |
+
with torch.no_grad():
|
| 20 |
+
unet.conv_in.weight[:, :4] = weights # original weights
|
| 21 |
+
unet.conv_in.weight[:, 3:] = torch.zeros(unet.conv_in.weight[:, 3:].shape) # new weights initialized to zero
|
| 22 |
+
|
| 23 |
+
return unet
|
| 24 |
+
|
| 25 |
+
'''
|
| 26 |
+
This module takes in CLIP + VAE embeddings and outputs CLIP-compatible embeddings.
|
| 27 |
+
'''
|
| 28 |
+
class Embedding_Adapter(nn.Module):
|
| 29 |
+
def __init__(self, input_nc=38, output_nc=4, norm_layer=nn.InstanceNorm2d, chkpt=None):
|
| 30 |
+
super(Embedding_Adapter, self).__init__()
|
| 31 |
+
|
| 32 |
+
self.save_method_name = "adapter"
|
| 33 |
+
|
| 34 |
+
self.pool = nn.MaxPool2d(2)
|
| 35 |
+
self.vae2clip = nn.Linear(1280, 768)
|
| 36 |
+
|
| 37 |
+
self.linear1 = nn.Linear(54, 50) # 50 x 54 shape
|
| 38 |
+
|
| 39 |
+
# initialize weights
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
self.linear1.weight = nn.Parameter(torch.eye(50, 54))
|
| 42 |
+
|
| 43 |
+
if chkpt is not None:
|
| 44 |
+
pass
|
| 45 |
+
|
| 46 |
+
def forward(self, clip, vae):
|
| 47 |
+
|
| 48 |
+
vae = self.pool(vae) # 1 4 80 64 --> 1 4 40 32
|
| 49 |
+
vae = rearrange(vae, 'b c h w -> b c (h w)') # 1 4 20 16 --> 1 4 1280
|
| 50 |
+
|
| 51 |
+
vae = self.vae2clip(vae) # 1 4 768
|
| 52 |
+
|
| 53 |
+
# Concatenate
|
| 54 |
+
concat = torch.cat((clip, vae), 1)
|
| 55 |
+
|
| 56 |
+
# Encode
|
| 57 |
+
|
| 58 |
+
concat = rearrange(concat, 'b c d -> b d c')
|
| 59 |
+
concat = self.linear1(concat)
|
| 60 |
+
concat = rearrange(concat, 'b d c -> b c d')
|
| 61 |
+
|
| 62 |
+
return concat
|
project_latent_space.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision.transforms as transforms
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import os, argparse
|
| 6 |
+
import tqdm
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from diffusers import AutoencoderKL
|
| 9 |
+
import random
|
| 10 |
+
device = torch.device("cuda")
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser(description="Configuration of the tensor projection.")
|
| 13 |
+
parser.add_argument('--dataset', default="fashion_dataset/train", help="Path to the dataset")
|
| 14 |
+
parser.add_argument('--output_dir', default="fashion_dataset_tensor", help="Path to save the tensors")
|
| 15 |
+
args = parser.parse_args()
|
| 16 |
+
|
| 17 |
+
vae = AutoencoderKL.from_pretrained(
|
| 18 |
+
"CompVis/stable-diffusion-v1-4",
|
| 19 |
+
subfolder="vae",
|
| 20 |
+
revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
|
| 21 |
+
).to(device)
|
| 22 |
+
vae.requires_grad_(False)
|
| 23 |
+
|
| 24 |
+
@torch.no_grad()
|
| 25 |
+
def VAE_encode(video):
|
| 26 |
+
for i in range(video.shape[0]):
|
| 27 |
+
image = video[i, :, :, :]
|
| 28 |
+
image = image.unsqueeze(0)
|
| 29 |
+
if i == 0:
|
| 30 |
+
init_latent_dist = vae.encode(image).latent_dist.sample()
|
| 31 |
+
init_latent_dist *= 0.18215
|
| 32 |
+
encoded_video = (init_latent_dist).unsqueeze(1)
|
| 33 |
+
else:
|
| 34 |
+
init_latent_dist = vae.encode(image).latent_dist.sample()
|
| 35 |
+
init_latent_dist *= 0.18215
|
| 36 |
+
encoded_video = torch.cat([encoded_video, (init_latent_dist).unsqueeze(1)], 1)
|
| 37 |
+
return encoded_video
|
| 38 |
+
|
| 39 |
+
def get_transform():
|
| 40 |
+
image_transforms = transforms.Compose(
|
| 41 |
+
[
|
| 42 |
+
transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
|
| 43 |
+
transforms.ToTensor(),
|
| 44 |
+
])
|
| 45 |
+
return image_transforms
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
path = osp.join(args.dataset)
|
| 49 |
+
video_names = os.listdir(path)
|
| 50 |
+
transform = get_transform()
|
| 51 |
+
|
| 52 |
+
if not os.path.exists(args.output_dir):
|
| 53 |
+
os.makedirs(args.output_dir)
|
| 54 |
+
|
| 55 |
+
for video_name in tqdm.tqdm(video_names):
|
| 56 |
+
cap = cv2.VideoCapture(osp.join(path, video_name))
|
| 57 |
+
numberOfFrames = 241
|
| 58 |
+
number = random.randint(0, numberOfFrames - 70)
|
| 59 |
+
for i in range(number, number + 70):
|
| 60 |
+
cap.set(cv2.CAP_PROP_POS_FRAMES, i)
|
| 61 |
+
_, frame = cap.read()
|
| 62 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 63 |
+
frame = Image.fromarray(frame)
|
| 64 |
+
frame = transform(frame)
|
| 65 |
+
if i == number:
|
| 66 |
+
inputImage = frame
|
| 67 |
+
torch.save(inputImage, args.output_dir + "/" + video_name[:-4] + "_image.pt")
|
| 68 |
+
frame = frame.unsqueeze(0)
|
| 69 |
+
restOfVideo = torch.clone(frame)
|
| 70 |
+
else:
|
| 71 |
+
frame = frame.unsqueeze(0)
|
| 72 |
+
restOfVideo = torch.cat([restOfVideo, frame], 0)
|
| 73 |
+
restOfVideo = restOfVideo.to(device=device)
|
| 74 |
+
vae_video = VAE_encode(restOfVideo).detach().cpu()[0]
|
| 75 |
+
torch.save(vae_video, args.output_dir + "/" + video_name[:-4] + ".pt")
|
requirements.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.26.1
|
| 2 |
+
certifi==2023.11.17
|
| 3 |
+
charset-normalizer==3.3.2
|
| 4 |
+
diffusers==0.14.0
|
| 5 |
+
einops==0.7.0
|
| 6 |
+
filelock==3.13.1
|
| 7 |
+
fsspec==2023.12.2
|
| 8 |
+
huggingface-hub==0.20.2
|
| 9 |
+
idna==3.6
|
| 10 |
+
importlib-metadata==7.0.1
|
| 11 |
+
numpy==1.26.3
|
| 12 |
+
opencv-python==4.9.0.80
|
| 13 |
+
packaging==23.2
|
| 14 |
+
pillow==10.2.0
|
| 15 |
+
protobuf==4.25.2
|
| 16 |
+
psutil==5.9.7
|
| 17 |
+
PyYAML==6.0.1
|
| 18 |
+
regex==2023.12.25
|
| 19 |
+
requests==2.31.0
|
| 20 |
+
safetensors==0.4.1
|
| 21 |
+
tensorboardX==2.6.2.2
|
| 22 |
+
tokenizers==0.15.0
|
| 23 |
+
torch==1.11.0+cu113
|
| 24 |
+
torchaudio==0.11.0+cu113
|
| 25 |
+
torchvision==0.12.0+cu113
|
| 26 |
+
tqdm==4.66.1
|
| 27 |
+
transformers==4.36.2
|
| 28 |
+
typing_extensions==4.9.0
|
| 29 |
+
urllib3==2.1.0
|
| 30 |
+
zipp==3.17.0
|
sample/blue.jpg
ADDED
|
sample/green.jpg
ADDED
|
sample/silver.jpg
ADDED
|
src/deps/__init__.py
ADDED
|
File without changes
|
src/deps/facial_recognition/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/__init__.py
|
| 3 |
+
"""
|
src/deps/facial_recognition/helpers.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/helpers.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
import torch
|
| 7 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Flatten(Module):
|
| 15 |
+
def forward(self, input):
|
| 16 |
+
return input.view(input.size(0), -1)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def l2_norm(input, axis=1):
|
| 20 |
+
norm = torch.norm(input, 2, axis, True)
|
| 21 |
+
output = torch.div(input, norm)
|
| 22 |
+
return output
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
| 26 |
+
""" A named tuple describing a ResNet block. """
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
| 30 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_blocks(num_layers):
|
| 34 |
+
if num_layers == 50:
|
| 35 |
+
blocks = [
|
| 36 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 37 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
| 38 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
| 39 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 40 |
+
]
|
| 41 |
+
elif num_layers == 100:
|
| 42 |
+
blocks = [
|
| 43 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 44 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
| 45 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
| 46 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 47 |
+
]
|
| 48 |
+
elif num_layers == 152:
|
| 49 |
+
blocks = [
|
| 50 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
| 51 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
| 52 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
| 53 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
| 54 |
+
]
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
| 57 |
+
return blocks
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class SEModule(Module):
|
| 61 |
+
def __init__(self, channels, reduction):
|
| 62 |
+
super(SEModule, self).__init__()
|
| 63 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
| 64 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
| 65 |
+
self.relu = ReLU(inplace=True)
|
| 66 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
| 67 |
+
self.sigmoid = Sigmoid()
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
module_input = x
|
| 71 |
+
x = self.avg_pool(x)
|
| 72 |
+
x = self.fc1(x)
|
| 73 |
+
x = self.relu(x)
|
| 74 |
+
x = self.fc2(x)
|
| 75 |
+
x = self.sigmoid(x)
|
| 76 |
+
return module_input * x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class bottleneck_IR(Module):
|
| 80 |
+
def __init__(self, in_channel, depth, stride):
|
| 81 |
+
super(bottleneck_IR, self).__init__()
|
| 82 |
+
if in_channel == depth:
|
| 83 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
| 84 |
+
else:
|
| 85 |
+
self.shortcut_layer = Sequential(
|
| 86 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
| 87 |
+
BatchNorm2d(depth)
|
| 88 |
+
)
|
| 89 |
+
self.res_layer = Sequential(
|
| 90 |
+
BatchNorm2d(in_channel),
|
| 91 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
| 92 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
shortcut = self.shortcut_layer(x)
|
| 97 |
+
res = self.res_layer(x)
|
| 98 |
+
return res + shortcut
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class bottleneck_IR_SE(Module):
|
| 102 |
+
def __init__(self, in_channel, depth, stride):
|
| 103 |
+
super(bottleneck_IR_SE, self).__init__()
|
| 104 |
+
if in_channel == depth:
|
| 105 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
| 106 |
+
else:
|
| 107 |
+
self.shortcut_layer = Sequential(
|
| 108 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
| 109 |
+
BatchNorm2d(depth)
|
| 110 |
+
)
|
| 111 |
+
self.res_layer = Sequential(
|
| 112 |
+
BatchNorm2d(in_channel),
|
| 113 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
| 114 |
+
PReLU(depth),
|
| 115 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
| 116 |
+
BatchNorm2d(depth),
|
| 117 |
+
SEModule(depth, 16)
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
def forward(self, x):
|
| 121 |
+
shortcut = self.shortcut_layer(x)
|
| 122 |
+
res = self.res_layer(x)
|
| 123 |
+
return res + shortcut
|
src/deps/facial_recognition/model_irse.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/model_irse.py
|
| 3 |
+
"""
|
| 4 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
| 5 |
+
from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
class Backbone(Module):
|
| 12 |
+
WEIGHTS_URL = "https://www.dropbox.com/s/n6xicva1lrghb5w/model_ir_se50.pth?dl=1"
|
| 13 |
+
|
| 14 |
+
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
| 15 |
+
super(Backbone, self).__init__()
|
| 16 |
+
assert input_size in [112, 224], "input_size should be 112 or 224"
|
| 17 |
+
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
| 18 |
+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
| 19 |
+
blocks = get_blocks(num_layers)
|
| 20 |
+
if mode == 'ir':
|
| 21 |
+
unit_module = bottleneck_IR
|
| 22 |
+
elif mode == 'ir_se':
|
| 23 |
+
unit_module = bottleneck_IR_SE
|
| 24 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
| 25 |
+
BatchNorm2d(64),
|
| 26 |
+
PReLU(64))
|
| 27 |
+
if input_size == 112:
|
| 28 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
| 29 |
+
Dropout(drop_ratio),
|
| 30 |
+
Flatten(),
|
| 31 |
+
Linear(512 * 7 * 7, 512),
|
| 32 |
+
BatchNorm1d(512, affine=affine))
|
| 33 |
+
else:
|
| 34 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
| 35 |
+
Dropout(drop_ratio),
|
| 36 |
+
Flatten(),
|
| 37 |
+
Linear(512 * 14 * 14, 512),
|
| 38 |
+
BatchNorm1d(512, affine=affine))
|
| 39 |
+
|
| 40 |
+
modules = []
|
| 41 |
+
for block in blocks:
|
| 42 |
+
for bottleneck in block:
|
| 43 |
+
modules.append(unit_module(bottleneck.in_channel,
|
| 44 |
+
bottleneck.depth,
|
| 45 |
+
bottleneck.stride))
|
| 46 |
+
self.body = Sequential(*modules)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
x = self.input_layer(x)
|
| 50 |
+
x = self.body(x)
|
| 51 |
+
x = self.output_layer(x)
|
| 52 |
+
return l2_norm(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def IR_50(input_size):
|
| 56 |
+
"""Constructs a ir-50 model."""
|
| 57 |
+
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
| 58 |
+
return model
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def IR_101(input_size):
|
| 62 |
+
"""Constructs a ir-101 model."""
|
| 63 |
+
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
| 64 |
+
return model
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def IR_152(input_size):
|
| 68 |
+
"""Constructs a ir-152 model."""
|
| 69 |
+
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
| 70 |
+
return model
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def IR_SE_50(input_size):
|
| 74 |
+
"""Constructs a ir_se-50 model."""
|
| 75 |
+
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 76 |
+
return model
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def IR_SE_101(input_size):
|
| 80 |
+
"""Constructs a ir_se-101 model."""
|
| 81 |
+
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 82 |
+
return model
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def IR_SE_152(input_size):
|
| 86 |
+
"""Constructs a ir_se-152 model."""
|
| 87 |
+
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
| 88 |
+
return model
|
src/dnnlib/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
from .util import EasyDict, make_cache_dir_path
|
src/dnnlib/util.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Miscellaneous utility classes and functions."""
|
| 10 |
+
|
| 11 |
+
import ctypes
|
| 12 |
+
import fnmatch
|
| 13 |
+
import importlib
|
| 14 |
+
import inspect
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
import shutil
|
| 18 |
+
import sys
|
| 19 |
+
import types
|
| 20 |
+
import io
|
| 21 |
+
import pickle
|
| 22 |
+
import re
|
| 23 |
+
import requests
|
| 24 |
+
import html
|
| 25 |
+
import hashlib
|
| 26 |
+
import glob
|
| 27 |
+
import tempfile
|
| 28 |
+
import urllib
|
| 29 |
+
import urllib.request
|
| 30 |
+
import uuid
|
| 31 |
+
|
| 32 |
+
from distutils.util import strtobool
|
| 33 |
+
from typing import Any, List, Tuple, Union, Dict
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Util classes
|
| 37 |
+
# ------------------------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EasyDict(dict):
|
| 41 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 42 |
+
|
| 43 |
+
def __getattr__(self, name: str) -> Any:
|
| 44 |
+
try:
|
| 45 |
+
return self[name]
|
| 46 |
+
except KeyError:
|
| 47 |
+
raise AttributeError(name)
|
| 48 |
+
|
| 49 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 50 |
+
self[name] = value
|
| 51 |
+
|
| 52 |
+
def __delattr__(self, name: str) -> None:
|
| 53 |
+
del self[name]
|
| 54 |
+
|
| 55 |
+
def to_dict(self) -> Dict:
|
| 56 |
+
return {k: (v.to_dict() if isinstance(v, EasyDict) else v) for (k, v) in self.items()}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Logger(object):
|
| 60 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
| 61 |
+
|
| 62 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
| 63 |
+
self.file = None
|
| 64 |
+
|
| 65 |
+
if file_name is not None:
|
| 66 |
+
self.file = open(file_name, file_mode)
|
| 67 |
+
|
| 68 |
+
self.should_flush = should_flush
|
| 69 |
+
self.stdout = sys.stdout
|
| 70 |
+
self.stderr = sys.stderr
|
| 71 |
+
|
| 72 |
+
sys.stdout = self
|
| 73 |
+
sys.stderr = self
|
| 74 |
+
|
| 75 |
+
def __enter__(self) -> "Logger":
|
| 76 |
+
return self
|
| 77 |
+
|
| 78 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
| 79 |
+
self.close()
|
| 80 |
+
|
| 81 |
+
def write(self, text: Union[str, bytes]) -> None:
|
| 82 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
| 83 |
+
if isinstance(text, bytes):
|
| 84 |
+
text = text.decode()
|
| 85 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
| 86 |
+
return
|
| 87 |
+
|
| 88 |
+
if self.file is not None:
|
| 89 |
+
self.file.write(text)
|
| 90 |
+
|
| 91 |
+
self.stdout.write(text)
|
| 92 |
+
|
| 93 |
+
if self.should_flush:
|
| 94 |
+
self.flush()
|
| 95 |
+
|
| 96 |
+
def flush(self) -> None:
|
| 97 |
+
"""Flush written text to both stdout and a file, if open."""
|
| 98 |
+
if self.file is not None:
|
| 99 |
+
self.file.flush()
|
| 100 |
+
|
| 101 |
+
self.stdout.flush()
|
| 102 |
+
|
| 103 |
+
def close(self) -> None:
|
| 104 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
| 105 |
+
self.flush()
|
| 106 |
+
|
| 107 |
+
# if using multiple loggers, prevent closing in wrong order
|
| 108 |
+
if sys.stdout is self:
|
| 109 |
+
sys.stdout = self.stdout
|
| 110 |
+
if sys.stderr is self:
|
| 111 |
+
sys.stderr = self.stderr
|
| 112 |
+
|
| 113 |
+
if self.file is not None:
|
| 114 |
+
self.file.close()
|
| 115 |
+
self.file = None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# Cache directories
|
| 119 |
+
# ------------------------------------------------------------------------------------------
|
| 120 |
+
|
| 121 |
+
_dnnlib_cache_dir = None
|
| 122 |
+
|
| 123 |
+
def set_cache_dir(path: str) -> None:
|
| 124 |
+
global _dnnlib_cache_dir
|
| 125 |
+
_dnnlib_cache_dir = path
|
| 126 |
+
|
| 127 |
+
def make_cache_dir_path(*paths: str) -> str:
|
| 128 |
+
if _dnnlib_cache_dir is not None:
|
| 129 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
| 130 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
| 131 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
| 132 |
+
if 'HOME' in os.environ:
|
| 133 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
| 134 |
+
if 'USERPROFILE' in os.environ:
|
| 135 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
| 136 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
| 137 |
+
|
| 138 |
+
# Small util functions
|
| 139 |
+
# ------------------------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def format_time(seconds: Union[int, float]) -> str:
|
| 143 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
| 144 |
+
s = int(np.rint(seconds))
|
| 145 |
+
|
| 146 |
+
if s < 60:
|
| 147 |
+
return "{0}s".format(s)
|
| 148 |
+
elif s < 60 * 60:
|
| 149 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
| 150 |
+
elif s < 24 * 60 * 60:
|
| 151 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
| 152 |
+
else:
|
| 153 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def ask_yes_no(question: str) -> bool:
|
| 157 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
| 158 |
+
while True:
|
| 159 |
+
try:
|
| 160 |
+
print("{0} [y/n]".format(question))
|
| 161 |
+
return strtobool(input().lower())
|
| 162 |
+
except ValueError:
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def tuple_product(t: Tuple) -> Any:
|
| 167 |
+
"""Calculate the product of the tuple elements."""
|
| 168 |
+
result = 1
|
| 169 |
+
|
| 170 |
+
for v in t:
|
| 171 |
+
result *= v
|
| 172 |
+
|
| 173 |
+
return result
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
_str_to_ctype = {
|
| 177 |
+
"uint8": ctypes.c_ubyte,
|
| 178 |
+
"uint16": ctypes.c_uint16,
|
| 179 |
+
"uint32": ctypes.c_uint32,
|
| 180 |
+
"uint64": ctypes.c_uint64,
|
| 181 |
+
"int8": ctypes.c_byte,
|
| 182 |
+
"int16": ctypes.c_int16,
|
| 183 |
+
"int32": ctypes.c_int32,
|
| 184 |
+
"int64": ctypes.c_int64,
|
| 185 |
+
"float32": ctypes.c_float,
|
| 186 |
+
"float64": ctypes.c_double
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
| 191 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
| 192 |
+
type_str = None
|
| 193 |
+
|
| 194 |
+
if isinstance(type_obj, str):
|
| 195 |
+
type_str = type_obj
|
| 196 |
+
elif hasattr(type_obj, "__name__"):
|
| 197 |
+
type_str = type_obj.__name__
|
| 198 |
+
elif hasattr(type_obj, "name"):
|
| 199 |
+
type_str = type_obj.name
|
| 200 |
+
else:
|
| 201 |
+
raise RuntimeError("Cannot infer type name from input")
|
| 202 |
+
|
| 203 |
+
assert type_str in _str_to_ctype.keys()
|
| 204 |
+
|
| 205 |
+
my_dtype = np.dtype(type_str)
|
| 206 |
+
my_ctype = _str_to_ctype[type_str]
|
| 207 |
+
|
| 208 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
| 209 |
+
|
| 210 |
+
return my_dtype, my_ctype
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def is_pickleable(obj: Any) -> bool:
|
| 214 |
+
try:
|
| 215 |
+
with io.BytesIO() as stream:
|
| 216 |
+
pickle.dump(obj, stream)
|
| 217 |
+
return True
|
| 218 |
+
except:
|
| 219 |
+
return False
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# Functionality to import modules/objects by name, and call functions by name
|
| 223 |
+
# ------------------------------------------------------------------------------------------
|
| 224 |
+
|
| 225 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
| 226 |
+
"""Searches for the underlying module behind the name to some python object.
|
| 227 |
+
Returns the module and the object name (original name with module part removed)."""
|
| 228 |
+
|
| 229 |
+
# allow convenience shorthands, substitute them by full names
|
| 230 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
| 231 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
| 232 |
+
|
| 233 |
+
# list alternatives for (module_name, local_obj_name)
|
| 234 |
+
parts = obj_name.split(".")
|
| 235 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
| 236 |
+
|
| 237 |
+
# try each alternative in turn
|
| 238 |
+
for module_name, local_obj_name in name_pairs:
|
| 239 |
+
try:
|
| 240 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 241 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 242 |
+
return module, local_obj_name
|
| 243 |
+
except:
|
| 244 |
+
pass
|
| 245 |
+
|
| 246 |
+
# maybe some of the modules themselves contain errors?
|
| 247 |
+
for module_name, _local_obj_name in name_pairs:
|
| 248 |
+
try:
|
| 249 |
+
importlib.import_module(module_name) # may raise ImportError
|
| 250 |
+
except ImportError:
|
| 251 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
| 252 |
+
raise
|
| 253 |
+
|
| 254 |
+
# maybe the requested attribute is missing?
|
| 255 |
+
for module_name, local_obj_name in name_pairs:
|
| 256 |
+
try:
|
| 257 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
| 258 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
| 259 |
+
except ImportError:
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
# we are out of luck, but we have no idea why
|
| 263 |
+
raise ImportError(obj_name)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
| 267 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
| 268 |
+
if obj_name == '':
|
| 269 |
+
return module
|
| 270 |
+
obj = module
|
| 271 |
+
for part in obj_name.split("."):
|
| 272 |
+
obj = getattr(obj, part)
|
| 273 |
+
return obj
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_obj_by_name(name: str) -> Any:
|
| 277 |
+
"""Finds the python object with the given name."""
|
| 278 |
+
module, obj_name = get_module_from_obj_name(name)
|
| 279 |
+
return get_obj_from_module(module, obj_name)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
| 283 |
+
"""Finds the python object with the given name and calls it as a function."""
|
| 284 |
+
assert func_name is not None
|
| 285 |
+
func_obj = get_obj_by_name(func_name)
|
| 286 |
+
assert callable(func_obj)
|
| 287 |
+
return func_obj(*args, **kwargs)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
| 291 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
| 292 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
| 296 |
+
"""Get the directory path of the module containing the given object name."""
|
| 297 |
+
module, _ = get_module_from_obj_name(obj_name)
|
| 298 |
+
return os.path.dirname(inspect.getfile(module))
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def is_top_level_function(obj: Any) -> bool:
|
| 302 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
| 303 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def get_top_level_function_name(obj: Any) -> str:
|
| 307 |
+
"""Return the fully-qualified name of a top-level function."""
|
| 308 |
+
assert is_top_level_function(obj)
|
| 309 |
+
module = obj.__module__
|
| 310 |
+
if module == '__main__':
|
| 311 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
| 312 |
+
return module + "." + obj.__name__
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# File system helpers
|
| 316 |
+
# ------------------------------------------------------------------------------------------
|
| 317 |
+
|
| 318 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
| 319 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
| 320 |
+
Returns list of tuples containing both absolute and relative paths."""
|
| 321 |
+
assert os.path.isdir(dir_path)
|
| 322 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
| 323 |
+
|
| 324 |
+
if ignores is None:
|
| 325 |
+
ignores = []
|
| 326 |
+
|
| 327 |
+
result = []
|
| 328 |
+
|
| 329 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
| 330 |
+
for ignore_ in ignores:
|
| 331 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
| 332 |
+
|
| 333 |
+
# dirs need to be edited in-place
|
| 334 |
+
for d in dirs_to_remove:
|
| 335 |
+
dirs.remove(d)
|
| 336 |
+
|
| 337 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
| 338 |
+
|
| 339 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
| 340 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
| 341 |
+
|
| 342 |
+
if add_base_to_relative:
|
| 343 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
| 344 |
+
|
| 345 |
+
assert len(absolute_paths) == len(relative_paths)
|
| 346 |
+
result += zip(absolute_paths, relative_paths)
|
| 347 |
+
|
| 348 |
+
return result
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
| 352 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
| 353 |
+
Will create all necessary directories."""
|
| 354 |
+
for file in files:
|
| 355 |
+
target_dir_name = os.path.dirname(file[1])
|
| 356 |
+
|
| 357 |
+
# will create all intermediate-level directories
|
| 358 |
+
if not os.path.exists(target_dir_name):
|
| 359 |
+
os.makedirs(target_dir_name)
|
| 360 |
+
|
| 361 |
+
shutil.copyfile(file[0], file[1])
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# URL helpers
|
| 365 |
+
# ------------------------------------------------------------------------------------------
|
| 366 |
+
|
| 367 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
| 368 |
+
"""Determine whether the given object is a valid URL string."""
|
| 369 |
+
if not isinstance(obj, str) or not "://" in obj:
|
| 370 |
+
return False
|
| 371 |
+
if allow_file_urls and obj.startswith('file://'):
|
| 372 |
+
return True
|
| 373 |
+
try:
|
| 374 |
+
res = requests.compat.urlparse(obj)
|
| 375 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 376 |
+
return False
|
| 377 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
| 378 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
| 379 |
+
return False
|
| 380 |
+
except:
|
| 381 |
+
return False
|
| 382 |
+
return True
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
| 386 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 387 |
+
assert num_attempts >= 1
|
| 388 |
+
assert not (return_filename and (not cache))
|
| 389 |
+
|
| 390 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 391 |
+
if not re.match('^[a-z]+://', url):
|
| 392 |
+
return url if return_filename else open(url, "rb")
|
| 393 |
+
|
| 394 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 395 |
+
# arise on Windows:
|
| 396 |
+
#
|
| 397 |
+
# file:///c:/foo.txt
|
| 398 |
+
#
|
| 399 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 400 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 401 |
+
#
|
| 402 |
+
# If you touch this code path, you should test it on both Linux and
|
| 403 |
+
# Windows.
|
| 404 |
+
#
|
| 405 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 406 |
+
# but that converts forward slashes to backslashes and this causes
|
| 407 |
+
# its own set of problems.
|
| 408 |
+
if url.startswith('file://'):
|
| 409 |
+
filename = urllib.parse.urlparse(url).path
|
| 410 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
| 411 |
+
filename = filename[1:]
|
| 412 |
+
return filename if return_filename else open(filename, "rb")
|
| 413 |
+
|
| 414 |
+
assert is_url(url)
|
| 415 |
+
|
| 416 |
+
# Lookup from cache.
|
| 417 |
+
if cache_dir is None:
|
| 418 |
+
cache_dir = make_cache_dir_path('downloads')
|
| 419 |
+
|
| 420 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 421 |
+
if cache:
|
| 422 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
| 423 |
+
if len(cache_files) == 1:
|
| 424 |
+
filename = cache_files[0]
|
| 425 |
+
return filename if return_filename else open(filename, "rb")
|
| 426 |
+
|
| 427 |
+
# Download.
|
| 428 |
+
url_name = None
|
| 429 |
+
url_data = None
|
| 430 |
+
with requests.Session() as session:
|
| 431 |
+
if verbose:
|
| 432 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 433 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 434 |
+
try:
|
| 435 |
+
with session.get(url) as res:
|
| 436 |
+
res.raise_for_status()
|
| 437 |
+
if len(res.content) == 0:
|
| 438 |
+
raise IOError("No data received")
|
| 439 |
+
|
| 440 |
+
if len(res.content) < 8192:
|
| 441 |
+
content_str = res.content.decode("utf-8")
|
| 442 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 443 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
| 444 |
+
if len(links) == 1:
|
| 445 |
+
url = requests.compat.urljoin(url, links[0])
|
| 446 |
+
raise IOError("Google Drive virus checker nag")
|
| 447 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 448 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
| 449 |
+
|
| 450 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
| 451 |
+
url_name = match[1] if match else url
|
| 452 |
+
url_data = res.content
|
| 453 |
+
if verbose:
|
| 454 |
+
print(" done")
|
| 455 |
+
break
|
| 456 |
+
except KeyboardInterrupt:
|
| 457 |
+
raise
|
| 458 |
+
except:
|
| 459 |
+
if not attempts_left:
|
| 460 |
+
if verbose:
|
| 461 |
+
print(" failed")
|
| 462 |
+
raise
|
| 463 |
+
if verbose:
|
| 464 |
+
print(".", end="", flush=True)
|
| 465 |
+
|
| 466 |
+
# Save to cache.
|
| 467 |
+
if cache:
|
| 468 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
| 469 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
| 470 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
| 471 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 472 |
+
with open(temp_file, "wb") as f:
|
| 473 |
+
f.write(url_data)
|
| 474 |
+
os.replace(temp_file, cache_file) # atomic
|
| 475 |
+
if return_filename:
|
| 476 |
+
return cache_file
|
| 477 |
+
|
| 478 |
+
# Return data as file object.
|
| 479 |
+
assert not return_filename
|
| 480 |
+
return io.BytesIO(url_data)
|
src/infra/__init__.py
ADDED
|
File without changes
|
src/infra/experiments.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#----------------------------------------------------------------------------
|
| 2 |
+
# Here, we keep the experiments HPs in case we want to do mass-launching via SLURM
|
| 3 |
+
#----------------------------------------------------------------------------
|
| 4 |
+
|
| 5 |
+
mocogan_sg2:
|
| 6 |
+
common_args:
|
| 7 |
+
model: mocogan
|
| 8 |
+
training.batch: 16
|
| 9 |
+
dataset.max_num_frames: 32
|
| 10 |
+
experiments:
|
| 11 |
+
b16_mnf16:
|
| 12 |
+
sampling: traditional_16
|
| 13 |
+
dataset.max_num_frames: 16
|
| 14 |
+
model.generator.motion.long_history: false
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
ffs:
|
| 19 |
+
common_args:
|
| 20 |
+
sampling.num_frames_per_video: 3
|
| 21 |
+
experiments:
|
| 22 |
+
mnf1024_sfpm32_minperiod16: {}
|
| 23 |
+
mnf1024_sfpm32_minperiod32:
|
| 24 |
+
model.generator.time_enc.min_period_len: 32
|
| 25 |
+
|
| 26 |
+
#----------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
sky_timelapse:
|
| 29 |
+
common_args:
|
| 30 |
+
sampling.num_frames_per_video: 3
|
| 31 |
+
experiments:
|
| 32 |
+
mnf1024_sfpm32_minperiod16: {}
|
| 33 |
+
mnf1024_sfpm256_minperiod256:
|
| 34 |
+
model.generator.motion.motion_z_distance: 256
|
| 35 |
+
model.generator.time_enc.min_period_len: 256
|
| 36 |
+
|
| 37 |
+
#----------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
highres:
|
| 40 |
+
common_args:
|
| 41 |
+
training.metrics: \"fvd2048_16f,fvd2048_128f_subsample,fid50k_full\"
|
| 42 |
+
training.batch: 16
|
| 43 |
+
sampling.num_frames_per_video: 2
|
| 44 |
+
experiments:
|
| 45 |
+
mnf1024_sfpm32_minperiod16_batch16: {}
|
| 46 |
+
mnf32_sfpm32_minperiod16_batch16:
|
| 47 |
+
dataset.max_num_frames: 32
|
| 48 |
+
|
| 49 |
+
#----------------------------------------------------------------------------
|
| 50 |
+
|
| 51 |
+
cond_ablation_ffs:
|
| 52 |
+
common_args:
|
| 53 |
+
sampling.num_frames_per_video: 3
|
| 54 |
+
experiments:
|
| 55 |
+
hyper_mod:
|
| 56 |
+
model.discriminator.hyper_type: hyper
|
| 57 |
+
without_proj_cond:
|
| 58 |
+
model.discriminator.dummy_c: true
|
| 59 |
+
|
| 60 |
+
#----------------------------------------------------------------------------
|
src/infra/launch.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run a __reproducible__ experiment on __allocated__ resources
|
| 3 |
+
It submits a slurm job(s) with the given hyperparams which will then execute `slurm_job.py`
|
| 4 |
+
This is the main entry-point
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
os.environ["HYDRA_FULL_ERROR"] = "1"
|
| 9 |
+
|
| 10 |
+
import subprocess
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
import hydra
|
| 14 |
+
from omegaconf import DictConfig, OmegaConf
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
from utils import create_project_dir, recursive_instantiate
|
| 18 |
+
|
| 19 |
+
#----------------------------------------------------------------------------
|
| 20 |
+
|
| 21 |
+
HYDRA_ARGS = "hydra.run.dir=. hydra.output_subdir=null hydra/job_logging=disabled hydra/hydra_logging=disabled"
|
| 22 |
+
|
| 23 |
+
#----------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
@hydra.main(config_path="../../configs", config_name="config.yaml")
|
| 26 |
+
def main(cfg: DictConfig):
|
| 27 |
+
recursive_instantiate(cfg)
|
| 28 |
+
OmegaConf.set_struct(cfg, True)
|
| 29 |
+
cfg.env.project_path = str(cfg.env.project_path) # This is needed to evaluate ${hydra:runtime.cwd}
|
| 30 |
+
|
| 31 |
+
before_train_cmd = '\n'.join(cfg.env.before_train_commands)
|
| 32 |
+
before_train_cmd = before_train_cmd + '\n' if len(before_train_cmd) > 0 else ''
|
| 33 |
+
torch_extensions_dir = os.environ.get('TORCH_EXTENSIONS_DIR', cfg.env.torch_extensions_dir)
|
| 34 |
+
training_cmd = f'{before_train_cmd}TORCH_EXTENSIONS_DIR={torch_extensions_dir} cd {cfg.project_release_dir} && {cfg.env.python_bin} src/train.py {HYDRA_ARGS}'
|
| 35 |
+
quiet = cfg.get('quiet', False)
|
| 36 |
+
training_cmd_save_path = os.path.join(cfg.project_release_dir, 'training_cmd.sh')
|
| 37 |
+
cfg_save_path = os.path.join(cfg.project_release_dir, 'experiment_config.yaml')
|
| 38 |
+
|
| 39 |
+
if not quiet:
|
| 40 |
+
print('<=== TRAINING COMMAND START ===>')
|
| 41 |
+
print(training_cmd)
|
| 42 |
+
print('<=== TRAINING COMMAND END ===>')
|
| 43 |
+
|
| 44 |
+
is_running_from_scratch = True
|
| 45 |
+
|
| 46 |
+
if cfg.training.resume == "latest" and os.path.isdir(cfg.project_release_dir) and os.path.isfile(training_cmd_save_path) and os.path.isfile(cfg_save_path):
|
| 47 |
+
is_running_from_scratch = False
|
| 48 |
+
if not quiet:
|
| 49 |
+
print("We are going to resume the training and the experiment already exists. " \
|
| 50 |
+
"That's why the provided config/training_cmd are discarded and the project dir is not created.")
|
| 51 |
+
|
| 52 |
+
if is_running_from_scratch and not cfg.print_only:
|
| 53 |
+
create_project_dir(
|
| 54 |
+
cfg.project_release_dir,
|
| 55 |
+
cfg.env.objects_to_copy,
|
| 56 |
+
cfg.env.symlinks_to_create,
|
| 57 |
+
quiet=quiet,
|
| 58 |
+
ignore_uncommited_changes=cfg.get('ignore_uncommited_changes', False),
|
| 59 |
+
overwrite=cfg.get('overwrite', False))
|
| 60 |
+
|
| 61 |
+
with open(training_cmd_save_path, 'w') as f:
|
| 62 |
+
f.write(training_cmd + '\n')
|
| 63 |
+
if not quiet:
|
| 64 |
+
print(f'Saved training command in {training_cmd_save_path}')
|
| 65 |
+
|
| 66 |
+
with open(cfg_save_path, 'w') as f:
|
| 67 |
+
OmegaConf.save(config=cfg, f=f)
|
| 68 |
+
if not quiet:
|
| 69 |
+
print(f'Saved config in {cfg_save_path}')
|
| 70 |
+
|
| 71 |
+
if not cfg.print_only:
|
| 72 |
+
os.chdir(cfg.project_release_dir)
|
| 73 |
+
|
| 74 |
+
if cfg.slurm:
|
| 75 |
+
assert Path(cfg.dataset.path_for_slurm_job).exists()
|
| 76 |
+
|
| 77 |
+
curr_job_id = None
|
| 78 |
+
|
| 79 |
+
for i in range(cfg.job_sequence_length):
|
| 80 |
+
if i == 0:
|
| 81 |
+
deps_args_str = ''
|
| 82 |
+
else:
|
| 83 |
+
deps_args_str = f'--dependency=afterany:{curr_job_id}'
|
| 84 |
+
|
| 85 |
+
# Submitting the slurm job
|
| 86 |
+
qos_arg_str = f'--account {os.environ["PRIORITY_BOOST_ACC"]}' if cfg.use_qos else ''
|
| 87 |
+
output_file_arg_str = f'--output {cfg.project_release_dir}/slurm_{i}.log'
|
| 88 |
+
submit_job_cmd = f'sbatch {cfg.sbatch_args_str} {output_file_arg_str} {qos_arg_str} --export=ALL,{cfg.env_args_str} {deps_args_str} src/infra/slurm_job_proxy.sh'
|
| 89 |
+
|
| 90 |
+
if cfg.print_only:
|
| 91 |
+
print(submit_job_cmd)
|
| 92 |
+
curr_job_id = "DUMMY_JOB_ID"
|
| 93 |
+
else:
|
| 94 |
+
result = subprocess.run(submit_job_cmd, stdout=subprocess.PIPE, shell=True)
|
| 95 |
+
output_str = result.stdout.decode("utf-8").strip("\n") # It has a format of "Submitted batch job 17033559"
|
| 96 |
+
if not quiet or i == 0:
|
| 97 |
+
print(output_str)
|
| 98 |
+
curr_job_id = re.findall(r"^Submitted\ batch\ job\ \d{5,8}$", output_str)
|
| 99 |
+
assert len(curr_job_id) == 1, f"Bad output: `{output_str}`"
|
| 100 |
+
curr_job_id = int(curr_job_id[0][len('Submitted batch job '):])
|
| 101 |
+
else:
|
| 102 |
+
assert cfg.job_sequence_length == 1, "You can use a job sequence only when running via slurm."
|
| 103 |
+
if cfg.print_only:
|
| 104 |
+
print(training_cmd)
|
| 105 |
+
else:
|
| 106 |
+
os.system(training_cmd)
|
| 107 |
+
|
| 108 |
+
#----------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
main()
|
| 112 |
+
|
| 113 |
+
#----------------------------------------------------------------------------
|
src/infra/slurm_batch_launch.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import copy
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
from omegaconf import OmegaConf, DictConfig
|
| 6 |
+
from src.infra.utils import cfg_to_args_str
|
| 7 |
+
|
| 8 |
+
#----------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
HYDRA_ARGS = "hydra.run.dir=. hydra.output_subdir=null hydra/job_logging=disabled hydra/hydra_logging=disabled"
|
| 11 |
+
|
| 12 |
+
#----------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
def batch_launch(launcher: str, experiments_dir: os.PathLike, cfg: DictConfig, datasets: List[str], print_only: bool, time: str, use_qos: bool=False, other_args: Dict={}, num_gpus: int=4, *args, **kwargs):
|
| 15 |
+
for dataset in datasets:
|
| 16 |
+
for exp_args in construct_experiments_args(cfg, *args, **kwargs):
|
| 17 |
+
exp_args['sbatch_args.time'] = time
|
| 18 |
+
exp_args['experiments_dir'] = experiments_dir
|
| 19 |
+
exp_args['dataset'] = dataset
|
| 20 |
+
exp_args['env'] = 'ibex'
|
| 21 |
+
exp_args['use_qos'] = use_qos
|
| 22 |
+
exp_args = {**exp_args, **other_args}
|
| 23 |
+
curr_exp_args_str = cfg_to_args_str(exp_args, use_dashes=False)
|
| 24 |
+
launching_command = f"{launcher} num_gpus={num_gpus} {curr_exp_args_str}"
|
| 25 |
+
|
| 26 |
+
if print_only:
|
| 27 |
+
os.makedirs(exp_args['experiments_dir'], exist_ok=True)
|
| 28 |
+
print(launching_command)
|
| 29 |
+
else:
|
| 30 |
+
os.system(launching_command)
|
| 31 |
+
|
| 32 |
+
#----------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
def construct_experiments_args(cfg: DictConfig, experiments_list: Optional[List[str]]=None, suffix: str="") -> List[Dict]:
|
| 35 |
+
args_dicts = []
|
| 36 |
+
common_cfg = cfg.get('common_args', {})
|
| 37 |
+
|
| 38 |
+
for exp_name, exp_cfg in to_dict(cfg.experiments).items():
|
| 39 |
+
if not experiments_list is None and not exp_name in experiments_list:
|
| 40 |
+
continue
|
| 41 |
+
curr_exp_cfg = {**copy.deepcopy(to_dict(common_cfg)), **to_dict(exp_cfg)}
|
| 42 |
+
curr_exp_cfg['exp_suffix'] = f'{exp_name}{suffix}'
|
| 43 |
+
args_dicts.append(curr_exp_cfg)
|
| 44 |
+
|
| 45 |
+
return args_dicts
|
| 46 |
+
|
| 47 |
+
#----------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def to_dict(cfg) -> Dict:
|
| 50 |
+
return OmegaConf.to_container(OmegaConf.create({**cfg}))
|
| 51 |
+
|
| 52 |
+
#----------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
parser = argparse.ArgumentParser(description="Experiments launcher")
|
| 56 |
+
parser.add_argument('-e', '--series_name', type=str, required=True, help="Which experiments series to launch?")
|
| 57 |
+
parser.add_argument('-d', '--datasets', required=True, type=str, help='Comma-separate list of datasets')
|
| 58 |
+
parser.add_argument('-p', '--print_only', action='store_true', help='Just print commands and exit?')
|
| 59 |
+
parser.add_argument('-t', '--time', type=str, default='1-0', help='Which time to specify for the sbatch command?')
|
| 60 |
+
parser.add_argument('-q', '--use_qos', action='store_true', help='Should we use QoS to launch jobs?')
|
| 61 |
+
parser.add_argument('--experiments_list', type=str, help='Should we run only some specific experiments from this experiments series?')
|
| 62 |
+
parser.add_argument('--other_args', type=str, default="", help='Additional arguments for the experiments')
|
| 63 |
+
parser.add_argument('--suffix', type=str, default="", help='Additional suffix for the experiments')
|
| 64 |
+
parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs to use per each experiment')
|
| 65 |
+
parser.add_argument('--project_dir', type=str, default=os.getcwd(), help='Project directory path')
|
| 66 |
+
parser.add_argument('--project_dir_for_exps_cfg', type=str, help="Overwrite the project directory to use for experiments.yaml. Useful for debugging the config.")
|
| 67 |
+
args = parser.parse_args()
|
| 68 |
+
|
| 69 |
+
os.chdir(args.project_dir)
|
| 70 |
+
user = os.environ.get('USER', 'unknown')
|
| 71 |
+
python_bin = os.path.join(args.project_dir, 'env/bin/python')
|
| 72 |
+
launcher = f"{python_bin} src/infra/launch.py {HYDRA_ARGS} +quiet=true slurm=true"
|
| 73 |
+
experiments_dir = f'experiments/{user}/{args.series_name}'
|
| 74 |
+
exps_cfg_path = os.path.join(args.project_dir if args.project_dir_for_exps_cfg is None else args.project_dir_for_exps_cfg, 'src/infra/experiments.yaml')
|
| 75 |
+
all_exp_series = OmegaConf.load(exps_cfg_path)
|
| 76 |
+
assert args.series_name in all_exp_series, f"Experiments series not found: {args.series_name}"
|
| 77 |
+
cfg = all_exp_series[args.series_name]
|
| 78 |
+
datasets = args.datasets.split(',')
|
| 79 |
+
experiments_list = None if args.experiments_list is None else args.experiments_list.split(',')
|
| 80 |
+
other_args = {kv.split('=')[0]: kv.split('=')[1] for kv in args.other_args.split(',') if len(kv.split('=')) == 2}
|
| 81 |
+
|
| 82 |
+
batch_launch(
|
| 83 |
+
launcher=launcher,
|
| 84 |
+
experiments_dir=experiments_dir,
|
| 85 |
+
cfg=cfg,
|
| 86 |
+
datasets=datasets,
|
| 87 |
+
print_only=args.print_only,
|
| 88 |
+
time=args.time,
|
| 89 |
+
use_qos=args.use_qos,
|
| 90 |
+
experiments_list=experiments_list,
|
| 91 |
+
other_args=other_args,
|
| 92 |
+
suffix=args.suffix,
|
| 93 |
+
num_gpus=args.num_gpus,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
#----------------------------------------------------------------------------
|
src/infra/slurm_job.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Must be launched from the released project dir
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import random
|
| 8 |
+
import subprocess
|
| 9 |
+
from shutil import copyfile
|
| 10 |
+
|
| 11 |
+
import hydra
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
|
| 14 |
+
# Unfortunately, (AFAIK) we cannot pass arguments normally (to parse them with argparse)
|
| 15 |
+
# that's why we are reading them from env
|
| 16 |
+
SLURM_JOB_ID = os.getenv('SLURM_JOB_ID')
|
| 17 |
+
project_dir = os.getenv('project_dir')
|
| 18 |
+
python_bin = os.getenv('python_bin')
|
| 19 |
+
|
| 20 |
+
# Printing the environment
|
| 21 |
+
print('PROJECT DIR:', project_dir)
|
| 22 |
+
print(f'SLURM_JOB_ID: {SLURM_JOB_ID}')
|
| 23 |
+
print('HOSTNAME:', subprocess.run(['hostname'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
|
| 24 |
+
print(subprocess.run([os.path.join(os.path.dirname(python_bin), 'gpustat')], stdout=subprocess.PIPE).stdout.decode('utf-8'))
|
| 25 |
+
|
| 26 |
+
@hydra.main(config_name=os.path.join(project_dir, 'experiment_config.yaml'))
|
| 27 |
+
def main(cfg: DictConfig):
|
| 28 |
+
os.chdir(project_dir)
|
| 29 |
+
|
| 30 |
+
target_data_dir_base = os.path.dirname(cfg.dataset.path)
|
| 31 |
+
if os.path.islink(target_data_dir_base):
|
| 32 |
+
os.makedirs(os.readlink(target_data_dir_base), exist_ok=True)
|
| 33 |
+
else:
|
| 34 |
+
os.makedirs(target_data_dir_base, exist_ok=True)
|
| 35 |
+
|
| 36 |
+
copyfile(cfg.dataset.path_for_slurm_job, cfg.dataset.path)
|
| 37 |
+
print(f'Copied the data: {cfg.dataset.path_for_slurm_job} => {cfg.dataset.path}. Starting the training...')
|
| 38 |
+
|
| 39 |
+
training_cmd = open('training_cmd.sh').read()
|
| 40 |
+
print('<=== TRAINING COMMAND ===>')
|
| 41 |
+
print(training_cmd)
|
| 42 |
+
os.system(training_cmd)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
src/infra/slurm_job_proxy.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# We need this proxy so not to put the shebang into `slurm_job.py`
|
| 3 |
+
# We cannot put a shebang there since we use different python executors for it
|
| 4 |
+
$python_bin $python_script
|
src/infra/utils.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import subprocess
|
| 4 |
+
from distutils.dir_util import copy_tree
|
| 5 |
+
from shutil import copyfile
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
|
| 8 |
+
from hydra.utils import instantiate
|
| 9 |
+
import click
|
| 10 |
+
import git
|
| 11 |
+
from omegaconf import DictConfig
|
| 12 |
+
|
| 13 |
+
#----------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]):
|
| 16 |
+
for src_path in objects_to_copy:
|
| 17 |
+
trg_path = os.path.join(target_dir, os.path.basename(src_path))
|
| 18 |
+
|
| 19 |
+
if os.path.islink(src_path):
|
| 20 |
+
os.symlink(os.readlink(src_path), trg_path)
|
| 21 |
+
elif os.path.isfile(src_path):
|
| 22 |
+
copyfile(src_path, trg_path)
|
| 23 |
+
elif os.path.isdir(src_path):
|
| 24 |
+
copy_tree(src_path, trg_path)
|
| 25 |
+
else:
|
| 26 |
+
raise NotImplementedError(f"Unknown object type: {src_path}")
|
| 27 |
+
|
| 28 |
+
#----------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]):
|
| 31 |
+
"""
|
| 32 |
+
Creates symlinks to the given paths
|
| 33 |
+
"""
|
| 34 |
+
for src_path in symlinks_to_create:
|
| 35 |
+
trg_path = os.path.join(target_dir, os.path.basename(src_path))
|
| 36 |
+
|
| 37 |
+
if os.path.islink(src_path):
|
| 38 |
+
# Let's not create symlinks to symlinks
|
| 39 |
+
# Since dropping the current symlink will break the experiment
|
| 40 |
+
os.symlink(os.readlink(src_path), trg_path)
|
| 41 |
+
else:
|
| 42 |
+
print(f'Creating a symlink to {src_path}, so try not to delete it occasionally!')
|
| 43 |
+
os.symlink(src_path, trg_path)
|
| 44 |
+
|
| 45 |
+
#----------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
def is_git_repo(path: os.PathLike):
|
| 48 |
+
try:
|
| 49 |
+
_ = git.Repo(path).git_dir
|
| 50 |
+
return True
|
| 51 |
+
except git.exc.InvalidGitRepositoryError:
|
| 52 |
+
return False
|
| 53 |
+
|
| 54 |
+
#----------------------------------------------------------------------------
|
| 55 |
+
|
| 56 |
+
def create_project_dir(
|
| 57 |
+
project_dir: os.PathLike,
|
| 58 |
+
objects_to_copy: List[os.PathLike],
|
| 59 |
+
symlinks_to_create: List[os.PathLike],
|
| 60 |
+
quiet: bool=False,
|
| 61 |
+
ignore_uncommited_changes: bool=False,
|
| 62 |
+
overwrite: bool=False):
|
| 63 |
+
|
| 64 |
+
if is_git_repo(os.getcwd()) and are_there_uncommitted_changes():
|
| 65 |
+
if ignore_uncommited_changes or click.confirm("There are uncommited changes. Continue?", default=False):
|
| 66 |
+
pass
|
| 67 |
+
else:
|
| 68 |
+
raise PermissionError("Cannot created a dir when there are uncommited changes")
|
| 69 |
+
|
| 70 |
+
if os.path.exists(project_dir):
|
| 71 |
+
if overwrite or click.confirm(f'Dir {project_dir} already exists. Overwrite it?', default=False):
|
| 72 |
+
shutil.rmtree(project_dir)
|
| 73 |
+
else:
|
| 74 |
+
print('User refused to delete an existing project dir.')
|
| 75 |
+
raise PermissionError("There is an existing dir and I cannot delete it.")
|
| 76 |
+
|
| 77 |
+
os.makedirs(project_dir)
|
| 78 |
+
copy_objects(project_dir, objects_to_copy)
|
| 79 |
+
create_symlinks(project_dir, symlinks_to_create)
|
| 80 |
+
|
| 81 |
+
if not quiet:
|
| 82 |
+
print(f'Created a project dir: {project_dir}')
|
| 83 |
+
|
| 84 |
+
#----------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def get_git_hash() -> Optional[str]:
|
| 87 |
+
if not is_git_repo(os.getcwd()):
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
return subprocess \
|
| 92 |
+
.check_output(['git', 'rev-parse', '--short', 'HEAD']) \
|
| 93 |
+
.decode("utf-8") \
|
| 94 |
+
.strip()
|
| 95 |
+
except:
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
#----------------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
# def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike:
|
| 101 |
+
# return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}")
|
| 102 |
+
|
| 103 |
+
#----------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
def get_git_hash_suffix() -> str:
|
| 106 |
+
git_hash: Optional[str] = get_git_hash()
|
| 107 |
+
git_hash_suffix = "-nogit" if git_hash is None else f"-{git_hash}"
|
| 108 |
+
|
| 109 |
+
return git_hash_suffix
|
| 110 |
+
|
| 111 |
+
#----------------------------------------------------------------------------
|
| 112 |
+
|
| 113 |
+
def are_there_uncommitted_changes() -> bool:
|
| 114 |
+
return len(subprocess.check_output('git status -s'.split()).decode("utf-8")) > 0
|
| 115 |
+
|
| 116 |
+
#----------------------------------------------------------------------------
|
| 117 |
+
|
| 118 |
+
def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str:
|
| 119 |
+
dashes = '--' if use_dashes else ''
|
| 120 |
+
|
| 121 |
+
return ' '.join([f'{dashes}{p}={cfg[p]}' for p in cfg])
|
| 122 |
+
|
| 123 |
+
#----------------------------------------------------------------------------
|
| 124 |
+
|
| 125 |
+
def recursive_instantiate(cfg: DictConfig):
|
| 126 |
+
for key in cfg:
|
| 127 |
+
# print(type(cfg[key]))
|
| 128 |
+
if isinstance(cfg[key], DictConfig):
|
| 129 |
+
if '_target_' in cfg[key]:
|
| 130 |
+
cfg[key] = instantiate(cfg[key])
|
| 131 |
+
else:
|
| 132 |
+
recursive_instantiate(cfg[key])
|
| 133 |
+
|
| 134 |
+
#----------------------------------------------------------------------------
|
| 135 |
+
|
| 136 |
+
def num_gpus_to_mem(num_gpus: int, mem_per_gpu: 64) -> str:
|
| 137 |
+
# Doing it here since hydra config cannot do formatting for ${...}
|
| 138 |
+
return f"{num_gpus * mem_per_gpu}G"
|
| 139 |
+
|
| 140 |
+
#----------------------------------------------------------------------------
|
src/metrics/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
# empty
|
src/metrics/frechet_inception_distance.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Frechet Inception Distance (FID) from the paper
|
| 10 |
+
"GANs trained by a two time-scale update rule converge to a local Nash
|
| 11 |
+
equilibrium". Matches the original implementation by Heusel et al. at
|
| 12 |
+
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import scipy.linalg
|
| 16 |
+
from . import metric_utils
|
| 17 |
+
|
| 18 |
+
NUM_FRAMES_IN_BATCH = {128: 32, 256: 32, 512: 8, 1024: 2}
|
| 19 |
+
|
| 20 |
+
#----------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
def compute_fid(opts, max_real, num_gen):
|
| 23 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
| 24 |
+
detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
|
| 25 |
+
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
|
| 26 |
+
|
| 27 |
+
batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution]
|
| 28 |
+
|
| 29 |
+
mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
|
| 30 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
| 31 |
+
rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, use_image_dataset=True).get_mean_cov()
|
| 32 |
+
|
| 33 |
+
if opts.generator_as_dataset:
|
| 34 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
|
| 35 |
+
gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
|
| 36 |
+
gen_kwargs = dict(use_image_dataset=True)
|
| 37 |
+
else:
|
| 38 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
|
| 39 |
+
gen_opts = opts
|
| 40 |
+
gen_kwargs = dict()
|
| 41 |
+
|
| 42 |
+
mu_gen, sigma_gen = compute_gen_stats_fn(
|
| 43 |
+
opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, batch_size=batch_size,
|
| 44 |
+
rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, **gen_kwargs).get_mean_cov()
|
| 45 |
+
|
| 46 |
+
if opts.rank != 0:
|
| 47 |
+
return float('nan')
|
| 48 |
+
|
| 49 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 50 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
| 51 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 52 |
+
return float(fid)
|
| 53 |
+
|
| 54 |
+
#----------------------------------------------------------------------------
|
src/metrics/frechet_video_distance.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Frechet Video Distance (FVD). Matches the original tensorflow implementation from
|
| 3 |
+
https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
|
| 4 |
+
up to the upsampling operation. Note that this tf.hub I3D model is different from the one released in the I3D repo.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
import numpy as np
|
| 9 |
+
import scipy.linalg
|
| 10 |
+
from . import metric_utils
|
| 11 |
+
|
| 12 |
+
#----------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def compute_fvd(opts, max_real: int, num_gen: int, num_frames: int, subsample_factor: int=1):
|
| 19 |
+
# Perfectly reproduced torchscript version of the I3D model, trained on Kinetics-400, used here:
|
| 20 |
+
# https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
|
| 21 |
+
# Note that the weights on tf.hub (used in the script above) differ from the original released weights
|
| 22 |
+
detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
|
| 23 |
+
detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
|
| 24 |
+
|
| 25 |
+
opts = copy.deepcopy(opts)
|
| 26 |
+
opts.dataset_kwargs.load_n_consecutive = num_frames
|
| 27 |
+
opts.dataset_kwargs.subsample_factor = subsample_factor
|
| 28 |
+
opts.dataset_kwargs.discard_short_videos = True
|
| 29 |
+
batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
|
| 30 |
+
|
| 31 |
+
mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
|
| 32 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=0,
|
| 33 |
+
capture_mean_cov=True, max_items=max_real, temporal_detector=True, batch_size=batch_size).get_mean_cov()
|
| 34 |
+
|
| 35 |
+
if opts.generator_as_dataset:
|
| 36 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
|
| 37 |
+
gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
|
| 38 |
+
gen_opts.dataset_kwargs.load_n_consecutive = num_frames
|
| 39 |
+
gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
|
| 40 |
+
gen_opts.dataset_kwargs.subsample_factor = subsample_factor
|
| 41 |
+
gen_kwargs = dict()
|
| 42 |
+
else:
|
| 43 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
|
| 44 |
+
gen_opts = opts
|
| 45 |
+
gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=subsample_factor)
|
| 46 |
+
|
| 47 |
+
mu_gen, sigma_gen = compute_gen_stats_fn(
|
| 48 |
+
opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=1, capture_mean_cov=True,
|
| 49 |
+
max_items=num_gen, temporal_detector=True, batch_size=batch_size, **gen_kwargs).get_mean_cov()
|
| 50 |
+
|
| 51 |
+
if opts.rank != 0:
|
| 52 |
+
return float('nan')
|
| 53 |
+
|
| 54 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 55 |
+
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
|
| 56 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 57 |
+
return float(fid)
|
| 58 |
+
|
| 59 |
+
#----------------------------------------------------------------------------
|
src/metrics/inception_score.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Inception Score (IS) from the paper "Improved techniques for training
|
| 10 |
+
GANs". Matches the original implementation by Salimans et al. at
|
| 11 |
+
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from . import metric_utils
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def compute_is(opts, num_gen, num_splits):
|
| 19 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
| 20 |
+
detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
|
| 21 |
+
detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
|
| 22 |
+
|
| 23 |
+
if opts.generator_as_dataset:
|
| 24 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
|
| 25 |
+
gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
|
| 26 |
+
gen_kwargs = dict(use_image_dataset=True)
|
| 27 |
+
else:
|
| 28 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
|
| 29 |
+
gen_opts = opts
|
| 30 |
+
gen_kwargs = dict()
|
| 31 |
+
|
| 32 |
+
gen_probs = compute_gen_stats_fn(
|
| 33 |
+
opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
| 34 |
+
capture_all=True, max_items=num_gen, **gen_kwargs).get_all()
|
| 35 |
+
|
| 36 |
+
if opts.rank != 0:
|
| 37 |
+
return float('nan'), float('nan')
|
| 38 |
+
|
| 39 |
+
scores = []
|
| 40 |
+
for i in range(num_splits):
|
| 41 |
+
part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
|
| 42 |
+
kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
|
| 43 |
+
kl = np.mean(np.sum(kl, axis=1))
|
| 44 |
+
scores.append(np.exp(kl))
|
| 45 |
+
return float(np.mean(scores)), float(np.std(scores))
|
| 46 |
+
|
| 47 |
+
#----------------------------------------------------------------------------
|
src/metrics/kernel_inception_distance.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
|
| 10 |
+
GANs". Matches the original implementation by Binkowski et al. at
|
| 11 |
+
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from . import metric_utils
|
| 15 |
+
|
| 16 |
+
#----------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
|
| 19 |
+
# Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
| 20 |
+
detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
|
| 21 |
+
detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
|
| 22 |
+
|
| 23 |
+
real_features = metric_utils.compute_feature_stats_for_dataset(
|
| 24 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
| 25 |
+
rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real, use_image_dataset=True).get_all()
|
| 26 |
+
|
| 27 |
+
gen_features = metric_utils.compute_feature_stats_for_generator(
|
| 28 |
+
opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
| 29 |
+
rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
|
| 30 |
+
|
| 31 |
+
if opts.rank != 0:
|
| 32 |
+
return float('nan')
|
| 33 |
+
|
| 34 |
+
n = real_features.shape[1]
|
| 35 |
+
m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
|
| 36 |
+
t = 0
|
| 37 |
+
for _subset_idx in range(num_subsets):
|
| 38 |
+
x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
|
| 39 |
+
y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
|
| 40 |
+
a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
|
| 41 |
+
b = (x @ y.T / n + 1) ** 3
|
| 42 |
+
t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
|
| 43 |
+
kid = t / num_subsets / m
|
| 44 |
+
return float(kid) * 1000.0
|
| 45 |
+
|
| 46 |
+
#----------------------------------------------------------------------------
|
src/metrics/metric_main.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import json
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from src import dnnlib
|
| 15 |
+
|
| 16 |
+
from . import metric_utils
|
| 17 |
+
from . import frechet_inception_distance
|
| 18 |
+
from . import kernel_inception_distance
|
| 19 |
+
from . import inception_score
|
| 20 |
+
from . import video_inception_score
|
| 21 |
+
from . import frechet_video_distance
|
| 22 |
+
|
| 23 |
+
#----------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
_metric_dict = dict() # name => fn
|
| 26 |
+
|
| 27 |
+
def register_metric(fn):
|
| 28 |
+
assert callable(fn)
|
| 29 |
+
_metric_dict[fn.__name__] = fn
|
| 30 |
+
return fn
|
| 31 |
+
|
| 32 |
+
def is_valid_metric(metric):
|
| 33 |
+
return metric in _metric_dict
|
| 34 |
+
|
| 35 |
+
def list_valid_metrics():
|
| 36 |
+
return list(_metric_dict.keys())
|
| 37 |
+
|
| 38 |
+
def is_power_of_two(n: int) -> bool:
|
| 39 |
+
return (n & (n-1) == 0) and n != 0
|
| 40 |
+
|
| 41 |
+
#----------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def calc_metric(metric, num_runs: int=1, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
|
| 44 |
+
assert is_valid_metric(metric)
|
| 45 |
+
opts = metric_utils.MetricOptions(**kwargs)
|
| 46 |
+
|
| 47 |
+
# Calculate.
|
| 48 |
+
start_time = time.time()
|
| 49 |
+
all_runs_results = [_metric_dict[metric](opts) for _ in range(num_runs)]
|
| 50 |
+
total_time = time.time() - start_time
|
| 51 |
+
|
| 52 |
+
# Broadcast results.
|
| 53 |
+
for results in all_runs_results:
|
| 54 |
+
for key, value in list(results.items()):
|
| 55 |
+
if opts.num_gpus > 1:
|
| 56 |
+
value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
|
| 57 |
+
torch.distributed.broadcast(tensor=value, src=0)
|
| 58 |
+
value = float(value.cpu())
|
| 59 |
+
results[key] = value
|
| 60 |
+
|
| 61 |
+
if num_runs > 1:
|
| 62 |
+
results = {f'{key}_run{i+1:02d}': value for i, results in enumerate(all_runs_results) for key, value in results.items()}
|
| 63 |
+
for key, value in all_runs_results[0].items():
|
| 64 |
+
all_runs_values = [r[key] for r in all_runs_results]
|
| 65 |
+
results[f'{key}_mean'] = np.mean(all_runs_values)
|
| 66 |
+
results[f'{key}_std'] = np.std(all_runs_values)
|
| 67 |
+
else:
|
| 68 |
+
results = all_runs_results[0]
|
| 69 |
+
|
| 70 |
+
# Decorate with metadata.
|
| 71 |
+
return dnnlib.EasyDict(
|
| 72 |
+
results = dnnlib.EasyDict(results),
|
| 73 |
+
metric = metric,
|
| 74 |
+
total_time = total_time,
|
| 75 |
+
total_time_str = dnnlib.util.format_time(total_time),
|
| 76 |
+
num_gpus = opts.num_gpus,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
#----------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
|
| 82 |
+
metric = result_dict['metric']
|
| 83 |
+
assert is_valid_metric(metric)
|
| 84 |
+
if run_dir is not None and snapshot_pkl is not None:
|
| 85 |
+
snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
|
| 86 |
+
|
| 87 |
+
jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
|
| 88 |
+
print(jsonl_line)
|
| 89 |
+
if run_dir is not None and os.path.isdir(run_dir):
|
| 90 |
+
with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
|
| 91 |
+
f.write(jsonl_line + '\n')
|
| 92 |
+
|
| 93 |
+
#----------------------------------------------------------------------------
|
| 94 |
+
# Primary metrics.
|
| 95 |
+
|
| 96 |
+
@register_metric
|
| 97 |
+
def fid50k_full(opts):
|
| 98 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 99 |
+
fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
|
| 100 |
+
return dict(fid50k_full=fid)
|
| 101 |
+
|
| 102 |
+
@register_metric
|
| 103 |
+
def kid50k_full(opts):
|
| 104 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 105 |
+
kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
|
| 106 |
+
return dict(kid50k_full=kid)
|
| 107 |
+
|
| 108 |
+
@register_metric
|
| 109 |
+
def is50k(opts):
|
| 110 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 111 |
+
mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
|
| 112 |
+
return dict(is50k_mean=mean, is50k_std=std)
|
| 113 |
+
|
| 114 |
+
@register_metric
|
| 115 |
+
def fvd2048_16f(opts):
|
| 116 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 117 |
+
fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16)
|
| 118 |
+
return dict(fvd2048_16f=fvd)
|
| 119 |
+
|
| 120 |
+
@register_metric
|
| 121 |
+
def fvd2048_128f(opts):
|
| 122 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 123 |
+
fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=128)
|
| 124 |
+
return dict(fvd2048_128f=fvd)
|
| 125 |
+
|
| 126 |
+
@register_metric
|
| 127 |
+
def fvd2048_128f_subsample8f(opts):
|
| 128 |
+
"""Similar to `fvd2048_128f`, but we sample each 8-th frame"""
|
| 129 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 130 |
+
fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16, subsample_factor=8)
|
| 131 |
+
return dict(fvd2048_128f_subsample8f=fvd)
|
| 132 |
+
|
| 133 |
+
@register_metric
|
| 134 |
+
def isv2048_ucf(opts):
|
| 135 |
+
opts.dataset_kwargs.update(max_size=None, xflip=False)
|
| 136 |
+
mean, std = video_inception_score.compute_isv(opts, num_gen=2048, num_splits=10, backbone='c3d_ucf101')
|
| 137 |
+
return dict(isv2048_ucf_mean=mean, isv2048_ucf_std=std)
|
| 138 |
+
|
| 139 |
+
#----------------------------------------------------------------------------
|
| 140 |
+
# Legacy metrics.
|
| 141 |
+
|
| 142 |
+
@register_metric
|
| 143 |
+
def fid50k(opts):
|
| 144 |
+
opts.dataset_kwargs.update(max_size=None)
|
| 145 |
+
fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
|
| 146 |
+
return dict(fid50k=fid)
|
| 147 |
+
|
| 148 |
+
@register_metric
|
| 149 |
+
def kid50k(opts):
|
| 150 |
+
opts.dataset_kwargs.update(max_size=None)
|
| 151 |
+
kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
|
| 152 |
+
return dict(kid50k=kid)
|
| 153 |
+
|
| 154 |
+
#----------------------------------------------------------------------------
|
src/metrics/metric_utils.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import hashlib
|
| 12 |
+
import pickle
|
| 13 |
+
import copy
|
| 14 |
+
import uuid
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from src import dnnlib
|
| 19 |
+
from src.training.dataset import video_to_image_dataset_kwargs
|
| 20 |
+
|
| 21 |
+
#----------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class MetricOptions:
|
| 24 |
+
def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None,
|
| 25 |
+
progress=None, cache=True, gen_dataset_kwargs={}, generator_as_dataset=False):
|
| 26 |
+
assert 0 <= rank < num_gpus
|
| 27 |
+
self.G = G
|
| 28 |
+
self.G_kwargs = dnnlib.EasyDict(G_kwargs)
|
| 29 |
+
self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
|
| 30 |
+
self.num_gpus = num_gpus
|
| 31 |
+
self.rank = rank
|
| 32 |
+
self.device = device if device is not None else torch.device('cuda', rank)
|
| 33 |
+
self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
|
| 34 |
+
self.cache = cache
|
| 35 |
+
self.gen_dataset_kwargs = gen_dataset_kwargs
|
| 36 |
+
self.generator_as_dataset = generator_as_dataset
|
| 37 |
+
|
| 38 |
+
#----------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
_feature_detector_cache = dict()
|
| 41 |
+
|
| 42 |
+
def get_feature_detector_name(url):
|
| 43 |
+
return os.path.splitext(url.split('/')[-1])[0]
|
| 44 |
+
|
| 45 |
+
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
|
| 46 |
+
assert 0 <= rank < num_gpus
|
| 47 |
+
key = (url, device)
|
| 48 |
+
if key not in _feature_detector_cache:
|
| 49 |
+
is_leader = (rank == 0)
|
| 50 |
+
if not is_leader and num_gpus > 1:
|
| 51 |
+
torch.distributed.barrier() # leader goes first
|
| 52 |
+
with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
|
| 53 |
+
if urlparse(url).path.endswith('.pkl'):
|
| 54 |
+
_feature_detector_cache[key] = pickle.load(f).to(device)
|
| 55 |
+
else:
|
| 56 |
+
_feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
|
| 57 |
+
if is_leader and num_gpus > 1:
|
| 58 |
+
torch.distributed.barrier() # others follow
|
| 59 |
+
return _feature_detector_cache[key]
|
| 60 |
+
|
| 61 |
+
#----------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
class FeatureStats:
|
| 64 |
+
def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
|
| 65 |
+
self.capture_all = capture_all
|
| 66 |
+
self.capture_mean_cov = capture_mean_cov
|
| 67 |
+
self.max_items = max_items
|
| 68 |
+
self.num_items = 0
|
| 69 |
+
self.num_features = None
|
| 70 |
+
self.all_features = None
|
| 71 |
+
self.raw_mean = None
|
| 72 |
+
self.raw_cov = None
|
| 73 |
+
|
| 74 |
+
def set_num_features(self, num_features):
|
| 75 |
+
if self.num_features is not None:
|
| 76 |
+
assert num_features == self.num_features
|
| 77 |
+
else:
|
| 78 |
+
self.num_features = num_features
|
| 79 |
+
self.all_features = []
|
| 80 |
+
self.raw_mean = np.zeros([num_features], dtype=np.float64)
|
| 81 |
+
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
|
| 82 |
+
|
| 83 |
+
def is_full(self):
|
| 84 |
+
return (self.max_items is not None) and (self.num_items >= self.max_items)
|
| 85 |
+
|
| 86 |
+
def append(self, x):
|
| 87 |
+
x = np.asarray(x, dtype=np.float32)
|
| 88 |
+
assert x.ndim == 2
|
| 89 |
+
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
|
| 90 |
+
if self.num_items >= self.max_items:
|
| 91 |
+
return
|
| 92 |
+
x = x[:self.max_items - self.num_items]
|
| 93 |
+
|
| 94 |
+
self.set_num_features(x.shape[1])
|
| 95 |
+
self.num_items += x.shape[0]
|
| 96 |
+
if self.capture_all:
|
| 97 |
+
self.all_features.append(x)
|
| 98 |
+
if self.capture_mean_cov:
|
| 99 |
+
x64 = x.astype(np.float64)
|
| 100 |
+
self.raw_mean += x64.sum(axis=0)
|
| 101 |
+
self.raw_cov += x64.T @ x64
|
| 102 |
+
|
| 103 |
+
def append_torch(self, x, num_gpus=1, rank=0):
|
| 104 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 2
|
| 105 |
+
assert 0 <= rank < num_gpus
|
| 106 |
+
if num_gpus > 1:
|
| 107 |
+
ys = []
|
| 108 |
+
for src in range(num_gpus):
|
| 109 |
+
y = x.clone()
|
| 110 |
+
torch.distributed.broadcast(y, src=src)
|
| 111 |
+
ys.append(y)
|
| 112 |
+
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
|
| 113 |
+
self.append(x.cpu().numpy())
|
| 114 |
+
|
| 115 |
+
def get_all(self):
|
| 116 |
+
assert self.capture_all
|
| 117 |
+
return np.concatenate(self.all_features, axis=0)
|
| 118 |
+
|
| 119 |
+
def get_all_torch(self):
|
| 120 |
+
return torch.from_numpy(self.get_all())
|
| 121 |
+
|
| 122 |
+
def get_mean_cov(self):
|
| 123 |
+
assert self.capture_mean_cov
|
| 124 |
+
mean = self.raw_mean / self.num_items
|
| 125 |
+
cov = self.raw_cov / self.num_items
|
| 126 |
+
cov = cov - np.outer(mean, mean)
|
| 127 |
+
return mean, cov
|
| 128 |
+
|
| 129 |
+
def save(self, pkl_file):
|
| 130 |
+
with open(pkl_file, 'wb') as f:
|
| 131 |
+
pickle.dump(self.__dict__, f)
|
| 132 |
+
|
| 133 |
+
@staticmethod
|
| 134 |
+
def load(pkl_file):
|
| 135 |
+
with open(pkl_file, 'rb') as f:
|
| 136 |
+
s = dnnlib.EasyDict(pickle.load(f))
|
| 137 |
+
obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
|
| 138 |
+
obj.__dict__.update(s)
|
| 139 |
+
return obj
|
| 140 |
+
|
| 141 |
+
#----------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
class ProgressMonitor:
|
| 144 |
+
def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
|
| 145 |
+
self.tag = tag
|
| 146 |
+
self.num_items = num_items
|
| 147 |
+
self.verbose = verbose
|
| 148 |
+
self.flush_interval = flush_interval
|
| 149 |
+
self.progress_fn = progress_fn
|
| 150 |
+
self.pfn_lo = pfn_lo
|
| 151 |
+
self.pfn_hi = pfn_hi
|
| 152 |
+
self.pfn_total = pfn_total
|
| 153 |
+
self.start_time = time.time()
|
| 154 |
+
self.batch_time = self.start_time
|
| 155 |
+
self.batch_items = 0
|
| 156 |
+
if self.progress_fn is not None:
|
| 157 |
+
self.progress_fn(self.pfn_lo, self.pfn_total)
|
| 158 |
+
|
| 159 |
+
def update(self, cur_items: int):
|
| 160 |
+
assert (self.num_items is None) or (cur_items <= self.num_items), f"Wrong `items` values: cur_items={cur_items}, self.num_items={self.num_items}"
|
| 161 |
+
if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
|
| 162 |
+
return
|
| 163 |
+
cur_time = time.time()
|
| 164 |
+
total_time = cur_time - self.start_time
|
| 165 |
+
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
|
| 166 |
+
if (self.verbose) and (self.tag is not None):
|
| 167 |
+
print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
|
| 168 |
+
self.batch_time = cur_time
|
| 169 |
+
self.batch_items = cur_items
|
| 170 |
+
|
| 171 |
+
if (self.progress_fn is not None) and (self.num_items is not None):
|
| 172 |
+
self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
|
| 173 |
+
|
| 174 |
+
def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
|
| 175 |
+
return ProgressMonitor(
|
| 176 |
+
tag = tag,
|
| 177 |
+
num_items = num_items,
|
| 178 |
+
flush_interval = flush_interval,
|
| 179 |
+
verbose = self.verbose,
|
| 180 |
+
progress_fn = self.progress_fn,
|
| 181 |
+
pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
|
| 182 |
+
pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
|
| 183 |
+
pfn_total = self.pfn_total,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
#----------------------------------------------------------------------------
|
| 187 |
+
|
| 188 |
+
@torch.no_grad()
|
| 189 |
+
def compute_feature_stats_for_dataset(
|
| 190 |
+
opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64,
|
| 191 |
+
data_loader_kwargs=None, max_items=None, temporal_detector=False, use_image_dataset=False,
|
| 192 |
+
feature_stats_cls=FeatureStats, **stats_kwargs):
|
| 193 |
+
|
| 194 |
+
dataset_kwargs = video_to_image_dataset_kwargs(opts.dataset_kwargs) if use_image_dataset else opts.dataset_kwargs
|
| 195 |
+
dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
|
| 196 |
+
|
| 197 |
+
if data_loader_kwargs is None:
|
| 198 |
+
data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
|
| 199 |
+
|
| 200 |
+
# Try to lookup from cache.
|
| 201 |
+
cache_file = None
|
| 202 |
+
if opts.cache:
|
| 203 |
+
# Choose cache file name.
|
| 204 |
+
args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs,
|
| 205 |
+
stats_kwargs=stats_kwargs, feature_stats_cls=feature_stats_cls.__name__)
|
| 206 |
+
md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
|
| 207 |
+
cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
|
| 208 |
+
cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
|
| 209 |
+
|
| 210 |
+
# Check if the file exists (all processes must agree).
|
| 211 |
+
flag = os.path.isfile(cache_file) if opts.rank == 0 else False
|
| 212 |
+
if opts.num_gpus > 1:
|
| 213 |
+
flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
|
| 214 |
+
torch.distributed.broadcast(tensor=flag, src=0)
|
| 215 |
+
flag = (float(flag.cpu()) != 0)
|
| 216 |
+
|
| 217 |
+
# Load.
|
| 218 |
+
if flag:
|
| 219 |
+
return feature_stats_cls.load(cache_file)
|
| 220 |
+
|
| 221 |
+
# Initialize.
|
| 222 |
+
num_items = len(dataset)
|
| 223 |
+
if max_items is not None:
|
| 224 |
+
num_items = min(num_items, max_items)
|
| 225 |
+
stats = feature_stats_cls(max_items=num_items, **stats_kwargs)
|
| 226 |
+
progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
| 227 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
| 228 |
+
|
| 229 |
+
# Main loop.
|
| 230 |
+
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
|
| 231 |
+
for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
|
| 232 |
+
images = batch['image']
|
| 233 |
+
if temporal_detector:
|
| 234 |
+
images = images.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
|
| 235 |
+
|
| 236 |
+
# images = images.float() / 255
|
| 237 |
+
# images = torch.nn.functional.interpolate(images, size=(images.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
|
| 238 |
+
# images = torch.nn.functional.interpolate(images, size=(images.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
|
| 239 |
+
# images = (images * 255).to(torch.uint8)
|
| 240 |
+
else:
|
| 241 |
+
images = images.view(-1, *images.shape[-3:]) # [-1, c, h, w]
|
| 242 |
+
|
| 243 |
+
if images.shape[1] == 1:
|
| 244 |
+
images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
|
| 245 |
+
features = detector(images.to(opts.device), **detector_kwargs)
|
| 246 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 247 |
+
progress.update(stats.num_items)
|
| 248 |
+
|
| 249 |
+
# Save to cache.
|
| 250 |
+
if cache_file is not None and opts.rank == 0:
|
| 251 |
+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
| 252 |
+
temp_file = cache_file + '.' + uuid.uuid4().hex
|
| 253 |
+
stats.save(temp_file)
|
| 254 |
+
os.replace(temp_file, cache_file) # atomic
|
| 255 |
+
return stats
|
| 256 |
+
|
| 257 |
+
#----------------------------------------------------------------------------
|
| 258 |
+
|
| 259 |
+
@torch.no_grad()
|
| 260 |
+
def compute_feature_stats_for_generator(
|
| 261 |
+
opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
|
| 262 |
+
batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
|
| 263 |
+
feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
|
| 264 |
+
|
| 265 |
+
if batch_gen is None:
|
| 266 |
+
batch_gen = min(batch_size, 4)
|
| 267 |
+
assert batch_size % batch_gen == 0
|
| 268 |
+
|
| 269 |
+
# Setup generator and load labels.
|
| 270 |
+
G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
|
| 271 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
| 272 |
+
|
| 273 |
+
# Image generation func.
|
| 274 |
+
def run_generator(z, c, t):
|
| 275 |
+
img = G(z=z, c=c, t=t, **opts.G_kwargs)
|
| 276 |
+
bt, c, h, w = img.shape
|
| 277 |
+
|
| 278 |
+
if temporal_detector:
|
| 279 |
+
img = img.view(bt // num_video_frames, num_video_frames, c, h, w) # [batch_size, t, c, h, w]
|
| 280 |
+
img = img.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
|
| 281 |
+
|
| 282 |
+
# img = torch.nn.functional.interpolate(img, size=(img.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
|
| 283 |
+
# img = torch.nn.functional.interpolate(img, size=(img.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
|
| 284 |
+
|
| 285 |
+
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
| 286 |
+
return img
|
| 287 |
+
|
| 288 |
+
# JIT.
|
| 289 |
+
if jit:
|
| 290 |
+
z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
|
| 291 |
+
c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
|
| 292 |
+
t = torch.zeros([batch_gen, G.cfg.sampling.num_frames_per_video], device=opts.device)
|
| 293 |
+
run_generator = torch.jit.trace(run_generator, [z, c, t], check_trace=False)
|
| 294 |
+
|
| 295 |
+
# Initialize.
|
| 296 |
+
stats = feature_stats_cls(**stats_kwargs)
|
| 297 |
+
assert stats.max_items is not None
|
| 298 |
+
progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
| 299 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
| 300 |
+
|
| 301 |
+
# Main loop.
|
| 302 |
+
while not stats.is_full():
|
| 303 |
+
images = []
|
| 304 |
+
for _i in range(batch_size // batch_gen):
|
| 305 |
+
z = torch.randn([batch_gen, G.z_dim], device=opts.device)
|
| 306 |
+
cond_sample_idx = [np.random.randint(len(dataset)) for _ in range(batch_gen)]
|
| 307 |
+
c = [dataset.get_label(i) for i in cond_sample_idx]
|
| 308 |
+
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
| 309 |
+
t = [list(range(0, num_video_frames * subsample_factor, subsample_factor)) for _i in range(batch_gen)]
|
| 310 |
+
t = torch.from_numpy(np.stack(t)).pin_memory().to(opts.device)
|
| 311 |
+
images.append(run_generator(z, c, t))
|
| 312 |
+
images = torch.cat(images)
|
| 313 |
+
if images.shape[1] == 1:
|
| 314 |
+
images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
|
| 315 |
+
features = detector(images, **detector_kwargs)
|
| 316 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
| 317 |
+
progress.update(stats.num_items)
|
| 318 |
+
return stats
|
| 319 |
+
|
| 320 |
+
#----------------------------------------------------------------------------
|
| 321 |
+
|
| 322 |
+
def rewrite_opts_for_gen_dataset(opts):
|
| 323 |
+
"""
|
| 324 |
+
Updates dataset arguments in the opts to enable the second dataset stats computation
|
| 325 |
+
"""
|
| 326 |
+
new_opts = copy.deepcopy(opts)
|
| 327 |
+
new_opts.dataset_kwargs = new_opts.gen_dataset_kwargs
|
| 328 |
+
new_opts.cache = False
|
| 329 |
+
|
| 330 |
+
return new_opts
|
| 331 |
+
|
| 332 |
+
#----------------------------------------------------------------------------
|
src/metrics/video_inception_score.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inception Score (IS) from the paper "Improved techniques for training
|
| 2 |
+
GANs". Matches the original implementation by Salimans et al. at
|
| 3 |
+
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from . import metric_utils
|
| 7 |
+
|
| 8 |
+
#----------------------------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
|
| 11 |
+
|
| 12 |
+
#----------------------------------------------------------------------------
|
| 13 |
+
|
| 14 |
+
def compute_isv(opts, num_gen: int, num_splits: int, backbone: str):
|
| 15 |
+
if backbone == 'c3d_ucf101':
|
| 16 |
+
# Perfectly reproduced torchscript version of the original chainer checkpoint:
|
| 17 |
+
# https://github.com/pfnet-research/tgan2/blob/f892bc432da315d4f6b6ae9448f69d046ef6fe01/tgan2/models/c3d/c3d_ucf101.py
|
| 18 |
+
# It is a UCF-101-finetuned C3D model.
|
| 19 |
+
detector_url = 'https://www.dropbox.com/s/jxpu7avzdc9n97q/c3d_ucf101.pt?dl=1'
|
| 20 |
+
else:
|
| 21 |
+
raise NotImplementedError(f'Backbone {backbone} is not supported.')
|
| 22 |
+
|
| 23 |
+
num_frames = 16
|
| 24 |
+
batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
|
| 25 |
+
|
| 26 |
+
if opts.generator_as_dataset:
|
| 27 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
|
| 28 |
+
gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
|
| 29 |
+
gen_opts.dataset_kwargs.load_n_consecutive = num_frames
|
| 30 |
+
gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
|
| 31 |
+
gen_opts.dataset_kwargs.subsample_factor = 1
|
| 32 |
+
gen_kwargs = dict()
|
| 33 |
+
else:
|
| 34 |
+
compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
|
| 35 |
+
gen_opts = opts
|
| 36 |
+
gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=1)
|
| 37 |
+
|
| 38 |
+
gen_probs = compute_gen_stats_fn(
|
| 39 |
+
opts=gen_opts, detector_url=detector_url, detector_kwargs={},
|
| 40 |
+
capture_all=True, max_items=num_gen, temporal_detector=True, **gen_kwargs).get_all() # [num_gen, num_classes]
|
| 41 |
+
|
| 42 |
+
if opts.rank != 0:
|
| 43 |
+
return float('nan'), float('nan')
|
| 44 |
+
|
| 45 |
+
scores = []
|
| 46 |
+
np.random.RandomState(42).shuffle(gen_probs)
|
| 47 |
+
for i in range(num_splits):
|
| 48 |
+
part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
|
| 49 |
+
kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
|
| 50 |
+
kl = np.mean(np.sum(kl, axis=1))
|
| 51 |
+
scores.append(np.exp(kl))
|
| 52 |
+
return float(np.mean(scores)), float(np.std(scores))
|
| 53 |
+
|
| 54 |
+
#----------------------------------------------------------------------------
|
src/scripts/__init__.py
ADDED
|
File without changes
|
src/scripts/calc_metrics.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Calculate quality metrics for previous training run or pretrained network pickle."""
|
| 10 |
+
|
| 11 |
+
import sys; sys.path.extend(['.', 'src'])
|
| 12 |
+
import os
|
| 13 |
+
import re
|
| 14 |
+
import click
|
| 15 |
+
import json
|
| 16 |
+
import tempfile
|
| 17 |
+
import copy
|
| 18 |
+
import torch
|
| 19 |
+
from src import dnnlib
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
|
| 22 |
+
import legacy
|
| 23 |
+
from metrics import metric_main
|
| 24 |
+
from metrics import metric_utils
|
| 25 |
+
from src.torch_utils import training_stats
|
| 26 |
+
from src.torch_utils import custom_ops
|
| 27 |
+
from src.torch_utils import misc
|
| 28 |
+
|
| 29 |
+
#----------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def subprocess_fn(rank, args, temp_dir):
|
| 32 |
+
dnnlib.util.Logger(should_flush=True)
|
| 33 |
+
|
| 34 |
+
# Init torch.distributed.
|
| 35 |
+
if args.num_gpus > 1:
|
| 36 |
+
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
|
| 37 |
+
if os.name == 'nt':
|
| 38 |
+
init_method = 'file:///' + init_file.replace('\\', '/')
|
| 39 |
+
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
|
| 40 |
+
else:
|
| 41 |
+
init_method = f'file://{init_file}'
|
| 42 |
+
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
|
| 43 |
+
|
| 44 |
+
# Init torch_utils.
|
| 45 |
+
sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
|
| 46 |
+
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
|
| 47 |
+
if rank != 0 or not args.verbose:
|
| 48 |
+
custom_ops.verbosity = 'none'
|
| 49 |
+
|
| 50 |
+
# Print network summary.
|
| 51 |
+
device = torch.device('cuda', rank)
|
| 52 |
+
torch.backends.cudnn.benchmark = True
|
| 53 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 54 |
+
torch.backends.cudnn.allow_tf32 = False
|
| 55 |
+
G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
|
| 56 |
+
if rank == 0 and args.verbose:
|
| 57 |
+
z = torch.empty([8, G.z_dim], device=device)
|
| 58 |
+
c = torch.empty([8, G.c_dim], device=device)
|
| 59 |
+
t = torch.zeros([8, G.cfg.sampling.num_frames_per_video], device=device).long()
|
| 60 |
+
misc.print_module_summary(G, [z, c, t])
|
| 61 |
+
|
| 62 |
+
# Calculate each metric.
|
| 63 |
+
for metric in args.metrics:
|
| 64 |
+
if rank == 0 and args.verbose:
|
| 65 |
+
print(f'Calculating {metric}...')
|
| 66 |
+
progress = metric_utils.ProgressMonitor(verbose=args.verbose)
|
| 67 |
+
result_dict = metric_main.calc_metric(
|
| 68 |
+
metric=metric,
|
| 69 |
+
G=G,
|
| 70 |
+
dataset_kwargs=args.dataset_kwargs,
|
| 71 |
+
num_gpus=args.num_gpus,
|
| 72 |
+
rank=rank,
|
| 73 |
+
device=device,
|
| 74 |
+
progress=progress,
|
| 75 |
+
cache=args.use_cache,
|
| 76 |
+
num_runs=(1 if metric == 'fid50k_full' else args.num_runs),
|
| 77 |
+
)
|
| 78 |
+
if rank == 0:
|
| 79 |
+
metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
|
| 80 |
+
if rank == 0 and args.verbose:
|
| 81 |
+
print()
|
| 82 |
+
|
| 83 |
+
# Done.
|
| 84 |
+
if rank == 0 and args.verbose:
|
| 85 |
+
print('Exiting...')
|
| 86 |
+
|
| 87 |
+
#----------------------------------------------------------------------------
|
| 88 |
+
|
| 89 |
+
class CommaSeparatedList(click.ParamType):
|
| 90 |
+
name = 'list'
|
| 91 |
+
|
| 92 |
+
def convert(self, value, param, ctx):
|
| 93 |
+
_ = param, ctx
|
| 94 |
+
if value is None or value.lower() == 'none' or value == '':
|
| 95 |
+
return []
|
| 96 |
+
return value.split(',')
|
| 97 |
+
|
| 98 |
+
#----------------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
@click.command()
|
| 101 |
+
@click.pass_context
|
| 102 |
+
@click.option('--network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH')
|
| 103 |
+
@click.option('--networks_dir', '--networks_dir', help='Path to the experiment directory if the latest checkpoint is requested.', metavar='PATH')
|
| 104 |
+
@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
|
| 105 |
+
@click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
|
| 106 |
+
@click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
|
| 107 |
+
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
|
| 108 |
+
@click.option('--cfg_path', help='Path to the experiments config', type=str, default="auto", metavar='PATH')
|
| 109 |
+
@click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
|
| 110 |
+
@click.option('--use_cache', help='Should we use the cache file?', type=bool, default=True, metavar='BOOL', show_default=True)
|
| 111 |
+
@click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
|
| 112 |
+
|
| 113 |
+
def calc_metrics(ctx, network_pkl, networks_dir, metrics, data, mirror, gpus, cfg_path, verbose, use_cache: bool, num_runs: int):
|
| 114 |
+
"""Calculate quality metrics for previous training run or pretrained network pickle.
|
| 115 |
+
|
| 116 |
+
Examples:
|
| 117 |
+
|
| 118 |
+
\b
|
| 119 |
+
# Previous training run: look up options automatically, save result to JSONL file.
|
| 120 |
+
python calc_metrics.py --metrics=pr50k3_full \\
|
| 121 |
+
--network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
|
| 122 |
+
|
| 123 |
+
\b
|
| 124 |
+
# Pre-trained network pickle: specify dataset explicitly, print result to stdout.
|
| 125 |
+
python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
|
| 126 |
+
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
|
| 127 |
+
|
| 128 |
+
Available metrics:
|
| 129 |
+
|
| 130 |
+
\b
|
| 131 |
+
ADA paper:
|
| 132 |
+
fid50k_full Frechet inception distance against the full dataset.
|
| 133 |
+
kid50k_full Kernel inception distance against the full dataset.
|
| 134 |
+
pr50k3_full Precision and recall againt the full dataset.
|
| 135 |
+
is50k Inception score for CIFAR-10.
|
| 136 |
+
|
| 137 |
+
\b
|
| 138 |
+
StyleGAN and StyleGAN2 papers:
|
| 139 |
+
fid50k Frechet inception distance against 50k real images.
|
| 140 |
+
kid50k Kernel inception distance against 50k real images.
|
| 141 |
+
pr50k3 Precision and recall against 50k real images.
|
| 142 |
+
ppl2_wend Perceptual path length in W at path endpoints against full image.
|
| 143 |
+
ppl_zfull Perceptual path length in Z for full paths against cropped image.
|
| 144 |
+
ppl_wfull Perceptual path length in W for full paths against cropped image.
|
| 145 |
+
ppl_zend Perceptual path length in Z at path endpoints against cropped image.
|
| 146 |
+
ppl_wend Perceptual path length in W at path endpoints against cropped image.
|
| 147 |
+
"""
|
| 148 |
+
dnnlib.util.Logger(should_flush=True)
|
| 149 |
+
|
| 150 |
+
if network_pkl is None:
|
| 151 |
+
output_regex = "^network-snapshot-\d{6}.pkl$"
|
| 152 |
+
ckpt_regex = re.compile("^network-snapshot-\d{6}.pkl$")
|
| 153 |
+
# ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
|
| 154 |
+
# network_pkl = os.path.join(networks_dir, ckpts[-1])
|
| 155 |
+
metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
|
| 156 |
+
with open(metrics_file, 'r') as f:
|
| 157 |
+
snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
|
| 158 |
+
best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
|
| 159 |
+
network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
|
| 160 |
+
print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
|
| 161 |
+
# Selecting a checkpoint with the best score
|
| 162 |
+
|
| 163 |
+
# Validate arguments.
|
| 164 |
+
args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
|
| 165 |
+
if cfg_path == "auto":
|
| 166 |
+
# Assuming that `network_pkl` has the structure /path/to/experiment/output/network-X.pkl
|
| 167 |
+
output_path = os.path.dirname(network_pkl)
|
| 168 |
+
assert os.path.basename(output_path) == "output", f"Unknown path structure: {output_path}"
|
| 169 |
+
experiment_path = os.path.dirname(output_path)
|
| 170 |
+
cfg_path = os.path.join(experiment_path, 'experiment_config.yaml')
|
| 171 |
+
|
| 172 |
+
cfg = OmegaConf.load(cfg_path)
|
| 173 |
+
if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
|
| 174 |
+
ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
|
| 175 |
+
if not args.num_gpus >= 1:
|
| 176 |
+
ctx.fail('--gpus must be at least 1')
|
| 177 |
+
|
| 178 |
+
# Load network.
|
| 179 |
+
if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
|
| 180 |
+
ctx.fail('--network must point to a file or URL')
|
| 181 |
+
if args.verbose:
|
| 182 |
+
print(f'Loading network from "{network_pkl}"...')
|
| 183 |
+
with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
|
| 184 |
+
network_dict = legacy.load_network_pkl(f)
|
| 185 |
+
args.G = network_dict['G_ema'] # subclass of torch.nn.Module
|
| 186 |
+
|
| 187 |
+
from src.training.networks import Generator
|
| 188 |
+
G = args.G
|
| 189 |
+
G.cfg.z_dim = G.z_dim
|
| 190 |
+
G_new = Generator(
|
| 191 |
+
w_dim=G.cfg.w_dim,
|
| 192 |
+
mapping_kwargs=dnnlib.EasyDict(num_layers=G.cfg.get('mapping_net_n_layers', 2), cfg=G.cfg),
|
| 193 |
+
synthesis_kwargs=dnnlib.EasyDict(
|
| 194 |
+
channel_base=int(G.cfg.get('fmaps', 0.5) * 32768),
|
| 195 |
+
channel_max=G.cfg.get('channel_max', 512),
|
| 196 |
+
num_fp16_res=4,
|
| 197 |
+
conv_clamp=256,
|
| 198 |
+
),
|
| 199 |
+
cfg=G.cfg,
|
| 200 |
+
img_resolution=256,
|
| 201 |
+
img_channels=3,
|
| 202 |
+
c_dim=G.cfg.c_dim,
|
| 203 |
+
).eval()
|
| 204 |
+
G_new.load_state_dict(G.state_dict())
|
| 205 |
+
args.G = G_new
|
| 206 |
+
|
| 207 |
+
# Initialize dataset options.
|
| 208 |
+
if data is not None:
|
| 209 |
+
args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.VideoFramesFolderDataset', cfg=cfg.dataset, path=data)
|
| 210 |
+
elif network_dict['training_set_kwargs'] is not None:
|
| 211 |
+
args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
|
| 212 |
+
else:
|
| 213 |
+
ctx.fail('Could not look up dataset options; please specify --data')
|
| 214 |
+
|
| 215 |
+
# Finalize dataset options.
|
| 216 |
+
args.dataset_kwargs.resolution = args.G.img_resolution
|
| 217 |
+
args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
|
| 218 |
+
if mirror is not None:
|
| 219 |
+
args.dataset_kwargs.xflip = mirror
|
| 220 |
+
args.use_cache = use_cache
|
| 221 |
+
args.num_runs = num_runs
|
| 222 |
+
|
| 223 |
+
# Print dataset options.
|
| 224 |
+
if args.verbose:
|
| 225 |
+
print('Dataset options:')
|
| 226 |
+
print(cfg.dataset)
|
| 227 |
+
|
| 228 |
+
# Locate run dir.
|
| 229 |
+
args.run_dir = None
|
| 230 |
+
if os.path.isfile(network_pkl):
|
| 231 |
+
pkl_dir = os.path.dirname(network_pkl)
|
| 232 |
+
if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
|
| 233 |
+
args.run_dir = pkl_dir
|
| 234 |
+
|
| 235 |
+
# Launch processes.
|
| 236 |
+
if args.verbose:
|
| 237 |
+
print('Launching processes...')
|
| 238 |
+
torch.multiprocessing.set_start_method('spawn')
|
| 239 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 240 |
+
if args.num_gpus == 1:
|
| 241 |
+
subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
|
| 242 |
+
else:
|
| 243 |
+
torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
|
| 244 |
+
|
| 245 |
+
#----------------------------------------------------------------------------
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
calc_metrics() # pylint: disable=no-value-for-parameter
|
| 249 |
+
|
| 250 |
+
#----------------------------------------------------------------------------
|
src/scripts/calc_metrics_for_dataset.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
"""Calculate quality metrics for previous training run or pretrained network pickle."""
|
| 10 |
+
|
| 11 |
+
import sys; sys.path.extend(['.', 'src'])
|
| 12 |
+
import os
|
| 13 |
+
import click
|
| 14 |
+
import tempfile
|
| 15 |
+
import torch
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from src import dnnlib
|
| 18 |
+
|
| 19 |
+
from metrics import metric_main
|
| 20 |
+
from metrics import metric_utils
|
| 21 |
+
from src.torch_utils import training_stats
|
| 22 |
+
from src.torch_utils import custom_ops
|
| 23 |
+
|
| 24 |
+
#----------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
def subprocess_fn(rank, args, temp_dir):
|
| 27 |
+
dnnlib.util.Logger(should_flush=True)
|
| 28 |
+
|
| 29 |
+
# Init torch.distributed.
|
| 30 |
+
if args.num_gpus > 1:
|
| 31 |
+
init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
|
| 32 |
+
if os.name == 'nt':
|
| 33 |
+
init_method = 'file:///' + init_file.replace('\\', '/')
|
| 34 |
+
torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
|
| 35 |
+
else:
|
| 36 |
+
init_method = f'file://{init_file}'
|
| 37 |
+
torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
|
| 38 |
+
|
| 39 |
+
# Init torch_utils.
|
| 40 |
+
sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
|
| 41 |
+
training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
|
| 42 |
+
if rank != 0 or not args.verbose:
|
| 43 |
+
custom_ops.verbosity = 'none'
|
| 44 |
+
|
| 45 |
+
# Print network summary.
|
| 46 |
+
device = torch.device('cuda', rank)
|
| 47 |
+
torch.backends.cudnn.benchmark = True
|
| 48 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
| 49 |
+
torch.backends.cudnn.allow_tf32 = False
|
| 50 |
+
|
| 51 |
+
# Calculate each metric.
|
| 52 |
+
for metric in args.metrics:
|
| 53 |
+
if rank == 0 and args.verbose:
|
| 54 |
+
print(f'Calculating {metric}...')
|
| 55 |
+
progress = metric_utils.ProgressMonitor(verbose=args.verbose)
|
| 56 |
+
result_dict = metric_main.calc_metric(
|
| 57 |
+
metric=metric,
|
| 58 |
+
dataset_kwargs=args.dataset_kwargs,
|
| 59 |
+
gen_dataset_kwargs=args.gen_dataset_kwargs,
|
| 60 |
+
generator_as_dataset=args.generator_as_dataset,
|
| 61 |
+
num_gpus=args.num_gpus,
|
| 62 |
+
rank=rank,
|
| 63 |
+
device=device,
|
| 64 |
+
progress=progress,
|
| 65 |
+
cache=args.use_cache,
|
| 66 |
+
num_runs=args.num_runs,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if rank == 0:
|
| 70 |
+
metric_main.report_metric(result_dict, run_dir=args.run_dir)
|
| 71 |
+
|
| 72 |
+
if rank == 0 and args.verbose:
|
| 73 |
+
print()
|
| 74 |
+
|
| 75 |
+
# Done.
|
| 76 |
+
if rank == 0 and args.verbose:
|
| 77 |
+
print('Exiting...')
|
| 78 |
+
|
| 79 |
+
#----------------------------------------------------------------------------
|
| 80 |
+
|
| 81 |
+
class CommaSeparatedList(click.ParamType):
|
| 82 |
+
name = 'list'
|
| 83 |
+
|
| 84 |
+
def convert(self, value, param, ctx):
|
| 85 |
+
_ = param, ctx
|
| 86 |
+
if value is None or value.lower() == 'none' or value == '':
|
| 87 |
+
return []
|
| 88 |
+
return value.split(',')
|
| 89 |
+
|
| 90 |
+
#----------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
def calc_metrics_for_dataset(ctx, metrics, real_data_path, fake_data_path, mirror, resolution, gpus, verbose, use_cache: bool, num_runs: int):
|
| 93 |
+
dnnlib.util.Logger(should_flush=True)
|
| 94 |
+
|
| 95 |
+
# Validate arguments.
|
| 96 |
+
args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, verbose=verbose)
|
| 97 |
+
if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
|
| 98 |
+
ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
|
| 99 |
+
if not args.num_gpus >= 1:
|
| 100 |
+
ctx.fail('--gpus must be at least 1')
|
| 101 |
+
|
| 102 |
+
dummy_dataset_cfg = OmegaConf.create({'max_num_frames': 10000})
|
| 103 |
+
|
| 104 |
+
# Initialize dataset options for real data.
|
| 105 |
+
args.dataset_kwargs = dnnlib.EasyDict(
|
| 106 |
+
class_name='training.dataset.VideoFramesFolderDataset',
|
| 107 |
+
path=real_data_path,
|
| 108 |
+
cfg=dummy_dataset_cfg,
|
| 109 |
+
xflip=mirror,
|
| 110 |
+
resolution=resolution,
|
| 111 |
+
use_labels=False,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# Initialize dataset options for fake data.
|
| 115 |
+
args.gen_dataset_kwargs = dnnlib.EasyDict(
|
| 116 |
+
class_name='training.dataset.VideoFramesFolderDataset',
|
| 117 |
+
path=fake_data_path,
|
| 118 |
+
cfg=dummy_dataset_cfg,
|
| 119 |
+
xflip=False,
|
| 120 |
+
resolution=resolution,
|
| 121 |
+
use_labels=False,
|
| 122 |
+
)
|
| 123 |
+
args.generator_as_dataset = True
|
| 124 |
+
|
| 125 |
+
# Print dataset options.
|
| 126 |
+
if args.verbose:
|
| 127 |
+
print('Real data options:')
|
| 128 |
+
print(args.dataset_kwargs)
|
| 129 |
+
|
| 130 |
+
print('Fake data options:')
|
| 131 |
+
print(args.gen_dataset_kwargs)
|
| 132 |
+
|
| 133 |
+
# Locate run dir.
|
| 134 |
+
args.run_dir = None
|
| 135 |
+
args.use_cache = use_cache
|
| 136 |
+
args.num_runs = num_runs
|
| 137 |
+
|
| 138 |
+
# Launch processes.
|
| 139 |
+
if args.verbose:
|
| 140 |
+
print('Launching processes...')
|
| 141 |
+
torch.multiprocessing.set_start_method('spawn')
|
| 142 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
| 143 |
+
if args.num_gpus == 1:
|
| 144 |
+
subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
|
| 145 |
+
else:
|
| 146 |
+
torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
|
| 147 |
+
|
| 148 |
+
#----------------------------------------------------------------------------
|
| 149 |
+
|
| 150 |
+
@click.command()
|
| 151 |
+
@click.pass_context
|
| 152 |
+
@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fvd2048_16f,fid50k_full', show_default=True)
|
| 153 |
+
@click.option('--real_data_path', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
|
| 154 |
+
@click.option('--fake_data_path', help='Generated images (directory or zip)', metavar='PATH')
|
| 155 |
+
@click.option('--mirror', help='Should we mirror the real data?', type=bool, metavar='BOOL')
|
| 156 |
+
@click.option('--resolution', help='Resolution for the source dataset', type=int, metavar='INT')
|
| 157 |
+
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
|
| 158 |
+
@click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
|
| 159 |
+
@click.option('--use_cache', help='Use stats cache', type=bool, default=True, metavar='BOOL', show_default=True)
|
| 160 |
+
@click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
|
| 161 |
+
def calc_metrics_cli_wrapper(ctx, *args, **kwargs):
|
| 162 |
+
calc_metrics_for_dataset(ctx, *args, **kwargs)
|
| 163 |
+
|
| 164 |
+
#----------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
calc_metrics_cli_wrapper() # pylint: disable=no-value-for-parameter
|
| 168 |
+
|
| 169 |
+
#----------------------------------------------------------------------------
|
src/scripts/clip_edit.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import sys; sys.path.extend(['.', 'src', '/home/skoroki/StyleCLIP'])
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from typing import List
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
import random
|
| 9 |
+
import yaml
|
| 10 |
+
import itertools
|
| 11 |
+
|
| 12 |
+
import torchvision
|
| 13 |
+
from torch import optim
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import click
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from omegaconf import OmegaConf
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
from torchvision import utils
|
| 23 |
+
from torch import Tensor
|
| 24 |
+
import torchvision.transforms.functional as TVF
|
| 25 |
+
from torchvision.utils import save_image
|
| 26 |
+
from torch import Tensor
|
| 27 |
+
|
| 28 |
+
from src.deps.facial_recognition.model_irse import Backbone
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
import clip
|
| 32 |
+
except ImportError:
|
| 33 |
+
raise ImportError(
|
| 34 |
+
"To edit videos with CLIP, you need to install the `clip` library. " \
|
| 35 |
+
"Please follow the instructions in https://github.com/openai/CLIP")
|
| 36 |
+
|
| 37 |
+
from src import dnnlib
|
| 38 |
+
import legacy
|
| 39 |
+
from src.scripts.project import save_edited_w
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
#----------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
|
| 45 |
+
lr_ramp = min(1, (1 - t) / rampdown)
|
| 46 |
+
lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
|
| 47 |
+
lr_ramp = lr_ramp * min(1, t / rampup)
|
| 48 |
+
|
| 49 |
+
return initial_lr * lr_ramp
|
| 50 |
+
|
| 51 |
+
#----------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class CLIPLoss(torch.nn.Module):
|
| 54 |
+
"""
|
| 55 |
+
Copy-pasted and adapted from StyleCLIP
|
| 56 |
+
"""
|
| 57 |
+
def __init__(self):
|
| 58 |
+
super(CLIPLoss, self).__init__()
|
| 59 |
+
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
| 60 |
+
#self.upsample = torch.nn.Upsample(scale_factor=7)
|
| 61 |
+
#self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
|
| 62 |
+
|
| 63 |
+
def forward(self, image, text):
|
| 64 |
+
#image = self.avg_pool(self.upsample(image))
|
| 65 |
+
#print('shape', image.shape, text.shape)
|
| 66 |
+
image = F.interpolate(image, size=(224, 224), mode='area')
|
| 67 |
+
similarity = 1 - self.model(image, text)[0] / 100
|
| 68 |
+
similarity = similarity.diag()
|
| 69 |
+
|
| 70 |
+
return similarity
|
| 71 |
+
|
| 72 |
+
#----------------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
class IDLoss(nn.Module):
|
| 75 |
+
"""
|
| 76 |
+
Copy-pasted from StyleCLIP
|
| 77 |
+
"""
|
| 78 |
+
def __init__(self):
|
| 79 |
+
super(IDLoss, self).__init__()
|
| 80 |
+
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
| 81 |
+
with dnnlib.util.open_url(Backbone.WEIGHTS_URL, verbose=True) as f:
|
| 82 |
+
ir_se50_weights = torch.load(f)
|
| 83 |
+
self.facenet.load_state_dict(ir_se50_weights)
|
| 84 |
+
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
| 85 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
| 86 |
+
self.facenet.eval()
|
| 87 |
+
self.facenet.cuda()
|
| 88 |
+
|
| 89 |
+
def extract_feats(self, x):
|
| 90 |
+
if x.shape[2] != 256:
|
| 91 |
+
x = self.pool(x)
|
| 92 |
+
x = x[:, :, 35:223, 32:220] # Crop interesting region
|
| 93 |
+
x = self.face_pool(x)
|
| 94 |
+
x_feats = self.facenet(x)
|
| 95 |
+
return x_feats
|
| 96 |
+
|
| 97 |
+
def forward(self, y_hat, y):
|
| 98 |
+
n_samples = y.shape[0]
|
| 99 |
+
y_feats = self.extract_feats(y) # Otherwise use the feature from there
|
| 100 |
+
y_hat_feats = self.extract_feats(y_hat)
|
| 101 |
+
y_feats = y_feats.detach()
|
| 102 |
+
loss = 0
|
| 103 |
+
|
| 104 |
+
for i in range(n_samples):
|
| 105 |
+
diff_target = y_hat_feats[i].dot(y_feats[i])
|
| 106 |
+
loss += 1 - diff_target
|
| 107 |
+
|
| 108 |
+
return loss / n_samples
|
| 109 |
+
|
| 110 |
+
#----------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
def run_edit_optimization(
|
| 113 |
+
_sentinel=None,
|
| 114 |
+
G: nn.Module=None,
|
| 115 |
+
w_orig: Tensor=None,
|
| 116 |
+
descriptions: List[str]=None,
|
| 117 |
+
# ckpt: float="stylegan2-ffhq-config-f.pt",
|
| 118 |
+
lr: float=0.1,
|
| 119 |
+
num_steps: int=40,
|
| 120 |
+
l2_lambda: float=0.001,
|
| 121 |
+
id_lambda: float=0.005,
|
| 122 |
+
# latent_path: float=latent_path,
|
| 123 |
+
# truncation: float=0.7,
|
| 124 |
+
# save_intermediate_image_every: float=1 if create_video else 20,
|
| 125 |
+
# results_dir: float="results",
|
| 126 |
+
mask: float=None,
|
| 127 |
+
mask_lambda: float=0.0,
|
| 128 |
+
verbose: bool=False,
|
| 129 |
+
) -> Tensor:
|
| 130 |
+
assert _sentinel is None
|
| 131 |
+
# text_inputs = torch.cat([clip.tokenize(d) for d in descriptions]).to(device)
|
| 132 |
+
num_prompts = len(descriptions)
|
| 133 |
+
num_images = len(w_orig)
|
| 134 |
+
device = w_orig.device
|
| 135 |
+
|
| 136 |
+
text_inputs = clip.tokenize(descriptions).to(device) # [num_prompts, 77]
|
| 137 |
+
text_inputs = text_inputs.repeat_interleave(len(w_orig), dim=0) # [num_prompts * num_images, 77]
|
| 138 |
+
|
| 139 |
+
c = torch.zeros(num_prompts * num_images, 0, device=device)
|
| 140 |
+
ts = torch.zeros(num_prompts * num_images, 1, device=device)
|
| 141 |
+
w_orig = w_orig.repeat(num_prompts, 1, 1) # [num_prompts * num_images, num_ws, w_dim]
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
img_orig = G.synthesis(ws=w_orig, c=c, t=ts) # [num_prompts * num_images, 3, c, h, w]
|
| 145 |
+
|
| 146 |
+
w = w_orig.detach().clone() # [num_prompts * num_images, num_ws, w_dim]
|
| 147 |
+
w.requires_grad = True
|
| 148 |
+
|
| 149 |
+
if mask_lambda > 0:
|
| 150 |
+
target_image = img_orig * (1 - mask) # [num_prompts * num_images, 3, c, h, w]
|
| 151 |
+
#target_image = img_orig[:, :, -128:, :128]
|
| 152 |
+
target_image = (target_image * 0.5 + 0.5) * 255.0 # [num_prompts * num_images, 3, c, h, w]
|
| 153 |
+
if target_image.shape[2] > 256:
|
| 154 |
+
target_image = F.interpolate(target_image, size=(256, 256), mode='area')
|
| 155 |
+
target_features = vgg16(target_image, resize_images=False, return_lpips=True)
|
| 156 |
+
#dist = (target_features - synth_features).square().sum()
|
| 157 |
+
else:
|
| 158 |
+
target_features = None
|
| 159 |
+
|
| 160 |
+
clip_loss = CLIPLoss()
|
| 161 |
+
id_loss = IDLoss()
|
| 162 |
+
optimizer = optim.Adam([w], lr=lr)
|
| 163 |
+
|
| 164 |
+
if verbose:
|
| 165 |
+
pbar = tqdm(range(num_steps))
|
| 166 |
+
else:
|
| 167 |
+
pbar = range(num_steps)
|
| 168 |
+
|
| 169 |
+
for curr_iter in pbar:
|
| 170 |
+
curr_lr = get_lr(curr_iter / num_steps, lr)
|
| 171 |
+
# optimizer.param_groups[0]["lr"] = lr
|
| 172 |
+
for param_group in optimizer.param_groups:
|
| 173 |
+
param_group['lr'] = curr_lr
|
| 174 |
+
|
| 175 |
+
#img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=work_in_stylespace)
|
| 176 |
+
img_gen = G.synthesis(ws=w, c=c, t=ts) # [num_prompts * num_images, 3, c, h, w]
|
| 177 |
+
|
| 178 |
+
if mask_lambda > 0:
|
| 179 |
+
raise NotImplementedError
|
| 180 |
+
synth_image = img_gen * (1 - mask)
|
| 181 |
+
#synth_image = img_gen[:, :, -128:, :128]
|
| 182 |
+
synth_image = (synth_image * 0.5 + 0.5) * 255.0
|
| 183 |
+
if synth_image.shape[2] > 256:
|
| 184 |
+
synth_image = F.interpolate(synth_image, size=(256, 256), mode='area')
|
| 185 |
+
synth_features = vgg16(synth_image, resize_images=False, return_lpips=True)
|
| 186 |
+
mask_loss = (target_features - synth_features).square().sum()
|
| 187 |
+
else:
|
| 188 |
+
mask_loss = 0
|
| 189 |
+
|
| 190 |
+
if not mask is None:
|
| 191 |
+
img_gen = img_gen * mask.unsqueeze(0) # [num_prompts * num_images, 3, c, h, w]
|
| 192 |
+
|
| 193 |
+
c_loss = clip_loss(img_gen, text_inputs) # [num_prompts * num_images]
|
| 194 |
+
|
| 195 |
+
if id_lambda > 0:
|
| 196 |
+
i_loss = id_loss(img_gen, img_orig)
|
| 197 |
+
else:
|
| 198 |
+
i_loss = 0
|
| 199 |
+
|
| 200 |
+
l2_loss = ((w_orig - w) ** 2) # [1]
|
| 201 |
+
loss = c_loss.sum() + l2_lambda * l2_loss.sum() + id_lambda * i_loss + mask_lambda * mask_loss
|
| 202 |
+
|
| 203 |
+
optimizer.zero_grad()
|
| 204 |
+
loss.backward()
|
| 205 |
+
optimizer.step()
|
| 206 |
+
|
| 207 |
+
if verbose:
|
| 208 |
+
pbar.set_description((f"loss: {loss.item():.4f};"))
|
| 209 |
+
|
| 210 |
+
final_result = torch.stack([img_orig, img_gen]) # [2, num_prompts * num_images, c, h, w]
|
| 211 |
+
|
| 212 |
+
return final_result, w
|
| 213 |
+
|
| 214 |
+
# x, new_w = main(args)
|
| 215 |
+
|
| 216 |
+
# pair = torch.cat([img for img in x], dim=2)
|
| 217 |
+
# TVF.to_pil_image((pair.cpu().detach() * 0.5 + 0.5).clamp(0, 1))
|
| 218 |
+
|
| 219 |
+
#----------------------------------------------------------------------------
|
| 220 |
+
|
| 221 |
+
@click.command()
|
| 222 |
+
@click.pass_context
|
| 223 |
+
@click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
|
| 224 |
+
@click.option('--networks_dir', help='Network pickles directory', metavar='PATH')
|
| 225 |
+
# @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
|
| 226 |
+
# @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
|
| 227 |
+
# @click.option('--same_motion_codes', type=bool, help='Should we use the same motion codes for all videos?', default=False, show_default=True)
|
| 228 |
+
@click.option('--w_dir', help='A directory leading to latent codes.', type=str, required=False, metavar='DIR')
|
| 229 |
+
@click.option('--results_dir', help='A directory to save the results in.', type=str, required=False, metavar='DIR')
|
| 230 |
+
@click.option('--truncation_psi', help='If we use new w, what truncation to use.', type=float, required=False, metavar='FLOAT', default=1.0)
|
| 231 |
+
@click.option('--num_w', help='If we use new w, how many to sample?', type=int, required=False, metavar='FLOAT', default=16)
|
| 232 |
+
@click.option('--prompts', help='A path to prompts or a string of prompts.', type=str, required=True, metavar='DIR')
|
| 233 |
+
@click.option('--seed', type=int, help='Random seed', default=42, metavar='DIR')
|
| 234 |
+
@click.option('--zero_periods', help='Zero-out periods predictor?', default=False, type=bool, metavar='BOOL')
|
| 235 |
+
@click.option('--num_weights_to_slice', help='Number of high-frequency coords to remove.', default=0, type=int, metavar='INT')
|
| 236 |
+
@click.option('--num_steps', help='Number of the optimization steps to perform.', default=40, type=int, metavar='INT')
|
| 237 |
+
@click.option('--stack_samples', help='When saving, should we stack samples together?', default=False, type=bool, metavar='BOOL')
|
| 238 |
+
# l2_lambda=0.001,
|
| 239 |
+
# id_lambda=0.005,
|
| 240 |
+
# l2_lambda=0.0005,
|
| 241 |
+
# id_lambda=0.0,
|
| 242 |
+
@click.option('--l2_lambda', help='L2 loss coef', default=0.001, type=float, metavar='FLOAT')
|
| 243 |
+
@click.option('--id_lambda', help='ID loss coef', default=0.005, type=float, metavar='FLOAT')
|
| 244 |
+
@click.option('--lr', help='Learning rate', default=0.1, type=float, metavar='FLOAT')
|
| 245 |
+
@click.option('--mask_lambda', help='If we use a mask, specify the loss coef', default=0.0, type=float, metavar='FLOAT')
|
| 246 |
+
@click.option('--use_id_lambda', help='Should we use id lambda in HPO?', default=False, type=bool, metavar='BOOL')
|
| 247 |
+
|
| 248 |
+
def main(
|
| 249 |
+
ctx: click.Context,
|
| 250 |
+
network_pkl: str,
|
| 251 |
+
networks_dir: str,
|
| 252 |
+
seed: int,
|
| 253 |
+
w_dir: str,
|
| 254 |
+
results_dir: str,
|
| 255 |
+
truncation_psi: float,
|
| 256 |
+
num_w: int,
|
| 257 |
+
# save_as_mp4: bool,
|
| 258 |
+
# video_len: int,
|
| 259 |
+
# fps: int,
|
| 260 |
+
# as_grids: bool,
|
| 261 |
+
zero_periods: bool,
|
| 262 |
+
num_weights_to_slice: int,
|
| 263 |
+
num_steps: int,
|
| 264 |
+
stack_samples: bool,
|
| 265 |
+
l2_lambda: float,
|
| 266 |
+
id_lambda: float,
|
| 267 |
+
lr: float,
|
| 268 |
+
prompts: str,
|
| 269 |
+
mask_lambda: float,
|
| 270 |
+
use_id_lambda: bool,
|
| 271 |
+
):
|
| 272 |
+
if network_pkl is None:
|
| 273 |
+
output_regex = "^network-snapshot-\d{6}.pkl$"
|
| 274 |
+
ckpt_regex = re.compile("^network-snapshot-\d{6}.pkl$")
|
| 275 |
+
# ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
|
| 276 |
+
# network_pkl = os.path.join(networks_dir, ckpts[-1])
|
| 277 |
+
metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
|
| 278 |
+
with open(metrics_file, 'r') as f:
|
| 279 |
+
snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
|
| 280 |
+
best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
|
| 281 |
+
network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
|
| 282 |
+
print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
|
| 283 |
+
# Selecting a checkpoint with the best score
|
| 284 |
+
else:
|
| 285 |
+
assert networks_dir is None, "Cant have both parameters: network_pkl and networks_dir"
|
| 286 |
+
|
| 287 |
+
print('Loading networks from "%s"...' % network_pkl, end='')
|
| 288 |
+
device = torch.device('cuda')
|
| 289 |
+
with dnnlib.util.open_url(network_pkl) as f:
|
| 290 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
|
| 291 |
+
print('Loaded!')
|
| 292 |
+
|
| 293 |
+
random.seed(seed)
|
| 294 |
+
np.random.seed(seed)
|
| 295 |
+
torch.manual_seed(seed)
|
| 296 |
+
|
| 297 |
+
if zero_periods:
|
| 298 |
+
G.synthesis.motion_encoder.time_encoder.periods_predictor.weight.data.zero_()
|
| 299 |
+
|
| 300 |
+
if num_weights_to_slice > 0:
|
| 301 |
+
G.synthesis.motion_encoder.time_encoder.weights[:, -num_weights_to_slice:] = 0.0
|
| 302 |
+
|
| 303 |
+
# description = "Bright sunny sky and mountains far away"
|
| 304 |
+
# experiment_type = 'edit' #@param ['edit', 'free_generation']
|
| 305 |
+
# mask = torch.zeros(3, 256, 256, device=device)
|
| 306 |
+
# mask[:, :, 64+32 : 128+32] = 1.0
|
| 307 |
+
# mask[:, :-128, :] = 1.0
|
| 308 |
+
# mask[:, :, 128:] = 1.0
|
| 309 |
+
|
| 310 |
+
if w_dir is None:
|
| 311 |
+
print('Sampling new w')
|
| 312 |
+
z = torch.randn(num_w, G.z_dim, device=device)
|
| 313 |
+
c = torch.zeros(len(z), G.c_dim, device=device)
|
| 314 |
+
w_orig = G.mapping(z=z, c=c, truncation_psi=truncation_psi)
|
| 315 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 316 |
+
torch.save(w_orig.cpu(), f'{results_dir}_w_orig.pt')
|
| 317 |
+
w_save_dir = os.path.join(results_dir, 'w_edit')
|
| 318 |
+
samples_save_dir = os.path.join(results_dir, 'edited_samples')
|
| 319 |
+
else:
|
| 320 |
+
w_paths = sorted([os.path.join(w_dir, f) for f in os.listdir(w_dir) if f.endswith('_w.pt')])
|
| 321 |
+
w_names = [os.path.basename(f) for f in w_paths]
|
| 322 |
+
w_orig = [torch.load(f) for f in w_paths]
|
| 323 |
+
w_orig = torch.stack(w_orig).to(device) # [num_images, num_ws, w_dim]
|
| 324 |
+
w_save_dir = f'{w_dir}_edited_w'
|
| 325 |
+
samples_save_dir = f'{w_dir}_edited_samples'
|
| 326 |
+
|
| 327 |
+
os.makedirs(w_save_dir, exist_ok=True)
|
| 328 |
+
os.makedirs(samples_save_dir, exist_ok=True)
|
| 329 |
+
|
| 330 |
+
print(f'Loading prompts from file: {prompts}')
|
| 331 |
+
with open(prompts, 'r') as f:
|
| 332 |
+
descs_dict = yaml.load(f)
|
| 333 |
+
edit_names, descriptions = list(zip(*descs_dict.items()))
|
| 334 |
+
edit_names = edit_names
|
| 335 |
+
descriptions = descriptions
|
| 336 |
+
|
| 337 |
+
del id_lambda, num_steps, l2_lambda
|
| 338 |
+
l2_lambdas = [1000000.0, 0.0025, 0.001, 0.00025, 0.0005, 0.0001]
|
| 339 |
+
if use_id_lambda:
|
| 340 |
+
id_lambdas = [0.005, 0.0025, 0.001, 0.00025, 0.0005, 0.0001, 0.0]
|
| 341 |
+
else:
|
| 342 |
+
id_lambdas = [0.0]
|
| 343 |
+
all_num_steps = [40]
|
| 344 |
+
|
| 345 |
+
for curr_edit_name, curr_prompt in zip(edit_names, descriptions):
|
| 346 |
+
all_images = []
|
| 347 |
+
all_w_edited = []
|
| 348 |
+
|
| 349 |
+
for l2_lambda, id_lambda, num_steps in tqdm(list(itertools.product(l2_lambdas, id_lambdas, all_num_steps)), desc=f'Performing HPO for {curr_edit_name}'):
|
| 350 |
+
final_image, w_edited = run_edit_optimization(
|
| 351 |
+
G=G,
|
| 352 |
+
w_orig=w_orig,
|
| 353 |
+
descriptions=[curr_prompt],
|
| 354 |
+
# ckpt="stylegan2-ffhq-config-f.pt",
|
| 355 |
+
lr=lr,
|
| 356 |
+
num_steps=num_steps,
|
| 357 |
+
l2_lambda=l2_lambda,
|
| 358 |
+
id_lambda=id_lambda,
|
| 359 |
+
mask_lambda=mask_lambda,
|
| 360 |
+
verbose=False,
|
| 361 |
+
# latent_path=latent_path,
|
| 362 |
+
# truncation=0.7,
|
| 363 |
+
# mask=None,
|
| 364 |
+
# mask_lambda=0.1,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
all_images.extend((final_image[1].cpu() * 0.5 + 0.5).clamp(0, 1))
|
| 368 |
+
all_w_edited.append({
|
| 369 |
+
"w_edit": w_edited.cpu(),
|
| 370 |
+
"l2_lambda": l2_lambda,
|
| 371 |
+
"id_lambda": id_lambda,
|
| 372 |
+
"num_steps": num_steps,
|
| 373 |
+
"prompt": curr_prompt,
|
| 374 |
+
"edit_name": curr_edit_name,
|
| 375 |
+
})
|
| 376 |
+
|
| 377 |
+
# img_names = [f'{w_name}_{edit_name}' for edit_name in edit_names for w_name in w_names]
|
| 378 |
+
|
| 379 |
+
# save_edited_w(
|
| 380 |
+
# G=G,
|
| 381 |
+
# w_outdir = f'{w_dir}_edited',
|
| 382 |
+
# samples_outdir = f'{w_dir}_projected_samples',
|
| 383 |
+
# img_names=img_names,
|
| 384 |
+
# stack_samples=stack_samples,
|
| 385 |
+
# all_w = w_edited,
|
| 386 |
+
# all_motion_z = None,
|
| 387 |
+
# stacked_samples_out_path = f'{w_dir}_edited_samples.png'
|
| 388 |
+
# )
|
| 389 |
+
|
| 390 |
+
torch.save(all_w_edited, f"{w_save_dir}/{curr_edit_name}_w.pt")
|
| 391 |
+
grid = utils.make_grid(torch.stack(all_images), nrow=len(w_orig))
|
| 392 |
+
print('savig intp', f"{samples_save_dir}/{curr_edit_name}.png")
|
| 393 |
+
save_image(grid, f"{samples_save_dir}/{curr_edit_name}.png")
|
| 394 |
+
|
| 395 |
+
print('Done!')
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
#----------------------------------------------------------------------------
|
| 399 |
+
|
| 400 |
+
if __name__ == "__main__":
|
| 401 |
+
main() # pylint: disable=no-value-for-parameter
|
| 402 |
+
|
| 403 |
+
#----------------------------------------------------------------------------
|
src/scripts/construct_static_videos_dataset.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Takes a dataset directory and repeats the frames to include only a random frame from each video
|
| 3 |
+
This is needed to calculate same-frame FVD and DiFID
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import argparse
|
| 8 |
+
from typing import List
|
| 9 |
+
import shutil
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def construct_static_videos_dataset(videos_dir: os.PathLike, max_len: int=None, output_dir: os.PathLike=None, force_len: int=None):
|
| 14 |
+
output_dir = output_dir if not output_dir is None else f'{videos_dir}_freeze'
|
| 15 |
+
clips_paths = [os.path.join(videos_dir, d) for d in os.listdir(videos_dir)]
|
| 16 |
+
|
| 17 |
+
print(f'Saving into {output_dir}')
|
| 18 |
+
|
| 19 |
+
for video_idx, clip_path in enumerate(tqdm(clips_paths)):
|
| 20 |
+
frames_paths = os.listdir(clip_path)
|
| 21 |
+
frame_to_repeat = random.choice(frames_paths)
|
| 22 |
+
curr_output_dir = os.path.join(output_dir, f'{video_idx:05d}')
|
| 23 |
+
os.makedirs(curr_output_dir, exist_ok=True)
|
| 24 |
+
num_frames_to_create = force_len if not force_len is None else min(len(frames_paths), max_len)
|
| 25 |
+
|
| 26 |
+
for i in range(num_frames_to_create):
|
| 27 |
+
ext = os.path.splitext(frame_to_repeat)[1].lower()
|
| 28 |
+
target_file_path = os.path.join(curr_output_dir, f'{i:06d}{ext}')
|
| 29 |
+
shutil.copy(os.path.join(clip_path, frame_to_repeat), target_file_path)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
parser = argparse.ArgumentParser()
|
| 34 |
+
parser.add_argument('-d', '--directory', type=str, help='Directory with video frames')
|
| 35 |
+
parser.add_argument('-o', '--output_dir', type=None, help='Where to save the file?.')
|
| 36 |
+
parser.add_argument('-l', '--max_len', type=int, help='Max video length')
|
| 37 |
+
parser.add_argument('-fl', '--force_len', type=int, help='Force video length')
|
| 38 |
+
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
construct_static_videos_dataset(
|
| 42 |
+
videos_dir=args.directory,
|
| 43 |
+
max_len=args.max_len,
|
| 44 |
+
output_dir=args.output_dir,
|
| 45 |
+
force_len=args.force_len,
|
| 46 |
+
)
|
src/scripts/convert_video_to_dataset.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts a dataset of mp4 videos into a dataset of video frames
|
| 3 |
+
I.e. a directory of mp4 files becomes a directory of directories of frames
|
| 4 |
+
This speeds up loading during training because we do not need
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from typing import List
|
| 8 |
+
import argparse
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from multiprocessing import Pool
|
| 11 |
+
from collections import Counter
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import torchvision.transforms.functional as TVF
|
| 16 |
+
from moviepy.editor import VideoFileClip
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def convert_videos_into_dataset(video_path: os.PathLike, target_dir: os.PathLike, num_chunks: int, chunk_size: int, start_frame: int, target_size: int, force_fps: int):
|
| 21 |
+
assert (num_chunks is None) or (chunk_size is None), "Cant use both num_chunks and chunk_size"
|
| 22 |
+
|
| 23 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 24 |
+
clip = VideoFileClip(video_path)
|
| 25 |
+
fps = clip.fps if force_fps is None else force_fps
|
| 26 |
+
num_frames_total = int(np.floor(clip.duration * fps)) - start_frame
|
| 27 |
+
|
| 28 |
+
if num_chunks is None:
|
| 29 |
+
num_chunks = num_frames_total // chunk_size
|
| 30 |
+
else:
|
| 31 |
+
chunk_size = num_frames_total // num_chunks
|
| 32 |
+
|
| 33 |
+
num_frames_to_save = chunk_size * num_chunks
|
| 34 |
+
|
| 35 |
+
print(f'Processing the video at {fps} fps. {num_frames_total} frames in total. We have {num_chunks} videos of {chunk_size} frames each.')
|
| 36 |
+
|
| 37 |
+
current_chunk_idx = 0
|
| 38 |
+
frame_idx = -start_frame
|
| 39 |
+
curr_chunk_dir = os.path.join(target_dir, f'{current_chunk_idx:06d}')
|
| 40 |
+
|
| 41 |
+
for frame in tqdm(clip.iter_frames(fps=fps), total=num_frames_total + start_frame):
|
| 42 |
+
if frame_idx >= 0:
|
| 43 |
+
os.makedirs(curr_chunk_dir, exist_ok=True)
|
| 44 |
+
frame = Image.fromarray(frame)
|
| 45 |
+
frame = TVF.center_crop(frame, output_size=min(frame.size))
|
| 46 |
+
frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
|
| 47 |
+
frame.save(os.path.join(curr_chunk_dir, f'{frame_idx % chunk_size:06d}.jpg'), q=95)
|
| 48 |
+
|
| 49 |
+
frame_idx += 1
|
| 50 |
+
if frame_idx % chunk_size == 0 and frame_idx > 0:
|
| 51 |
+
current_chunk_idx += 1
|
| 52 |
+
curr_chunk_dir = os.path.join(target_dir, f'{current_chunk_idx:06d}')
|
| 53 |
+
|
| 54 |
+
if frame_idx == num_frames_to_save:
|
| 55 |
+
# Stop here so not to have a partially-filled chunk
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
chunk_sizes = [len(os.listdir(d)) for d in listdir_full_paths(target_dir)]
|
| 59 |
+
assert len(set(chunk_sizes)) == 1, f"Bad chunk sizes: {set(chunk_sizes)}"
|
| 60 |
+
|
| 61 |
+
print('Finished successfully!')
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def listdir_full_paths(d) -> List[os.PathLike]:
|
| 65 |
+
return sorted([os.path.join(d, x) for x in os.listdir(d)])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
parser = argparse.ArgumentParser(description='Convert a long video into a dataset of frame dirs')
|
| 70 |
+
parser.add_argument('-s', '--source_video_path', type=str, help='Path to the source video')
|
| 71 |
+
parser.add_argument('-t', '--target_dir', type=str, help='Where to save the new dataset')
|
| 72 |
+
parser.add_argument('-n', '--num_chunks', type=int, help='How many samples should there be in the dataset?')
|
| 73 |
+
parser.add_argument('-cs', '--chunk_size', type=int, help='Each video length. Should be used separately from num_chunks')
|
| 74 |
+
parser.add_argument('-sf', '--start_frame', type=int, default=0, help='Start frame idx. Should we skip several frames?')
|
| 75 |
+
parser.add_argument('--target_size', type=int, default=128, help='What size should we resize to?')
|
| 76 |
+
parser.add_argument('--force_fps', type=int, help='What fps should we run videos with?')
|
| 77 |
+
args = parser.parse_args()
|
| 78 |
+
|
| 79 |
+
convert_videos_into_dataset(
|
| 80 |
+
video_path=args.source_video_path,
|
| 81 |
+
target_dir=args.target_dir,
|
| 82 |
+
num_chunks=args.num_chunks,
|
| 83 |
+
chunk_size=args.chunk_size,
|
| 84 |
+
start_frame=args.start_frame,
|
| 85 |
+
target_size=args.target_size,
|
| 86 |
+
force_fps=args.force_fps,
|
| 87 |
+
)
|
src/scripts/convert_videos_to_frames.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts a dataset of mp4 videos into a dataset of video frames
|
| 3 |
+
I.e. a directory of mp4 files becomes a directory of directories of frames
|
| 4 |
+
This speeds up loading during training because we do not need
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from typing import List
|
| 8 |
+
import argparse
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from multiprocessing import Pool
|
| 11 |
+
from collections import Counter
|
| 12 |
+
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torchvision.transforms.functional as TVF
|
| 15 |
+
from moviepy.editor import VideoFileClip
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def convert_videos_to_frames(source_dir: os.PathLike, target_dir: os.PathLike, num_workers: int, video_ext: str, **process_video_kwargs):
|
| 20 |
+
broken_clips_dir = f'{target_dir}_broken_clips'
|
| 21 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 22 |
+
os.makedirs(broken_clips_dir, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
clips_paths = [cp for cp in listdir_full_paths(source_dir) if cp.endswith(video_ext)]
|
| 25 |
+
clips_fps = []
|
| 26 |
+
tasks_kwargs = [dict(
|
| 27 |
+
clip_path=cp,
|
| 28 |
+
target_dir=target_dir,
|
| 29 |
+
broken_clips_dir=broken_clips_dir,
|
| 30 |
+
**process_video_kwargs,
|
| 31 |
+
) for cp in clips_paths]
|
| 32 |
+
pool = Pool(processes=num_workers)
|
| 33 |
+
|
| 34 |
+
for fps in tqdm(pool.imap_unordered(task_proxy, tasks_kwargs), total=len(clips_paths)):
|
| 35 |
+
clips_fps.append(fps)
|
| 36 |
+
|
| 37 |
+
print(f'All possible fps: {Counter(clips_fps).most_common()}')
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def task_proxy(kwargs):
|
| 41 |
+
"""I do not know, how to pass several arguments to a pool job..."""
|
| 42 |
+
return process_video(**kwargs)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def process_video(
|
| 46 |
+
clip_path: os.PathLike, target_dir: os.PathLike, force_fps: int=None, target_size: int=None,
|
| 47 |
+
broken_clips_dir: os.PathLike=None, compute_fps_only: bool=False) -> int:
|
| 48 |
+
|
| 49 |
+
clip_name = os.path.basename(clip_path)
|
| 50 |
+
clip_name = clip_name[:clip_name.rfind('.')]
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
clip = VideoFileClip(clip_path)
|
| 54 |
+
except KeyboardInterrupt:
|
| 55 |
+
raise
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f'Coudnt process clip: {clip_path}')
|
| 58 |
+
if not broken_clips_dir is None:
|
| 59 |
+
Path(os.path.join(broken_clips_dir, clip_name)).touch()
|
| 60 |
+
return 0
|
| 61 |
+
|
| 62 |
+
if compute_fps_only:
|
| 63 |
+
return clip.fps
|
| 64 |
+
|
| 65 |
+
fps = clip.fps if force_fps is None else force_fps
|
| 66 |
+
clip_target_dir = os.path.join(target_dir, clip_name)
|
| 67 |
+
clip_target_dir = clip_target_dir.replace('#', '_')
|
| 68 |
+
os.makedirs(clip_target_dir, exist_ok=True)
|
| 69 |
+
|
| 70 |
+
frame_idx = 0
|
| 71 |
+
for frame in clip.iter_frames(fps=fps):
|
| 72 |
+
frame = Image.fromarray(frame)
|
| 73 |
+
if not target_size is None:
|
| 74 |
+
frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
|
| 75 |
+
frame = TVF.center_crop(frame, output_size=(target_size, target_size))
|
| 76 |
+
frame.save(os.path.join(clip_target_dir, f'{frame_idx:06d}.jpg'), q=95)
|
| 77 |
+
frame_idx += 1
|
| 78 |
+
|
| 79 |
+
return clip.fps
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def listdir_full_paths(d) -> List[os.PathLike]:
|
| 83 |
+
return sorted([os.path.join(d, x) for x in os.listdir(d)])
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
parser = argparse.ArgumentParser(description='Convert a dataset of mp4 files into a dataset of individual frames')
|
| 88 |
+
parser.add_argument('-s', '--source_dir', type=str, help='Path to the source dataset')
|
| 89 |
+
parser.add_argument('-t', '--target_dir', type=str, help='Where to save the new dataset')
|
| 90 |
+
parser.add_argument('--video_ext', type=str, default='mp4', help='Video extension')
|
| 91 |
+
parser.add_argument('--target_size', type=int, default=128, help='What size should we resize to?')
|
| 92 |
+
parser.add_argument('--force_fps', type=int, help='What fps should we run videos with?')
|
| 93 |
+
parser.add_argument('--num_workers', type=int, default=8, help='Number of processes to launch')
|
| 94 |
+
parser.add_argument('--compute_fps_only', action='store_true', help='Should we just compute fps?')
|
| 95 |
+
args = parser.parse_args()
|
| 96 |
+
|
| 97 |
+
convert_videos_to_frames(
|
| 98 |
+
source_dir=args.source_dir,
|
| 99 |
+
target_dir=args.target_dir,
|
| 100 |
+
target_size=args.target_size,
|
| 101 |
+
force_fps=args.force_fps,
|
| 102 |
+
num_workers=args.num_workers,
|
| 103 |
+
video_ext=args.video_ext,
|
| 104 |
+
compute_fps_only=args.compute_fps_only,
|
| 105 |
+
)
|
src/scripts/crop_video_dataset.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import argparse
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def crop_video_dataset(source_dir: str, max_num_frames: int=None, slice_n_left_frames: int=0, resize: int=None, target_dir: str=None):
|
| 12 |
+
dataset_name = os.path.basename(source_dir)
|
| 13 |
+
if target_dir is None:
|
| 14 |
+
max_num_frames_prefix = "" if max_num_frames is None else f"_cut{max_num_frames}"
|
| 15 |
+
slice_prefix = "" if slice_n_left_frames == 0 else f"_slice{slice_n_left_frames}"
|
| 16 |
+
new_dataset_name = f"{dataset_name}{max_num_frames_prefix}{slice_prefix}"
|
| 17 |
+
target_dir = os.path.join(os.path.dirname(source_dir), new_dataset_name)
|
| 18 |
+
all_clips_paths = listdir_full_paths(source_dir)
|
| 19 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 20 |
+
slice_proportions = []
|
| 21 |
+
|
| 22 |
+
total_num_frames = 0
|
| 23 |
+
|
| 24 |
+
for source_clip_dir in tqdm(all_clips_paths, desc=f'Cropping the dataset into {target_dir}'):
|
| 25 |
+
all_frames = listdir_full_paths(source_clip_dir)
|
| 26 |
+
if len(all_frames) == 0:
|
| 27 |
+
print(f'{source_clip_dir} is empty. Skipping it.')
|
| 28 |
+
continue
|
| 29 |
+
target_clip_dir = os.path.join(target_dir, os.path.basename(source_clip_dir))
|
| 30 |
+
os.makedirs(target_clip_dir, exist_ok=True)
|
| 31 |
+
total_num_frames += len(all_frames)
|
| 32 |
+
slice_proportions.append(slice_n_left_frames / len(all_frames))
|
| 33 |
+
all_frames = all_frames[slice_n_left_frames:]
|
| 34 |
+
|
| 35 |
+
if not max_num_frames is None:
|
| 36 |
+
all_frames = all_frames[:max_num_frames]
|
| 37 |
+
|
| 38 |
+
for source_frame_path in all_frames:
|
| 39 |
+
target_frame_path = os.path.join(target_clip_dir, os.path.basename(source_frame_path))
|
| 40 |
+
|
| 41 |
+
if resize is None:
|
| 42 |
+
shutil.copy(source_frame_path, target_frame_path)
|
| 43 |
+
else:
|
| 44 |
+
assert target_frame_path.endswith('.jpg')
|
| 45 |
+
Image.open(source_frame_path).resize((resize, resize), resample=Image.LANCZOS).save(target_frame_path, q=95)
|
| 46 |
+
|
| 47 |
+
print(f'Done! Sliced {np.mean(slice_proportions) * 100.0 : .02f}% on average. {len(all_clips_paths) * slice_n_left_frames / total_num_frames * 100.0 : .02f}% of total num frames.')
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def listdir_full_paths(d) -> List[os.PathLike]:
|
| 51 |
+
return sorted([os.path.join(d, x) for x in os.listdir(d)])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
parser = argparse.ArgumentParser(description='Crops a video dataset temporally into several frames')
|
| 56 |
+
parser.add_argument('source_dir', type=str, help='Path to the dataset')
|
| 57 |
+
parser.add_argument('-n', '--max_num_frames', type=int, default=None, help='Number of frames to preserve')
|
| 58 |
+
parser.add_argument('--slice_n_left_frames', type=int, default=0, help='Number of frames to slice from the left')
|
| 59 |
+
parser.add_argument('--resize', type=int, default=None, help='Should we resize the dataset')
|
| 60 |
+
parser.add_argument('--target_dir', type=str, default=None, help='Should we resize the dataset')
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
|
| 63 |
+
crop_video_dataset(
|
| 64 |
+
source_dir=args.source_dir,
|
| 65 |
+
max_num_frames=args.max_num_frames,
|
| 66 |
+
slice_n_left_frames=args.slice_n_left_frames,
|
| 67 |
+
resize=args.resize,
|
| 68 |
+
target_dir=args.target_dir,
|
| 69 |
+
)
|
src/scripts/frames_to_video_grid.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Converts a directory of video frames into an mp4-grid
|
| 3 |
+
"""
|
| 4 |
+
import sys; sys.path.extend(['.'])
|
| 5 |
+
import os
|
| 6 |
+
import argparse
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
import torchvision.transforms.functional as TVF
|
| 13 |
+
from torchvision import utils
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import torchvision
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def frames_to_video_grid(videos_dir: os.PathLike, num_videos: int, length: int, fps: int, output_path: os.PathLike, select_random: bool=False, random_seed: int=None):
|
| 20 |
+
clips_paths = [os.path.join(videos_dir, d) for d in os.listdir(videos_dir)]
|
| 21 |
+
|
| 22 |
+
# bad_idx = [0, 9, 11, 16]
|
| 23 |
+
# clips_paths = [c for i, c in enumerate(clips_paths) if not i in bad_idx]
|
| 24 |
+
|
| 25 |
+
if select_random:
|
| 26 |
+
random.seed(random_seed)
|
| 27 |
+
clips_paths = random.sample(clips_paths, k=num_videos)
|
| 28 |
+
else:
|
| 29 |
+
clips_paths = clips_paths[:num_videos]
|
| 30 |
+
videos = [read_first_n_frames(d, length) for d in tqdm(clips_paths, desc='Reading data...')] # [num_videos, length, c, h, w]
|
| 31 |
+
videos = [fill_with_black_squares(v, length) for v in tqdm(videos, desc='Adding empty frames')] # [num_videos, length, c, h, w]
|
| 32 |
+
frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [video_len, num_videos, c, h, w]
|
| 33 |
+
frame_grids = [utils.make_grid(fs, nrow=int(np.ceil(np.sqrt(num_videos)))) for fs in tqdm(frame_grids, desc='Making grids')]
|
| 34 |
+
|
| 35 |
+
if os.path.dirname(output_path) != "":
|
| 36 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 37 |
+
frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
|
| 38 |
+
torchvision.io.write_video(output_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def read_first_n_frames(d: os.PathLike, num_frames: int) -> Tensor:
|
| 42 |
+
images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))[:num_frames]]
|
| 43 |
+
images = [TVF.to_tensor(x) for x in images]
|
| 44 |
+
|
| 45 |
+
return torch.stack(images)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def fill_with_black_squares(video, desired_len: int) -> Tensor:
|
| 49 |
+
if len(video) >= desired_len:
|
| 50 |
+
return video
|
| 51 |
+
|
| 52 |
+
return torch.cat([
|
| 53 |
+
video,
|
| 54 |
+
torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1),
|
| 55 |
+
], dim=0)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
parser.add_argument('-d', '--directory', type=str, help='Directory with video frames')
|
| 61 |
+
parser.add_argument('-n', '--num_videos', type=int, help='Number of videos to consider')
|
| 62 |
+
parser.add_argument('-l', '--length', type=int, help='Video length (in frames)')
|
| 63 |
+
parser.add_argument('--fps', type=int, default=25, help='FPS to save with.')
|
| 64 |
+
parser.add_argument('-o', '--output_path', type=str, help='Where to save the file?.')
|
| 65 |
+
parser.add_argument('--select_random', action='store_true', help='Select videos at random?')
|
| 66 |
+
parser.add_argument('--random_seed', type=int, default=None, help='Random seed when selecting videos at random')
|
| 67 |
+
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
frames_to_video_grid(
|
| 71 |
+
videos_dir=args.directory,
|
| 72 |
+
num_videos=args.num_videos,
|
| 73 |
+
length=args.length,
|
| 74 |
+
fps=args.fps,
|
| 75 |
+
output_path=args.output_path,
|
| 76 |
+
select_random=args.select_random,
|
| 77 |
+
random_seed=args.random_seed,
|
| 78 |
+
)
|
src/scripts/generate.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generates a dataset of images using pretrained network pickle."""
|
| 2 |
+
|
| 3 |
+
import sys; sys.path.extend(['.', 'src'])
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import random
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
import click
|
| 10 |
+
from src import dnnlib
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
|
| 16 |
+
import src.legacy as legacy
|
| 17 |
+
from src.training.logging import generate_videos, save_video_frames_as_mp4, save_video_frames_as_frames_parallel
|
| 18 |
+
|
| 19 |
+
torch.set_grad_enabled(False)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#----------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
@click.command()
|
| 25 |
+
@click.pass_context
|
| 26 |
+
@click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
|
| 27 |
+
@click.option('--networks_dir', help='Network pickles directory. Selects a checkpoint from it automatically based on the fvd2048_16f metric.', metavar='PATH')
|
| 28 |
+
@click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
|
| 29 |
+
@click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
|
| 30 |
+
@click.option('--num_videos', type=int, help='Number of images to generate', default=50000, show_default=True)
|
| 31 |
+
@click.option('--batch_size', type=int, help='Batch size to use for generation', default=32, show_default=True)
|
| 32 |
+
@click.option('--moco_decomposition', type=bool, help='Should we do content/motion decomposition (available only for `--as_grids 1` generation)?', default=False, show_default=True)
|
| 33 |
+
@click.option('--seed', type=int, help='Random seed', default=42, metavar='DIR')
|
| 34 |
+
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
|
| 35 |
+
@click.option('--save_as_mp4', help='Should we save as independent frames or mp4?', type=bool, default=False, metavar='BOOL')
|
| 36 |
+
@click.option('--video_len', help='Number of frames to generate', type=int, default=16, metavar='INT')
|
| 37 |
+
@click.option('--fps', help='FPS for mp4 saving', type=int, default=25, metavar='INT')
|
| 38 |
+
@click.option('--as_grids', help='Save videos as grids', type=bool, default=False, metavar='BOOl')
|
| 39 |
+
@click.option('--time_offset', help='Additional time offset', default=0, type=int, metavar='INT')
|
| 40 |
+
@click.option('--dataset_path', help='Dataset path. In case we want to use the conditioning signal.', default="", type=str, metavar='PATH')
|
| 41 |
+
@click.option('--hydra_cfg_path', help='Config path', default="", type=str, metavar='PATH')
|
| 42 |
+
@click.option('--slowmo_coef', help='Increase this value if you want to produce slow-motion videos.', default=1, type=int, metavar='INT')
|
| 43 |
+
def generate(
|
| 44 |
+
ctx: click.Context,
|
| 45 |
+
network_pkl: str,
|
| 46 |
+
networks_dir: str,
|
| 47 |
+
truncation_psi: float,
|
| 48 |
+
noise_mode: str,
|
| 49 |
+
num_videos: int,
|
| 50 |
+
batch_size: int,
|
| 51 |
+
moco_decomposition: bool,
|
| 52 |
+
seed: int,
|
| 53 |
+
outdir: str,
|
| 54 |
+
save_as_mp4: bool,
|
| 55 |
+
video_len: int,
|
| 56 |
+
fps: int,
|
| 57 |
+
as_grids: bool,
|
| 58 |
+
time_offset: int,
|
| 59 |
+
dataset_path: os.PathLike,
|
| 60 |
+
hydra_cfg_path: os.PathLike,
|
| 61 |
+
slowmo_coef: int,
|
| 62 |
+
):
|
| 63 |
+
if network_pkl is None:
|
| 64 |
+
# output_regex = "^network-snapshot-\d{6}.pkl$"
|
| 65 |
+
# ckpt_regex = re.compile("^network-snapshot-\d{6}.pkl$")
|
| 66 |
+
# ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
|
| 67 |
+
# network_pkl = os.path.join(networks_dir, ckpts[-1])
|
| 68 |
+
ckpt_select_metric = 'fvd2048_16f'
|
| 69 |
+
metrics_file = os.path.join(networks_dir, f'metric-{ckpt_select_metric}.jsonl')
|
| 70 |
+
with open(metrics_file, 'r') as f:
|
| 71 |
+
snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
|
| 72 |
+
best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results'][ckpt_select_metric])[0]
|
| 73 |
+
network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
|
| 74 |
+
print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results'][ckpt_select_metric])
|
| 75 |
+
# Selecting a checkpoint with the best score
|
| 76 |
+
else:
|
| 77 |
+
assert networks_dir is None, "Cant have both parameters: network_pkl and networks_dir"
|
| 78 |
+
|
| 79 |
+
if moco_decomposition:
|
| 80 |
+
assert as_grids, f"Content/motion decomposition is available only when we generate as grids."
|
| 81 |
+
assert batch_size == num_videos, "Same motion is supported only for batch_size == num_videos"
|
| 82 |
+
|
| 83 |
+
print('Loading networks from "%s"...' % network_pkl)
|
| 84 |
+
device = torch.device('cuda')
|
| 85 |
+
with dnnlib.util.open_url(network_pkl) as f:
|
| 86 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
|
| 87 |
+
|
| 88 |
+
os.makedirs(outdir, exist_ok=True)
|
| 89 |
+
|
| 90 |
+
random.seed(seed)
|
| 91 |
+
np.random.seed(seed)
|
| 92 |
+
torch.manual_seed(seed)
|
| 93 |
+
|
| 94 |
+
all_z = torch.randn(num_videos, G.z_dim, device=device) # [curr_batch_size, z_dim]
|
| 95 |
+
if dataset_path and G.c_dim > 0:
|
| 96 |
+
hydra_cfg_path = hydra_cfg_path or os.path.join(networks_dir, '..', "experiment_config.yaml")
|
| 97 |
+
hydra_cfg = OmegaConf.load(hydra_cfg_path)
|
| 98 |
+
training_set_kwargs = dnnlib.EasyDict(
|
| 99 |
+
class_name='training.dataset.VideoFramesFolderDataset',
|
| 100 |
+
path=dataset_path, cfg=hydra_cfg.dataset, use_labels=True, max_size=None, xflip=False)
|
| 101 |
+
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)
|
| 102 |
+
all_c = [training_set.get_label(random.choice(range(len(training_set)))) for _ in range(num_videos)] # [num_videos, c_dim]
|
| 103 |
+
all_c = torch.from_numpy(np.array(all_c)).to(device) # [num_videos, c_dim]
|
| 104 |
+
elif G.c_dim > 0:
|
| 105 |
+
warnings.warn('Assuming that the conditioning is one-hot!')
|
| 106 |
+
c_idx = torch.randint(low=0, high=G.c_dim, size=(num_videos, 1), device=device)
|
| 107 |
+
all_c = torch.zeros(num_videos, G.c_dim, device=device) # [num_videos, c_dim]
|
| 108 |
+
all_c.scatter_(1, c_idx, 1)
|
| 109 |
+
else:
|
| 110 |
+
all_c = torch.zeros(num_videos, G.c_dim, device=device) # [num_videos, c_dim]
|
| 111 |
+
ts = time_offset + torch.arange(video_len, device=device).float().unsqueeze(0).repeat(batch_size, 1) / slowmo_coef # [batch_size, video_len]
|
| 112 |
+
if moco_decomposition:
|
| 113 |
+
num_rows = num_cols = int(np.sqrt(num_videos))
|
| 114 |
+
motion_z = G.synthesis.motion_encoder(c=all_c[:num_rows], t=ts[:num_rows])['motion_z'] # [1, *motion_dims]
|
| 115 |
+
motion_z = motion_z.repeat_interleave(num_cols, dim=0) # [batch_size, *motion_dims]
|
| 116 |
+
|
| 117 |
+
all_z = all_z[:num_cols].repeat(num_rows, 1) # [num_videos, z_dim]
|
| 118 |
+
all_c = all_c[:num_cols].repeat(num_rows, 1) # [num_videos, z_dim]
|
| 119 |
+
else:
|
| 120 |
+
motion_z = None
|
| 121 |
+
|
| 122 |
+
# Generate images.
|
| 123 |
+
for batch_idx in tqdm(range((num_videos + batch_size - 1) // batch_size), desc='Generating videos'):
|
| 124 |
+
curr_batch_size = batch_size if batch_size * (batch_idx + 1) <= num_videos else num_videos % batch_size
|
| 125 |
+
z = all_z[batch_idx * batch_size:batch_idx * batch_size + curr_batch_size] # [curr_batch_size, z_dim]
|
| 126 |
+
c = all_c[batch_idx * batch_size:batch_idx * batch_size + curr_batch_size] # [curr_batch_size, c_dim]
|
| 127 |
+
videos = generate_videos(
|
| 128 |
+
G, z, c, ts, motion_z=motion_z, noise_mode=noise_mode,
|
| 129 |
+
truncation_psi=truncation_psi, as_grids=as_grids, batch_size_num_frames=128)
|
| 130 |
+
|
| 131 |
+
if as_grids:
|
| 132 |
+
videos = [videos]
|
| 133 |
+
|
| 134 |
+
for video_idx, video in enumerate(videos):
|
| 135 |
+
if save_as_mp4:
|
| 136 |
+
save_path = os.path.join(outdir, f'{batch_idx * batch_size + video_idx:06d}.mp4')
|
| 137 |
+
save_video_frames_as_mp4(video, fps, save_path)
|
| 138 |
+
else:
|
| 139 |
+
save_dir = os.path.join(outdir, f'{batch_idx * batch_size + video_idx:06d}')
|
| 140 |
+
video = (video * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy() # [video_len, h, w, c]
|
| 141 |
+
save_video_frames_as_frames_parallel(video, save_dir, time_offset=time_offset, num_processes=8)
|
| 142 |
+
|
| 143 |
+
#----------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
generate() # pylint: disable=no-value-for-parameter
|
| 147 |
+
|
| 148 |
+
#----------------------------------------------------------------------------
|
src/scripts/preprocess_ffs.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file preprocesses FaceForensics dataset by cropping it
|
| 3 |
+
Copied from https://github.com/pfnet-research/tgan2/blob/master/scripts/make_face_forensics.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
from typing import List
|
| 9 |
+
from multiprocessing import Pool
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
import cv2
|
| 13 |
+
# import h5py
|
| 14 |
+
import imageio
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def parse_videos(source_dir, splits: List[str], categories: List[dir]):
|
| 21 |
+
results = []
|
| 22 |
+
for split in splits:
|
| 23 |
+
for category in categories:
|
| 24 |
+
target_dir = os.path.join(source_dir, split, category)
|
| 25 |
+
filenames = sorted(os.listdir(target_dir))
|
| 26 |
+
for filename in filenames:
|
| 27 |
+
results.append({
|
| 28 |
+
'split': split,
|
| 29 |
+
'category': category,
|
| 30 |
+
'filename': filename,
|
| 31 |
+
'filepath': os.path.join(split, category, filename),
|
| 32 |
+
})
|
| 33 |
+
return pandas.DataFrame(results)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def crop(img, left, right, top, bottom, margin):
|
| 37 |
+
cols = right - left
|
| 38 |
+
rows = bottom - top
|
| 39 |
+
if cols < rows:
|
| 40 |
+
padding = rows - cols
|
| 41 |
+
left -= padding // 2
|
| 42 |
+
right += (padding // 2) + (padding % 2)
|
| 43 |
+
cols = right - left
|
| 44 |
+
else:
|
| 45 |
+
padding = cols - rows
|
| 46 |
+
top -= padding // 2
|
| 47 |
+
bottom += (padding // 2) + (padding % 2)
|
| 48 |
+
rows = bottom - top
|
| 49 |
+
assert(rows == cols)
|
| 50 |
+
return img[top:bottom, left:right]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def job_proxy(kwargs):
|
| 54 |
+
process_and_save_video(**kwargs)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def process_and_save_video(video_path: os.PathLike, mask_path: os.PathLike, img_size: int, wide_crop: bool, output_dir: os.PathLike):
|
| 58 |
+
try:
|
| 59 |
+
video = process_video(video_path, mask_path, img_size=img_size, wide_crop=wide_crop)
|
| 60 |
+
except KeyboardInterrupt:
|
| 61 |
+
raise
|
| 62 |
+
except:
|
| 63 |
+
print(f'Couldnt process {video_path}')
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 67 |
+
|
| 68 |
+
# if os.path.isdir(output_dir) and len(os.listdir(output_dir)) > 0:
|
| 69 |
+
# return
|
| 70 |
+
|
| 71 |
+
for i, frame in enumerate(video):
|
| 72 |
+
frame = frame.transpose(1, 2, 0)
|
| 73 |
+
Image.fromarray(frame).save(os.path.join(output_dir, f'{i:06d}.jpg'), q=95)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def process_video(video_path, mask_path, img_size, threshold=5, margin=0.02, wide_crop: bool=False):
|
| 77 |
+
video_reader = imageio.get_reader(video_path)
|
| 78 |
+
mask_reader = imageio.get_reader(mask_path)
|
| 79 |
+
assert(video_reader.get_length() == mask_reader.get_length())
|
| 80 |
+
|
| 81 |
+
# Searching for the widest crop which would work for the whole video
|
| 82 |
+
if wide_crop:
|
| 83 |
+
left_most = float('inf')
|
| 84 |
+
top_most = float('inf')
|
| 85 |
+
right_most = float('-inf')
|
| 86 |
+
bottom_most = float('-inf')
|
| 87 |
+
|
| 88 |
+
for img, mask in zip(video_reader, mask_reader):
|
| 89 |
+
hist = (255 - mask).astype(np.float64).sum(axis=2)
|
| 90 |
+
horiz_hist = np.where(hist.mean(axis=0) > threshold)[0]
|
| 91 |
+
vert_hist = np.where(hist.mean(axis=1) > threshold)[0]
|
| 92 |
+
left, right = horiz_hist[0], horiz_hist[-1]
|
| 93 |
+
top, bottom = vert_hist[0], vert_hist[-1]
|
| 94 |
+
left_most = min(left_most, left)
|
| 95 |
+
top_most = min(top_most, top)
|
| 96 |
+
right_most = max(right_most, right)
|
| 97 |
+
bottom_most = max(bottom_most, bottom)
|
| 98 |
+
|
| 99 |
+
video = []
|
| 100 |
+
for img, mask in zip(video_reader, mask_reader):
|
| 101 |
+
if wide_crop:
|
| 102 |
+
left, right, top, bottom = left_most, right_most, top_most, bottom_most
|
| 103 |
+
else:
|
| 104 |
+
hist = (255 - mask).astype(np.float64).sum(axis=2)
|
| 105 |
+
horiz_hist = np.where(hist.mean(axis=0) > threshold)[0]
|
| 106 |
+
vert_hist = np.where(hist.mean(axis=1) > threshold)[0]
|
| 107 |
+
left, right = horiz_hist[0], horiz_hist[-1]
|
| 108 |
+
top, bottom = vert_hist[0], vert_hist[-1]
|
| 109 |
+
|
| 110 |
+
dst_img = crop(img, left, right, top, bottom, margin)
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
dst_img = cv2.resize(
|
| 114 |
+
dst_img, (img_size, img_size),
|
| 115 |
+
interpolation=cv2.INTER_LANCZOS4).transpose(2, 0, 1)
|
| 116 |
+
video.append(dst_img)
|
| 117 |
+
except KeyboardInterrupt:
|
| 118 |
+
raise
|
| 119 |
+
except:
|
| 120 |
+
print(img.shape, dst_img.shape, left, right, top, bottom)
|
| 121 |
+
|
| 122 |
+
T = len(video)
|
| 123 |
+
video = np.concatenate(video).reshape(T, 3, img_size, img_size)
|
| 124 |
+
return video
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# def count_frames(path):
|
| 128 |
+
# reader = imageio.get_reader(path)
|
| 129 |
+
# n_frames = 0
|
| 130 |
+
# while True:
|
| 131 |
+
# try:
|
| 132 |
+
# img = reader.get_next_data()
|
| 133 |
+
# except IndexError as e:
|
| 134 |
+
# break
|
| 135 |
+
# else:
|
| 136 |
+
# n_frames += 1
|
| 137 |
+
# return n_frames
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def main():
|
| 141 |
+
parser = argparse.ArgumentParser()
|
| 142 |
+
parser.add_argument('--source_dir', type=str, default='data/FaceForensics_compressed')
|
| 143 |
+
parser.add_argument('--output_dir', type=str, default='data/ffs_processed')
|
| 144 |
+
parser.add_argument('--img_size', type=int, default=256)
|
| 145 |
+
parser.add_argument('--num_workers', type=int, default=8)
|
| 146 |
+
parser.add_argument('--wide_crop', action='store_true', help="Should we crop each frame independently (this makes a video shaking)?")
|
| 147 |
+
args = parser.parse_args()
|
| 148 |
+
|
| 149 |
+
# splits = ['train', 'val', 'test']
|
| 150 |
+
# categories = ['original', 'mask', 'altered']
|
| 151 |
+
splits = ['train']
|
| 152 |
+
categories = ['original', 'mask']
|
| 153 |
+
df = parse_videos(args.source_dir, splits, categories)
|
| 154 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 155 |
+
|
| 156 |
+
for split in splits:
|
| 157 |
+
target_frame = df[df['split'] == split]
|
| 158 |
+
filenames = target_frame['filename'].unique()
|
| 159 |
+
|
| 160 |
+
# print('Count # of frames')
|
| 161 |
+
# rets = []
|
| 162 |
+
# for i, filename in enumerate(filenames):
|
| 163 |
+
# fn_frame = target_frame[target_frame['filename'] == filename]
|
| 164 |
+
# video_path = os.path.join(
|
| 165 |
+
# args.source_dir, fn_frame[fn_frame['category'] == 'original'].iloc[0]['filepath'])
|
| 166 |
+
# rets.append(p.apply_async(count_frames, args=(video_path,)))
|
| 167 |
+
# n_frames = 0
|
| 168 |
+
# for ret in tqdm(rets):
|
| 169 |
+
# n_frames += ret.get()
|
| 170 |
+
# print('# of frames: {}'.format(n_frames))
|
| 171 |
+
|
| 172 |
+
# h5file = h5py.File(os.path.join(args.output_dir, '{}.h5'.format(split)), 'w')
|
| 173 |
+
# dset = h5file.create_dataset('image', (n_frames, 3, args.img_size, args.img_size), dtype=np.uint8)
|
| 174 |
+
# conf = []
|
| 175 |
+
# start = 0
|
| 176 |
+
|
| 177 |
+
pool = Pool(processes=args.num_workers)
|
| 178 |
+
job_kwargs_list = []
|
| 179 |
+
|
| 180 |
+
for i, filename in enumerate(filenames):
|
| 181 |
+
fn_frame = target_frame[target_frame['filename'] == filename]
|
| 182 |
+
video_path = os.path.join(args.source_dir, fn_frame[fn_frame['category'] == 'original'].iloc[0]['filepath'])
|
| 183 |
+
mask_path = os.path.join(args.source_dir, fn_frame[fn_frame['category'] == 'mask'].iloc[0]['filepath'])
|
| 184 |
+
|
| 185 |
+
job_kwargs_list.append(dict(
|
| 186 |
+
video_path=video_path,
|
| 187 |
+
mask_path=mask_path,
|
| 188 |
+
img_size=args.img_size,
|
| 189 |
+
wide_crop=args.wide_crop,
|
| 190 |
+
output_dir=os.path.join(args.output_dir, filename[:filename.rfind('.')]),
|
| 191 |
+
))
|
| 192 |
+
|
| 193 |
+
for _ in tqdm(pool.imap_unordered(job_proxy, job_kwargs_list), desc=f'Processing {split}', total=len(job_kwargs_list)):
|
| 194 |
+
pass
|
| 195 |
+
# T = len(video)
|
| 196 |
+
#dset[start:(start + T)] = video
|
| 197 |
+
# conf.append({'start': start, 'end': (start + T)})
|
| 198 |
+
# start += T
|
| 199 |
+
# conf = pandas.DataFrame(conf)
|
| 200 |
+
# conf.to_json(os.path.join(args.output_dir, '{}.json'.format(split)), orient='records')
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == '__main__':
|
| 204 |
+
main()
|
src/scripts/profile_model.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script computes imgs/sec for a generator in the eval mode
|
| 3 |
+
for different batch sizes
|
| 4 |
+
"""
|
| 5 |
+
import sys; sys.path.extend(['..', '.', 'src'])
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import hydra
|
| 12 |
+
from hydra.experimental import initialize
|
| 13 |
+
from omegaconf import DictConfig, OmegaConf
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import torch.autograd.profiler as profiler
|
| 16 |
+
|
| 17 |
+
from src import dnnlib
|
| 18 |
+
from src.infra.utils import recursive_instantiate
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
DEVICE = 'cuda'
|
| 22 |
+
BATCH_SIZES = [32]
|
| 23 |
+
NUM_WARMUP_ITERS = 5
|
| 24 |
+
NUM_PROFILE_ITERS = 25
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def instantiate_G(cfg: DictConfig) -> nn.Module:
|
| 28 |
+
G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
|
| 29 |
+
G_kwargs.synthesis_kwargs.channel_base = int(cfg.model.generator.get('fmaps', 0.5) * 32768)
|
| 30 |
+
G_kwargs.synthesis_kwargs.channel_max = 512
|
| 31 |
+
G_kwargs.mapping_kwargs.num_layers = cfg.model.generator.get('mapping_net_n_layers', 2)
|
| 32 |
+
if cfg.get('num_fp16_res', 0) > 0:
|
| 33 |
+
G_kwargs.synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
|
| 34 |
+
G_kwargs.synthesis_kwargs.conv_clamp = 256
|
| 35 |
+
G_kwargs.cfg = cfg.model.generator
|
| 36 |
+
G_kwargs.c_dim = 0
|
| 37 |
+
G_kwargs.img_resolution = cfg.get('resolution', 256)
|
| 38 |
+
G_kwargs.img_channels = 3
|
| 39 |
+
|
| 40 |
+
G = dnnlib.util.construct_class_by_name(**G_kwargs).eval().requires_grad_(False).to(DEVICE)
|
| 41 |
+
|
| 42 |
+
return G
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def profile_for_batch_size(G: nn.Module, cfg: DictConfig, batch_size: int):
|
| 47 |
+
z = torch.randn(batch_size, G.z_dim, device=DEVICE)
|
| 48 |
+
c = torch.zeros(batch_size, G.c_dim, device=DEVICE)
|
| 49 |
+
t = torch.zeros(batch_size, 2, device=DEVICE)
|
| 50 |
+
times = []
|
| 51 |
+
|
| 52 |
+
for i in tqdm(range(NUM_WARMUP_ITERS), desc='Warming up'):
|
| 53 |
+
torch.cuda.synchronize()
|
| 54 |
+
fake_img = G(z, c=c, t=t).contiguous()
|
| 55 |
+
y = fake_img[0, 0, 0, 0].item() # sync
|
| 56 |
+
torch.cuda.synchronize()
|
| 57 |
+
|
| 58 |
+
time.sleep(1)
|
| 59 |
+
|
| 60 |
+
torch.cuda.reset_peak_memory_stats()
|
| 61 |
+
|
| 62 |
+
with profiler.profile(record_shapes=True, use_cuda=True) as prof:
|
| 63 |
+
for i in tqdm(range(NUM_PROFILE_ITERS), desc='Profiling'):
|
| 64 |
+
torch.cuda.synchronize()
|
| 65 |
+
start_time = time.time()
|
| 66 |
+
with profiler.record_function("forward"):
|
| 67 |
+
fake_img = G(z, c=c, t=t).contiguous()
|
| 68 |
+
y = fake_img[0, 0, 0, 0].item() # sync
|
| 69 |
+
torch.cuda.synchronize()
|
| 70 |
+
times.append(time.time() - start_time)
|
| 71 |
+
|
| 72 |
+
torch.cuda.empty_cache()
|
| 73 |
+
num_imgs_processed = len(times) * batch_size
|
| 74 |
+
total_time_spent = np.sum(times)
|
| 75 |
+
bandwidth = num_imgs_processed / total_time_spent
|
| 76 |
+
summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
|
| 77 |
+
|
| 78 |
+
print(f'[Batch size: {batch_size}] Mean: {np.mean(times):.05f}s/it. Std: {np.std(times):.05f}s')
|
| 79 |
+
print(f'[Batch size: {batch_size}] Imgs/sec: {bandwidth:.03f}')
|
| 80 |
+
print(f'[Batch size: {batch_size}] Max mem: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:<6.2f} gb')
|
| 81 |
+
|
| 82 |
+
return bandwidth, summary
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@hydra.main(config_path="../../configs", config_name="config.yaml")
|
| 86 |
+
def profile(cfg: DictConfig):
|
| 87 |
+
recursive_instantiate(cfg)
|
| 88 |
+
G = instantiate_G(cfg)
|
| 89 |
+
bandwidths = []
|
| 90 |
+
summaries = []
|
| 91 |
+
print(f'Number of parameters: {sum(p.numel() for p in G.parameters())}')
|
| 92 |
+
|
| 93 |
+
for batch_size in BATCH_SIZES:
|
| 94 |
+
bandwidth, summary = profile_for_batch_size(G, cfg, batch_size)
|
| 95 |
+
bandwidths.append(bandwidth)
|
| 96 |
+
summaries.append(summary)
|
| 97 |
+
|
| 98 |
+
best_batch_size_idx = int(np.argmax(bandwidths))
|
| 99 |
+
print(f'------------ Best batch size is {BATCH_SIZES[best_batch_size_idx]} ------------')
|
| 100 |
+
print(summaries[best_batch_size_idx])
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == '__main__':
|
| 104 |
+
profile()
|
src/scripts/project.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Given a dataset of images, it (optionally crops it) and embeds into the model
|
| 3 |
+
Also optionally generates random videos from the found w
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys; sys.path.extend(['.', 'src'])
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
from typing import List, Optional, Callable
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import click
|
| 16 |
+
from src import dnnlib
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torchvision import utils
|
| 24 |
+
from torch import Tensor
|
| 25 |
+
import torchvision.transforms.functional as TVF
|
| 26 |
+
from torchvision.utils import save_image
|
| 27 |
+
|
| 28 |
+
import legacy
|
| 29 |
+
from src.training.logging import generate_videos, save_video_frames_as_mp4, save_video_frames_as_frames
|
| 30 |
+
from src.torch_utils import misc
|
| 31 |
+
|
| 32 |
+
#----------------------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
def project(
|
| 35 |
+
_sentinel=None,
|
| 36 |
+
G: Callable=None,
|
| 37 |
+
vgg16: nn.Module=None,
|
| 38 |
+
target_images: List[Tensor]=None,
|
| 39 |
+
device: str='cuda',
|
| 40 |
+
use_w_init: bool=False,
|
| 41 |
+
use_motion_init: bool=False,
|
| 42 |
+
w_avg_samples = 10000,
|
| 43 |
+
num_steps = 1000,
|
| 44 |
+
initial_learning_rate = 0.1,
|
| 45 |
+
initial_noise_factor = 0.05,
|
| 46 |
+
noise_ramp_length = 0.75,
|
| 47 |
+
lr_rampdown_length = 0.25,
|
| 48 |
+
lr_rampup_length = 0.05,
|
| 49 |
+
#regularize_noise_weight = 1e5,
|
| 50 |
+
regularize_noise_weight = 0.0001,
|
| 51 |
+
motion_reg_type: str=None,
|
| 52 |
+
):
|
| 53 |
+
num_videos = len(target_images)
|
| 54 |
+
|
| 55 |
+
# misc.assert_shape(target_images, [None, G.img_channels, G.img_resolution, G.img_resolution])
|
| 56 |
+
G = G.eval().requires_grad_(False).to(device) # type: ignore
|
| 57 |
+
|
| 58 |
+
c = torch.zeros(num_videos, G.c_dim, device=device)
|
| 59 |
+
ts = torch.zeros(num_videos, 1, device=device)
|
| 60 |
+
|
| 61 |
+
# Compute w stats.
|
| 62 |
+
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
|
| 63 |
+
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
|
| 64 |
+
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
|
| 65 |
+
w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
|
| 66 |
+
w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
|
| 67 |
+
|
| 68 |
+
# img_mean = G.synthesis(
|
| 69 |
+
# ws=torch.from_numpy(w_avg).repeat(1, G.num_ws, 1).to(device),
|
| 70 |
+
# c=c[0], t=ts[[0]],
|
| 71 |
+
# )
|
| 72 |
+
# img_mean = (img_mean * 0.5 + 0.5).cpu().detach()
|
| 73 |
+
# TVF.to_pil_image(img_mean[0]).save('/tmp/data/mean.png')
|
| 74 |
+
# print('saved!')
|
| 75 |
+
|
| 76 |
+
# Load VGG16 feature detector.
|
| 77 |
+
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
|
| 78 |
+
with dnnlib.util.open_url(url) as f:
|
| 79 |
+
vgg16 = torch.jit.load(f).eval().to(device)
|
| 80 |
+
|
| 81 |
+
# Features for target image.
|
| 82 |
+
target_features = []
|
| 83 |
+
for img in target_images:
|
| 84 |
+
img = img.to(device).to(torch.float32).unsqueeze(0) * 255.0
|
| 85 |
+
if img.shape[2] > 256:
|
| 86 |
+
img = F.interpolate(img, size=(256, 256), mode='area')
|
| 87 |
+
target_features.append(vgg16(img, resize_images=False, return_lpips=True).squeeze(0))
|
| 88 |
+
target_features = torch.stack(target_features) # [num_images, lpips_dim]
|
| 89 |
+
|
| 90 |
+
if use_w_init:
|
| 91 |
+
w_opt = find_w_init() # [num_videos, 1, w_dim]
|
| 92 |
+
w_opt = w_opt.detach().requires_grad_(True) # [num_videos, num_ws, w_dim]
|
| 93 |
+
else:
|
| 94 |
+
w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
|
| 95 |
+
w_opt = w_opt.repeat(num_videos, G.num_ws, 1).detach().requires_grad_(True) # [num_videos, num_ws, w_dim]
|
| 96 |
+
|
| 97 |
+
# w_opt_to_ws = lambda w_opt: torch.cat([w_opt[:, [0]].repeat(1, G.num_ws // 2, 1), w_opt[:, 1:]], dim=1)
|
| 98 |
+
|
| 99 |
+
# Trying a lot of motions to find which one works best
|
| 100 |
+
if use_motion_init:
|
| 101 |
+
motion_z_opt = select_motions(motion_codes)
|
| 102 |
+
else:
|
| 103 |
+
motion_z_opt = G.synthesis.motion_encoder(c=c, t=ts)['motion_z']
|
| 104 |
+
# motion_z_opt.data = torch.randn_like(motion_z_opt.data) * 1e-3
|
| 105 |
+
|
| 106 |
+
motion_z_opt.requires_grad_(True)
|
| 107 |
+
|
| 108 |
+
w_result = torch.zeros([num_steps] + list(w_opt.shape), dtype=torch.float32, device=device)
|
| 109 |
+
# optimizer = torch.optim.Adam([w_opt] + [motion_z_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
|
| 110 |
+
optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
|
| 111 |
+
|
| 112 |
+
for step in tqdm(range(num_steps)):
|
| 113 |
+
# Learning rate schedule.
|
| 114 |
+
t = step / num_steps
|
| 115 |
+
w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
|
| 116 |
+
lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
|
| 117 |
+
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
|
| 118 |
+
lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
|
| 119 |
+
lr = initial_learning_rate * lr_ramp
|
| 120 |
+
|
| 121 |
+
for param_group in optimizer.param_groups:
|
| 122 |
+
param_group['lr'] = lr
|
| 123 |
+
|
| 124 |
+
# Synth images from opt_w.
|
| 125 |
+
w_noise = torch.randn_like(w_opt) * w_noise_scale
|
| 126 |
+
ws = w_opt + w_noise
|
| 127 |
+
#ws = w_opt_to_ws(w_opt + w_noise)
|
| 128 |
+
#ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
|
| 129 |
+
#synth_images = G.synthesis(ws, c=c, t=ts, motion_z=motion_z_opt + torch.randn_like(motion_z_opt) * w_noise_scale)
|
| 130 |
+
synth_images = G.synthesis(ws, c=c, t=ts, motion_z=motion_z_opt)
|
| 131 |
+
#synth_images = G.synthesis(ws, c=c, t=ts)
|
| 132 |
+
|
| 133 |
+
# Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
|
| 134 |
+
synth_images = (synth_images * 0.5 + 0.5) * 255.0
|
| 135 |
+
if synth_images.shape[2] > 256:
|
| 136 |
+
synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
|
| 137 |
+
|
| 138 |
+
# Features for synth images.
|
| 139 |
+
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
|
| 140 |
+
dist = (target_features - synth_features).square().sum()
|
| 141 |
+
|
| 142 |
+
# Noise regularization.
|
| 143 |
+
if motion_reg_type is None:
|
| 144 |
+
reg_loss = 0.0
|
| 145 |
+
elif motion_reg_type == "norm":
|
| 146 |
+
reg_loss = motion_z_opt.norm(dim=2).mean()
|
| 147 |
+
elif motion_reg_type == "dist":
|
| 148 |
+
reg_loss = motion_z_opt.mean().pow(2) + (motion_z_opt.var() - 1).pow(2)
|
| 149 |
+
elif motion_reg_type == "sg2":
|
| 150 |
+
for v in noise_bufs.values():
|
| 151 |
+
noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
|
| 152 |
+
while True:
|
| 153 |
+
reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
|
| 154 |
+
reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
|
| 155 |
+
if noise.shape[2] <= 8:
|
| 156 |
+
break
|
| 157 |
+
noise = F.avg_pool2d(noise, kernel_size=2)
|
| 158 |
+
else:
|
| 159 |
+
raise NotImplementedError(f"Uknown motion_reg_type: {motion_reg_type}")
|
| 160 |
+
|
| 161 |
+
loss = dist + reg_loss * regularize_noise_weight
|
| 162 |
+
|
| 163 |
+
# Step
|
| 164 |
+
optimizer.zero_grad(set_to_none=True)
|
| 165 |
+
loss.backward()
|
| 166 |
+
optimizer.step()
|
| 167 |
+
|
| 168 |
+
# Save projected W for each optimization step.
|
| 169 |
+
w_result[step] = w_opt.detach()
|
| 170 |
+
|
| 171 |
+
# Normalize noise.
|
| 172 |
+
# with torch.no_grad():
|
| 173 |
+
# for buf in motion_z_opt.values():
|
| 174 |
+
# buf -= buf.mean()
|
| 175 |
+
# buf *= buf.square().mean().rsqrt()
|
| 176 |
+
|
| 177 |
+
return w_result, motion_z_opt
|
| 178 |
+
|
| 179 |
+
#----------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
@torch.no_grad()
|
| 182 |
+
def find_motions_init(G: Callable, vgg16: nn.Module, target_features: Tensor, c: Tensor, t: Tensor, num_motions_to_try: int=128):
|
| 183 |
+
motions = G.synthesis.motion_encoder(
|
| 184 |
+
c=c.repeat_interleave(num_motions_to_try, dim=0),
|
| 185 |
+
t=t.repeat_interleave(num_motions_to_try, dim=0))['motion_z'] # [num_videos * num_motions_to_try, ...]
|
| 186 |
+
|
| 187 |
+
synth_images = G.synthesis(
|
| 188 |
+
w_opt.repeat_interleave(num_motions_to_try, dim=0),
|
| 189 |
+
c=c.repeat_interleave(num_motions_to_try, dim=0),
|
| 190 |
+
t=t.repeat_interleave(num_motions_to_try, dim=0),
|
| 191 |
+
motion_z=motions)
|
| 192 |
+
|
| 193 |
+
if synth_images.shape[2] > 256:
|
| 194 |
+
synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
|
| 195 |
+
|
| 196 |
+
synth_images = (synth_images * 0.5 + 0.5) * 255.0
|
| 197 |
+
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) # [num_videos * num_motions_to_try, ...]
|
| 198 |
+
dist = (target_features.repeat_interleave(num_motions_to_try, dim=0) - synth_features).square().sum(dim=1) # [num_videos * num_motions_to_try]
|
| 199 |
+
best_motions_idx = dist.view(num_videos, num_motions_to_try).argmin(dim=1) # [num_videos]
|
| 200 |
+
motion_z_opt = motions[best_motions_idx] # [num_videos, ...]
|
| 201 |
+
|
| 202 |
+
return motion_z_opt
|
| 203 |
+
|
| 204 |
+
#----------------------------------------------------------------------------
|
| 205 |
+
|
| 206 |
+
@torch.no_grad()
|
| 207 |
+
def find_w_init(G: Callable, vgg16: nn.Module, target_features: Tensor, c: Tensor, t: Tensor, l: Tensor, num_w_to_try: int=128):
|
| 208 |
+
z = torch.randn(num_videos * num_w_to_try, G.z_dim, device=device)
|
| 209 |
+
w = G.mapping(z=z, c=None) # [N, L, C]
|
| 210 |
+
|
| 211 |
+
synth_images = G.synthesis(
|
| 212 |
+
ws=w,
|
| 213 |
+
c=c.repeat_interleave(num_w_to_try, dim=0),
|
| 214 |
+
t=t.repeat_interleave(num_w_to_try, dim=0))
|
| 215 |
+
if synth_images.shape[2] > 256:
|
| 216 |
+
synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
|
| 217 |
+
synth_images = (synth_images * 0.5 + 0.5) * 255.0
|
| 218 |
+
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) # [num_videos * num_motions_to_try, ...]
|
| 219 |
+
dist = (target_features.repeat_interleave(num_w_to_try, dim=0) - synth_features).square().sum(dim=1) # [num_videos * num_motions_to_try]
|
| 220 |
+
best_w_idx = dist.view(num_videos, num_w_to_try).argmin(dim=1) # [num_videos]
|
| 221 |
+
w_opt = w[best_w_idx] # [num_videos, num_ws, w_dim]
|
| 222 |
+
|
| 223 |
+
return w_opt
|
| 224 |
+
|
| 225 |
+
#----------------------------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
@torch.no_grad()
|
| 228 |
+
def load_target_images(img_paths: List[os.PathLike], extract_faces: bool=False, ref_image: Tensor=None):
|
| 229 |
+
images = [Image.open(f) for f in tqdm(img_paths, desc='Loading images')]
|
| 230 |
+
|
| 231 |
+
if extract_faces:
|
| 232 |
+
images = extract_faces_from_images(imgs=images, ref_image=ref_image)
|
| 233 |
+
for p, img in zip(img_paths, images):
|
| 234 |
+
img.save('/tmp/data/faces_extracted/' + os.path.basename(p), q=95)
|
| 235 |
+
assert False
|
| 236 |
+
# grid = torch.stack([TVF.to_tensor(x) for x in images])
|
| 237 |
+
# grid = utils.make_grid(grid, nrow=8)
|
| 238 |
+
# save_image(grid, f'/tmp/data/faces_extracted.png')
|
| 239 |
+
# print('Saved the extracted images!')
|
| 240 |
+
|
| 241 |
+
# images = [x[:, 200:-400, 450:-200] for x in images]
|
| 242 |
+
images = [TVF.to_tensor(x) for x in images]
|
| 243 |
+
images = [TVF.resize(x, size=(256, 256)) for x in images]
|
| 244 |
+
|
| 245 |
+
return images
|
| 246 |
+
|
| 247 |
+
#----------------------------------------------------------------------------
|
| 248 |
+
|
| 249 |
+
@torch.no_grad()
|
| 250 |
+
def extract_faces_from_images(_sentinel=None, imgs: List=None, ref_image: "Image"=None, device: str='cuda'):
|
| 251 |
+
assert _sentinel is None
|
| 252 |
+
try:
|
| 253 |
+
import face_alignment
|
| 254 |
+
except ImportError:
|
| 255 |
+
raise ImportError("To project images with alignment, you need to install the `face_alignment` library.")
|
| 256 |
+
|
| 257 |
+
SELECTED_LANDMARKS = [38, 44]
|
| 258 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
|
| 259 |
+
|
| 260 |
+
ref_landmarks = fa.get_landmarks_from_image(np.array(ref_image))[0][SELECTED_LANDMARKS] # [2, 2]
|
| 261 |
+
landmarks = [fa.get_landmarks_from_image(np.array(x))[0][SELECTED_LANDMARKS] for x in imgs] # [num_imgs, 2, 2]
|
| 262 |
+
ref_dist = ((ref_landmarks[0] - ref_landmarks[1]) ** 2).sum() ** 0.5 # [1]
|
| 263 |
+
dists = [((p[0] - p[1]) ** 2).sum() ** 0.5 for p in landmarks] # [num_imgs]
|
| 264 |
+
resize_ratios = [ref_dist / d for d in dists] # [num_imgs]
|
| 265 |
+
new_sizes = [(int(r * x.size[1]), int(r * x.size[0])) for r, x in zip(resize_ratios, imgs)]
|
| 266 |
+
imgs_resized = [TVF.resize(x, size=s, interpolation=Image.LANCZOS) for x, s in zip(imgs, new_sizes)] # [num_imgs, Image]
|
| 267 |
+
bbox_left = [p[0][0] * r - ref_landmarks[0][0] for p, r in zip(landmarks, resize_ratios)]
|
| 268 |
+
bbox_top = [p[0][1] * r - ref_landmarks[0][1] for p, r in zip(landmarks, resize_ratios)]
|
| 269 |
+
|
| 270 |
+
out = [x.crop(box=(l, t, l + ref_image.size[0], t + ref_image.size[1])) for x, l, t in zip(imgs_resized, bbox_left, bbox_top)]
|
| 271 |
+
|
| 272 |
+
return out
|
| 273 |
+
|
| 274 |
+
#----------------------------------------------------------------------------
|
| 275 |
+
|
| 276 |
+
def pad_box_to_square(left, upper, right, lower):
|
| 277 |
+
h = lower - upper
|
| 278 |
+
w = right - left
|
| 279 |
+
|
| 280 |
+
if h == w:
|
| 281 |
+
return left, upper, right, lower
|
| 282 |
+
elif w > h:
|
| 283 |
+
diff = w - h
|
| 284 |
+
assert False, "Not implemented"
|
| 285 |
+
else:
|
| 286 |
+
pad = (h - w) // 2
|
| 287 |
+
|
| 288 |
+
return (left - pad, upper, right + pad, lower)
|
| 289 |
+
|
| 290 |
+
#----------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
def add_margins(box, margin, width: int=float('inf'), height: int=float('inf')):
|
| 293 |
+
left, upper, right, lower = box
|
| 294 |
+
|
| 295 |
+
return (
|
| 296 |
+
max(0, left - margin[0]),
|
| 297 |
+
max(0, upper - margin[1]),
|
| 298 |
+
min(width, right + margin[2]),
|
| 299 |
+
min(height, lower + margin[3]),
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
#----------------------------------------------------------------------------
|
| 303 |
+
|
| 304 |
+
def add_top_margin(box, margin_ratio: float=0.0):
|
| 305 |
+
left, upper, right, lower = box
|
| 306 |
+
height = lower - upper
|
| 307 |
+
margin = int(height * margin_ratio)
|
| 308 |
+
|
| 309 |
+
return (left, max(0, upper - margin), right, lower)
|
| 310 |
+
|
| 311 |
+
#----------------------------------------------------------------------------
|
| 312 |
+
|
| 313 |
+
def save_edited_w(
|
| 314 |
+
_sentinel=None,
|
| 315 |
+
G: Callable=None,
|
| 316 |
+
w_outdir: os.PathLike=None,
|
| 317 |
+
samples_outdir: os.PathLike=None,
|
| 318 |
+
img_names: List[str]=None,
|
| 319 |
+
stack_samples: bool=False,
|
| 320 |
+
num_frames: int = 16,
|
| 321 |
+
each_nth_frame: int = 3,
|
| 322 |
+
all_w: Tensor=None,
|
| 323 |
+
all_motion_z: Tensor=None,
|
| 324 |
+
stacked_samples_out_path: os.PathLike=None,
|
| 325 |
+
):
|
| 326 |
+
assert _sentinel is None
|
| 327 |
+
|
| 328 |
+
# w_outdir = os.path.join(os.path.basename(images_dir))
|
| 329 |
+
|
| 330 |
+
os.makedirs(w_outdir, exist_ok=True)
|
| 331 |
+
num_videos = len(img_names)
|
| 332 |
+
device = all_w.device
|
| 333 |
+
|
| 334 |
+
if not stack_samples:
|
| 335 |
+
os.makedirs(samples_outdir, exist_ok=True)
|
| 336 |
+
else:
|
| 337 |
+
all_samples = []
|
| 338 |
+
|
| 339 |
+
# Generate samples from the given w and save them.
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
z = torch.randn(num_videos, G.z_dim, device=device) # [num_videos, z_dim]
|
| 342 |
+
c = torch.zeros(num_videos, G.c_dim, device=device) # [num_videos, c_dim]
|
| 343 |
+
|
| 344 |
+
for i, w in enumerate(all_w):
|
| 345 |
+
torch.save(w.cpu(), os.path.join(w_outdir, f'{img_names[i]}_w.pt'))
|
| 346 |
+
|
| 347 |
+
if all_motion_z is None:
|
| 348 |
+
motion_z = None
|
| 349 |
+
else:
|
| 350 |
+
motion_z = all_motion_z[i] # [...<any>...]
|
| 351 |
+
torch.save(motion_z.cpu(), os.path.join(w_outdir, f'{img_names[i]}_motion.pt'))
|
| 352 |
+
motion_z = motion_z.unsqueeze(0).to(device) # [1, ...<any>...]
|
| 353 |
+
motion_z = torch.randn_like(motion_z)
|
| 354 |
+
|
| 355 |
+
w = w.unsqueeze(0).to(device) # [1, num_ws, w_dim]
|
| 356 |
+
t = torch.linspace(0, num_frames * (1 + each_nth_frame), num_frames, device=device).unsqueeze(0)
|
| 357 |
+
imgs = G.synthesis(w, c=c[[i]]], t=t, motion_z=motion_z)
|
| 358 |
+
imgs = (imgs * 0.5 + 0.5).clamp(0, 1)
|
| 359 |
+
grid = utils.make_grid(imgs, nrow=num_frames).cpu()
|
| 360 |
+
|
| 361 |
+
if stack_samples:
|
| 362 |
+
all_samples.append(grid)
|
| 363 |
+
else:
|
| 364 |
+
# TVF.to_pil_image(grid).save(os.path.join(samples_outdir, img_names[i]) + '.jpg', q=95)
|
| 365 |
+
save_image(grid, os.path.join(samples_outdir, img_names[i]) + '.png')
|
| 366 |
+
|
| 367 |
+
if stack_samples:
|
| 368 |
+
main_grid = torch.stack(all_samples) # [num_videos, c, h, w * num_frames]
|
| 369 |
+
main_grid = utils.make_grid(main_grid, nrow=1)
|
| 370 |
+
# TVF.to_pil_image(main_grid).save(f'{images_dir}.jpg', q=95)
|
| 371 |
+
save_image(main_grid, stacked_samples_out_path)
|
| 372 |
+
|
| 373 |
+
#----------------------------------------------------------------------------
|
| 374 |
+
|
| 375 |
+
@click.command()
|
| 376 |
+
@click.pass_context
|
| 377 |
+
@click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
|
| 378 |
+
@click.option('--networks_dir', help='Network pickles directory', metavar='PATH')
|
| 379 |
+
# @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
|
| 380 |
+
# @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
|
| 381 |
+
# @click.option('--same_motion_codes', type=bool, help='Should we use the same motion codes for all videos?', default=False, show_default=True)
|
| 382 |
+
@click.option('--seed', type=int, help='Random seed', default=42, metavar='DIR')
|
| 383 |
+
@click.option('--images_dir', help='Where to save the output images', type=str, required=True, metavar='DIR')
|
| 384 |
+
# @click.option('--save_as_mp4', help='Should we save as independent frames or mp4?', type=bool, default=False, metavar='BOOL')
|
| 385 |
+
# @click.option('--video_len', help='Number of frames to generate', type=int, default=16, metavar='INT')
|
| 386 |
+
# @click.option('--fps', help='FPS for mp4 saving', type=int, default=25, metavar='INT')
|
| 387 |
+
# @click.option('--as_grids', help='Save videos as grids', type=bool, default=False, metavar='BOOl')
|
| 388 |
+
@click.option('--zero_periods', help='Zero-out periods predictor?', default=False, type=bool, metavar='BOOL')
|
| 389 |
+
@click.option('--num_weights_to_slice', help='Number of high-frequency coords to remove.', default=0, type=int, metavar='INT')
|
| 390 |
+
@click.option('--use_w_init', help='Init w by LPIPS.', default=False, type=bool, metavar='BOOL')
|
| 391 |
+
@click.option('--use_motion_init', help='Init motions by LPIPS.', default=False, type=bool, metavar='BOOL')
|
| 392 |
+
@click.option('--motion_reg_type', help='Type of the regularization for motion', default=None, type=str, metavar='STR')
|
| 393 |
+
@click.option('--num_steps', help='Number of the optimization steps to perform.', default=1000, type=int, metavar='INT')
|
| 394 |
+
@click.option('--stack_samples', help='When saving, should we stack samples together?', default=False, type=bool, metavar='BOOL')
|
| 395 |
+
@click.option('--extract_faces', help='Use FaceNet to extract the face?', default=False, type=bool, metavar='BOOL')
|
| 396 |
+
|
| 397 |
+
def main(
|
| 398 |
+
ctx: click.Context,
|
| 399 |
+
network_pkl: str,
|
| 400 |
+
networks_dir: str,
|
| 401 |
+
seed: int,
|
| 402 |
+
images_dir: str,
|
| 403 |
+
# save_as_mp4: bool,
|
| 404 |
+
# video_len: int,
|
| 405 |
+
# fps: int,
|
| 406 |
+
# as_grids: bool,
|
| 407 |
+
zero_periods: bool,
|
| 408 |
+
num_weights_to_slice: int,
|
| 409 |
+
use_w_init: bool,
|
| 410 |
+
use_motion_init: bool,
|
| 411 |
+
motion_reg_type: str,
|
| 412 |
+
num_steps: int,
|
| 413 |
+
stack_samples: bool,
|
| 414 |
+
extract_faces: bool,
|
| 415 |
+
):
|
| 416 |
+
if network_pkl is None:
|
| 417 |
+
output_regex = "^network-snapshot-\d{6}.pkl$"
|
| 418 |
+
ckpt_regex = re.compile("^network-snapshot-\d{6}.pkl$")
|
| 419 |
+
# ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
|
| 420 |
+
# network_pkl = os.path.join(networks_dir, ckpts[-1])
|
| 421 |
+
metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
|
| 422 |
+
with open(metrics_file, 'r') as f:
|
| 423 |
+
snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
|
| 424 |
+
best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
|
| 425 |
+
network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
|
| 426 |
+
print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
|
| 427 |
+
# Selecting a checkpoint with the best score
|
| 428 |
+
else:
|
| 429 |
+
assert networks_dir is None, "Cant have both parameters: network_pkl and networks_dir"
|
| 430 |
+
|
| 431 |
+
print('Loading networks from "%s"...' % network_pkl, end='')
|
| 432 |
+
device = torch.device('cuda')
|
| 433 |
+
with dnnlib.util.open_url(network_pkl) as f:
|
| 434 |
+
G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
|
| 435 |
+
print('Loaded!')
|
| 436 |
+
|
| 437 |
+
random.seed(seed)
|
| 438 |
+
np.random.seed(seed)
|
| 439 |
+
torch.manual_seed(seed)
|
| 440 |
+
|
| 441 |
+
if zero_periods:
|
| 442 |
+
G.synthesis.motion_encoder.time_encoder.periods_predictor.weight.data.zero_()
|
| 443 |
+
|
| 444 |
+
if num_weights_to_slice > 0:
|
| 445 |
+
G.synthesis.motion_encoder.time_encoder.weights[:, -num_weights_to_slice:] = 0.0
|
| 446 |
+
|
| 447 |
+
img_paths = sorted([os.path.join(images_dir, p) for p in os.listdir(images_dir) if p.endswith('.jpg')])
|
| 448 |
+
img_names = [n[:n.rfind('.')] for n in [os.path.basename(p) for p in img_paths]]
|
| 449 |
+
target_images = load_target_images(img_paths, extract_faces, ref_image=Image.open('/tmp/data/mean.png')) # [b, c, h, w]
|
| 450 |
+
|
| 451 |
+
assert G.c_dim == 0, "G.c_dim > 0 is not supported"
|
| 452 |
+
|
| 453 |
+
w_all_iters, motion_z_final = project(
|
| 454 |
+
G=G,
|
| 455 |
+
target_images=target_images,
|
| 456 |
+
num_steps=num_steps,
|
| 457 |
+
device=device,
|
| 458 |
+
use_w_init=use_w_init,
|
| 459 |
+
use_motion_init=use_motion_init,
|
| 460 |
+
motion_reg_type=motion_reg_type,
|
| 461 |
+
) # [num_videos, num_ws, w_dim]
|
| 462 |
+
|
| 463 |
+
save_edited_w(
|
| 464 |
+
G=G,
|
| 465 |
+
w_outdir = f'{images_dir}_projected',
|
| 466 |
+
samples_outdir = f'{images_dir}_projected_samples',
|
| 467 |
+
img_names=img_names,
|
| 468 |
+
stack_samples=stack_samples,
|
| 469 |
+
all_w = w_all_iters[-1],
|
| 470 |
+
all_motion_z = motion_z_final,
|
| 471 |
+
stacked_samples_out_path = f'{images_dir}.png'
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
#----------------------------------------------------------------------------
|
| 475 |
+
|
| 476 |
+
if __name__ == "__main__":
|
| 477 |
+
main() # pylint: disable=no-value-for-parameter
|
| 478 |
+
|
| 479 |
+
#----------------------------------------------------------------------------
|
src/torch_utils/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
# empty
|
src/torch_utils/custom_ops.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
+
# and proprietary rights in and to this software, related documentation
|
| 5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
+
# distribution of this software and related documentation without an express
|
| 7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import glob
|
| 11 |
+
import torch
|
| 12 |
+
import torch.utils.cpp_extension
|
| 13 |
+
import importlib
|
| 14 |
+
import hashlib
|
| 15 |
+
import shutil
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
from torch.utils.file_baton import FileBaton
|
| 19 |
+
|
| 20 |
+
#----------------------------------------------------------------------------
|
| 21 |
+
# Global options.
|
| 22 |
+
|
| 23 |
+
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
| 24 |
+
|
| 25 |
+
#----------------------------------------------------------------------------
|
| 26 |
+
# Internal helper funcs.
|
| 27 |
+
|
| 28 |
+
def _find_compiler_bindir():
|
| 29 |
+
patterns = [
|
| 30 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
| 31 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
| 32 |
+
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
| 33 |
+
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
| 34 |
+
]
|
| 35 |
+
for pattern in patterns:
|
| 36 |
+
matches = sorted(glob.glob(pattern))
|
| 37 |
+
if len(matches):
|
| 38 |
+
return matches[-1]
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
#----------------------------------------------------------------------------
|
| 42 |
+
# Main entry point for compiling and loading C++/CUDA plugins.
|
| 43 |
+
|
| 44 |
+
_cached_plugins = dict()
|
| 45 |
+
|
| 46 |
+
def get_plugin(module_name, sources, **build_kwargs):
|
| 47 |
+
assert verbosity in ['none', 'brief', 'full']
|
| 48 |
+
|
| 49 |
+
# Already cached?
|
| 50 |
+
if module_name in _cached_plugins:
|
| 51 |
+
return _cached_plugins[module_name]
|
| 52 |
+
|
| 53 |
+
# Print status.
|
| 54 |
+
if verbosity == 'full':
|
| 55 |
+
print(f'Setting up PyTorch plugin "{module_name}"...')
|
| 56 |
+
elif verbosity == 'brief':
|
| 57 |
+
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
| 58 |
+
|
| 59 |
+
try: # pylint: disable=too-many-nested-blocks
|
| 60 |
+
# Make sure we can find the necessary compiler binaries.
|
| 61 |
+
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
| 62 |
+
compiler_bindir = _find_compiler_bindir()
|
| 63 |
+
if compiler_bindir is None:
|
| 64 |
+
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
| 65 |
+
os.environ['PATH'] += ';' + compiler_bindir
|
| 66 |
+
|
| 67 |
+
# Compile and load.
|
| 68 |
+
verbose_build = (verbosity == 'full')
|
| 69 |
+
|
| 70 |
+
# Incremental build md5sum trickery. Copies all the input source files
|
| 71 |
+
# into a cached build directory under a combined md5 digest of the input
|
| 72 |
+
# source files. Copying is done only if the combined digest has changed.
|
| 73 |
+
# This keeps input file timestamps and filenames the same as in previous
|
| 74 |
+
# extension builds, allowing for fast incremental rebuilds.
|
| 75 |
+
#
|
| 76 |
+
# This optimization is done only in case all the source files reside in
|
| 77 |
+
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
| 78 |
+
# environment variable is set (we take this as a signal that the user
|
| 79 |
+
# actually cares about this.)
|
| 80 |
+
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
| 81 |
+
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
| 82 |
+
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
| 83 |
+
|
| 84 |
+
# Compute a combined hash digest for all source files in the same
|
| 85 |
+
# custom op directory (usually .cu, .cpp, .py and .h files).
|
| 86 |
+
hash_md5 = hashlib.md5()
|
| 87 |
+
for src in all_source_files:
|
| 88 |
+
with open(src, 'rb') as f:
|
| 89 |
+
hash_md5.update(f.read())
|
| 90 |
+
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
| 91 |
+
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
| 92 |
+
|
| 93 |
+
if not os.path.isdir(digest_build_dir):
|
| 94 |
+
os.makedirs(digest_build_dir, exist_ok=True)
|
| 95 |
+
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
| 96 |
+
if baton.try_acquire():
|
| 97 |
+
try:
|
| 98 |
+
for src in all_source_files:
|
| 99 |
+
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
| 100 |
+
finally:
|
| 101 |
+
baton.release()
|
| 102 |
+
else:
|
| 103 |
+
# Someone else is copying source files under the digest dir,
|
| 104 |
+
# wait until done and continue.
|
| 105 |
+
baton.wait()
|
| 106 |
+
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
| 107 |
+
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
| 108 |
+
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
| 109 |
+
else:
|
| 110 |
+
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
| 111 |
+
module = importlib.import_module(module_name)
|
| 112 |
+
|
| 113 |
+
except:
|
| 114 |
+
if verbosity == 'brief':
|
| 115 |
+
print('Failed!')
|
| 116 |
+
raise
|
| 117 |
+
|
| 118 |
+
# Print status and add to cache.
|
| 119 |
+
if verbosity == 'full':
|
| 120 |
+
print(f'Done setting up PyTorch plugin "{module_name}".')
|
| 121 |
+
elif verbosity == 'brief':
|
| 122 |
+
print('Done.')
|
| 123 |
+
_cached_plugins[module_name] = module
|
| 124 |
+
return module
|
| 125 |
+
|
| 126 |
+
#----------------------------------------------------------------------------
|