tasin committed
Commit f075308 · 1 Parent(s): 3f8c938
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +1 -0
  2. .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  3. README.md +84 -12
  4. Untitled.ipynb +68 -0
  5. app.py +188 -0
  6. distributed.py +126 -0
  7. model_structure.txt +113 -0
  8. models/diffusion_model.py +762 -0
  9. models/unet_dual_encoder.py +62 -0
  10. project_latent_space.py +75 -0
  11. requirements.txt +30 -0
  12. sample/blue.jpg +0 -0
  13. sample/green.jpg +0 -0
  14. sample/silver.jpg +0 -0
  15. src/deps/__init__.py +0 -0
  16. src/deps/facial_recognition/__init__.py +3 -0
  17. src/deps/facial_recognition/helpers.py +123 -0
  18. src/deps/facial_recognition/model_irse.py +88 -0
  19. src/dnnlib/__init__.py +9 -0
  20. src/dnnlib/util.py +480 -0
  21. src/infra/__init__.py +0 -0
  22. src/infra/experiments.yaml +60 -0
  23. src/infra/launch.py +113 -0
  24. src/infra/slurm_batch_launch.py +96 -0
  25. src/infra/slurm_job.py +46 -0
  26. src/infra/slurm_job_proxy.sh +4 -0
  27. src/infra/utils.py +140 -0
  28. src/metrics/__init__.py +9 -0
  29. src/metrics/frechet_inception_distance.py +54 -0
  30. src/metrics/frechet_video_distance.py +59 -0
  31. src/metrics/inception_score.py +47 -0
  32. src/metrics/kernel_inception_distance.py +46 -0
  33. src/metrics/metric_main.py +154 -0
  34. src/metrics/metric_utils.py +332 -0
  35. src/metrics/video_inception_score.py +54 -0
  36. src/scripts/__init__.py +0 -0
  37. src/scripts/calc_metrics.py +250 -0
  38. src/scripts/calc_metrics_for_dataset.py +169 -0
  39. src/scripts/clip_edit.py +403 -0
  40. src/scripts/construct_static_videos_dataset.py +46 -0
  41. src/scripts/convert_video_to_dataset.py +87 -0
  42. src/scripts/convert_videos_to_frames.py +105 -0
  43. src/scripts/crop_video_dataset.py +69 -0
  44. src/scripts/frames_to_video_grid.py +78 -0
  45. src/scripts/generate.py +148 -0
  46. src/scripts/preprocess_ffs.py +204 -0
  47. src/scripts/profile_model.py +104 -0
  48. src/scripts/project.py +479 -0
  49. src/torch_utils/__init__.py +9 -0
  50. src/torch_utils/custom_ops.py +126 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea/
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+  "cells": [],
+  "metadata": {},
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
README.md CHANGED
@@ -1,12 +1,84 @@
- ---
- title: FashionFlow
- emoji: 🏆
- colorFrom: green
- colorTo: purple
- sdk: streamlit
- sdk_version: 1.37.1
- app_file: app.py
- pinned: false
- ---
- 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div id="top"></div>
+ 
+ <h3>FashionFlow: Leveraging Diffusion Models for Dynamic Fashion Video Synthesis from Static Imagery</h3>
+ 
+ <p>
+ This repository contains the official code for 'FashionFlow: Leveraging Diffusion Models for Dynamic Fashion Video Synthesis from Static Imagery'.
+ We have included the pre-trained checkpoint, dataset and results.
+ </p>
+ 
+ > **Abstract:** *Our study introduces a new image-to-video generator called FashionFlow to generate fashion videos. By utilising a diffusion model, we are able to create short videos from still fashion images. Our approach involves developing and connecting relevant components with the diffusion model, which results in the creation of high-fidelity videos that are aligned with the conditional image. The components include the use of pseudo-3D convolutional layers to generate videos efficiently. VAE and CLIP encoders capture vital characteristics from still images to condition the diffusion model at a global level. Our research demonstrates a successful synthesis of fashion videos featuring models posing from various angles, showcasing the fit and appearance of the garment. Our findings hold great promise for improving and enhancing the shopping experience for the online fashion industry.*
+ 
+ <!-- Results -->
+ ## Teaser
+ ![image](sample/teaser.gif)
+ 
+ ## Requirements
+ - Python 3.9
+ - PyTorch 1.11+
+ - TensorBoard (tensorboardX)
+ - OpenCV (cv2)
+ - transformers
+ - diffusers
+ 
+ ## Model Specification
+ 
+ The model was developed in PyTorch and loads pretrained weights for the VAE and CLIP encoders. The latent diffusion model stacks a 1D temporal convolution on top of a 2D spatial convolution (forming a pseudo-3D convolution) and includes attention layers. See ```model_structure.txt``` for the exact layers of our LDM.
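
The pseudo-3D factorisation described above can be summarised in a few lines. This is an illustrative sketch only (the class name and sizes are mine); the repository's actual implementation is `PseudoConv3d` in `models/diffusion_model.py`:

```python
# Minimal sketch of a pseudo-3D convolution: a 2D conv over (H, W) per frame,
# followed by a 1D conv over the frame axis. Illustrative only.
import torch
import torch.nn as nn

class TinyPseudoConv3d(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.spatial = nn.Conv2d(dim, dim, kernel_size, padding=pad)   # over (H, W)
        self.temporal = nn.Conv1d(dim, dim, kernel_size, padding=pad)  # over frames

    def forward(self, x):                       # x: (B, C, F, H, W)
        b, c, f, h, w = x.shape
        y = x.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
        y = self.spatial(y)                     # 2D conv applied to every frame
        y = y.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)
        z = y.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, f)
        z = self.temporal(z)                    # 1D conv applied along the frame axis
        return z.reshape(b, h, w, c, f).permute(0, 3, 4, 1, 2)

out = TinyPseudoConv3d(8)(torch.randn(1, 8, 4, 16, 16))  # -> (1, 8, 4, 16, 16)
```

This factorisation is what makes video generation affordable here: it reuses pretrained 2D weights for the spatial part and adds only a cheap temporal mixer.
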
+ 
+ ## Installation
+ 
+ Clone this repository:
+ 
+ ```
+ git clone https://github.com/1702609/FashionFlow
+ cd ./FashionFlow/
+ ```
+ 
+ Install PyTorch and the other dependencies:
+ 
+ ```
+ pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
+ pip install -r requirements.txt
+ ```
+ 
+ ## Dataset
+ 
+ Download the Fashion dataset from this link:
+ [[Fashion dataset]](https://vision.cs.ubc.ca/datasets/fashion/)
+ 
+ Extract the files and place them in the ```fashion_dataset``` directory. The dataset should be organised as follows:
+ 
+ ```
+ fashion_dataset
+     test
+     |-- 91-3003CN5S.mp4
+     |-- 91BjuE6irxS.mp4
+     |-- 91bxAN6BjAS.mp4
+     |-- ...
+     train
+     |-- 81FyMPk-WIS.mp4
+     |-- 91+bCFG1jOS.mp4
+     |-- 91+PxmDyrgS.mp4
+     |-- ...
+ ```
+ 
+ Feel free to add your own dataset, as long as it follows the provided file and folder structure.
+ 
+ ## Pre-trained Checkpoint
+ 
+ Download the checkpoint from this link:
+ [[Pre-trained checkpoints]](https://www.dropbox.com/scl/fi/p9fv7o3j7ti0yu2umsgmv/FashionFlow_checkpoint.pth?rlkey=mqsto9i4ujh6xhvab0e2s6n7d&dl=0)
+ Extract the files and place them in the ```checkpoint``` directory.
+ 
+ ## Inference
+ To run inference with our model, execute ```python inference.py```. The results will be saved in the ```result``` directory.
+ 
+ ## Train
+ 
+ Before training, images and videos have to be projected to latent space for efficient training. Execute ```python project_latent_space.py```; the tensors will be saved in the ```fashion_dataset_tensor``` directory.
+ 
+ Run ```python -m torch.distributed.launch --nproc_per_node=<number of GPUs> train.py``` to train the model. Checkpoints will be saved periodically in the ```checkpoint``` directory. You can also monitor training progress with tensorboardX (logs in ```video_progress```) or inspect the generated ```.mp4``` files in ```training_sample```.
+ 
+ ## Comparison
+ 
+ ![image](sample/comparison.gif)
Untitled.ipynb ADDED
@@ -0,0 +1,68 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "5b3c9fac-51c3-4ecc-8606-5a298076560e",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from huggingface_hub import notebook_login\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "51fe1170-c6eb-4b3a-a055-663faf35ab5a",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "6c6d20c8f5e847d7985f6b49a7206a2d",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "notebook_login()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "4297ea17-a4f8-4290-b561-f582b1adc189",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "work",
+    "language": "python",
+    "name": "work"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.9"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
app.py ADDED
@@ -0,0 +1,188 @@
+ import torch
+ import cv2
+ import torchvision.transforms as transforms
+ from models.unet_dual_encoder import Embedding_Adapter
+ from models.diffusion_model import SpaceTimeUnet
+ import numpy as np
+ import torchvision.transforms.functional as TVF
+ from diffusers import AutoencoderKL
+ from PIL import Image
+ from transformers import CLIPVisionModel, CLIPProcessor
+ import torch.nn.functional as F
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ 
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ frameLimit = 70
+ 
+ def cosine_beta_schedule(timesteps, start=0.0001, end=0.02):
+     betas = []
+     for i in reversed(range(timesteps)):
+         T = timesteps - 1
+         beta = start + 0.5 * (end - start) * (1 + np.cos((i / T) * np.pi))
+         betas.append(beta)
+     return torch.Tensor(betas)
+ 
+ 
+ def get_index_from_list(vals, t, x_shape):
+     batch_size = t.shape[0]
+     out = vals.gather(-1, t.cpu())
+     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+ 
+ def forward_diffusion_sample(x_0, t):
+     noise = torch.randn_like(x_0)
+     sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
+     sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
+         sqrt_one_minus_alphas_cumprod, t, x_0.shape
+     )
+     # mean + variance
+     return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
+         + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
+ 
+ T = 1000
+ betas = cosine_beta_schedule(timesteps=T)
+ # Pre-calculate different terms for the closed form
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+ sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
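
The tables above implement the standard DDPM closed form `q(x_t | x_0)`: `x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps`. A quick sanity check, assuming the definitions above are in scope (shapes match the latent video used later in this file; illustrative only):

```python
# Illustrative sanity check for the closed-form forward process defined above.
x_0 = torch.randn(1, 70, 4, 80, 64)           # a dummy latent video
t = torch.randint(0, T, (1,)).long()          # a random timestep
x_t, eps = forward_diffusion_sample(x_0, t)   # x_t = sqrt(a_bar_t)*x_0 + sqrt(1-a_bar_t)*eps
print(x_t.shape, eps.shape)                   # both torch.Size([1, 70, 4, 80, 64])
```
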
+ 
+ def get_transform():
+     image_transforms = transforms.Compose(
+         [
+             transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
+             transforms.ToTensor(),
+         ])
+     return image_transforms
+ 
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4",
+                                     subfolder="vae",
+                                     revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c")
+ vae.to(device)
+ vae.requires_grad_(False)
+ 
+ with torch.no_grad():
+     Net = SpaceTimeUnet(
+         dim = 64,
+         channels = 4,
+         dim_mult = (1, 2, 4, 8),
+         temporal_compression = (False, False, False, True),
+         self_attns = (False, False, False, True),
+         condition_on_timestep=True
+     ).to(device)
+     adapter = Embedding_Adapter(input_nc=1280, output_nc=1280).to(device)
+ 
+ clip_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+ clip_encoder.requires_grad_(False)
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ 
+ checkpoint = torch.load(hf_hub_download(repo_id="sunjuice/FashionFlow_model", filename="FashionFlow_model.pth"))
+ 
+ Net.load_state_dict(checkpoint['net'])
+ adapter.load_state_dict(checkpoint['adapter'])
+ del checkpoint
+ torch.cuda.empty_cache()
+ 
+ def save_video_frames_as_mp4(frames, fps, save_path):
+     frame_h, frame_w = frames[0].shape[2:]
+     fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
+     video = cv2.VideoWriter(save_path, fourcc, fps, (frame_w, frame_h))
+     frames = frames[0]
+     for frame in frames:
+         frame = np.array(TVF.to_pil_image(frame))
+         video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+     video.release()
+ 
+ @torch.no_grad()
+ def VAE_encode(image):
+     init_latent_dist = vae.encode(image).latent_dist.sample()
+     init_latent_dist *= 0.18215
+     encoded_image = init_latent_dist.unsqueeze(1)
+     return encoded_image
+ 
+ @torch.no_grad()
+ def VAE_decode(video, vae_net):
+     decoded_video = None
+     for i in range(video.shape[1]):
+         image = video[:, i, :, :, :]
+         image = 1 / 0.18215 * image
+         image = vae_net.decode(image).sample
+         image = image.clamp(0, 1)
+         if i == 0:
+             decoded_video = image.unsqueeze(1)
+         else:
+             decoded_video = torch.cat([decoded_video, image.unsqueeze(1)], 1)
+     return decoded_video
+ 
+ @torch.no_grad()
+ def sample_timestep(x, image, t):
+     betas_t = get_index_from_list(betas, t, x.shape)
+     sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
+         sqrt_one_minus_alphas_cumprod, t, x.shape
+     )
+     sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
+ 
+     # Call model (current image - noise prediction)
+     with torch.cuda.amp.autocast():
+         sample_output = Net(x.permute(0, 2, 1, 3, 4), image, timestep=t.float())
+ 
+     sample_output = sample_output.permute(0, 2, 1, 3, 4)
+     model_mean = sqrt_recip_alphas_t * (
+         x - betas_t * sample_output / sqrt_one_minus_alphas_cumprod_t
+     )
+     if t.item() == 0:
+         return model_mean
+     else:
+         noise = torch.randn_like(x)
+         posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
+         return model_mean + torch.sqrt(posterior_variance_t) * noise
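
For reference, `sample_timestep` is the standard DDPM ancestral sampling step: the predicted mean is `mu = (1/sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - a_bar_t) * eps_pred)`, with noise of variance `posterior_variance_t` added for `t > 0`. A scalar illustration (all names here are mine; the real code indexes the precomputed tables above):

```python
# Scalar illustration of the posterior-mean formula used in sample_timestep.
import math

def ddpm_mean(x_t, eps_pred, beta_t, alpha_bar_t):
    alpha_t = 1.0 - beta_t
    return (1.0 / math.sqrt(alpha_t)) * (x_t - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_pred)

print(ddpm_mean(x_t=0.5, eps_pred=0.1, beta_t=0.01, alpha_bar_t=0.9))
```
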
+ 
+ def tensor2image(tensor):
+     numpy_image = tensor[0].cpu().detach().numpy()
+     rescaled_image = (numpy_image * 255).astype(np.uint8)
+     pil_image = Image.fromarray(rescaled_image.transpose(1, 2, 0))
+     return pil_image
+ 
+ @torch.no_grad()
+ def get_image_embedding(input_image):
+     inputs = clip_processor(images=list(input_image), return_tensors="pt")
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     clip_hidden_states = clip_encoder(**inputs).last_hidden_state.to(device)
+     vae_hidden_states = vae.encode(input_image).latent_dist.sample() * 0.18215
+     encoder_hidden_states = adapter(clip_hidden_states, vae_hidden_states)
+     return encoder_hidden_states
+ 
+ def predict_fn(img_path, progress=gr.Progress()):
+     # get_transform() returns the transform, so it must be called on the image,
+     # and the tensor needs a batch dimension before it reaches the VAE/CLIP
+     # (the original line called get_transform with the image as its argument).
+     image = get_transform()(Image.open(img_path).convert('RGB')).unsqueeze(0).to(device)
+     encoder_hidden_states = get_image_embedding(input_image=image)
+     encoded_image = VAE_encode(image)
+     noise_video = torch.randn([1, frameLimit, 4, 80, 64]).to(device)
+     noise_video[:, 0:1] = encoded_image
+     with torch.no_grad():
+         for i in progress.tqdm(range(0, T)[::-1]):
+             t = torch.full((1,), i, device=device).long()
+             noise_video = sample_timestep(noise_video, encoder_hidden_states, t)
+             noise_video[:, 0:1] = encoded_image
+     final_video = VAE_decode(noise_video, vae)
+     save_video_frames_as_mp4(final_video, 25, "result.mp4")
+     return "result.mp4"
+ 
+ # gr.Tab must live inside a gr.Blocks context, and the button was never wired
+ # to predict_fn; the wrapper, click handler and launch call below were missing
+ # from the original file.
+ with gr.Blocks() as demo:
+     with gr.Tab("Image-to-Video"):
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="filepath", label="Input Image")  # predict_fn expects a file path
+                 img_generate = gr.Button("Generate Video")
+             with gr.Column():
+                 img_output = gr.Video(label="Generated Video")
+         img_generate.click(fn=predict_fn, inputs=[image_input], outputs=[img_output])
+         gr.Examples(
+             examples=[
+                 ['sample/blue.jpg'],
+             ],
+             inputs=[image_input],
+             outputs=[img_output],
+             fn=predict_fn,
+             cache_examples='lazy',
+         )
+ 
+ demo.launch()
distributed.py ADDED
@@ -0,0 +1,126 @@
+ import math
+ import pickle
+ 
+ import torch
+ from torch import distributed as dist
+ from torch.utils.data.sampler import Sampler
+ 
+ 
+ def get_rank():
+     if not dist.is_available():
+         return 0
+ 
+     if not dist.is_initialized():
+         return 0
+ 
+     return dist.get_rank()
+ 
+ 
+ def synchronize():
+     if not dist.is_available():
+         return
+ 
+     if not dist.is_initialized():
+         return
+ 
+     world_size = dist.get_world_size()
+ 
+     if world_size == 1:
+         return
+ 
+     dist.barrier()
+ 
+ 
+ def get_world_size():
+     if not dist.is_available():
+         return 1
+ 
+     if not dist.is_initialized():
+         return 1
+ 
+     return dist.get_world_size()
+ 
+ 
+ def reduce_sum(tensor):
+     if not dist.is_available():
+         return tensor
+ 
+     if not dist.is_initialized():
+         return tensor
+ 
+     tensor = tensor.clone()
+     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+ 
+     return tensor
+ 
+ 
+ def gather_grad(params):
+     world_size = get_world_size()
+ 
+     if world_size == 1:
+         return
+ 
+     for param in params:
+         if param.grad is not None:
+             dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
+             param.grad.data.div_(world_size)
+ 
+ 
+ def all_gather(data):
+     world_size = get_world_size()
+ 
+     if world_size == 1:
+         return [data]
+ 
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to('cuda')
+ 
+     local_size = torch.IntTensor([tensor.numel()]).to('cuda')
+     size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
+     dist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+ 
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
+ 
+     if local_size != max_size:
+         padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
+         tensor = torch.cat((tensor, padding), 0)
+ 
+     dist.all_gather(tensor_list, tensor)
+ 
+     data_list = []
+ 
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+ 
+     return data_list
+ 
+ 
+ def reduce_loss_dict(loss_dict):
+     world_size = get_world_size()
+ 
+     if world_size < 2:
+         return loss_dict
+ 
+     with torch.no_grad():
+         keys = []
+         losses = []
+ 
+         for k in sorted(loss_dict.keys()):
+             keys.append(k)
+             losses.append(loss_dict[k])
+ 
+         losses = torch.stack(losses, 0)
+         dist.reduce(losses, dst=0)
+ 
+         if dist.get_rank() == 0:
+             losses /= world_size
+ 
+         reduced_losses = {k: v for k, v in zip(keys, losses)}
+ 
+     return reduced_losses
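
A sketch of how these helpers are typically used in a training loop. It assumes `dist.init_process_group()` has already been called by the launcher (e.g. `torch.distributed.launch`, as in the README) and that the loss names are illustrative:

```python
# Illustrative usage of reduce_loss_dict; loss names and device are assumptions.
import torch
from distributed import reduce_loss_dict, get_rank

loss_dict = {'recon': torch.tensor(0.12, device='cuda'),
             'kl': torch.tensor(0.03, device='cuda')}
reduced = reduce_loss_dict(loss_dict)   # averaged across ranks (on rank 0)
if get_rank() == 0:
    print({k: float(v) for k, v in reduced.items()})
```
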
model_structure.txt ADDED
@@ -0,0 +1,113 @@
+ ======================================================================================================================================================
+ Layer (type (var_name))   Input Shape   Output Shape   Param #   Trainable
+ ======================================================================================================================================================
+ SpaceTimeUnet (SpaceTimeUnet)   [1, 4, 70, 80, 64]   [1, 4, 70, 80, 64]   --   True
+ ├─Sequential (to_timestep_cond)   [1]   [1, 256]   --   True
+ │    └─SinusoidalPosEmb (0)   [1]   [1, 64]   --   --
+ │    └─Linear (1)   [1, 64]   [1, 256]   16,640   True
+ │    └─SiLU (2)   [1, 256]   [1, 256]   --   --
+ ├─PseudoConv3d (conv_in)   [1, 4, 70, 80, 64]   [1, 64, 70, 80, 64]   --   True
+ │    └─Conv2d (spatial_conv)   [70, 4, 80, 64]   [70, 64, 80, 64]   12,608   True
+ │    └─Conv1d (temporal_conv)   [5120, 64, 70]   [5120, 64, 70]   12,352   True
+ ├─ModuleList (downs)   --   --   --   True
+ │    └─ModuleList (0)   --   --   --   True
+ │    │    └─ResnetBlock (0)   [1, 64, 70, 80, 64]   [1, 64, 70, 80, 64]   131,712   True
+ │    │    └─ModuleList (1)   --   --   197,632   True
+ │    │    └─Downsample (3)   [1, 64, 70, 80, 64]   [1, 64, 70, 40, 32]   16,384   True
+ │    │    └─AttentionBlock (4)   [1, 64, 70, 40, 32]   [1, 64, 70, 40, 32]   160,704   True
+ │    └─ModuleList (1)   --   --   --   True
+ │    │    └─ResnetBlock (0)   [1, 64, 70, 40, 32]   [1, 128, 70, 40, 32]   394,624   True
+ │    │    └─ModuleList (1)   --   --   788,480   True
+ │    │    └─Downsample (3)   [1, 128, 70, 40, 32]   [1, 128, 70, 20, 16]   65,536   True
+ │    │    └─AttentionBlock (4)   [1, 128, 70, 20, 16]   [1, 128, 70, 20, 16]   444,288   True
+ │    └─ModuleList (2)   --   --   --   True
+ │    │    └─ResnetBlock (0)   [1, 128, 70, 20, 16]   [1, 256, 70, 20, 16]   1,444,608   True
+ │    │    └─ModuleList (1)   --   --   3,149,824   True
+ │    │    └─Downsample (3)   [1, 256, 70, 20, 16]   [1, 256, 70, 10, 8]   262,144   True
+ │    │    └─AttentionBlock (4)   [1, 256, 70, 10, 8]   [1, 256, 70, 10, 8]   1,380,096   True
+ │    └─ModuleList (3)   --   --   --   True
+ │    │    └─ResnetBlock (0)   [1, 256, 70, 10, 8]   [1, 512, 70, 10, 8]   5,510,656   True
+ │    │    └─ModuleList (1)   --   --   12,591,104   True
+ │    │    └─SpatioTemporalAttention (2)   [1, 512, 70, 10, 8]   [1, 512, 70, 10, 8]   4,334,181   True
+ │    │    └─Downsample (3)   [1, 512, 70, 10, 8]   [1, 512, 35, 5, 4]   1,572,864   True
+ │    │    └─AttentionBlock (4)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   4,726,272   True
+ ├─ResnetBlock (mid_block1)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    └─Sequential (timestep_mlp)   [1, 256]   [1, 1024]   --   True
+ │    │    └─SiLU (0)   [1, 256]   [1, 256]   --   --
+ │    │    └─Linear (1)   [1, 256]   [1, 1024]   263,168   True
+ │    └─Block (block1)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    │    └─PseudoConv3d (project)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   3,146,752   True
+ │    │    └─GroupNorm (norm)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   1,024   True
+ │    │    └─SiLU (act)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ │    └─Block (block2)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    │    └─PseudoConv3d (project)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   3,146,752   True
+ │    │    └─GroupNorm (norm)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   1,024   True
+ │    │    └─SiLU (act)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ │    └─Identity (res_conv)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ ├─SpatioTemporalAttention (mid_attn)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    └─ContinuousPositionBias (spatial_rel_pos_bias)   --   [8, 20, 20]   --   True
+ │    │    └─ModuleList (net)   --   --   68,616   True
+ │    └─Attention (spatial_attn)   [35, 20, 512]   [35, 20, 512]   --   True
+ │    │    └─LayerNorm (norm)   [35, 20, 512]   [35, 20, 512]   1,024   True
+ │    │    └─Linear (to_q)   [35, 20, 512]   [35, 20, 512]   262,144   True
+ │    │    └─Linear (to_kv)   [35, 20, 512]   [35, 20, 1024]   524,288   True
+ │    │    └─Linear (to_out)   [35, 20, 512]   [35, 20, 512]   262,144   True
+ │    └─ContinuousPositionBias (temporal_rel_pos_bias)   --   [8, 35, 35]   --   True
+ │    │    └─ModuleList (net)   --   --   68,360   True
+ │    └─Attention (temporal_attn)   [20, 35, 512]   [20, 35, 512]   --   True
+ │    │    └─LayerNorm (norm)   [20, 35, 512]   [20, 35, 512]   1,024   True
+ │    │    └─Linear (to_q)   [20, 35, 512]   [20, 35, 512]   262,144   True
+ │    │    └─Linear (to_kv)   [20, 35, 512]   [20, 35, 1024]   524,288   True
+ │    │    └─Linear (to_out)   [20, 35, 512]   [20, 35, 512]   262,144   True
+ │    └─FeedForward (ff)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    │    └─Sequential (proj_in)   [1, 512, 35, 5, 4]   [1, 1365, 35, 5, 4]   1,397,760   True
+ │    │    └─Sequential (proj_out)   [1, 1365, 35, 5, 4]   [1, 512, 35, 5, 4]   700,245   True
+ ├─ResnetBlock (mid_block2)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    └─Sequential (timestep_mlp)   [1, 256]   [1, 1024]   --   True
+ │    │    └─SiLU (0)   [1, 256]   [1, 256]   --   --
+ │    │    └─Linear (1)   [1, 256]   [1, 1024]   263,168   True
+ │    └─Block (block1)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    │    └─PseudoConv3d (project)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   3,146,752   True
+ │    │    └─GroupNorm (norm)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   1,024   True
+ │    │    └─SiLU (act)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ │    └─Block (block2)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   True
+ │    │    └─PseudoConv3d (project)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   3,146,752   True
+ │    │    └─GroupNorm (norm)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   1,024   True
+ │    │    └─SiLU (act)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ │    └─Identity (res_conv)   [1, 512, 35, 5, 4]   [1, 512, 35, 5, 4]   --   --
+ ├─ModuleList (ups)   --   --   --   True
+ │    └─ModuleList (3)   --   --   --   True
+ │    │    └─Upsample (3)   [1, 512, 35, 5, 4]   [1, 512, 70, 10, 8]   1,575,936   True
+ │    │    └─ResnetBlock (0)   [1, 1024, 70, 10, 8]   [1, 256, 70, 10, 8]   3,738,368   True
+ │    │    └─ModuleList (1)   --   --   4,526,336   True
+ │    │    └─SpatioTemporalAttention (2)   [1, 256, 70, 10, 8]   [1, 256, 70, 10, 8]   1,609,786   True
+ │    │    └─AttentionBlock (4)   [1, 256, 70, 10, 8]   [1, 256, 70, 10, 8]   1,380,096   True
+ │    └─ModuleList (2)   --   --   --   True
+ │    │    └─Upsample (3)   [1, 256, 70, 10, 8]   [1, 256, 70, 20, 16]   263,168   True
+ │    │    └─ResnetBlock (0)   [1, 512, 70, 20, 16]   [1, 128, 70, 20, 16]   968,064   True
+ │    │    └─ModuleList (1)   --   --   1,132,672   True
+ │    │    └─AttentionBlock (4)   [1, 128, 70, 20, 16]   [1, 128, 70, 20, 16]   444,288   True
+ │    └─ModuleList (1)   --   --   --   True
+ │    │    └─Upsample (3)   [1, 128, 70, 20, 16]   [1, 128, 70, 40, 32]   66,048   True
+ │    │    └─ResnetBlock (0)   [1, 256, 70, 40, 32]   [1, 64, 70, 40, 32]   258,752   True
+ │    │    └─ModuleList (1)   --   --   283,712   True
+ │    │    └─AttentionBlock (4)   [1, 64, 70, 40, 32]   [1, 64, 70, 40, 32]   160,704   True
+ │    └─ModuleList (0)   --   --   --   True
+ │    │    └─Upsample (3)   [1, 64, 70, 40, 32]   [1, 64, 70, 80, 64]   16,640   True
+ │    │    └─ResnetBlock (0)   [1, 128, 70, 80, 64]   [1, 64, 70, 80, 64]   176,832   True
+ │    │    └─ModuleList (1)   --   --   242,752   True
+ │    │    └─AttentionBlock (4)   [1, 64, 70, 80, 64]   [1, 64, 70, 80, 64]   160,704   True
+ ├─PseudoConv3d (conv_out)   [1, 64, 70, 80, 64]   [1, 4, 70, 80, 64]   --   True
+ │    └─Conv2d (spatial_conv)   [70, 64, 80, 64]   [70, 4, 80, 64]   2,308   True
+ │    └─Conv1d (temporal_conv)   [5120, 4, 70]   [5120, 4, 70]   52   True
+ ======================================================================================================================================================
+ Total params: 71,671,548
+ Trainable params: 71,671,548
+ Non-trainable params: 0
+ Total mult-adds (G): 732.56
+ ======================================================================================================================================================
+ Input size (MB): 5.89
+ Forward/backward pass size (MB): 18136.46
+ Params size (MB): 286.69
+ Estimated Total Size (MB): 18429.04
+ ======================================================================================================================================================
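
The summary above appears to be torchinfo output (the provenance and exact arguments are my assumption, not stated in the repository). A sketch of how it could be regenerated:

```python
# Sketch for regenerating model_structure.txt; assumes the torchinfo package.
# Input shapes follow the table above: (batch, channels, frames, H, W), a
# 50-token 768-dim CLIP+VAE context, and a scalar diffusion timestep.
import torch
from torchinfo import summary
from models.diffusion_model import SpaceTimeUnet

net = SpaceTimeUnet(dim=64, channels=4, dim_mult=(1, 2, 4, 8),
                    temporal_compression=(False, False, False, True),
                    self_attns=(False, False, False, True),
                    condition_on_timestep=True)
summary(net, input_data=(torch.randn(1, 4, 70, 80, 64),
                         torch.randn(1, 50, 768),
                         torch.tensor([1.0])))
```
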
models/diffusion_model.py ADDED
@@ -0,0 +1,762 @@
+ import math
+ import functools
+ from operator import mul
+ 
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, einsum
+ 
+ from einops import rearrange, repeat, pack, unpack
+ from einops.layers.torch import Rearrange
+ 
+ # helper functions
+ 
+ def exists(val):
+     return val is not None
+ 
+ 
+ def default(val, d):
+     return val if exists(val) else d
+ 
+ 
+ def mul_reduce(tup):
+     return functools.reduce(mul, tup)
+ 
+ 
+ def divisible_by(numer, denom):
+     return (numer % denom) == 0
+ 
+ 
+ mlist = nn.ModuleList
+ 
+ # for time conditioning
+ 
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim, theta=10000):
+         super().__init__()
+         self.theta = theta
+         self.dim = dim
+ 
+     def forward(self, x):
+         dtype, device = x.dtype, x.device
+         assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'
+ 
+         half_dim = self.dim // 2
+         emb = math.log(self.theta) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device, dtype=dtype) * -emb)
+         emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
+         return torch.cat((emb.sin(), emb.cos()), dim=-1).type(dtype)
+ 
+ 
+ # layernorm 3d
+ 
+ class ChanLayerNorm(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))
+ 
+     def forward(self, x):
+         eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+         var = torch.var(x, dim=1, unbiased=False, keepdim=True)
+         mean = torch.mean(x, dim=1, keepdim=True)
+         return (x - mean) * var.clamp(min=eps).rsqrt() * self.g
+ 
+ 
+ # feedforward
+ 
+ def shift_token(t):
+     t, t_shift = t.chunk(2, dim=1)
+     t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value=0.)
+     return torch.cat((t, t_shift), dim=1)
+ 
+ 
+ class GEGLU(nn.Module):
+     def forward(self, x):
+         x, gate = x.chunk(2, dim=1)
+         return x * F.gelu(gate)
+ 
+ 
+ class FeedForward(nn.Module):
+     def __init__(self, dim, mult=4):
+         super().__init__()
+ 
+         inner_dim = int(dim * mult * 2 / 3)
+         self.proj_in = nn.Sequential(
+             nn.Conv3d(dim, inner_dim * 2, 1, bias=False),
+             GEGLU()
+         )
+ 
+         self.proj_out = nn.Sequential(
+             ChanLayerNorm(inner_dim),
+             nn.Conv3d(inner_dim, dim, 1, bias=False)
+         )
+ 
+     def forward(self, x, enable_time=True):
+         x = self.proj_in(x)
+         if enable_time:
+             x = shift_token(x)
+         return self.proj_out(x)
+ 
+ 
+ # best relative positional encoding
+ 
+ class ContinuousPositionBias(nn.Module):
+     """ from https://arxiv.org/abs/2111.09883 """
+ 
+     def __init__(
+         self,
+         *,
+         dim,
+         heads,
+         num_dims=1,
+         layers=2
+     ):
+         super().__init__()
+         self.num_dims = num_dims
+ 
+         self.net = nn.ModuleList([])
+         self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))
+ 
+         for _ in range(layers - 1):
+             self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))
+ 
+         self.net.append(nn.Linear(dim, heads))
+ 
+     @property
+     def device(self):
+         return next(self.parameters()).device
+ 
+     def forward(self, *dimensions):
+         device = self.device
+ 
+         shape = torch.tensor(dimensions, device=device)
+         rel_pos_shape = 2 * shape - 1
+ 
+         # calculate strides
+ 
+         strides = torch.flip(rel_pos_shape, (0,)).cumprod(dim=-1)
+         strides = torch.flip(F.pad(strides, (1, -1), value=1), (0,))
+ 
+         # get all positions and calculate all the relative distances
+ 
+         positions = [torch.arange(d, device=device) for d in dimensions]
+         grid = torch.stack(torch.meshgrid(*positions, indexing='ij'), dim=-1)
+         grid = rearrange(grid, '... c -> (...) c')
+         rel_dist = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
+ 
+         # get all relative positions across all dimensions
+ 
+         rel_positions = [torch.arange(-d + 1, d, device=device) for d in dimensions]
+         rel_pos_grid = torch.stack(torch.meshgrid(*rel_positions, indexing='ij'), dim=-1)
+         rel_pos_grid = rearrange(rel_pos_grid, '... c -> (...) c')
+ 
+         # mlp input
+ 
+         bias = rel_pos_grid.float()
+ 
+         for layer in self.net:
+             bias = layer(bias)
+ 
+         # convert relative distances to indices of the bias
+ 
+         rel_dist += (shape - 1)  # make sure all positive
+         rel_dist *= strides
+         rel_dist_indices = rel_dist.sum(dim=-1)
+ 
+         # now select the bias for each unique relative position combination
+ 
+         bias = bias[rel_dist_indices]
+         return rearrange(bias, 'i j h -> h i j')
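
The module above returns one learned bias per head for every pair of positions. A quick illustrative call (sizes are hypothetical; assumes the imports at the top of this file):

```python
# Illustrative call of ContinuousPositionBias for an 8x8 feature map.
cpb = ContinuousPositionBias(dim=32, heads=8, num_dims=2)
bias = cpb(8, 8)       # one scalar per head for each of the 64*64 position pairs
print(bias.shape)      # torch.Size([8, 64, 64])
```
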
+ 
+ 
+ # helper classes
+ 
+ class CrossAttention(nn.Module):
+     def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+         super().__init__()
+         self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
+         self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+         self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
+         self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
+         self.n_heads = n_heads
+         self.d_head = d_embed // n_heads
+ 
+     def forward(self, x, y):
+         input_shape = x.shape
+         batch_size, sequence_length, d_embed = input_shape
+         interim_shape = (batch_size, -1, self.n_heads, self.d_head)
+ 
+         q = self.q_proj(x)
+         k = self.k_proj(y)
+         v = self.v_proj(y)
+ 
+         q = q.view(interim_shape).transpose(1, 2)
+         k = k.view(interim_shape).transpose(1, 2)
+         v = v.view(interim_shape).transpose(1, 2)
+ 
+         weight = q @ k.transpose(-1, -2)
+         weight /= math.sqrt(self.d_head)
+         weight = F.softmax(weight, dim=-1)
+ 
+         output = weight @ v
+         output = output.transpose(1, 2).contiguous()
+         output = output.view(input_shape)
+         output = self.out_proj(output)
+         return output
+ 
+ class AttentionBlock(nn.Module):
+     def __init__(self, n_head: int, n_embd: int, d_context=768):
+         super().__init__()
+         channels = n_head * n_embd
+ 
+         #self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
+         #self.conv_input = PseudoConv3d(channels, channels, 1)
+         self.layernorm_2 = nn.LayerNorm(channels)
+         self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
+         self.layernorm_3 = nn.LayerNorm(channels)
+         self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
+         self.linear_geglu_2 = nn.Linear(4 * channels, channels)
+         self.conv_output = PseudoConv3d(channels, channels, 1, bias=False)
+ 
+     def forward(self, x, context):
+         b, c, *_, h, w = x.shape
+         #x = self.groupnorm(x)
+         #x = self.conv_input(x)
+         x = rearrange(x, 'b c f h w -> b (h w f) c')
+ 
+         residue_short = x
+         x = self.layernorm_2(x)
+         x = self.attention_2(x, context)
+         x += residue_short
+ 
+         residue_short = x
+         x = self.layernorm_3(x)
+         x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
+         x = x * F.gelu(gate)
+         x = self.linear_geglu_2(x)
+         x += residue_short
+ 
+         x = rearrange(x, 'b (h w f) c -> b c f h w', b=b, c=c, h=h, w=w)
+         x = self.conv_output(x)
+         return x
+ 
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_head=64,
+         heads=8
+     ):
+         super().__init__()
+         self.heads = heads
+         self.scale = dim_head ** -0.5
+         inner_dim = dim_head * heads
+ 
+         self.norm = nn.LayerNorm(dim)
+ 
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+         nn.init.zeros_(self.to_out.weight.data)  # identity with skip connection
+ 
+     def forward(
+         self,
+         x,
+         rel_pos_bias=None
+     ):
+         x = self.norm(x)
+ 
+         q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
+ 
+         q = q * self.scale
+ 
+         sim = einsum('b h i d, b h j d -> b h i j', q, k)
+ 
+         if exists(rel_pos_bias):
+             sim = sim + rel_pos_bias
+ 
+         attn = sim.softmax(dim=-1)
+ 
+         out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ 
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         return self.to_out(out)
+ 
+ 
+ # main contribution - pseudo 3d conv
+ 
+ class PseudoConv3d(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_out=None,
+         kernel_size=3,
+         *,
+         temporal_kernel_size=None,
+         **kwargs  # note: extra kwargs (e.g. bias=False above) are accepted but not forwarded to the convolutions
+     ):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         temporal_kernel_size = default(temporal_kernel_size, kernel_size)
+ 
+         self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size=kernel_size, padding=kernel_size // 2)
+         self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size=temporal_kernel_size,
+                                        padding=temporal_kernel_size // 2) if kernel_size > 1 else None
+ 
+         if exists(self.temporal_conv):
+             nn.init.dirac_(self.temporal_conv.weight.data)  # initialized to be identity
+             nn.init.zeros_(self.temporal_conv.bias.data)
+ 
+     def forward(
+         self,
+         x,
+         enable_time=True
+     ):
+         b, c, *_, h, w = x.shape
+ 
+         is_video = x.ndim == 5
+         enable_time &= is_video
+ 
+         if is_video:
+             x = rearrange(x, 'b c f h w -> (b f) c h w')
+ 
+         x = self.spatial_conv(x)
+ 
+         if is_video:
+             x = rearrange(x, '(b f) c h w -> b c f h w', b=b)
+ 
+         if not enable_time or not exists(self.temporal_conv):
+             return x
+ 
+         x = rearrange(x, 'b c f h w -> (b h w) c f')
+ 
+         x = self.temporal_conv(x)
+ 
+         x = rearrange(x, '(b h w) c f -> b c f h w', h=h, w=w)
+ 
+         return x
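
A quick shape check of `PseudoConv3d` on a dummy latent video (sizes are illustrative; input layout is `(batch, channels, frames, height, width)`):

```python
# Shape check for PseudoConv3d as defined above.
conv = PseudoConv3d(4, 64, kernel_size=3)
video = torch.randn(1, 4, 8, 32, 32)
print(conv(video).shape)                     # torch.Size([1, 64, 8, 32, 32])
print(conv(video, enable_time=False).shape)  # spatial conv only, same shape
```

Because the temporal convolution is Dirac-initialised, a freshly constructed module behaves like a purely spatial convolution until training updates the temporal weights.
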
+ 
+ 
+ # factorized spatial temporal attention from Ho et al.
+ 
+ class SpatioTemporalAttention(nn.Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         dim_head=64,
+         heads=8,
+         add_feed_forward=True,
+         ff_mult=4
+     ):
+         super().__init__()
+         self.spatial_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
+         self.spatial_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=2)
+ 
+         self.temporal_attn = Attention(dim=dim, dim_head=dim_head, heads=heads)
+         self.temporal_rel_pos_bias = ContinuousPositionBias(dim=dim // 2, heads=heads, num_dims=1)
+ 
+         self.has_feed_forward = add_feed_forward
+         if not add_feed_forward:
+             return
+ 
+         self.ff = FeedForward(dim=dim, mult=ff_mult)
+ 
+     def forward(
+         self,
+         x,
+         enable_time=True
+     ):
+         b, c, *_, h, w = x.shape
+         is_video = x.ndim == 5
+         enable_time &= is_video
+ 
+         if is_video:
+             x = rearrange(x, 'b c f h w -> (b f) (h w) c')
+         else:
+             x = rearrange(x, 'b c h w -> b (h w) c')
+ 
+         space_rel_pos_bias = self.spatial_rel_pos_bias(h, w)
+ 
+         x = self.spatial_attn(x, rel_pos_bias=space_rel_pos_bias) + x
+ 
+         if is_video:
+             x = rearrange(x, '(b f) (h w) c -> b c f h w', b=b, h=h, w=w)
+         else:
+             x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ 
+         if enable_time:
+             x = rearrange(x, 'b c f h w -> (b h w) f c')
+ 
+             time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
+ 
+             x = self.temporal_attn(x, rel_pos_bias=time_rel_pos_bias) + x
+ 
+             x = rearrange(x, '(b h w) f c -> b c f h w', w=w, h=h)
+ 
+         if self.has_feed_forward:
+             x = self.ff(x, enable_time=enable_time) + x
+ 
+         return x
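
Factorising attention this way attends over space within each frame, then over time at each spatial location, which scales roughly as F·(HW)² + HW·F² rather than the (F·H·W)² of full 3D attention. A quick check with illustrative sizes:

```python
# Shape check for the factorized spatio-temporal attention defined above.
attn = SpatioTemporalAttention(dim=64)
x = torch.randn(1, 64, 8, 16, 16)   # (batch, channels, frames, H, W)
print(attn(x).shape)                # torch.Size([1, 64, 8, 16, 16])
```
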
+ 
+ 
+ # resnet block
+ 
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_out,
+         kernel_size=3,
+         temporal_kernel_size=None,
+         groups=8
+     ):
+         super().__init__()
+         # the kernel size was hardcoded to 3; the arguments are now respected (defaults unchanged)
+         self.project = PseudoConv3d(dim, dim_out, kernel_size, temporal_kernel_size=temporal_kernel_size)
+         self.norm = nn.GroupNorm(groups, dim_out)
+         self.act = nn.SiLU()
+ 
+     def forward(
+         self,
+         x,
+         scale_shift=None,
+         enable_time=False
+     ):
+         x = self.project(x, enable_time=enable_time)
+         x = self.norm(x)
+ 
+         if exists(scale_shift):
+             scale, shift = scale_shift
+             x = x * (scale + 1) + shift
+ 
+         return self.act(x)
+ 
+ 
+ class ResnetBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_out,
+         *,
+         timestep_cond_dim=None,
+         groups=8
+     ):
+         super().__init__()
+ 
+         self.timestep_mlp = None
+ 
+         if exists(timestep_cond_dim):
+             self.timestep_mlp = nn.Sequential(
+                 nn.SiLU(),
+                 nn.Linear(timestep_cond_dim, dim_out * 2)
+             )
+ 
+         self.block1 = Block(dim, dim_out, groups=groups)
+         self.block2 = Block(dim_out, dim_out, groups=groups)
+         self.res_conv = PseudoConv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+ 
+     def forward(
+         self,
+         x,
+         timestep_emb=None,
+         enable_time=True
+     ):
+         assert not (exists(timestep_emb) ^ exists(self.timestep_mlp))
+ 
+         scale_shift = None
+ 
+         if exists(self.timestep_mlp) and exists(timestep_emb):
+             time_emb = self.timestep_mlp(timestep_emb)
+             to_einsum_eq = 'b c 1 1 1' if x.ndim == 5 else 'b c 1 1'
+             time_emb = rearrange(time_emb, f'b c -> {to_einsum_eq}')
+             scale_shift = time_emb.chunk(2, dim=1)
+ 
+         h = self.block1(x, scale_shift=scale_shift, enable_time=enable_time)
+ 
+         h = self.block2(h, enable_time=enable_time)
+ 
+         return h + self.res_conv(x)
+ 
+ 
+ # pixelshuffle upsamples and downsamples
+ # where time dimension can be configured
+ 
+ class Downsample(nn.Module):
+     def __init__(
+         self,
+         dim,
+         downsample_space=True,
+         downsample_time=False,
+         nonlin=False
+     ):
+         super().__init__()
+         assert downsample_space or downsample_time
+ 
+         self.down_space = nn.Sequential(
+             Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2),
+             nn.Conv2d(dim * 4, dim, 1, bias=False),
+             nn.SiLU() if nonlin else nn.Identity()
+         ) if downsample_space else None
+ 
+         self.down_time = nn.Sequential(
+             Rearrange('b c (f p) h w -> b (c p) f h w', p=2),
+             nn.Conv3d(dim * 2, dim, 1, bias=False),
+             nn.SiLU() if nonlin else nn.Identity()
+         ) if downsample_time else None
+ 
+     def forward(
+         self,
+         x,
+         enable_time=True
+     ):
+         is_video = x.ndim == 5
+ 
+         if is_video:
+             x = rearrange(x, 'b c f h w -> b f c h w')
+             x, ps = pack([x], '* c h w')
+ 
+         if exists(self.down_space):
+             x = self.down_space(x)
+ 
+         if is_video:
+             x, = unpack(x, ps, '* c h w')
+             x = rearrange(x, 'b f c h w -> b c f h w')
+ 
+         if not is_video or not exists(self.down_time) or not enable_time:
+             return x
+ 
+         x = self.down_time(x)
+ 
+         return x
+ 
+ 
+ class Upsample(nn.Module):
+     def __init__(
+         self,
+         dim,
+         upsample_space=True,
+         upsample_time=False,
+         nonlin=False
+     ):
+         super().__init__()
+         assert upsample_space or upsample_time
+ 
+         self.up_space = nn.Sequential(
+             nn.Conv2d(dim, dim * 4, 1),
+             nn.SiLU() if nonlin else nn.Identity(),
+             Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1=2, p2=2)
+         ) if upsample_space else None
+ 
+         self.up_time = nn.Sequential(
+             nn.Conv3d(dim, dim * 2, 1),
+             nn.SiLU() if nonlin else nn.Identity(),
+             Rearrange('b (c p) f h w -> b c (f p) h w', p=2)
+         ) if upsample_time else None
+ 
+         self.init_()
+ 
+     def init_(self):
+         if exists(self.up_space):
+             self.init_conv_(self.up_space[0], 4)
+ 
+         if exists(self.up_time):
+             self.init_conv_(self.up_time[0], 2)
+ 
+     def init_conv_(self, conv, factor):
+         o, *remain_dims = conv.weight.shape
+         conv_weight = torch.empty(o // factor, *remain_dims)
+         nn.init.kaiming_uniform_(conv_weight)
+         conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r=factor)
+ 
+         conv.weight.data.copy_(conv_weight)
+         nn.init.zeros_(conv.bias.data)
+ 
+     def forward(
+         self,
+         x,
+         enable_time=True
+     ):
+         is_video = x.ndim == 5
+ 
+         if is_video:
+             x = rearrange(x, 'b c f h w -> b f c h w')
+             x, ps = pack([x], '* c h w')
+ 
+         if exists(self.up_space):
+             x = self.up_space(x)
+ 
+         if is_video:
+             x, = unpack(x, ps, '* c h w')
+             x = rearrange(x, 'b f c h w -> b c f h w')
+ 
+         if not is_video or not exists(self.up_time) or not enable_time:
+             return x
+ 
+         x = self.up_time(x)
+ 
+         return x
+ 
+ 
+ class SpaceTimeUnet(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         channels=4,
+         dim_mult=(1, 2, 4, 8),
+         self_attns=(False, False, False, True),
+         temporal_compression=(False, True, True, True),
+         resnet_block_depths=(2, 2, 2, 2),
+         attn_dim_head=64,
+         attn_heads=8,
+         condition_on_timestep=False,
+     ):
+         super().__init__()
+         assert len(dim_mult) == len(self_attns) == len(temporal_compression) == len(resnet_block_depths)
+         num_layers = len(dim_mult)
+ 
+         dims = [dim, *map(lambda mult: mult * dim, dim_mult)]
+         dim_in_out = zip(dims[:-1], dims[1:])
+ 
+         # determine the valid multiples of the image size and frames of the video
+         self.frame_multiple = 2 ** sum(tuple(map(int, temporal_compression)))
+         self.image_size_multiple = 2 ** num_layers
+ 
+         # timestep conditioning for DDPM, not to be confused with the time dimension of the video
+ 
+         self.to_timestep_cond = None
+         timestep_cond_dim = (dim * 4) if condition_on_timestep else None
+ 
+         if condition_on_timestep:
+             self.to_timestep_cond = nn.Sequential(
+                 SinusoidalPosEmb(dim),
+                 nn.Linear(dim, timestep_cond_dim),
+                 nn.SiLU()
+             )
+ 
+         # cross attention (conditions each resolution on the CLIP + VAE embedding)
+         cross_attention_D1 = AttentionBlock(1, 64)   # 64 channels
+         cross_attention_D2 = AttentionBlock(1, 128)  # 128 channels
+         cross_attention_D3 = AttentionBlock(2, 128)  # 256 channels
+         cross_attention_D4 = AttentionBlock(4, 128)  # 512 channels
+ 
+         cross_attention_U1 = AttentionBlock(4, 64)   # 256 channels
+         cross_attention_U2 = AttentionBlock(2, 64)   # 128 channels
+         cross_attention_U3 = AttentionBlock(1, 64)   # 64 channels
+         cross_attention_U4 = AttentionBlock(1, 64)   # 64 channels
+ 
+         cross_attns_down = (cross_attention_D1, cross_attention_D2, cross_attention_D3, cross_attention_D4)
+         cross_attns_up = (cross_attention_U4, cross_attention_U3, cross_attention_U2, cross_attention_U1)
+ 
+         # layers
+ 
+         self.downs = mlist([])
+         self.ups = mlist([])
+ 
+         attn_kwargs = dict(
+             dim_head=attn_dim_head,
+             heads=attn_heads
+         )
+ 
+         mid_dim = dims[-1]
+ 
+         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
+         self.mid_attn = SpatioTemporalAttention(dim=mid_dim)
+         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, timestep_cond_dim=timestep_cond_dim)
+ 
+         for _, self_attend, (dim_in, dim_out), compress_time, resnet_block_depth, cross_attns_d, cross_attns_u in zip(
+                 range(num_layers), self_attns, dim_in_out, temporal_compression,
+                 resnet_block_depths, cross_attns_down, cross_attns_up):
+             assert resnet_block_depth >= 1
+             self.downs.append(mlist([
+                 ResnetBlock(dim_in, dim_out, timestep_cond_dim=timestep_cond_dim),
+                 mlist([ResnetBlock(dim_out, dim_out) for _ in range(resnet_block_depth)]),
+                 SpatioTemporalAttention(dim=dim_out, **attn_kwargs) if self_attend else None,
+                 Downsample(dim_out, downsample_time=compress_time),
+                 cross_attns_d if exists(cross_attns_d) else None
+             ]))
+             self.ups.append(mlist([
+                 ResnetBlock(dim_out * 2, dim_in, timestep_cond_dim=timestep_cond_dim),
+                 mlist([ResnetBlock(dim_in + (dim_out if ind == 0 else 0), dim_in) for ind in range(resnet_block_depth)]),
+                 SpatioTemporalAttention(dim=dim_in, **attn_kwargs) if self_attend else None,
+                 Upsample(dim_out, upsample_time=compress_time),
+                 cross_attns_u if exists(cross_attns_u) else None
+             ]))
+ 
+         self.skip_scale = 2 ** -0.5  # paper shows faster convergence
+ 
+         self.conv_in = PseudoConv3d(dim=channels, dim_out=dim, kernel_size=7, temporal_kernel_size=3)
+         self.conv_out = PseudoConv3d(dim=dim, dim_out=channels, kernel_size=3, temporal_kernel_size=3)
+ 
+     def forward(
+         self,
+         x,
+         clip_vae_embed,
+         timestep=None,
+         enable_time=True
+     ):
+         assert not (exists(self.to_timestep_cond) ^ exists(timestep))
+         is_video = x.ndim == 5
+ 
+         if enable_time and is_video:
+             frames = x.shape[2]
+             assert divisible_by(frames, self.frame_multiple), \
+                 f'number of frames on the video ({frames}) must be divisible by the frame multiple ({self.frame_multiple})'
+ 
+         height, width = x.shape[-2:]
+         assert divisible_by(height, self.image_size_multiple) and divisible_by(width, self.image_size_multiple), \
+             f'height and width of the image or video must be a multiple of {self.image_size_multiple}'
+ 
+         # main logic
+ 
+         t = self.to_timestep_cond(rearrange(timestep, '... -> (...)')) if exists(timestep) else None
+         x = self.conv_in(x, enable_time=enable_time)
+ 
+         hiddens = []
+         for init_block, blocks, maybe_attention, downsample, cross_attn in self.downs:
+             x = init_block(x, t, enable_time=enable_time)
+             hiddens.append(x.clone())
+             for block in blocks:
+                 x = block(x, enable_time=enable_time)
+             if exists(maybe_attention):
+                 x = maybe_attention(x, enable_time=enable_time)  # only happens in the last layer
+             hiddens.append(x.clone())
+             x = downsample(x, enable_time=enable_time)
+             if exists(cross_attn):
+                 x = cross_attn(x, clip_vae_embed)
+ 
+         x = self.mid_block1(x, t, enable_time=enable_time)
+         x = self.mid_attn(x, enable_time=enable_time)
+         x = self.mid_block2(x, t, enable_time=enable_time)
+ 
+         for init_block, blocks, maybe_attention, upsample, cross_attn in reversed(self.ups):
+             x = upsample(x, enable_time=enable_time)
+             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
+             x = init_block(x, t, enable_time=enable_time)
+             x = torch.cat((hiddens.pop() * self.skip_scale, x), dim=1)
+             for block in blocks:
+                 x = block(x, enable_time=enable_time)
+             if exists(maybe_attention):
+                 x = maybe_attention(x, enable_time=enable_time)
+             if exists(cross_attn):
+                 x = cross_attn(x, clip_vae_embed)
+ 
+         x = self.conv_out(x, enable_time=enable_time)
+         return x
+ 
+ if __name__ == '__main__':
+     Net = SpaceTimeUnet(
+         dim=64,
+         channels=3,
+         dim_mult=(1, 2, 4, 8),
+         temporal_compression=(False, False, False, True),
+         self_attns=(False, False, False, True),
+         condition_on_timestep=False)
+ 
+     x = torch.randn([1, 8, 3, 32, 32])
+     # forward() also expects the CLIP + VAE context; the original call omitted it,
+     # so a dummy 768-dimensional context is supplied here to make the smoke test run.
+     context = torch.randn(1, 50, 768)
+     sample_output = Net(x.permute(0, 2, 1, 3, 4), context)
models/unet_dual_encoder.py ADDED
@@ -0,0 +1,62 @@
+ # Load pretrained 2D UNet and modify with temporal attention
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ from einops import rearrange
+ 
+ from diffusers.models import UNet2DConditionModel
+ 
+ def get_unet(pretrained_model_name_or_path, revision, resolution=256, n_poses=5):
+     # Load pretrained UNet layers
+     unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4",
+                                                 subfolder="unet",
+                                                 revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c",
+                                                 cache_dir="checkpoints/unet")
+ 
+     # Modify the input layer to take 2*n_poses additional input channels (pose)
+     weights = unet.conv_in.weight.clone()
+     unet.conv_in = nn.Conv2d(4 + 2*n_poses, weights.shape[0], kernel_size=3, padding=(1, 1))  # input noise + n poses
+     with torch.no_grad():
+         unet.conv_in.weight[:, :4] = weights  # original weights
+         unet.conv_in.weight[:, 4:] = torch.zeros(unet.conv_in.weight[:, 4:].shape)  # new channels initialised to zero (was [:, 3:], which overwrote an original channel)
+ 
+     return unet
+ 
+ '''
+ This module takes in CLIP + VAE embeddings and outputs CLIP-compatible embeddings.
+ '''
+ class Embedding_Adapter(nn.Module):
+     def __init__(self, input_nc=38, output_nc=4, norm_layer=nn.InstanceNorm2d, chkpt=None):
+         super(Embedding_Adapter, self).__init__()
+ 
+         self.save_method_name = "adapter"
+ 
+         self.pool = nn.MaxPool2d(2)
+         self.vae2clip = nn.Linear(1280, 768)
+ 
+         self.linear1 = nn.Linear(54, 50)  # maps the 54 concatenated tokens down to 50
+ 
+         # initialize weights
+         with torch.no_grad():
+             self.linear1.weight = nn.Parameter(torch.eye(50, 54))
+ 
+         if chkpt is not None:
+             pass
+ 
+     def forward(self, clip, vae):
+ 
+         vae = self.pool(vae)  # 1 4 80 64 --> 1 4 40 32
+         vae = rearrange(vae, 'b c h w -> b c (h w)')  # 1 4 40 32 --> 1 4 1280
+ 
+         vae = self.vae2clip(vae)  # 1 4 768
+ 
+         # Concatenate
+         concat = torch.cat((clip, vae), 1)
+ 
+         # Encode
+ 
+         concat = rearrange(concat, 'b c d -> b d c')
+         concat = self.linear1(concat)
+         concat = rearrange(concat, 'b d c -> b c d')
+ 
+         return concat
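
For the inputs used in `app.py` (CLIP hidden states of 50 tokens × 768 dims, VAE latents of 4×80×64), the adapter's data flow works out as follows (an illustrative check; variable names are mine):

```python
# Shape walkthrough for Embedding_Adapter with app.py's inputs.
import torch
from models.unet_dual_encoder import Embedding_Adapter

adapter = Embedding_Adapter(input_nc=1280, output_nc=1280)
clip_tokens = torch.randn(1, 50, 768)   # CLIP last_hidden_state
vae_latent = torch.randn(1, 4, 80, 64)  # scaled VAE latent
# pool -> (1,4,40,32); flatten -> (1,4,1280); vae2clip -> (1,4,768);
# concat 50+4=54 tokens; linear1 projects back to 50 tokens.
out = adapter(clip_tokens, vae_latent)
print(out.shape)                        # torch.Size([1, 50, 768])
```
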
project_latent_space.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as transforms
2
+ import os.path as osp
3
+ import cv2
4
+ import torch
5
+ import os, argparse
6
+ import tqdm
7
+ from PIL import Image
8
+ from diffusers import AutoencoderKL
9
+ import random
10
+ device = torch.device("cuda")
11
+
12
+ parser = argparse.ArgumentParser(description="Configuration of the tensor projection.")
13
+ parser.add_argument('--dataset', default="fashion_dataset/train", help="Path to the dataset")
14
+ parser.add_argument('--output_dir', default="fashion_dataset_tensor", help="Path to save the tensors")
15
+ args = parser.parse_args()
16
+
17
+ vae = AutoencoderKL.from_pretrained(
18
+ "CompVis/stable-diffusion-v1-4",
19
+ subfolder="vae",
20
+ revision="ebb811dd71cdc38a204ecbdd6ac5d580f529fd8c"
21
+ ).to(device)
22
+ vae.requires_grad_(False)
23
+
24
+ @torch.no_grad()
25
+ def VAE_encode(video):
26
+ for i in range(video.shape[0]):
27
+ image = video[i, :, :, :]
28
+ image = image.unsqueeze(0)
29
+ if i == 0:
30
+ init_latent_dist = vae.encode(image).latent_dist.sample()
31
+ init_latent_dist *= 0.18215
32
+ encoded_video = (init_latent_dist).unsqueeze(1)
33
+ else:
34
+ init_latent_dist = vae.encode(image).latent_dist.sample()
35
+ init_latent_dist *= 0.18215
36
+ encoded_video = torch.cat([encoded_video, (init_latent_dist).unsqueeze(1)], 1)
37
+ return encoded_video
38
+
39
+ def get_transform():
40
+ image_transforms = transforms.Compose(
41
+ [
42
+ transforms.Resize((640, 512), interpolation=transforms.InterpolationMode.BILINEAR),
43
+ transforms.ToTensor(),
44
+ ])
45
+ return image_transforms
46
+
47
+
48
+ path = osp.join(args.dataset)
49
+ video_names = os.listdir(path)
50
+ transform = get_transform()
51
+
52
+ if not os.path.exists(args.output_dir):
53
+ os.makedirs(args.output_dir)
54
+
55
+ for video_name in tqdm.tqdm(video_names):
56
+ cap = cv2.VideoCapture(osp.join(path, video_name))
57
+ numberOfFrames = 241
58
+ number = random.randint(0, numberOfFrames - 70)
59
+ for i in range(number, number + 70):
60
+ cap.set(cv2.CAP_PROP_POS_FRAMES, i)
61
+ _, frame = cap.read()
62
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63
+ frame = Image.fromarray(frame)
64
+ frame = transform(frame)
65
+ if i == number:
66
+ inputImage = frame
67
+ torch.save(inputImage, args.output_dir + "/" + video_name[:-4] + "_image.pt")
68
+ frame = frame.unsqueeze(0)
69
+ restOfVideo = torch.clone(frame)
70
+ else:
71
+ frame = frame.unsqueeze(0)
72
+ restOfVideo = torch.cat([restOfVideo, frame], 0)
73
+ restOfVideo = restOfVideo.to(device=device)
74
+ vae_video = VAE_encode(restOfVideo).detach().cpu()[0]
75
+ torch.save(vae_video, args.output_dir + "/" + video_name[:-4] + ".pt")
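For reference, a small sketch of reading the saved tensors back; the clip name is hypothetical, and the shapes assume the 640 x 512 transform above (70 frames of 4-channel latents at 80 x 64).

```python
import torch

video_name = "clip_0001"  # hypothetical example name (files are saved without the .mp4 suffix)
out_dir = "fashion_dataset_tensor"

cond_image = torch.load(f"{out_dir}/{video_name}_image.pt")  # 3 x 640 x 512 conditioning frame
vae_video = torch.load(f"{out_dir}/{video_name}.pt")         # 70 x 4 x 80 x 64 encoded latents

print(cond_image.shape, vae_video.shape)
```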
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ accelerate==0.26.1
2
+ certifi==2023.11.17
3
+ charset-normalizer==3.3.2
4
+ diffusers==0.14.0
5
+ einops==0.7.0
6
+ filelock==3.13.1
7
+ fsspec==2023.12.2
8
+ huggingface-hub==0.20.2
9
+ idna==3.6
10
+ importlib-metadata==7.0.1
11
+ numpy==1.26.3
12
+ opencv-python==4.9.0.80
13
+ packaging==23.2
14
+ pillow==10.2.0
15
+ protobuf==4.25.2
16
+ psutil==5.9.7
17
+ PyYAML==6.0.1
18
+ regex==2023.12.25
19
+ requests==2.31.0
20
+ safetensors==0.4.1
21
+ tensorboardX==2.6.2.2
22
+ tokenizers==0.15.0
23
+ torch==1.11.0+cu113
24
+ torchaudio==0.11.0+cu113
25
+ torchvision==0.12.0+cu113
26
+ tqdm==4.66.1
27
+ transformers==4.36.2
28
+ typing_extensions==4.9.0
29
+ urllib3==2.1.0
30
+ zipp==3.17.0
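Note that the `+cu113` torch wheels are not hosted on PyPI; installing this file typically requires pointing pip at the PyTorch wheel index as well, e.g. `pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113`.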
sample/blue.jpg ADDED
sample/green.jpg ADDED
sample/silver.jpg ADDED
src/deps/__init__.py ADDED
File without changes
src/deps/facial_recognition/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/__init__.py
3
+ """
src/deps/facial_recognition/helpers.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/helpers.py
3
+ """
4
+
5
+ from collections import namedtuple
6
+ import torch
7
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
8
+
9
+ """
10
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
11
+ """
12
+
13
+
14
+ class Flatten(Module):
15
+ def forward(self, input):
16
+ return input.view(input.size(0), -1)
17
+
18
+
19
+ def l2_norm(input, axis=1):
20
+ norm = torch.norm(input, 2, axis, True)
21
+ output = torch.div(input, norm)
22
+ return output
23
+
24
+
25
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
26
+ """ A named tuple describing a ResNet block. """
27
+
28
+
29
+ def get_block(in_channel, depth, num_units, stride=2):
30
+ return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
31
+
32
+
33
+ def get_blocks(num_layers):
34
+ if num_layers == 50:
35
+ blocks = [
36
+ get_block(in_channel=64, depth=64, num_units=3),
37
+ get_block(in_channel=64, depth=128, num_units=4),
38
+ get_block(in_channel=128, depth=256, num_units=14),
39
+ get_block(in_channel=256, depth=512, num_units=3)
40
+ ]
41
+ elif num_layers == 100:
42
+ blocks = [
43
+ get_block(in_channel=64, depth=64, num_units=3),
44
+ get_block(in_channel=64, depth=128, num_units=13),
45
+ get_block(in_channel=128, depth=256, num_units=30),
46
+ get_block(in_channel=256, depth=512, num_units=3)
47
+ ]
48
+ elif num_layers == 152:
49
+ blocks = [
50
+ get_block(in_channel=64, depth=64, num_units=3),
51
+ get_block(in_channel=64, depth=128, num_units=8),
52
+ get_block(in_channel=128, depth=256, num_units=36),
53
+ get_block(in_channel=256, depth=512, num_units=3)
54
+ ]
55
+ else:
56
+ raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
57
+ return blocks
58
+
59
+
60
+ class SEModule(Module):
61
+ def __init__(self, channels, reduction):
62
+ super(SEModule, self).__init__()
63
+ self.avg_pool = AdaptiveAvgPool2d(1)
64
+ self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
65
+ self.relu = ReLU(inplace=True)
66
+ self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
67
+ self.sigmoid = Sigmoid()
68
+
69
+ def forward(self, x):
70
+ module_input = x
71
+ x = self.avg_pool(x)
72
+ x = self.fc1(x)
73
+ x = self.relu(x)
74
+ x = self.fc2(x)
75
+ x = self.sigmoid(x)
76
+ return module_input * x
77
+
78
+
79
+ class bottleneck_IR(Module):
80
+ def __init__(self, in_channel, depth, stride):
81
+ super(bottleneck_IR, self).__init__()
82
+ if in_channel == depth:
83
+ self.shortcut_layer = MaxPool2d(1, stride)
84
+ else:
85
+ self.shortcut_layer = Sequential(
86
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
87
+ BatchNorm2d(depth)
88
+ )
89
+ self.res_layer = Sequential(
90
+ BatchNorm2d(in_channel),
91
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
92
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
93
+ )
94
+
95
+ def forward(self, x):
96
+ shortcut = self.shortcut_layer(x)
97
+ res = self.res_layer(x)
98
+ return res + shortcut
99
+
100
+
101
+ class bottleneck_IR_SE(Module):
102
+ def __init__(self, in_channel, depth, stride):
103
+ super(bottleneck_IR_SE, self).__init__()
104
+ if in_channel == depth:
105
+ self.shortcut_layer = MaxPool2d(1, stride)
106
+ else:
107
+ self.shortcut_layer = Sequential(
108
+ Conv2d(in_channel, depth, (1, 1), stride, bias=False),
109
+ BatchNorm2d(depth)
110
+ )
111
+ self.res_layer = Sequential(
112
+ BatchNorm2d(in_channel),
113
+ Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
114
+ PReLU(depth),
115
+ Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
116
+ BatchNorm2d(depth),
117
+ SEModule(depth, 16)
118
+ )
119
+
120
+ def forward(self, x):
121
+ shortcut = self.shortcut_layer(x)
122
+ res = self.res_layer(x)
123
+ return res + shortcut
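A quick sanity check of the building blocks above (assuming the repository root is on `PYTHONPATH`): a stride-2 IR-SE bottleneck halves the spatial resolution while changing the channel depth, and `get_blocks(50)` yields the 3-4-14-3 stage layout of the 50-layer backbone.

```python
import torch
from src.deps.facial_recognition.helpers import bottleneck_IR_SE, get_blocks

block = bottleneck_IR_SE(in_channel=64, depth=128, stride=2)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # torch.Size([2, 128, 28, 28])

print([len(stage) for stage in get_blocks(50)])  # [3, 4, 14, 3]
```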
src/deps/facial_recognition/model_irse.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Copy-pasted from https://github.com/orpatashnik/StyleCLIP/tree/main/models/facial_recognition/model_irse.py
3
+ """
4
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
5
+ from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
6
+
7
+ """
8
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
9
+ """
10
+
11
+ class Backbone(Module):
12
+ WEIGHTS_URL = "https://www.dropbox.com/s/n6xicva1lrghb5w/model_ir_se50.pth?dl=1"
13
+
14
+ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
15
+ super(Backbone, self).__init__()
16
+ assert input_size in [112, 224], "input_size should be 112 or 224"
17
+ assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
18
+ assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
19
+ blocks = get_blocks(num_layers)
20
+ if mode == 'ir':
21
+ unit_module = bottleneck_IR
22
+ elif mode == 'ir_se':
23
+ unit_module = bottleneck_IR_SE
24
+ self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
25
+ BatchNorm2d(64),
26
+ PReLU(64))
27
+ if input_size == 112:
28
+ self.output_layer = Sequential(BatchNorm2d(512),
29
+ Dropout(drop_ratio),
30
+ Flatten(),
31
+ Linear(512 * 7 * 7, 512),
32
+ BatchNorm1d(512, affine=affine))
33
+ else:
34
+ self.output_layer = Sequential(BatchNorm2d(512),
35
+ Dropout(drop_ratio),
36
+ Flatten(),
37
+ Linear(512 * 14 * 14, 512),
38
+ BatchNorm1d(512, affine=affine))
39
+
40
+ modules = []
41
+ for block in blocks:
42
+ for bottleneck in block:
43
+ modules.append(unit_module(bottleneck.in_channel,
44
+ bottleneck.depth,
45
+ bottleneck.stride))
46
+ self.body = Sequential(*modules)
47
+
48
+ def forward(self, x):
49
+ x = self.input_layer(x)
50
+ x = self.body(x)
51
+ x = self.output_layer(x)
52
+ return l2_norm(x)
53
+
54
+
55
+ def IR_50(input_size):
56
+ """Constructs a ir-50 model."""
57
+ model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
58
+ return model
59
+
60
+
61
+ def IR_101(input_size):
62
+ """Constructs a ir-101 model."""
63
+ model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
64
+ return model
65
+
66
+
67
+ def IR_152(input_size):
68
+ """Constructs a ir-152 model."""
69
+ model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
70
+ return model
71
+
72
+
73
+ def IR_SE_50(input_size):
74
+ """Constructs a ir_se-50 model."""
75
+ model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
76
+ return model
77
+
78
+
79
+ def IR_SE_101(input_size):
80
+ """Constructs a ir_se-101 model."""
81
+ model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
82
+ return model
83
+
84
+
85
+ def IR_SE_152(input_size):
86
+ """Constructs a ir_se-152 model."""
87
+ model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
88
+ return model
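A minimal usage sketch for these constructors (weights are random here; the pretrained IR-SE-50 checkpoint is linked in `Backbone.WEIGHTS_URL`), again assuming the repo root is importable. The backbone maps a 112 x 112 face crop to an L2-normalised 512-d identity embedding.

```python
import torch
from src.deps.facial_recognition.model_irse import IR_SE_50

model = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = model(torch.randn(1, 3, 112, 112))
print(emb.shape, emb.norm(dim=1))  # torch.Size([1, 512]), norms ~1.0
```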
src/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
src/dnnlib/util.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union, Dict
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+ def to_dict(self) -> Dict:
56
+ return {k: (v.to_dict() if isinstance(v, EasyDict) else v) for (k, v) in self.items()}
57
+
58
+
59
+ class Logger(object):
60
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
61
+
62
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
63
+ self.file = None
64
+
65
+ if file_name is not None:
66
+ self.file = open(file_name, file_mode)
67
+
68
+ self.should_flush = should_flush
69
+ self.stdout = sys.stdout
70
+ self.stderr = sys.stderr
71
+
72
+ sys.stdout = self
73
+ sys.stderr = self
74
+
75
+ def __enter__(self) -> "Logger":
76
+ return self
77
+
78
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
79
+ self.close()
80
+
81
+ def write(self, text: Union[str, bytes]) -> None:
82
+ """Write text to stdout (and a file) and optionally flush."""
83
+ if isinstance(text, bytes):
84
+ text = text.decode()
85
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
86
+ return
87
+
88
+ if self.file is not None:
89
+ self.file.write(text)
90
+
91
+ self.stdout.write(text)
92
+
93
+ if self.should_flush:
94
+ self.flush()
95
+
96
+ def flush(self) -> None:
97
+ """Flush written text to both stdout and a file, if open."""
98
+ if self.file is not None:
99
+ self.file.flush()
100
+
101
+ self.stdout.flush()
102
+
103
+ def close(self) -> None:
104
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
105
+ self.flush()
106
+
107
+ # if using multiple loggers, prevent closing in wrong order
108
+ if sys.stdout is self:
109
+ sys.stdout = self.stdout
110
+ if sys.stderr is self:
111
+ sys.stderr = self.stderr
112
+
113
+ if self.file is not None:
114
+ self.file.close()
115
+ self.file = None
116
+
117
+
118
+ # Cache directories
119
+ # ------------------------------------------------------------------------------------------
120
+
121
+ _dnnlib_cache_dir = None
122
+
123
+ def set_cache_dir(path: str) -> None:
124
+ global _dnnlib_cache_dir
125
+ _dnnlib_cache_dir = path
126
+
127
+ def make_cache_dir_path(*paths: str) -> str:
128
+ if _dnnlib_cache_dir is not None:
129
+ return os.path.join(_dnnlib_cache_dir, *paths)
130
+ if 'DNNLIB_CACHE_DIR' in os.environ:
131
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
132
+ if 'HOME' in os.environ:
133
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
134
+ if 'USERPROFILE' in os.environ:
135
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
136
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
137
+
138
+ # Small util functions
139
+ # ------------------------------------------------------------------------------------------
140
+
141
+
142
+ def format_time(seconds: Union[int, float]) -> str:
143
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
144
+ s = int(np.rint(seconds))
145
+
146
+ if s < 60:
147
+ return "{0}s".format(s)
148
+ elif s < 60 * 60:
149
+ return "{0}m {1:02}s".format(s // 60, s % 60)
150
+ elif s < 24 * 60 * 60:
151
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
152
+ else:
153
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
154
+
155
+
156
+ def ask_yes_no(question: str) -> bool:
157
+ """Ask the user the question until the user inputs a valid answer."""
158
+ while True:
159
+ try:
160
+ print("{0} [y/n]".format(question))
161
+ return strtobool(input().lower())
162
+ except ValueError:
163
+ pass
164
+
165
+
166
+ def tuple_product(t: Tuple) -> Any:
167
+ """Calculate the product of the tuple elements."""
168
+ result = 1
169
+
170
+ for v in t:
171
+ result *= v
172
+
173
+ return result
174
+
175
+
176
+ _str_to_ctype = {
177
+ "uint8": ctypes.c_ubyte,
178
+ "uint16": ctypes.c_uint16,
179
+ "uint32": ctypes.c_uint32,
180
+ "uint64": ctypes.c_uint64,
181
+ "int8": ctypes.c_byte,
182
+ "int16": ctypes.c_int16,
183
+ "int32": ctypes.c_int32,
184
+ "int64": ctypes.c_int64,
185
+ "float32": ctypes.c_float,
186
+ "float64": ctypes.c_double
187
+ }
188
+
189
+
190
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
191
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
192
+ type_str = None
193
+
194
+ if isinstance(type_obj, str):
195
+ type_str = type_obj
196
+ elif hasattr(type_obj, "__name__"):
197
+ type_str = type_obj.__name__
198
+ elif hasattr(type_obj, "name"):
199
+ type_str = type_obj.name
200
+ else:
201
+ raise RuntimeError("Cannot infer type name from input")
202
+
203
+ assert type_str in _str_to_ctype.keys()
204
+
205
+ my_dtype = np.dtype(type_str)
206
+ my_ctype = _str_to_ctype[type_str]
207
+
208
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
209
+
210
+ return my_dtype, my_ctype
211
+
212
+
213
+ def is_pickleable(obj: Any) -> bool:
214
+ try:
215
+ with io.BytesIO() as stream:
216
+ pickle.dump(obj, stream)
217
+ return True
218
+ except:
219
+ return False
220
+
221
+
222
+ # Functionality to import modules/objects by name, and call functions by name
223
+ # ------------------------------------------------------------------------------------------
224
+
225
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
226
+ """Searches for the underlying module behind the name to some python object.
227
+ Returns the module and the object name (original name with module part removed)."""
228
+
229
+ # allow convenience shorthands, substitute them by full names
230
+ obj_name = re.sub("^np.", "numpy.", obj_name)
231
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
232
+
233
+ # list alternatives for (module_name, local_obj_name)
234
+ parts = obj_name.split(".")
235
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
236
+
237
+ # try each alternative in turn
238
+ for module_name, local_obj_name in name_pairs:
239
+ try:
240
+ module = importlib.import_module(module_name) # may raise ImportError
241
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
242
+ return module, local_obj_name
243
+ except:
244
+ pass
245
+
246
+ # maybe some of the modules themselves contain errors?
247
+ for module_name, _local_obj_name in name_pairs:
248
+ try:
249
+ importlib.import_module(module_name) # may raise ImportError
250
+ except ImportError:
251
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
252
+ raise
253
+
254
+ # maybe the requested attribute is missing?
255
+ for module_name, local_obj_name in name_pairs:
256
+ try:
257
+ module = importlib.import_module(module_name) # may raise ImportError
258
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
259
+ except ImportError:
260
+ pass
261
+
262
+ # we are out of luck, but we have no idea why
263
+ raise ImportError(obj_name)
264
+
265
+
266
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
267
+ """Traverses the object name and returns the last (rightmost) python object."""
268
+ if obj_name == '':
269
+ return module
270
+ obj = module
271
+ for part in obj_name.split("."):
272
+ obj = getattr(obj, part)
273
+ return obj
274
+
275
+
276
+ def get_obj_by_name(name: str) -> Any:
277
+ """Finds the python object with the given name."""
278
+ module, obj_name = get_module_from_obj_name(name)
279
+ return get_obj_from_module(module, obj_name)
280
+
281
+
282
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
283
+ """Finds the python object with the given name and calls it as a function."""
284
+ assert func_name is not None
285
+ func_obj = get_obj_by_name(func_name)
286
+ assert callable(func_obj)
287
+ return func_obj(*args, **kwargs)
288
+
289
+
290
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
291
+ """Finds the python class with the given name and constructs it with the given arguments."""
292
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
293
+
294
+
295
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
296
+ """Get the directory path of the module containing the given object name."""
297
+ module, _ = get_module_from_obj_name(obj_name)
298
+ return os.path.dirname(inspect.getfile(module))
299
+
300
+
301
+ def is_top_level_function(obj: Any) -> bool:
302
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
303
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
304
+
305
+
306
+ def get_top_level_function_name(obj: Any) -> str:
307
+ """Return the fully-qualified name of a top-level function."""
308
+ assert is_top_level_function(obj)
309
+ module = obj.__module__
310
+ if module == '__main__':
311
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
312
+ return module + "." + obj.__name__
313
+
314
+
315
+ # File system helpers
316
+ # ------------------------------------------------------------------------------------------
317
+
318
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
319
+ """List all files recursively in a given directory while ignoring given file and directory names.
320
+ Returns list of tuples containing both absolute and relative paths."""
321
+ assert os.path.isdir(dir_path)
322
+ base_name = os.path.basename(os.path.normpath(dir_path))
323
+
324
+ if ignores is None:
325
+ ignores = []
326
+
327
+ result = []
328
+
329
+ for root, dirs, files in os.walk(dir_path, topdown=True):
330
+ for ignore_ in ignores:
331
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
332
+
333
+ # dirs need to be edited in-place
334
+ for d in dirs_to_remove:
335
+ dirs.remove(d)
336
+
337
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
338
+
339
+ absolute_paths = [os.path.join(root, f) for f in files]
340
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
341
+
342
+ if add_base_to_relative:
343
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
344
+
345
+ assert len(absolute_paths) == len(relative_paths)
346
+ result += zip(absolute_paths, relative_paths)
347
+
348
+ return result
349
+
350
+
351
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
352
+ """Takes in a list of tuples of (src, dst) paths and copies files.
353
+ Will create all necessary directories."""
354
+ for file in files:
355
+ target_dir_name = os.path.dirname(file[1])
356
+
357
+ # will create all intermediate-level directories
358
+ if not os.path.exists(target_dir_name):
359
+ os.makedirs(target_dir_name)
360
+
361
+ shutil.copyfile(file[0], file[1])
362
+
363
+
364
+ # URL helpers
365
+ # ------------------------------------------------------------------------------------------
366
+
367
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
368
+ """Determine whether the given object is a valid URL string."""
369
+ if not isinstance(obj, str) or not "://" in obj:
370
+ return False
371
+ if allow_file_urls and obj.startswith('file://'):
372
+ return True
373
+ try:
374
+ res = requests.compat.urlparse(obj)
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
378
+ if not res.scheme or not res.netloc or not "." in res.netloc:
379
+ return False
380
+ except:
381
+ return False
382
+ return True
383
+
384
+
385
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
386
+ """Download the given URL and return a binary-mode file object to access the data."""
387
+ assert num_attempts >= 1
388
+ assert not (return_filename and (not cache))
389
+
390
+ # Doesn't look like an URL scheme so interpret it as a local filename.
391
+ if not re.match('^[a-z]+://', url):
392
+ return url if return_filename else open(url, "rb")
393
+
394
+ # Handle file URLs. This code handles unusual file:// patterns that
395
+ # arise on Windows:
396
+ #
397
+ # file:///c:/foo.txt
398
+ #
399
+ # which would translate to a local '/c:/foo.txt' filename that's
400
+ # invalid. Drop the forward slash for such pathnames.
401
+ #
402
+ # If you touch this code path, you should test it on both Linux and
403
+ # Windows.
404
+ #
405
+ # Some internet resources suggest using urllib.request.url2pathname() but
406
+ # that converts forward slashes to backslashes, and this causes
407
+ # its own set of problems.
408
+ if url.startswith('file://'):
409
+ filename = urllib.parse.urlparse(url).path
410
+ if re.match(r'^/[a-zA-Z]:', filename):
411
+ filename = filename[1:]
412
+ return filename if return_filename else open(filename, "rb")
413
+
414
+ assert is_url(url)
415
+
416
+ # Lookup from cache.
417
+ if cache_dir is None:
418
+ cache_dir = make_cache_dir_path('downloads')
419
+
420
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
421
+ if cache:
422
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
423
+ if len(cache_files) == 1:
424
+ filename = cache_files[0]
425
+ return filename if return_filename else open(filename, "rb")
426
+
427
+ # Download.
428
+ url_name = None
429
+ url_data = None
430
+ with requests.Session() as session:
431
+ if verbose:
432
+ print("Downloading %s ..." % url, end="", flush=True)
433
+ for attempts_left in reversed(range(num_attempts)):
434
+ try:
435
+ with session.get(url) as res:
436
+ res.raise_for_status()
437
+ if len(res.content) == 0:
438
+ raise IOError("No data received")
439
+
440
+ if len(res.content) < 8192:
441
+ content_str = res.content.decode("utf-8")
442
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
443
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
444
+ if len(links) == 1:
445
+ url = requests.compat.urljoin(url, links[0])
446
+ raise IOError("Google Drive virus checker nag")
447
+ if "Google Drive - Quota exceeded" in content_str:
448
+ raise IOError("Google Drive download quota exceeded -- please try again later")
449
+
450
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
451
+ url_name = match[1] if match else url
452
+ url_data = res.content
453
+ if verbose:
454
+ print(" done")
455
+ break
456
+ except KeyboardInterrupt:
457
+ raise
458
+ except:
459
+ if not attempts_left:
460
+ if verbose:
461
+ print(" failed")
462
+ raise
463
+ if verbose:
464
+ print(".", end="", flush=True)
465
+
466
+ # Save to cache.
467
+ if cache:
468
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
469
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
470
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
471
+ os.makedirs(cache_dir, exist_ok=True)
472
+ with open(temp_file, "wb") as f:
473
+ f.write(url_data)
474
+ os.replace(temp_file, cache_file) # atomic
475
+ if return_filename:
476
+ return cache_file
477
+
478
+ # Return data as file object.
479
+ assert not return_filename
480
+ return io.BytesIO(url_data)
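Two of the small utilities above in action (import path assumed to be resolvable from the repo root):

```python
from src.dnnlib.util import EasyDict, format_time

# EasyDict allows attribute-style access on top of a plain dict.
cfg = EasyDict(lr=2e-4, sched=EasyDict(warmup=1000))
cfg.batch = 16
print(cfg.lr, cfg.sched.warmup, cfg.to_dict())

# format_time renders a duration in the largest sensible units.
print(format_time(90))      # 1m 30s
print(format_time(100000))  # 1d 03h 46m
```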
src/infra/__init__.py ADDED
File without changes
src/infra/experiments.yaml ADDED
@@ -0,0 +1,60 @@
1
+ #----------------------------------------------------------------------------
2
+ # Here we keep the experiment hyperparameters in case we want to do mass-launching via SLURM
3
+ #----------------------------------------------------------------------------
4
+
5
+ mocogan_sg2:
6
+ common_args:
7
+ model: mocogan
8
+ training.batch: 16
9
+ dataset.max_num_frames: 32
10
+ experiments:
11
+ b16_mnf16:
12
+ sampling: traditional_16
13
+ dataset.max_num_frames: 16
14
+ model.generator.motion.long_history: false
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ ffs:
19
+ common_args:
20
+ sampling.num_frames_per_video: 3
21
+ experiments:
22
+ mnf1024_sfpm32_minperiod16: {}
23
+ mnf1024_sfpm32_minperiod32:
24
+ model.generator.time_enc.min_period_len: 32
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ sky_timelapse:
29
+ common_args:
30
+ sampling.num_frames_per_video: 3
31
+ experiments:
32
+ mnf1024_sfpm32_minperiod16: {}
33
+ mnf1024_sfpm256_minperiod256:
34
+ model.generator.motion.motion_z_distance: 256
35
+ model.generator.time_enc.min_period_len: 256
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ highres:
40
+ common_args:
41
+ training.metrics: \"fvd2048_16f,fvd2048_128f_subsample,fid50k_full\"
42
+ training.batch: 16
43
+ sampling.num_frames_per_video: 2
44
+ experiments:
45
+ mnf1024_sfpm32_minperiod16_batch16: {}
46
+ mnf32_sfpm32_minperiod16_batch16:
47
+ dataset.max_num_frames: 32
48
+
49
+ #----------------------------------------------------------------------------
50
+
51
+ cond_ablation_ffs:
52
+ common_args:
53
+ sampling.num_frames_per_video: 3
54
+ experiments:
55
+ hyper_mod:
56
+ model.discriminator.hyper_type: hyper
57
+ without_proj_cond:
58
+ model.discriminator.dummy_c: true
59
+
60
+ #----------------------------------------------------------------------------
src/infra/launch.py ADDED
@@ -0,0 +1,113 @@
1
+ """
2
+ Run a __reproducible__ experiment on __allocated__ resources
3
+ It submits slurm job(s) with the given hyperparameters, which will then execute `slurm_job.py`
4
+ This is the main entry-point
5
+ """
6
+
7
+ import os
8
+ os.environ["HYDRA_FULL_ERROR"] = "1"
9
+
10
+ import subprocess
11
+ import re
12
+
13
+ import hydra
14
+ from omegaconf import DictConfig, OmegaConf
15
+ from pathlib import Path
16
+
17
+ from utils import create_project_dir, recursive_instantiate
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ HYDRA_ARGS = "hydra.run.dir=. hydra.output_subdir=null hydra/job_logging=disabled hydra/hydra_logging=disabled"
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ @hydra.main(config_path="../../configs", config_name="config.yaml")
26
+ def main(cfg: DictConfig):
27
+ recursive_instantiate(cfg)
28
+ OmegaConf.set_struct(cfg, True)
29
+ cfg.env.project_path = str(cfg.env.project_path) # This is needed to evaluate ${hydra:runtime.cwd}
30
+
31
+ before_train_cmd = '\n'.join(cfg.env.before_train_commands)
32
+ before_train_cmd = before_train_cmd + '\n' if len(before_train_cmd) > 0 else ''
33
+ torch_extensions_dir = os.environ.get('TORCH_EXTENSIONS_DIR', cfg.env.torch_extensions_dir)
34
+ training_cmd = f'{before_train_cmd}TORCH_EXTENSIONS_DIR={torch_extensions_dir} cd {cfg.project_release_dir} && {cfg.env.python_bin} src/train.py {HYDRA_ARGS}'
35
+ quiet = cfg.get('quiet', False)
36
+ training_cmd_save_path = os.path.join(cfg.project_release_dir, 'training_cmd.sh')
37
+ cfg_save_path = os.path.join(cfg.project_release_dir, 'experiment_config.yaml')
38
+
39
+ if not quiet:
40
+ print('<=== TRAINING COMMAND START ===>')
41
+ print(training_cmd)
42
+ print('<=== TRAINING COMMAND END ===>')
43
+
44
+ is_running_from_scratch = True
45
+
46
+ if cfg.training.resume == "latest" and os.path.isdir(cfg.project_release_dir) and os.path.isfile(training_cmd_save_path) and os.path.isfile(cfg_save_path):
47
+ is_running_from_scratch = False
48
+ if not quiet:
49
+ print("We are going to resume the training and the experiment already exists. " \
50
+ "That's why the provided config/training_cmd are discarded and the project dir is not created.")
51
+
52
+ if is_running_from_scratch and not cfg.print_only:
53
+ create_project_dir(
54
+ cfg.project_release_dir,
55
+ cfg.env.objects_to_copy,
56
+ cfg.env.symlinks_to_create,
57
+ quiet=quiet,
58
+ ignore_uncommited_changes=cfg.get('ignore_uncommited_changes', False),
59
+ overwrite=cfg.get('overwrite', False))
60
+
61
+ with open(training_cmd_save_path, 'w') as f:
62
+ f.write(training_cmd + '\n')
63
+ if not quiet:
64
+ print(f'Saved training command in {training_cmd_save_path}')
65
+
66
+ with open(cfg_save_path, 'w') as f:
67
+ OmegaConf.save(config=cfg, f=f)
68
+ if not quiet:
69
+ print(f'Saved config in {cfg_save_path}')
70
+
71
+ if not cfg.print_only:
72
+ os.chdir(cfg.project_release_dir)
73
+
74
+ if cfg.slurm:
75
+ assert Path(cfg.dataset.path_for_slurm_job).exists()
76
+
77
+ curr_job_id = None
78
+
79
+ for i in range(cfg.job_sequence_length):
80
+ if i == 0:
81
+ deps_args_str = ''
82
+ else:
83
+ deps_args_str = f'--dependency=afterany:{curr_job_id}'
84
+
85
+ # Submitting the slurm job
86
+ qos_arg_str = f'--account {os.environ["PRIORITY_BOOST_ACC"]}' if cfg.use_qos else ''
87
+ output_file_arg_str = f'--output {cfg.project_release_dir}/slurm_{i}.log'
88
+ submit_job_cmd = f'sbatch {cfg.sbatch_args_str} {output_file_arg_str} {qos_arg_str} --export=ALL,{cfg.env_args_str} {deps_args_str} src/infra/slurm_job_proxy.sh'
89
+
90
+ if cfg.print_only:
91
+ print(submit_job_cmd)
92
+ curr_job_id = "DUMMY_JOB_ID"
93
+ else:
94
+ result = subprocess.run(submit_job_cmd, stdout=subprocess.PIPE, shell=True)
95
+ output_str = result.stdout.decode("utf-8").strip("\n") # It has a format of "Submitted batch job 17033559"
96
+ if not quiet or i == 0:
97
+ print(output_str)
98
+ curr_job_id = re.findall(r"^Submitted\ batch\ job\ \d{5,8}$", output_str)
99
+ assert len(curr_job_id) == 1, f"Bad output: `{output_str}`"
100
+ curr_job_id = int(curr_job_id[0][len('Submitted batch job '):])
101
+ else:
102
+ assert cfg.job_sequence_length == 1, "You can use a job sequence only when running via slurm."
103
+ if cfg.print_only:
104
+ print(training_cmd)
105
+ else:
106
+ os.system(training_cmd)
107
+
108
+ #----------------------------------------------------------------------------
109
+
110
+ if __name__ == "__main__":
111
+ main()
112
+
113
+ #----------------------------------------------------------------------------
src/infra/slurm_batch_launch.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ import argparse
3
+ import copy
4
+ from typing import List, Dict, Optional
5
+ from omegaconf import OmegaConf, DictConfig
6
+ from src.infra.utils import cfg_to_args_str
7
+
8
+ #----------------------------------------------------------------------------
9
+
10
+ HYDRA_ARGS = "hydra.run.dir=. hydra.output_subdir=null hydra/job_logging=disabled hydra/hydra_logging=disabled"
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def batch_launch(launcher: str, experiments_dir: os.PathLike, cfg: DictConfig, datasets: List[str], print_only: bool, time: str, use_qos: bool=False, other_args: Dict={}, num_gpus: int=4, *args, **kwargs):
15
+ for dataset in datasets:
16
+ for exp_args in construct_experiments_args(cfg, *args, **kwargs):
17
+ exp_args['sbatch_args.time'] = time
18
+ exp_args['experiments_dir'] = experiments_dir
19
+ exp_args['dataset'] = dataset
20
+ exp_args['env'] = 'ibex'
21
+ exp_args['use_qos'] = use_qos
22
+ exp_args = {**exp_args, **other_args}
23
+ curr_exp_args_str = cfg_to_args_str(exp_args, use_dashes=False)
24
+ launching_command = f"{launcher} num_gpus={num_gpus} {curr_exp_args_str}"
25
+
26
+ if print_only:
27
+ os.makedirs(exp_args['experiments_dir'], exist_ok=True)
28
+ print(launching_command)
29
+ else:
30
+ os.system(launching_command)
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def construct_experiments_args(cfg: DictConfig, experiments_list: Optional[List[str]]=None, suffix: str="") -> List[Dict]:
35
+ args_dicts = []
36
+ common_cfg = cfg.get('common_args', {})
37
+
38
+ for exp_name, exp_cfg in to_dict(cfg.experiments).items():
39
+ if experiments_list is not None and exp_name not in experiments_list:
40
+ continue
41
+ curr_exp_cfg = {**copy.deepcopy(to_dict(common_cfg)), **to_dict(exp_cfg)}
42
+ curr_exp_cfg['exp_suffix'] = f'{exp_name}{suffix}'
43
+ args_dicts.append(curr_exp_cfg)
44
+
45
+ return args_dicts
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def to_dict(cfg) -> Dict:
50
+ return OmegaConf.to_container(OmegaConf.create({**cfg}))
51
+
52
+ #----------------------------------------------------------------------------
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser(description="Experiments launcher")
56
+ parser.add_argument('-e', '--series_name', type=str, required=True, help="Which experiments series to launch?")
57
+ parser.add_argument('-d', '--datasets', required=True, type=str, help='Comma-separate list of datasets')
58
+ parser.add_argument('-p', '--print_only', action='store_true', help='Just print commands and exit?')
59
+ parser.add_argument('-t', '--time', type=str, default='1-0', help='Which time to specify for the sbatch command?')
60
+ parser.add_argument('-q', '--use_qos', action='store_true', help='Should we use QoS to launch jobs?')
61
+ parser.add_argument('--experiments_list', type=str, help='Should we run only some specific experiments from this experiments series?')
62
+ parser.add_argument('--other_args', type=str, default="", help='Additional arguments for the experiments')
63
+ parser.add_argument('--suffix', type=str, default="", help='Additional suffix for the experiments')
64
+ parser.add_argument('--num_gpus', type=int, default=4, help='Number of GPUs to use per each experiment')
65
+ parser.add_argument('--project_dir', type=str, default=os.getcwd(), help='Project directory path')
66
+ parser.add_argument('--project_dir_for_exps_cfg', type=str, help="Overwrite the project directory to use for experiments.yaml. Useful for debugging the config.")
67
+ args = parser.parse_args()
68
+
69
+ os.chdir(args.project_dir)
70
+ user = os.environ.get('USER', 'unknown')
71
+ python_bin = os.path.join(args.project_dir, 'env/bin/python')
72
+ launcher = f"{python_bin} src/infra/launch.py {HYDRA_ARGS} +quiet=true slurm=true"
73
+ experiments_dir = f'experiments/{user}/{args.series_name}'
74
+ exps_cfg_path = os.path.join(args.project_dir if args.project_dir_for_exps_cfg is None else args.project_dir_for_exps_cfg, 'src/infra/experiments.yaml')
75
+ all_exp_series = OmegaConf.load(exps_cfg_path)
76
+ assert args.series_name in all_exp_series, f"Experiments series not found: {args.series_name}"
77
+ cfg = all_exp_series[args.series_name]
78
+ datasets = args.datasets.split(',')
79
+ experiments_list = None if args.experiments_list is None else args.experiments_list.split(',')
80
+ other_args = {kv.split('=')[0]: kv.split('=')[1] for kv in args.other_args.split(',') if len(kv.split('=')) == 2}
81
+
82
+ batch_launch(
83
+ launcher=launcher,
84
+ experiments_dir=experiments_dir,
85
+ cfg=cfg,
86
+ datasets=datasets,
87
+ print_only=args.print_only,
88
+ time=args.time,
89
+ use_qos=args.use_qos,
90
+ experiments_list=experiments_list,
91
+ other_args=other_args,
92
+ suffix=args.suffix,
93
+ num_gpus=args.num_gpus,
94
+ )
95
+
96
+ #----------------------------------------------------------------------------
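A sketch of how `construct_experiments_args` expands one series: every experiment inherits `common_args`, overrides it with its own entries, and receives an `exp_suffix`. The YAML mirrors the `mocogan_sg2` series from `experiments.yaml` above; imports assume the repo root is on `PYTHONPATH`.

```python
from omegaconf import OmegaConf
from src.infra.slurm_batch_launch import construct_experiments_args
from src.infra.utils import cfg_to_args_str

cfg = OmegaConf.create("""
common_args:
  model: mocogan
  training.batch: 16
experiments:
  b16_mnf16:
    dataset.max_num_frames: 16
""")

(exp_args,) = construct_experiments_args(cfg)
print(exp_args)
# {'model': 'mocogan', 'training.batch': 16,
#  'dataset.max_num_frames': 16, 'exp_suffix': 'b16_mnf16'}

print(cfg_to_args_str(exp_args, use_dashes=False))
# model=mocogan training.batch=16 dataset.max_num_frames=16 exp_suffix=b16_mnf16
```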
src/infra/slurm_job.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ Must be launched from the released project dir
3
+ """
4
+
5
+ import os
6
+ import time
7
+ import random
8
+ import subprocess
9
+ from shutil import copyfile
10
+
11
+ import hydra
12
+ from omegaconf import DictConfig
13
+
14
+ # Unfortunately (AFAIK), we cannot pass arguments normally (to parse them with argparse),
15
+ # that's why we are reading them from env
16
+ SLURM_JOB_ID = os.getenv('SLURM_JOB_ID')
17
+ project_dir = os.getenv('project_dir')
18
+ python_bin = os.getenv('python_bin')
19
+
20
+ # Printing the environment
21
+ print('PROJECT DIR:', project_dir)
22
+ print(f'SLURM_JOB_ID: {SLURM_JOB_ID}')
23
+ print('HOSTNAME:', subprocess.run(['hostname'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
24
+ print(subprocess.run([os.path.join(os.path.dirname(python_bin), 'gpustat')], stdout=subprocess.PIPE).stdout.decode('utf-8'))
25
+
26
+ @hydra.main(config_name=os.path.join(project_dir, 'experiment_config.yaml'))
27
+ def main(cfg: DictConfig):
28
+ os.chdir(project_dir)
29
+
30
+ target_data_dir_base = os.path.dirname(cfg.dataset.path)
31
+ if os.path.islink(target_data_dir_base):
32
+ os.makedirs(os.readlink(target_data_dir_base), exist_ok=True)
33
+ else:
34
+ os.makedirs(target_data_dir_base, exist_ok=True)
35
+
36
+ copyfile(cfg.dataset.path_for_slurm_job, cfg.dataset.path)
37
+ print(f'Copied the data: {cfg.dataset.path_for_slurm_job} => {cfg.dataset.path}. Starting the training...')
38
+
39
+ training_cmd = open('training_cmd.sh').read()
40
+ print('<=== TRAINING COMMAND ===>')
41
+ print(training_cmd)
42
+ os.system(training_cmd)
43
+
44
+
45
+ if __name__ == "__main__":
46
+ main()
src/infra/slurm_job_proxy.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+ # We need this proxy so that we do not put the shebang into `slurm_job.py`
3
+ # We cannot put a shebang there since different python interpreters may be used to run it
4
+ $python_bin $python_script
src/infra/utils.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ from distutils.dir_util import copy_tree
5
+ from shutil import copyfile
6
+ from typing import List, Optional
7
+
8
+ from hydra.utils import instantiate
9
+ import click
10
+ import git
11
+ from omegaconf import DictConfig
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def copy_objects(target_dir: os.PathLike, objects_to_copy: List[os.PathLike]):
16
+ for src_path in objects_to_copy:
17
+ trg_path = os.path.join(target_dir, os.path.basename(src_path))
18
+
19
+ if os.path.islink(src_path):
20
+ os.symlink(os.readlink(src_path), trg_path)
21
+ elif os.path.isfile(src_path):
22
+ copyfile(src_path, trg_path)
23
+ elif os.path.isdir(src_path):
24
+ copy_tree(src_path, trg_path)
25
+ else:
26
+ raise NotImplementedError(f"Unknown object type: {src_path}")
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ def create_symlinks(target_dir: os.PathLike, symlinks_to_create: List[os.PathLike]):
31
+ """
32
+ Creates symlinks to the given paths
33
+ """
34
+ for src_path in symlinks_to_create:
35
+ trg_path = os.path.join(target_dir, os.path.basename(src_path))
36
+
37
+ if os.path.islink(src_path):
38
+ # Let's not create symlinks to symlinks
39
+ # Since dropping the current symlink will break the experiment
40
+ os.symlink(os.readlink(src_path), trg_path)
41
+ else:
42
+ print(f'Creating a symlink to {src_path}, so try not to delete it accidentally!')
43
+ os.symlink(src_path, trg_path)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def is_git_repo(path: os.PathLike):
48
+ try:
49
+ _ = git.Repo(path).git_dir
50
+ return True
51
+ except git.exc.InvalidGitRepositoryError:
52
+ return False
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def create_project_dir(
57
+ project_dir: os.PathLike,
58
+ objects_to_copy: List[os.PathLike],
59
+ symlinks_to_create: List[os.PathLike],
60
+ quiet: bool=False,
61
+ ignore_uncommited_changes: bool=False,
62
+ overwrite: bool=False):
63
+
64
+ if is_git_repo(os.getcwd()) and are_there_uncommitted_changes():
65
+ if ignore_uncommited_changes or click.confirm("There are uncommitted changes. Continue?", default=False):
66
+ pass
67
+ else:
68
+ raise PermissionError("Cannot create a dir when there are uncommitted changes")
69
+
70
+ if os.path.exists(project_dir):
71
+ if overwrite or click.confirm(f'Dir {project_dir} already exists. Overwrite it?', default=False):
72
+ shutil.rmtree(project_dir)
73
+ else:
74
+ print('User refused to delete an existing project dir.')
75
+ raise PermissionError("There is an existing dir and I cannot delete it.")
76
+
77
+ os.makedirs(project_dir)
78
+ copy_objects(project_dir, objects_to_copy)
79
+ create_symlinks(project_dir, symlinks_to_create)
80
+
81
+ if not quiet:
82
+ print(f'Created a project dir: {project_dir}')
83
+
84
+ #----------------------------------------------------------------------------
85
+
86
+ def get_git_hash() -> Optional[str]:
87
+ if not is_git_repo(os.getcwd()):
88
+ return None
89
+
90
+ try:
91
+ return subprocess \
92
+ .check_output(['git', 'rev-parse', '--short', 'HEAD']) \
93
+ .decode("utf-8") \
94
+ .strip()
95
+ except:
96
+ return None
97
+
98
+ #----------------------------------------------------------------------------
99
+
100
+ # def get_experiment_path(master_dir: os.PathLike, experiment_name: str) -> os.PathLike:
101
+ # return os.path.join(master_dir, f"{experiment_name}-{get_git_hash()}")
102
+
103
+ #----------------------------------------------------------------------------
104
+
105
+ def get_git_hash_suffix() -> str:
106
+ git_hash: Optional[str] = get_git_hash()
107
+ git_hash_suffix = "-nogit" if git_hash is None else f"-{git_hash}"
108
+
109
+ return git_hash_suffix
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ def are_there_uncommitted_changes() -> bool:
114
+ return len(subprocess.check_output('git status -s'.split()).decode("utf-8")) > 0
115
+
116
+ #----------------------------------------------------------------------------
117
+
118
+ def cfg_to_args_str(cfg: DictConfig, use_dashes=True) -> str:
119
+ dashes = '--' if use_dashes else ''
120
+
121
+ return ' '.join([f'{dashes}{p}={cfg[p]}' for p in cfg])
122
+
123
+ #----------------------------------------------------------------------------
124
+
125
+ def recursive_instantiate(cfg: DictConfig):
126
+ for key in cfg:
127
+ # print(type(cfg[key]))
128
+ if isinstance(cfg[key], DictConfig):
129
+ if '_target_' in cfg[key]:
130
+ cfg[key] = instantiate(cfg[key])
131
+ else:
132
+ recursive_instantiate(cfg[key])
133
+
134
+ #----------------------------------------------------------------------------
135
+
136
+ def num_gpus_to_mem(num_gpus: int, mem_per_gpu: int = 64) -> str:
137
+ # Doing it here since hydra config cannot do formatting for ${...}
138
+ return f"{num_gpus * mem_per_gpu}G"
139
+
140
+ #----------------------------------------------------------------------------
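For illustration, the two small helpers above on a toy config (import path assumed):

```python
from omegaconf import OmegaConf
from src.infra.utils import cfg_to_args_str, num_gpus_to_mem

cfg = OmegaConf.create({'batch': 16, 'resolution': 256})
print(cfg_to_args_str(cfg))                    # --batch=16 --resolution=256
print(cfg_to_args_str(cfg, use_dashes=False))  # batch=16 resolution=256
print(num_gpus_to_mem(4, 64))                  # 256G
```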
src/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
src/metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ NUM_FRAMES_IN_BATCH = {128: 32, 256: 32, 512: 8, 1024: 2}
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def compute_fid(opts, max_real, num_gen):
23
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
24
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
25
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
26
+
27
+ batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution]
28
+
29
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
30
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real, use_image_dataset=True).get_mean_cov()
32
+
33
+ if opts.generator_as_dataset:
34
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
35
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
36
+ gen_kwargs = dict(use_image_dataset=True)
37
+ else:
38
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
39
+ gen_opts = opts
40
+ gen_kwargs = dict()
41
+
42
+ mu_gen, sigma_gen = compute_gen_stats_fn(
43
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, batch_size=batch_size,
44
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen, **gen_kwargs).get_mean_cov()
45
+
46
+ if opts.rank != 0:
47
+ return float('nan')
48
+
49
+ m = np.square(mu_gen - mu_real).sum()
50
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
51
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
52
+ return float(fid)
53
+
54
+ #----------------------------------------------------------------------------
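The closed-form Fréchet distance computed above, demonstrated on synthetic Gaussian features rather than Inception activations. For a pure mean shift of 0.5 in 8 dimensions with identity covariances, the expected value is 8 * 0.5^2 = 2.

```python
import numpy as np
import scipy.linalg

rng = np.random.RandomState(0)
real = rng.randn(1000, 8)        # stand-in for real-image features
gen = rng.randn(1000, 8) + 0.5   # shifted distribution --> non-zero FID

mu_real, sigma_real = real.mean(0), np.cov(real, rowvar=False)
mu_gen, sigma_gen = gen.mean(0), np.cov(gen, rowvar=False)

m = np.square(mu_gen - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
print(fid)  # ~2.0 for this mean shift
```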
src/metrics/frechet_video_distance.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Frechet Video Distance (FVD). Matches the original tensorflow implementation from
3
+ https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
4
+ up to the upsampling operation. Note that this tf.hub I3D model is different from the one released in the I3D repo.
5
+ """
6
+
7
+ import copy
8
+ import numpy as np
9
+ import scipy.linalg
10
+ from . import metric_utils
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_fvd(opts, max_real: int, num_gen: int, num_frames: int, subsample_factor: int=1):
19
+ # Perfectly reproduced torchscript version of the I3D model, trained on Kinetics-400, used here:
20
+ # https://github.com/google-research/google-research/blob/master/frechet_video_distance/frechet_video_distance.py
21
+ # Note that the weights on tf.hub (used in the script above) differ from the original released weights
22
+ detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1'
23
+ detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer.
24
+
25
+ opts = copy.deepcopy(opts)
26
+ opts.dataset_kwargs.load_n_consecutive = num_frames
27
+ opts.dataset_kwargs.subsample_factor = subsample_factor
28
+ opts.dataset_kwargs.discard_short_videos = True
29
+ batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
30
+
31
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
32
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=0,
33
+ capture_mean_cov=True, max_items=max_real, temporal_detector=True, batch_size=batch_size).get_mean_cov()
34
+
35
+ if opts.generator_as_dataset:
36
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
37
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
38
+ gen_opts.dataset_kwargs.load_n_consecutive = num_frames
39
+ gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
40
+ gen_opts.dataset_kwargs.subsample_factor = subsample_factor
41
+ gen_kwargs = dict()
42
+ else:
43
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
44
+ gen_opts = opts
45
+ gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=subsample_factor)
46
+
47
+ mu_gen, sigma_gen = compute_gen_stats_fn(
48
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs, rel_lo=0, rel_hi=1, capture_mean_cov=True,
49
+ max_items=num_gen, temporal_detector=True, batch_size=batch_size, **gen_kwargs).get_mean_cov()
50
+
51
+ if opts.rank != 0:
52
+ return float('nan')
53
+
54
+ m = np.square(mu_gen - mu_real).sum()
55
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
56
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
57
+ return float(fid)
58
+
59
+ #----------------------------------------------------------------------------
src/metrics/inception_score.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
+
23
+ if opts.generator_as_dataset:
24
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
25
+ gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
26
+ gen_kwargs = dict(use_image_dataset=True)
27
+ else:
28
+ compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
29
+ gen_opts = opts
30
+ gen_kwargs = dict()
31
+
32
+ gen_probs = compute_gen_stats_fn(
33
+ opts=gen_opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
34
+ capture_all=True, max_items=num_gen, **gen_kwargs).get_all()
35
+
36
+ if opts.rank != 0:
37
+ return float('nan'), float('nan')
38
+
39
+ scores = []
40
+ for i in range(num_splits):
41
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
42
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
43
+ kl = np.mean(np.sum(kl, axis=1))
44
+ scores.append(np.exp(kl))
45
+ return float(np.mean(scores)), float(np.std(scores))
46
+
47
+ #----------------------------------------------------------------------------
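Each split above evaluates the standard Inception Score,

```latex
\mathrm{IS} = \exp\!\Big( \mathbb{E}_{x}\, D_{\mathrm{KL}}\big( p(y \mid x) \,\big\|\, p(y) \big) \Big)
```

with the marginal p(y) approximated by `np.mean(part, axis=0)` inside each split; the function returns the mean and standard deviation of the score over the `num_splits` splits.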
src/metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
+ GANs". Matches the original implementation by Binkowski et al. at
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+ import numpy as np
+ from . import metric_utils
+
+ #----------------------------------------------------------------------------
+
+ def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+     # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+     detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+     real_features = metric_utils.compute_feature_stats_for_dataset(
+         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+         rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real, use_image_dataset=True).get_all()
+
+     gen_features = metric_utils.compute_feature_stats_for_generator(
+         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+         rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+     if opts.rank != 0:
+         return float('nan')
+
+     n = real_features.shape[1]
+     m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+     t = 0
+     for _subset_idx in range(num_subsets):
+         x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+         y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+         a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+         b = (x @ y.T / n + 1) ** 3
+         t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+     kid = t / num_subsets / m
+     return float(kid) * 1000.0
+
+ #----------------------------------------------------------------------------
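The subset loop implements the block estimator of the squared MMD with the cubic polynomial kernel k(x, y) = (x·y / n + 1)^3, where n is the feature dimension. Per subset of size m it computes:

```latex
\widehat{\mathrm{MMD}}^2 = \frac{1}{m(m-1)} \sum_{i \neq j} \big[ k(x_i, x_j) + k(y_i, y_j) \big]
  - \frac{2}{m^2} \sum_{i, j} k(x_i, y_j)
```

The reported KID is the average over `num_subsets` subsets, scaled by 1000 for readability.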
src/metrics/metric_main.py ADDED
@@ -0,0 +1,154 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import os
+ import time
+ import json
+ import torch
+ import numpy as np
+ from src import dnnlib
+
+ from . import metric_utils
+ from . import frechet_inception_distance
+ from . import kernel_inception_distance
+ from . import inception_score
+ from . import video_inception_score
+ from . import frechet_video_distance
+
+ #----------------------------------------------------------------------------
+
+ _metric_dict = dict() # name => fn
+
+ def register_metric(fn):
+     assert callable(fn)
+     _metric_dict[fn.__name__] = fn
+     return fn
+
+ def is_valid_metric(metric):
+     return metric in _metric_dict
+
+ def list_valid_metrics():
+     return list(_metric_dict.keys())
+
+ def is_power_of_two(n: int) -> bool:
+     return (n & (n - 1) == 0) and n != 0
+
+ #----------------------------------------------------------------------------
+
+ def calc_metric(metric, num_runs: int=1, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+     assert is_valid_metric(metric)
+     opts = metric_utils.MetricOptions(**kwargs)
+
+     # Calculate.
+     start_time = time.time()
+     all_runs_results = [_metric_dict[metric](opts) for _ in range(num_runs)]
+     total_time = time.time() - start_time
+
+     # Broadcast results.
+     for results in all_runs_results:
+         for key, value in list(results.items()):
+             if opts.num_gpus > 1:
+                 value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+                 torch.distributed.broadcast(tensor=value, src=0)
+                 value = float(value.cpu())
+             results[key] = value
+
+     if num_runs > 1:
+         results = {f'{key}_run{i+1:02d}': value for i, results in enumerate(all_runs_results) for key, value in results.items()}
+         for key, value in all_runs_results[0].items():
+             all_runs_values = [r[key] for r in all_runs_results]
+             results[f'{key}_mean'] = np.mean(all_runs_values)
+             results[f'{key}_std'] = np.std(all_runs_values)
+     else:
+         results = all_runs_results[0]
+
+     # Decorate with metadata.
+     return dnnlib.EasyDict(
+         results = dnnlib.EasyDict(results),
+         metric = metric,
+         total_time = total_time,
+         total_time_str = dnnlib.util.format_time(total_time),
+         num_gpus = opts.num_gpus,
+     )
+
+ #----------------------------------------------------------------------------
+
+ def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+     metric = result_dict['metric']
+     assert is_valid_metric(metric)
+     if run_dir is not None and snapshot_pkl is not None:
+         snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+     jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+     print(jsonl_line)
+     if run_dir is not None and os.path.isdir(run_dir):
+         with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+             f.write(jsonl_line + '\n')
+
+ #----------------------------------------------------------------------------
+ # Primary metrics.
+
+ @register_metric
+ def fid50k_full(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+     return dict(fid50k_full=fid)
+
+ @register_metric
+ def kid50k_full(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+     return dict(kid50k_full=kid)
+
+ @register_metric
+ def is50k(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+     return dict(is50k_mean=mean, is50k_std=std)
+
+ @register_metric
+ def fvd2048_16f(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16)
+     return dict(fvd2048_16f=fvd)
+
+ @register_metric
+ def fvd2048_128f(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=128)
+     return dict(fvd2048_128f=fvd)
+
+ @register_metric
+ def fvd2048_128f_subsample8f(opts):
+     """Similar to `fvd2048_128f`, but we sample every 8th frame."""
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     fvd = frechet_video_distance.compute_fvd(opts, max_real=2048, num_gen=2048, num_frames=16, subsample_factor=8)
+     return dict(fvd2048_128f_subsample8f=fvd)
+
+ @register_metric
+ def isv2048_ucf(opts):
+     opts.dataset_kwargs.update(max_size=None, xflip=False)
+     mean, std = video_inception_score.compute_isv(opts, num_gen=2048, num_splits=10, backbone='c3d_ucf101')
+     return dict(isv2048_ucf_mean=mean, isv2048_ucf_std=std)
+
+ #----------------------------------------------------------------------------
+ # Legacy metrics.
+
+ @register_metric
+ def fid50k(opts):
+     opts.dataset_kwargs.update(max_size=None)
+     fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+     return dict(fid50k=fid)
+
+ @register_metric
+ def kid50k(opts):
+     opts.dataset_kwargs.update(max_size=None)
+     kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+     return dict(kid50k=kid)
+
+ #----------------------------------------------------------------------------
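A minimal sketch of how this registry is driven (the generator `G_ema` and `dataset_kwargs` here are assumed to be already loaded, e.g. as in `src/scripts/calc_metrics.py` below, with the repo root on `PYTHONPATH`):

```python
import torch
from src.metrics import metric_main

# Every function decorated with @register_metric is addressable by name.
assert metric_main.is_valid_metric('fvd2048_16f')

result = metric_main.calc_metric(
    metric='fvd2048_16f',
    G=G_ema,                        # generator snapshot (assumed loaded elsewhere)
    dataset_kwargs=dataset_kwargs,  # real-data options (assumed)
    num_gpus=1, rank=0, device=torch.device('cuda'),
)
metric_main.report_metric(result, run_dir='training-runs/00000')  # prints and appends to metric-fvd2048_16f.jsonl
```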
src/metrics/metric_utils.py ADDED
@@ -0,0 +1,332 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import os
+ import time
+ import hashlib
+ import pickle
+ import copy
+ import uuid
+ from urllib.parse import urlparse
+ import numpy as np
+ import torch
+ from src import dnnlib
+ from src.training.dataset import video_to_image_dataset_kwargs
+
+ #----------------------------------------------------------------------------
+
+ class MetricOptions:
+     def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None,
+         progress=None, cache=True, gen_dataset_kwargs={}, generator_as_dataset=False):
+         assert 0 <= rank < num_gpus
+         self.G = G
+         self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+         self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+         self.num_gpus = num_gpus
+         self.rank = rank
+         self.device = device if device is not None else torch.device('cuda', rank)
+         self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+         self.cache = cache
+         self.gen_dataset_kwargs = gen_dataset_kwargs
+         self.generator_as_dataset = generator_as_dataset
+
+ #----------------------------------------------------------------------------
+
+ _feature_detector_cache = dict()
+
+ def get_feature_detector_name(url):
+     return os.path.splitext(url.split('/')[-1])[0]
+
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+     assert 0 <= rank < num_gpus
+     key = (url, device)
+     if key not in _feature_detector_cache:
+         is_leader = (rank == 0)
+         if not is_leader and num_gpus > 1:
+             torch.distributed.barrier() # leader goes first
+         with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+             if urlparse(url).path.endswith('.pkl'):
+                 _feature_detector_cache[key] = pickle.load(f).to(device)
+             else:
+                 _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
+         if is_leader and num_gpus > 1:
+             torch.distributed.barrier() # others follow
+     return _feature_detector_cache[key]
+
+ #----------------------------------------------------------------------------
+
+ class FeatureStats:
+     def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+         self.capture_all = capture_all
+         self.capture_mean_cov = capture_mean_cov
+         self.max_items = max_items
+         self.num_items = 0
+         self.num_features = None
+         self.all_features = None
+         self.raw_mean = None
+         self.raw_cov = None
+
+     def set_num_features(self, num_features):
+         if self.num_features is not None:
+             assert num_features == self.num_features
+         else:
+             self.num_features = num_features
+             self.all_features = []
+             self.raw_mean = np.zeros([num_features], dtype=np.float64)
+             self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+     def is_full(self):
+         return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+     def append(self, x):
+         x = np.asarray(x, dtype=np.float32)
+         assert x.ndim == 2
+         if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+             if self.num_items >= self.max_items:
+                 return
+             x = x[:self.max_items - self.num_items]
+
+         self.set_num_features(x.shape[1])
+         self.num_items += x.shape[0]
+         if self.capture_all:
+             self.all_features.append(x)
+         if self.capture_mean_cov:
+             x64 = x.astype(np.float64)
+             self.raw_mean += x64.sum(axis=0)
+             self.raw_cov += x64.T @ x64
+
+     def append_torch(self, x, num_gpus=1, rank=0):
+         assert isinstance(x, torch.Tensor) and x.ndim == 2
+         assert 0 <= rank < num_gpus
+         if num_gpus > 1:
+             ys = []
+             for src in range(num_gpus):
+                 y = x.clone()
+                 torch.distributed.broadcast(y, src=src)
+                 ys.append(y)
+             x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+         self.append(x.cpu().numpy())
+
+     def get_all(self):
+         assert self.capture_all
+         return np.concatenate(self.all_features, axis=0)
+
+     def get_all_torch(self):
+         return torch.from_numpy(self.get_all())
+
+     def get_mean_cov(self):
+         assert self.capture_mean_cov
+         mean = self.raw_mean / self.num_items
+         cov = self.raw_cov / self.num_items
+         cov = cov - np.outer(mean, mean)
+         return mean, cov
+
+     def save(self, pkl_file):
+         with open(pkl_file, 'wb') as f:
+             pickle.dump(self.__dict__, f)
+
+     @staticmethod
+     def load(pkl_file):
+         with open(pkl_file, 'rb') as f:
+             s = dnnlib.EasyDict(pickle.load(f))
+         obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+         obj.__dict__.update(s)
+         return obj
+
+ #----------------------------------------------------------------------------
+
+ class ProgressMonitor:
+     def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+         self.tag = tag
+         self.num_items = num_items
+         self.verbose = verbose
+         self.flush_interval = flush_interval
+         self.progress_fn = progress_fn
+         self.pfn_lo = pfn_lo
+         self.pfn_hi = pfn_hi
+         self.pfn_total = pfn_total
+         self.start_time = time.time()
+         self.batch_time = self.start_time
+         self.batch_items = 0
+         if self.progress_fn is not None:
+             self.progress_fn(self.pfn_lo, self.pfn_total)
+
+     def update(self, cur_items: int):
+         assert (self.num_items is None) or (cur_items <= self.num_items), f"Wrong `items` values: cur_items={cur_items}, self.num_items={self.num_items}"
+         if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+             return
+         cur_time = time.time()
+         total_time = cur_time - self.start_time
+         time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+         if (self.verbose) and (self.tag is not None):
+             print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+         self.batch_time = cur_time
+         self.batch_items = cur_items
+
+         if (self.progress_fn is not None) and (self.num_items is not None):
+             self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+     def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+         return ProgressMonitor(
+             tag = tag,
+             num_items = num_items,
+             flush_interval = flush_interval,
+             verbose = self.verbose,
+             progress_fn = self.progress_fn,
+             pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+             pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+             pfn_total = self.pfn_total,
+         )
+
+ #----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def compute_feature_stats_for_dataset(
+     opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64,
+     data_loader_kwargs=None, max_items=None, temporal_detector=False, use_image_dataset=False,
+     feature_stats_cls=FeatureStats, **stats_kwargs):
+
+     dataset_kwargs = video_to_image_dataset_kwargs(opts.dataset_kwargs) if use_image_dataset else opts.dataset_kwargs
+     dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)
+
+     if data_loader_kwargs is None:
+         data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+     # Try to lookup from cache.
+     cache_file = None
+     if opts.cache:
+         # Choose cache file name.
+         args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs,
+             stats_kwargs=stats_kwargs, feature_stats_cls=feature_stats_cls.__name__)
+         md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+         cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+         cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+         # Check if the file exists (all processes must agree).
+         flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+         if opts.num_gpus > 1:
+             flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+             torch.distributed.broadcast(tensor=flag, src=0)
+             flag = (float(flag.cpu()) != 0)
+
+         # Load.
+         if flag:
+             return feature_stats_cls.load(cache_file)
+
+     # Initialize.
+     num_items = len(dataset)
+     if max_items is not None:
+         num_items = min(num_items, max_items)
+     stats = feature_stats_cls(max_items=num_items, **stats_kwargs)
+     progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+     detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+     # Main loop.
+     item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+     for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+         images = batch['image']
+         if temporal_detector:
+             images = images.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
+
+             # images = images.float() / 255
+             # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
+             # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
+             # images = (images * 255).to(torch.uint8)
+         else:
+             images = images.view(-1, *images.shape[-3:]) # [-1, c, h, w]
+
+         if images.shape[1] == 1:
+             images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
+         features = detector(images.to(opts.device), **detector_kwargs)
+         stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+         progress.update(stats.num_items)
+
+     # Save to cache.
+     if cache_file is not None and opts.rank == 0:
+         os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+         temp_file = cache_file + '.' + uuid.uuid4().hex
+         stats.save(temp_file)
+         os.replace(temp_file, cache_file) # atomic
+     return stats
+
+ #----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def compute_feature_stats_for_generator(
+     opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
+     batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
+     feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
+
+     if batch_gen is None:
+         batch_gen = min(batch_size, 4)
+     assert batch_size % batch_gen == 0
+
+     # Setup generator and load labels.
+     G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+     dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+
+     # Image generation func.
+     def run_generator(z, c, t):
+         img = G(z=z, c=c, t=t, **opts.G_kwargs)
+         bt, c, h, w = img.shape
+
+         if temporal_detector:
+             img = img.view(bt // num_video_frames, num_video_frames, c, h, w) # [batch_size, t, c, h, w]
+             img = img.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]
+
+             # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
+             # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
+
+         img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+         return img
+
+     # JIT.
+     if jit:
+         z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
+         c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
+         t = torch.zeros([batch_gen, G.cfg.sampling.num_frames_per_video], device=opts.device)
+         run_generator = torch.jit.trace(run_generator, [z, c, t], check_trace=False)
+
+     # Initialize.
+     stats = feature_stats_cls(**stats_kwargs)
+     assert stats.max_items is not None
+     progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+     detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+     # Main loop.
+     while not stats.is_full():
+         images = []
+         for _i in range(batch_size // batch_gen):
+             z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+             cond_sample_idx = [np.random.randint(len(dataset)) for _ in range(batch_gen)]
+             c = [dataset.get_label(i) for i in cond_sample_idx]
+             c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+             t = [list(range(0, num_video_frames * subsample_factor, subsample_factor)) for _i in range(batch_gen)]
+             t = torch.from_numpy(np.stack(t)).pin_memory().to(opts.device)
+             images.append(run_generator(z, c, t))
+         images = torch.cat(images)
+         if images.shape[1] == 1:
+             images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
+         features = detector(images, **detector_kwargs)
+         stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+         progress.update(stats.num_items)
+     return stats
+
+ #----------------------------------------------------------------------------
+
+ def rewrite_opts_for_gen_dataset(opts):
+     """
+     Updates dataset arguments in the opts to enable the second dataset stats computation
+     """
+     new_opts = copy.deepcopy(opts)
+     new_opts.dataset_kwargs = new_opts.gen_dataset_kwargs
+     new_opts.cache = False
+
+     return new_opts
+
+ #----------------------------------------------------------------------------
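`FeatureStats` never stores the full feature matrix when only FID/FVD statistics are needed: it accumulates sum(x) and xᵀx and recovers the mean and covariance at the end. A self-contained sketch of that identity:

```python
import numpy as np

x = np.random.randn(1000, 8).astype(np.float64)  # stand-in for detector features

raw_mean = x.sum(axis=0)   # what FeatureStats.append adds to raw_mean
raw_cov = x.T @ x          # what FeatureStats.append adds to raw_cov

mean = raw_mean / len(x)                        # FeatureStats.get_mean_cov, step 1
cov = raw_cov / len(x) - np.outer(mean, mean)   # E[x x^T] - mu mu^T

assert np.allclose(cov, np.cov(x, rowvar=False, bias=True))
```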
src/metrics/video_inception_score.py ADDED
@@ -0,0 +1,54 @@
+ """Inception Score (IS) from the paper "Improved techniques for training
+ GANs". Matches the original implementation by Salimans et al. at
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+ import numpy as np
+ from . import metric_utils
+
+ #----------------------------------------------------------------------------
+
+ NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
+
+ #----------------------------------------------------------------------------
+
+ def compute_isv(opts, num_gen: int, num_splits: int, backbone: str):
+     if backbone == 'c3d_ucf101':
+         # Perfectly reproduced torchscript version of the original chainer checkpoint:
+         # https://github.com/pfnet-research/tgan2/blob/f892bc432da315d4f6b6ae9448f69d046ef6fe01/tgan2/models/c3d/c3d_ucf101.py
+         # It is a UCF-101-finetuned C3D model.
+         detector_url = 'https://www.dropbox.com/s/jxpu7avzdc9n97q/c3d_ucf101.pt?dl=1'
+     else:
+         raise NotImplementedError(f'Backbone {backbone} is not supported.')
+
+     num_frames = 16
+     batch_size = NUM_FRAMES_IN_BATCH[opts.dataset_kwargs.resolution] // num_frames
+
+     if opts.generator_as_dataset:
+         compute_gen_stats_fn = metric_utils.compute_feature_stats_for_dataset
+         gen_opts = metric_utils.rewrite_opts_for_gen_dataset(opts)
+         gen_opts.dataset_kwargs.load_n_consecutive = num_frames
+         gen_opts.dataset_kwargs.load_n_consecutive_random_offset = False
+         gen_opts.dataset_kwargs.subsample_factor = 1
+         gen_kwargs = dict()
+     else:
+         compute_gen_stats_fn = metric_utils.compute_feature_stats_for_generator
+         gen_opts = opts
+         gen_kwargs = dict(num_video_frames=num_frames, subsample_factor=1)
+
+     gen_probs = compute_gen_stats_fn(
+         opts=gen_opts, detector_url=detector_url, detector_kwargs={},
+         capture_all=True, max_items=num_gen, temporal_detector=True, **gen_kwargs).get_all() # [num_gen, num_classes]
+
+     if opts.rank != 0:
+         return float('nan'), float('nan')
+
+     scores = []
+     np.random.RandomState(42).shuffle(gen_probs)
+     for i in range(num_splits):
+         part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+         kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+         kl = np.mean(np.sum(kl, axis=1))
+         scores.append(np.exp(kl))
+     return float(np.mean(scores)), float(np.std(scores))
+
+ #----------------------------------------------------------------------------
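The batch size above caps GPU memory by loading fewer clips as resolution grows: the C3D detector always consumes 16-frame clips, so `NUM_FRAMES_IN_BATCH` fixes the total frame budget per resolution. Illustration of the arithmetic:

```python
NUM_FRAMES_IN_BATCH = {128: 128, 256: 128, 512: 64, 1024: 32}
num_frames = 16

assert NUM_FRAMES_IN_BATCH[256] // num_frames == 8   # 8 clips of 16 frames at 256x256
assert NUM_FRAMES_IN_BATCH[1024] // num_frames == 2  # 2 clips of 16 frames at 1024x1024
```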
src/scripts/__init__.py ADDED
File without changes
src/scripts/calc_metrics.py ADDED
@@ -0,0 +1,250 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Calculate quality metrics for previous training run or pretrained network pickle."""
+
+ import sys; sys.path.extend(['.', 'src'])
+ import os
+ import re
+ import click
+ import json
+ import tempfile
+ import copy
+ import torch
+ from src import dnnlib
+ from omegaconf import OmegaConf
+
+ import legacy
+ from metrics import metric_main
+ from metrics import metric_utils
+ from src.torch_utils import training_stats
+ from src.torch_utils import custom_ops
+ from src.torch_utils import misc
+
+ #----------------------------------------------------------------------------
+
+ def subprocess_fn(rank, args, temp_dir):
+     dnnlib.util.Logger(should_flush=True)
+
+     # Init torch.distributed.
+     if args.num_gpus > 1:
+         init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+         if os.name == 'nt':
+             init_method = 'file:///' + init_file.replace('\\', '/')
+             torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+         else:
+             init_method = f'file://{init_file}'
+             torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+     # Init torch_utils.
+     sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+     training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+     if rank != 0 or not args.verbose:
+         custom_ops.verbosity = 'none'
+
+     # Print network summary.
+     device = torch.device('cuda', rank)
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cuda.matmul.allow_tf32 = False
+     torch.backends.cudnn.allow_tf32 = False
+     G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
+     if rank == 0 and args.verbose:
+         z = torch.empty([8, G.z_dim], device=device)
+         c = torch.empty([8, G.c_dim], device=device)
+         t = torch.zeros([8, G.cfg.sampling.num_frames_per_video], device=device).long()
+         misc.print_module_summary(G, [z, c, t])
+
+     # Calculate each metric.
+     for metric in args.metrics:
+         if rank == 0 and args.verbose:
+             print(f'Calculating {metric}...')
+         progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+         result_dict = metric_main.calc_metric(
+             metric=metric,
+             G=G,
+             dataset_kwargs=args.dataset_kwargs,
+             num_gpus=args.num_gpus,
+             rank=rank,
+             device=device,
+             progress=progress,
+             cache=args.use_cache,
+             num_runs=(1 if metric == 'fid50k_full' else args.num_runs),
+         )
+         if rank == 0:
+             metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
+         if rank == 0 and args.verbose:
+             print()
+
+     # Done.
+     if rank == 0 and args.verbose:
+         print('Exiting...')
+
+ #----------------------------------------------------------------------------
+
+ class CommaSeparatedList(click.ParamType):
+     name = 'list'
+
+     def convert(self, value, param, ctx):
+         _ = param, ctx
+         if value is None or value.lower() == 'none' or value == '':
+             return []
+         return value.split(',')
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.pass_context
+ @click.option('--network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH')
+ @click.option('--networks_dir', help='Path to the experiment directory if the latest checkpoint is requested.', metavar='PATH')
+ @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
+ @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
+ @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
+ @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+ @click.option('--cfg_path', help='Path to the experiments config', type=str, default="auto", metavar='PATH')
+ @click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
+ @click.option('--use_cache', help='Should we use the cache file?', type=bool, default=True, metavar='BOOL', show_default=True)
+ @click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
+
+ def calc_metrics(ctx, network_pkl, networks_dir, metrics, data, mirror, gpus, cfg_path, verbose, use_cache: bool, num_runs: int):
+     """Calculate quality metrics for previous training run or pretrained network pickle.
+
+     Examples:
+
+     \b
+     # Previous training run: look up options automatically, save result to JSONL file.
+     python calc_metrics.py --metrics=pr50k3_full \\
+         --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
+
+     \b
+     # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
+     python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
+         --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+
+     Available metrics:
+
+     \b
+     ADA paper:
+       fid50k_full  Frechet inception distance against the full dataset.
+       kid50k_full  Kernel inception distance against the full dataset.
+       pr50k3_full  Precision and recall against the full dataset.
+       is50k        Inception score for CIFAR-10.
+
+     \b
+     StyleGAN and StyleGAN2 papers:
+       fid50k       Frechet inception distance against 50k real images.
+       kid50k       Kernel inception distance against 50k real images.
+       pr50k3       Precision and recall against 50k real images.
+       ppl2_wend    Perceptual path length in W at path endpoints against full image.
+       ppl_zfull    Perceptual path length in Z for full paths against cropped image.
+       ppl_wfull    Perceptual path length in W for full paths against cropped image.
+       ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
+       ppl_wend     Perceptual path length in W at path endpoints against cropped image.
+     """
+     dnnlib.util.Logger(should_flush=True)
+
+     if network_pkl is None:
+         # Select the checkpoint with the best FVD16 score.
+         output_regex = r"^network-snapshot-\d{6}\.pkl$"
+         ckpt_regex = re.compile(r"^network-snapshot-\d{6}\.pkl$")
+         # ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
+         # network_pkl = os.path.join(networks_dir, ckpts[-1])
+         metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
+         with open(metrics_file, 'r') as f:
+             snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
+         best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
+         network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
+         print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
+
+     # Validate arguments.
+     args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
+     if cfg_path == "auto":
+         # Assuming that `network_pkl` has the structure /path/to/experiment/output/network-X.pkl
+         output_path = os.path.dirname(network_pkl)
+         assert os.path.basename(output_path) == "output", f"Unknown path structure: {output_path}"
+         experiment_path = os.path.dirname(output_path)
+         cfg_path = os.path.join(experiment_path, 'experiment_config.yaml')
+
+     cfg = OmegaConf.load(cfg_path)
+     if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+         ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+     if not args.num_gpus >= 1:
+         ctx.fail('--gpus must be at least 1')
+
+     # Load network.
+     if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
+         ctx.fail('--network must point to a file or URL')
+     if args.verbose:
+         print(f'Loading network from "{network_pkl}"...')
+     with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
+         network_dict = legacy.load_network_pkl(f)
+         args.G = network_dict['G_ema'] # subclass of torch.nn.Module
+
+     from src.training.networks import Generator
+     G = args.G
+     G.cfg.z_dim = G.z_dim
+     G_new = Generator(
+         w_dim=G.cfg.w_dim,
+         mapping_kwargs=dnnlib.EasyDict(num_layers=G.cfg.get('mapping_net_n_layers', 2), cfg=G.cfg),
+         synthesis_kwargs=dnnlib.EasyDict(
+             channel_base=int(G.cfg.get('fmaps', 0.5) * 32768),
+             channel_max=G.cfg.get('channel_max', 512),
+             num_fp16_res=4,
+             conv_clamp=256,
+         ),
+         cfg=G.cfg,
+         img_resolution=256,
+         img_channels=3,
+         c_dim=G.cfg.c_dim,
+     ).eval()
+     G_new.load_state_dict(G.state_dict())
+     args.G = G_new
+
+     # Initialize dataset options.
+     if data is not None:
+         args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.VideoFramesFolderDataset', cfg=cfg.dataset, path=data)
+     elif network_dict['training_set_kwargs'] is not None:
+         args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
+     else:
+         ctx.fail('Could not look up dataset options; please specify --data')
+
+     # Finalize dataset options.
+     args.dataset_kwargs.resolution = args.G.img_resolution
+     args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
+     if mirror is not None:
+         args.dataset_kwargs.xflip = mirror
+     args.use_cache = use_cache
+     args.num_runs = num_runs
+
+     # Print dataset options.
+     if args.verbose:
+         print('Dataset options:')
+         print(cfg.dataset)
+
+     # Locate run dir.
+     args.run_dir = None
+     if os.path.isfile(network_pkl):
+         pkl_dir = os.path.dirname(network_pkl)
+         if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
+             args.run_dir = pkl_dir
+
+     # Launch processes.
+     if args.verbose:
+         print('Launching processes...')
+     torch.multiprocessing.set_start_method('spawn')
+     with tempfile.TemporaryDirectory() as temp_dir:
+         if args.num_gpus == 1:
+             subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+         else:
+             torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     calc_metrics() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
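An illustrative invocation (paths are placeholders): `python src/scripts/calc_metrics.py --networks_dir=<experiment_dir>/output --metrics=fvd2048_16f,fid50k_full --gpus=1 --verbose=true`. When `--network_pkl` is omitted, the script reads the run's `metric-fvd2048_16f.jsonl` and evaluates the snapshot with the best FVD16, as implemented above.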
src/scripts/calc_metrics_for_dataset.py ADDED
@@ -0,0 +1,169 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ """Calculate quality metrics for previous training run or pretrained network pickle."""
+
+ import sys; sys.path.extend(['.', 'src'])
+ import os
+ import click
+ import tempfile
+ import torch
+ from omegaconf import OmegaConf
+ from src import dnnlib
+
+ from metrics import metric_main
+ from metrics import metric_utils
+ from src.torch_utils import training_stats
+ from src.torch_utils import custom_ops
+
+ #----------------------------------------------------------------------------
+
+ def subprocess_fn(rank, args, temp_dir):
+     dnnlib.util.Logger(should_flush=True)
+
+     # Init torch.distributed.
+     if args.num_gpus > 1:
+         init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
+         if os.name == 'nt':
+             init_method = 'file:///' + init_file.replace('\\', '/')
+             torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
+         else:
+             init_method = f'file://{init_file}'
+             torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
+
+     # Init torch_utils.
+     sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
+     training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
+     if rank != 0 or not args.verbose:
+         custom_ops.verbosity = 'none'
+
+     # Configure devices and cuDNN/TF32 settings.
+     device = torch.device('cuda', rank)
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cuda.matmul.allow_tf32 = False
+     torch.backends.cudnn.allow_tf32 = False
+
+     # Calculate each metric.
+     for metric in args.metrics:
+         if rank == 0 and args.verbose:
+             print(f'Calculating {metric}...')
+         progress = metric_utils.ProgressMonitor(verbose=args.verbose)
+         result_dict = metric_main.calc_metric(
+             metric=metric,
+             dataset_kwargs=args.dataset_kwargs,
+             gen_dataset_kwargs=args.gen_dataset_kwargs,
+             generator_as_dataset=args.generator_as_dataset,
+             num_gpus=args.num_gpus,
+             rank=rank,
+             device=device,
+             progress=progress,
+             cache=args.use_cache,
+             num_runs=args.num_runs,
+         )
+
+         if rank == 0:
+             metric_main.report_metric(result_dict, run_dir=args.run_dir)
+
+         if rank == 0 and args.verbose:
+             print()
+
+     # Done.
+     if rank == 0 and args.verbose:
+         print('Exiting...')
+
+ #----------------------------------------------------------------------------
+
+ class CommaSeparatedList(click.ParamType):
+     name = 'list'
+
+     def convert(self, value, param, ctx):
+         _ = param, ctx
+         if value is None or value.lower() == 'none' or value == '':
+             return []
+         return value.split(',')
+
+ #----------------------------------------------------------------------------
+
+ def calc_metrics_for_dataset(ctx, metrics, real_data_path, fake_data_path, mirror, resolution, gpus, verbose, use_cache: bool, num_runs: int):
+     dnnlib.util.Logger(should_flush=True)
+
+     # Validate arguments.
+     args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, verbose=verbose)
+     if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
+         ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
+     if not args.num_gpus >= 1:
+         ctx.fail('--gpus must be at least 1')
+
+     dummy_dataset_cfg = OmegaConf.create({'max_num_frames': 10000})
+
+     # Initialize dataset options for real data.
+     args.dataset_kwargs = dnnlib.EasyDict(
+         class_name='training.dataset.VideoFramesFolderDataset',
+         path=real_data_path,
+         cfg=dummy_dataset_cfg,
+         xflip=mirror,
+         resolution=resolution,
+         use_labels=False,
+     )
+
+     # Initialize dataset options for fake data.
+     args.gen_dataset_kwargs = dnnlib.EasyDict(
+         class_name='training.dataset.VideoFramesFolderDataset',
+         path=fake_data_path,
+         cfg=dummy_dataset_cfg,
+         xflip=False,
+         resolution=resolution,
+         use_labels=False,
+     )
+     args.generator_as_dataset = True
+
+     # Print dataset options.
+     if args.verbose:
+         print('Real data options:')
+         print(args.dataset_kwargs)
+
+         print('Fake data options:')
+         print(args.gen_dataset_kwargs)
+
+     # Locate run dir.
+     args.run_dir = None
+     args.use_cache = use_cache
+     args.num_runs = num_runs
+
+     # Launch processes.
+     if args.verbose:
+         print('Launching processes...')
+     torch.multiprocessing.set_start_method('spawn')
+     with tempfile.TemporaryDirectory() as temp_dir:
+         if args.num_gpus == 1:
+             subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+         else:
+             torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+ #----------------------------------------------------------------------------
+
+ @click.command()
+ @click.pass_context
+ @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fvd2048_16f,fid50k_full', show_default=True)
+ @click.option('--real_data_path', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
+ @click.option('--fake_data_path', help='Generated images (directory or zip)', metavar='PATH')
+ @click.option('--mirror', help='Should we mirror the real data?', type=bool, metavar='BOOL')
+ @click.option('--resolution', help='Resolution for the source dataset', type=int, metavar='INT')
+ @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
+ @click.option('--verbose', help='Print optional information', type=bool, default=False, metavar='BOOL', show_default=True)
+ @click.option('--use_cache', help='Use stats cache', type=bool, default=True, metavar='BOOL', show_default=True)
+ @click.option('--num_runs', help='Number of runs', type=int, default=1, metavar='INT', show_default=True)
+ def calc_metrics_cli_wrapper(ctx, *args, **kwargs):
+     calc_metrics_for_dataset(ctx, *args, **kwargs)
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     calc_metrics_cli_wrapper() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
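Unlike `calc_metrics.py`, this script never loads a generator: it treats a folder of pre-generated frames as the "fake" dataset. An illustrative invocation (placeholder paths): `python src/scripts/calc_metrics_for_dataset.py --real_data_path=data/real_frames --fake_data_path=results/generated_frames --resolution=256 --metrics=fvd2048_16f --gpus=1 --verbose=true`; both directories are expected to follow the `VideoFramesFolderDataset` layout.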
src/scripts/clip_edit.py ADDED
@@ -0,0 +1,403 @@
+ # import sys; sys.path.extend(['.', 'src', '/home/skoroki/StyleCLIP'])
+ import argparse
+ import math
+ import os
+ from typing import List
+ import json
+ import re
+ import random
+ import yaml
+ import itertools
+
+ import torchvision
+ from torch import optim
+ from PIL import Image
+ import click
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from omegaconf import OmegaConf
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import utils
+ from torch import Tensor
+ import torchvision.transforms.functional as TVF
+ from torchvision.utils import save_image
+
+ from src.deps.facial_recognition.model_irse import Backbone
+
+ try:
+     import clip
+ except ImportError:
+     raise ImportError(
+         "To edit videos with CLIP, you need to install the `clip` library. "
+         "Please follow the instructions in https://github.com/openai/CLIP")
+
+ from src import dnnlib
+ import legacy
+ from src.scripts.project import save_edited_w
+
+
+ #----------------------------------------------------------------------------
+
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
+     lr_ramp = min(1, (1 - t) / rampdown)
+     lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
+     lr_ramp = lr_ramp * min(1, t / rampup)
+
+     return initial_lr * lr_ramp
+
+ #----------------------------------------------------------------------------
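`get_lr` ramps the learning rate up linearly over the first `rampup` fraction of the run and decays it with a half-cosine over the final `rampdown` fraction. A quick self-contained probe of the schedule (step counts are illustrative; only `get_lr` and `math` from above are assumed):

```python
# Print the LR at a few points of a 40-step run (the script's default num_steps).
# t is the normalized progress (curr_iter / num_steps) passed in by the main loop.
num_steps, initial_lr = 40, 0.1
for step in range(0, num_steps + 1, 5):
    t = step / num_steps
    print(f'step={step:2d}  t={t:.3f}  lr={get_lr(t, initial_lr):.4f}')
```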
+
+ class CLIPLoss(torch.nn.Module):
+     """
+     Copy-pasted and adapted from StyleCLIP.
+     """
+     def __init__(self):
+         super(CLIPLoss, self).__init__()
+         self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
+         # self.upsample = torch.nn.Upsample(scale_factor=7)
+         # self.avg_pool = torch.nn.AvgPool2d(kernel_size=opts.stylegan_size // 32)
+
+     def forward(self, image, text):
+         # image = self.avg_pool(self.upsample(image))
+         # print('shape', image.shape, text.shape)
+         image = F.interpolate(image, size=(224, 224), mode='area')
+         similarity = 1 - self.model(image, text)[0] / 100
+         similarity = similarity.diag()
+
+         return similarity
+
+ #----------------------------------------------------------------------------
+
+ class IDLoss(nn.Module):
+     """
+     Copy-pasted from StyleCLIP.
+     """
+     def __init__(self):
+         super(IDLoss, self).__init__()
+         self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+         with dnnlib.util.open_url(Backbone.WEIGHTS_URL, verbose=True) as f:
+             ir_se50_weights = torch.load(f)
+         self.facenet.load_state_dict(ir_se50_weights)
+         self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+         self.facenet.eval()
+         self.facenet.cuda()
+
+     def extract_feats(self, x):
+         if x.shape[2] != 256:
+             x = self.pool(x)
+         x = x[:, :, 35:223, 32:220] # Crop interesting region
+         x = self.face_pool(x)
+         x_feats = self.facenet(x)
+         return x_feats
+
+     def forward(self, y_hat, y):
+         n_samples = y.shape[0]
+         y_feats = self.extract_feats(y) # Otherwise use the feature from there
+         y_hat_feats = self.extract_feats(y_hat)
+         y_feats = y_feats.detach()
+         loss = 0
+
+         for i in range(n_samples):
+             diff_target = y_hat_feats[i].dot(y_feats[i])
+             loss += 1 - diff_target
+
+         return loss / n_samples
+
+ #----------------------------------------------------------------------------
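`IDLoss.forward` therefore penalizes identity drift through the face-embedding similarity of the edited and original images; assuming the IR-SE50 backbone returns L2-normalized embeddings R(·), the dot product is a cosine similarity and the loss is:

```latex
\mathcal{L}_{\mathrm{ID}}(\hat{y}, y) = \frac{1}{N} \sum_{i=1}^{N}
  \Big( 1 - \big\langle R(\hat{y}_i),\, R(y_i) \big\rangle \Big)
```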
111
+
112
+ def run_edit_optimization(
113
+ _sentinel=None,
114
+ G: nn.Module=None,
115
+ w_orig: Tensor=None,
116
+ descriptions: List[str]=None,
117
+ # ckpt: float="stylegan2-ffhq-config-f.pt",
118
+ lr: float=0.1,
119
+ num_steps: int=40,
120
+ l2_lambda: float=0.001,
121
+ id_lambda: float=0.005,
122
+ # latent_path: float=latent_path,
123
+ # truncation: float=0.7,
124
+ # save_intermediate_image_every: float=1 if create_video else 20,
125
+ # results_dir: float="results",
126
+ mask: float=None,
127
+ mask_lambda: float=0.0,
128
+ verbose: bool=False,
129
+ ) -> Tensor:
130
+ assert _sentinel is None
131
+ # text_inputs = torch.cat([clip.tokenize(d) for d in descriptions]).to(device)
132
+ num_prompts = len(descriptions)
133
+ num_images = len(w_orig)
134
+ device = w_orig.device
135
+
136
+ text_inputs = clip.tokenize(descriptions).to(device) # [num_prompts, 77]
137
+ text_inputs = text_inputs.repeat_interleave(len(w_orig), dim=0) # [num_prompts * num_images, 77]
138
+
139
+ c = torch.zeros(num_prompts * num_images, 0, device=device)
140
+ ts = torch.zeros(num_prompts * num_images, 1, device=device)
141
+ w_orig = w_orig.repeat(num_prompts, 1, 1) # [num_prompts * num_images, num_ws, w_dim]
142
+
143
+ with torch.no_grad():
144
+ img_orig = G.synthesis(ws=w_orig, c=c, t=ts) # [num_prompts * num_images, 3, c, h, w]
145
+
146
+ w = w_orig.detach().clone() # [num_prompts * num_images, num_ws, w_dim]
147
+ w.requires_grad = True
148
+
149
+ if mask_lambda > 0:
150
+ target_image = img_orig * (1 - mask) # [num_prompts * num_images, 3, c, h, w]
151
+ #target_image = img_orig[:, :, -128:, :128]
152
+ target_image = (target_image * 0.5 + 0.5) * 255.0 # [num_prompts * num_images, 3, c, h, w]
153
+ if target_image.shape[2] > 256:
154
+ target_image = F.interpolate(target_image, size=(256, 256), mode='area')
155
+ target_features = vgg16(target_image, resize_images=False, return_lpips=True)
156
+ #dist = (target_features - synth_features).square().sum()
157
+ else:
158
+ target_features = None
159
+
160
+ clip_loss = CLIPLoss()
161
+ id_loss = IDLoss()
162
+ optimizer = optim.Adam([w], lr=lr)
163
+
164
+ if verbose:
165
+ pbar = tqdm(range(num_steps))
166
+ else:
167
+ pbar = range(num_steps)
168
+
169
+ for curr_iter in pbar:
170
+ curr_lr = get_lr(curr_iter / num_steps, lr)
171
+ # optimizer.param_groups[0]["lr"] = lr
172
+ for param_group in optimizer.param_groups:
173
+ param_group['lr'] = curr_lr
174
+
175
+ #img_gen, _ = g_ema([latent], input_is_latent=True, randomize_noise=False, input_is_stylespace=work_in_stylespace)
176
+ img_gen = G.synthesis(ws=w, c=c, t=ts) # [num_prompts * num_images, 3, c, h, w]
177
+
178
+ if mask_lambda > 0:
179
+ raise NotImplementedError
180
+ synth_image = img_gen * (1 - mask)
181
+ #synth_image = img_gen[:, :, -128:, :128]
182
+ synth_image = (synth_image * 0.5 + 0.5) * 255.0
183
+ if synth_image.shape[2] > 256:
184
+ synth_image = F.interpolate(synth_image, size=(256, 256), mode='area')
185
+ synth_features = vgg16(synth_image, resize_images=False, return_lpips=True)
186
+ mask_loss = (target_features - synth_features).square().sum()
187
+ else:
188
+ mask_loss = 0
189
+
190
+ if not mask is None:
191
+ img_gen = img_gen * mask.unsqueeze(0) # [num_prompts * num_images, 3, c, h, w]
192
+
193
+ c_loss = clip_loss(img_gen, text_inputs) # [num_prompts * num_images]
194
+
195
+ if id_lambda > 0:
196
+ i_loss = id_loss(img_gen, img_orig)
197
+ else:
198
+ i_loss = 0
199
+
200
+ l2_loss = ((w_orig - w) ** 2) # [1]
201
+ loss = c_loss.sum() + l2_lambda * l2_loss.sum() + id_lambda * i_loss + mask_lambda * mask_loss
202
+
203
+ optimizer.zero_grad()
204
+ loss.backward()
205
+ optimizer.step()
206
+
207
+ if verbose:
208
+ pbar.set_description((f"loss: {loss.item():.4f};"))
209
+
210
+ final_result = torch.stack([img_orig, img_gen]) # [2, num_prompts * num_images, c, h, w]
211
+
212
+ return final_result, w
213
+
214
+ # x, new_w = main(args)
215
+
216
+ # pair = torch.cat([img for img in x], dim=2)
217
+ # TVF.to_pil_image((pair.cpu().detach() * 0.5 + 0.5).clamp(0, 1))
218
+
219
+ #----------------------------------------------------------------------------
220
+
221
+ @click.command()
222
+ @click.pass_context
223
+ @click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
224
+ @click.option('--networks_dir', help='Network pickles directory', metavar='PATH')
225
+ # @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
226
+ # @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
227
+ # @click.option('--same_motion_codes', type=bool, help='Should we use the same motion codes for all videos?', default=False, show_default=True)
228
+ @click.option('--w_dir', help='A directory leading to latent codes.', type=str, required=False, metavar='DIR')
229
+ @click.option('--results_dir', help='A directory to save the results in.', type=str, required=False, metavar='DIR')
230
+ @click.option('--truncation_psi', help='If we use new w, what truncation to use.', type=float, required=False, metavar='FLOAT', default=1.0)
231
+ @click.option('--num_w', help='If we use new w, how many to sample?', type=int, required=False, metavar='FLOAT', default=16)
232
+ @click.option('--prompts', help='A path to prompts or a string of prompts.', type=str, required=True, metavar='DIR')
233
+ @click.option('--seed', type=int, help='Random seed', default=42, metavar='DIR')
234
+ @click.option('--zero_periods', help='Zero-out periods predictor?', default=False, type=bool, metavar='BOOL')
235
+ @click.option('--num_weights_to_slice', help='Number of high-frequency coords to remove.', default=0, type=int, metavar='INT')
236
+ @click.option('--num_steps', help='Number of the optimization steps to perform.', default=40, type=int, metavar='INT')
237
+ @click.option('--stack_samples', help='When saving, should we stack samples together?', default=False, type=bool, metavar='BOOL')
238
+ # l2_lambda=0.001,
239
+ # id_lambda=0.005,
240
+ # l2_lambda=0.0005,
241
+ # id_lambda=0.0,
242
+ @click.option('--l2_lambda', help='L2 loss coef', default=0.001, type=float, metavar='FLOAT')
243
+ @click.option('--id_lambda', help='ID loss coef', default=0.005, type=float, metavar='FLOAT')
244
+ @click.option('--lr', help='Learning rate', default=0.1, type=float, metavar='FLOAT')
245
+ @click.option('--mask_lambda', help='If we use a mask, specify the loss coef', default=0.0, type=float, metavar='FLOAT')
246
+ @click.option('--use_id_lambda', help='Should we use id lambda in HPO?', default=False, type=bool, metavar='BOOL')
247
+
248
+ def main(
249
+ ctx: click.Context,
250
+ network_pkl: str,
251
+ networks_dir: str,
252
+ seed: int,
253
+ w_dir: str,
254
+ results_dir: str,
255
+ truncation_psi: float,
256
+ num_w: int,
257
+ # save_as_mp4: bool,
258
+ # video_len: int,
259
+ # fps: int,
260
+ # as_grids: bool,
261
+ zero_periods: bool,
262
+ num_weights_to_slice: int,
263
+ num_steps: int,
264
+ stack_samples: bool,
265
+ l2_lambda: float,
266
+ id_lambda: float,
267
+ lr: float,
268
+ prompts: str,
269
+ mask_lambda: float,
270
+ use_id_lambda: bool,
271
+ ):
272
+ if network_pkl is None:
273
+ output_regex = r"^network-snapshot-\d{6}\.pkl$"
274
+ ckpt_regex = re.compile(output_regex)
275
+ # ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
276
+ # network_pkl = os.path.join(networks_dir, ckpts[-1])
277
+ metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
278
+ with open(metrics_file, 'r') as f:
279
+ snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
280
+ best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
281
+ network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
282
+ print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
283
+ # Selecting a checkpoint with the best score
284
+ else:
285
+ assert networks_dir is None, "Can't have both parameters: network_pkl and networks_dir"
286
+
287
+ print('Loading networks from "%s"...' % network_pkl, end='')
288
+ device = torch.device('cuda')
289
+ with dnnlib.util.open_url(network_pkl) as f:
290
+ G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
291
+ print('Loaded!')
292
+
293
+ random.seed(seed)
294
+ np.random.seed(seed)
295
+ torch.manual_seed(seed)
296
+
297
+ if zero_periods:
298
+ G.synthesis.motion_encoder.time_encoder.periods_predictor.weight.data.zero_()
299
+
300
+ if num_weights_to_slice > 0:
301
+ G.synthesis.motion_encoder.time_encoder.weights[:, -num_weights_to_slice:] = 0.0
302
+
303
+ # description = "Bright sunny sky and mountains far away"
304
+ # experiment_type = 'edit' #@param ['edit', 'free_generation']
305
+ # mask = torch.zeros(3, 256, 256, device=device)
306
+ # mask[:, :, 64+32 : 128+32] = 1.0
307
+ # mask[:, :-128, :] = 1.0
308
+ # mask[:, :, 128:] = 1.0
309
+
310
+ if w_dir is None:
311
+ print('Sampling new w')
312
+ z = torch.randn(num_w, G.z_dim, device=device)
313
+ c = torch.zeros(len(z), G.c_dim, device=device)
314
+ w_orig = G.mapping(z=z, c=c, truncation_psi=truncation_psi)
315
+ os.makedirs(results_dir, exist_ok=True)
316
+ torch.save(w_orig.cpu(), f'{results_dir}_w_orig.pt')
317
+ w_save_dir = os.path.join(results_dir, 'w_edit')
318
+ samples_save_dir = os.path.join(results_dir, 'edited_samples')
319
+ else:
320
+ w_paths = sorted([os.path.join(w_dir, f) for f in os.listdir(w_dir) if f.endswith('_w.pt')])
321
+ w_names = [os.path.basename(f) for f in w_paths]
322
+ w_orig = [torch.load(f) for f in w_paths]
323
+ w_orig = torch.stack(w_orig).to(device) # [num_images, num_ws, w_dim]
324
+ w_save_dir = f'{w_dir}_edited_w'
325
+ samples_save_dir = f'{w_dir}_edited_samples'
326
+
327
+ os.makedirs(w_save_dir, exist_ok=True)
328
+ os.makedirs(samples_save_dir, exist_ok=True)
329
+
330
+ print(f'Loading prompts from file: {prompts}')
331
+ with open(prompts, 'r') as f:
332
+ descs_dict = yaml.safe_load(f)
333
+ edit_names, descriptions = list(zip(*descs_dict.items()))
336
+
337
+ del id_lambda, num_steps, l2_lambda
338
+ l2_lambdas = [1000000.0, 0.0025, 0.001, 0.00025, 0.0005, 0.0001]
339
+ if use_id_lambda:
340
+ id_lambdas = [0.005, 0.0025, 0.001, 0.00025, 0.0005, 0.0001, 0.0]
341
+ else:
342
+ id_lambdas = [0.0]
343
+ all_num_steps = [40]
344
+
345
+ for curr_edit_name, curr_prompt in zip(edit_names, descriptions):
346
+ all_images = []
347
+ all_w_edited = []
348
+
349
+ for l2_lambda, id_lambda, num_steps in tqdm(list(itertools.product(l2_lambdas, id_lambdas, all_num_steps)), desc=f'Performing HPO for {curr_edit_name}'):
350
+ final_image, w_edited = run_edit_optimization(
351
+ G=G,
352
+ w_orig=w_orig,
353
+ descriptions=[curr_prompt],
354
+ # ckpt="stylegan2-ffhq-config-f.pt",
355
+ lr=lr,
356
+ num_steps=num_steps,
357
+ l2_lambda=l2_lambda,
358
+ id_lambda=id_lambda,
359
+ mask_lambda=mask_lambda,
360
+ verbose=False,
361
+ # latent_path=latent_path,
362
+ # truncation=0.7,
363
+ # mask=None,
364
+ # mask_lambda=0.1,
365
+ )
366
+
367
+ all_images.extend((final_image[1].cpu() * 0.5 + 0.5).clamp(0, 1))
368
+ all_w_edited.append({
369
+ "w_edit": w_edited.cpu(),
370
+ "l2_lambda": l2_lambda,
371
+ "id_lambda": id_lambda,
372
+ "num_steps": num_steps,
373
+ "prompt": curr_prompt,
374
+ "edit_name": curr_edit_name,
375
+ })
376
+
377
+ # img_names = [f'{w_name}_{edit_name}' for edit_name in edit_names for w_name in w_names]
378
+
379
+ # save_edited_w(
380
+ # G=G,
381
+ # w_outdir = f'{w_dir}_edited',
382
+ # samples_outdir = f'{w_dir}_projected_samples',
383
+ # img_names=img_names,
384
+ # stack_samples=stack_samples,
385
+ # all_w = w_edited,
386
+ # all_motion_z = None,
387
+ # stacked_samples_out_path = f'{w_dir}_edited_samples.png'
388
+ # )
389
+
390
+ torch.save(all_w_edited, f"{w_save_dir}/{curr_edit_name}_w.pt")
391
+ grid = utils.make_grid(torch.stack(all_images), nrow=len(w_orig))
392
+ print('Saving into', f"{samples_save_dir}/{curr_edit_name}.png")
393
+ save_image(grid, f"{samples_save_dir}/{curr_edit_name}.png")
394
+
395
+ print('Done!')
396
+
397
+
398
+ #----------------------------------------------------------------------------
399
+
400
+ if __name__ == "__main__":
401
+ main() # pylint: disable=no-value-for-parameter
402
+
403
+ #----------------------------------------------------------------------------
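Note on the sweep above: the HPO loop simply enumerates the Cartesian product of the candidate coefficients, one edit-optimization run per combination. A minimal sketch of the same pattern (the values below are illustrative, not the script's defaults):

    import itertools

    l2_lambdas = [0.0025, 0.001]
    id_lambdas = [0.005, 0.0]
    all_num_steps = [40]

    # One optimization run per (l2, id, steps) combination: 2 * 2 * 1 = 4 runs.
    for l2_lambda, id_lambda, num_steps in itertools.product(l2_lambdas, id_lambdas, all_num_steps):
        print(l2_lambda, id_lambda, num_steps)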
src/scripts/construct_static_videos_dataset.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ Takes a dataset directory and, for each video, repeats a single randomly chosen frame so every clip becomes static
3
+ This is needed to calculate same-frame FVD and DiFID
4
+ """
5
+ import os
6
+ import random
7
+ import argparse
8
+ from typing import List
9
+ import shutil
10
+ from tqdm import tqdm
11
+
12
+
13
+ def construct_static_videos_dataset(videos_dir: os.PathLike, max_len: int=None, output_dir: os.PathLike=None, force_len: int=None):
14
+ output_dir = output_dir if output_dir is not None else f'{videos_dir}_freeze'
15
+ clips_paths = [os.path.join(videos_dir, d) for d in os.listdir(videos_dir)]
16
+
17
+ print(f'Saving into {output_dir}')
18
+
19
+ for video_idx, clip_path in enumerate(tqdm(clips_paths)):
20
+ frames_paths = os.listdir(clip_path)
21
+ frame_to_repeat = random.choice(frames_paths)
22
+ curr_output_dir = os.path.join(output_dir, f'{video_idx:05d}')
23
+ os.makedirs(curr_output_dir, exist_ok=True)
24
+ num_frames_to_create = force_len if force_len is not None else min(len(frames_paths), max_len)
25
+
26
+ for i in range(num_frames_to_create):
27
+ ext = os.path.splitext(frame_to_repeat)[1].lower()
28
+ target_file_path = os.path.join(curr_output_dir, f'{i:06d}{ext}')
29
+ shutil.copy(os.path.join(clip_path, frame_to_repeat), target_file_path)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument('-d', '--directory', type=str, help='Directory with video frames')
35
+ parser.add_argument('-o', '--output_dir', type=str, help='Where to save the resulting dataset')
36
+ parser.add_argument('-l', '--max_len', type=int, help='Max video length')
37
+ parser.add_argument('-fl', '--force_len', type=int, help='Force video length')
38
+
39
+ args = parser.parse_args()
40
+
41
+ construct_static_videos_dataset(
42
+ videos_dir=args.directory,
43
+ max_len=args.max_len,
44
+ output_dir=args.output_dir,
45
+ force_len=args.force_len,
46
+ )
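A minimal usage sketch, calling the function directly instead of via the CLI (the data path is illustrative, and it assumes the script's directory is on sys.path):

    from construct_static_videos_dataset import construct_static_videos_dataset

    # Freeze each clip to a single random frame, forcing 16-frame outputs
    # into a sibling `<videos_dir>_freeze` directory.
    construct_static_videos_dataset(videos_dir='data/videos', force_len=16)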
src/scripts/convert_video_to_dataset.py ADDED
@@ -0,0 +1,87 @@
1
+ """
2
+ Converts a single long video into a dataset of fixed-length chunks, stored as frames
3
+ I.e. one mp4 file becomes a directory of directories of frames
4
+ This speeds up loading during training because we do not need to decode the video on the fly
5
+ """
6
+ import os
7
+ from typing import List
8
+ import argparse
9
+ from pathlib import Path
10
+ from multiprocessing import Pool
11
+ from collections import Counter
12
+
13
+ import numpy as np
14
+ from PIL import Image
15
+ import torchvision.transforms.functional as TVF
16
+ from moviepy.editor import VideoFileClip
17
+ from tqdm import tqdm
18
+
19
+
20
+ def convert_videos_into_dataset(video_path: os.PathLike, target_dir: os.PathLike, num_chunks: int, chunk_size: int, start_frame: int, target_size: int, force_fps: int):
21
+ assert (num_chunks is None) or (chunk_size is None), "Can't use both num_chunks and chunk_size"
22
+
23
+ os.makedirs(target_dir, exist_ok=True)
24
+ clip = VideoFileClip(video_path)
25
+ fps = clip.fps if force_fps is None else force_fps
26
+ num_frames_total = int(np.floor(clip.duration * fps)) - start_frame
27
+
28
+ if num_chunks is None:
29
+ num_chunks = num_frames_total // chunk_size
30
+ else:
31
+ chunk_size = num_frames_total // num_chunks
32
+
33
+ num_frames_to_save = chunk_size * num_chunks
34
+
35
+ print(f'Processing the video at {fps} fps. {num_frames_total} frames in total. We have {num_chunks} videos of {chunk_size} frames each.')
36
+
37
+ current_chunk_idx = 0
38
+ frame_idx = -start_frame
39
+ curr_chunk_dir = os.path.join(target_dir, f'{current_chunk_idx:06d}')
40
+
41
+ for frame in tqdm(clip.iter_frames(fps=fps), total=num_frames_total + start_frame):
42
+ if frame_idx >= 0:
43
+ os.makedirs(curr_chunk_dir, exist_ok=True)
44
+ frame = Image.fromarray(frame)
45
+ frame = TVF.center_crop(frame, output_size=min(frame.size))
46
+ frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
47
+ frame.save(os.path.join(curr_chunk_dir, f'{frame_idx % chunk_size:06d}.jpg'), quality=95)
48
+
49
+ frame_idx += 1
50
+ if frame_idx % chunk_size == 0 and frame_idx > 0:
51
+ current_chunk_idx += 1
52
+ curr_chunk_dir = os.path.join(target_dir, f'{current_chunk_idx:06d}')
53
+
54
+ if frame_idx == num_frames_to_save:
55
+ # Stop here so not to have a partially-filled chunk
56
+ break
57
+
58
+ chunk_sizes = [len(os.listdir(d)) for d in listdir_full_paths(target_dir)]
59
+ assert len(set(chunk_sizes)) == 1, f"Bad chunk sizes: {set(chunk_sizes)}"
60
+
61
+ print('Finished successfully!')
62
+
63
+
64
+ def listdir_full_paths(d) -> List[os.PathLike]:
65
+ return sorted([os.path.join(d, x) for x in os.listdir(d)])
66
+
67
+
68
+ if __name__ == "__main__":
69
+ parser = argparse.ArgumentParser(description='Convert a long video into a dataset of frame dirs')
70
+ parser.add_argument('-s', '--source_video_path', type=str, help='Path to the source video')
71
+ parser.add_argument('-t', '--target_dir', type=str, help='Where to save the new dataset')
72
+ parser.add_argument('-n', '--num_chunks', type=int, help='How many samples should there be in the dataset?')
73
+ parser.add_argument('-cs', '--chunk_size', type=int, help='Each video length. Should be used separately from num_chunks')
74
+ parser.add_argument('-sf', '--start_frame', type=int, default=0, help='Start frame idx. Should we skip several frames?')
75
+ parser.add_argument('--target_size', type=int, default=128, help='What size should we resize to?')
76
+ parser.add_argument('--force_fps', type=int, help='What fps should we run videos with?')
77
+ args = parser.parse_args()
78
+
79
+ convert_videos_into_dataset(
80
+ video_path=args.source_video_path,
81
+ target_dir=args.target_dir,
82
+ num_chunks=args.num_chunks,
83
+ chunk_size=args.chunk_size,
84
+ start_frame=args.start_frame,
85
+ target_size=args.target_size,
86
+ force_fps=args.force_fps,
87
+ )
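The chunking arithmetic above is an integer trade-off between num_chunks and chunk_size, and any trailing partial chunk is dropped. A worked example with illustrative numbers (a 60 s clip at 25 fps, start_frame=0):

    num_frames_total = 1500                        # 60 s * 25 fps
    chunk_size = 50
    num_chunks = num_frames_total // chunk_size    # 30
    num_frames_to_save = chunk_size * num_chunks   # 1500; a remainder would be cut off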
src/scripts/convert_videos_to_frames.py ADDED
@@ -0,0 +1,105 @@
1
+ """
2
+ Converts a dataset of mp4 videos into a dataset of video frames
3
+ I.e. a directory of mp4 files becomes a directory of directories of frames
4
+ This speeds up loading during training because we do not need to decode mp4 files on the fly
5
+ """
6
+ import os
7
+ from typing import List
8
+ import argparse
9
+ from pathlib import Path
10
+ from multiprocessing import Pool
11
+ from collections import Counter
12
+
13
+ from PIL import Image
14
+ import torchvision.transforms.functional as TVF
15
+ from moviepy.editor import VideoFileClip
16
+ from tqdm import tqdm
17
+
18
+
19
+ def convert_videos_to_frames(source_dir: os.PathLike, target_dir: os.PathLike, num_workers: int, video_ext: str, **process_video_kwargs):
20
+ broken_clips_dir = f'{target_dir}_broken_clips'
21
+ os.makedirs(target_dir, exist_ok=True)
22
+ os.makedirs(broken_clips_dir, exist_ok=True)
23
+
24
+ clips_paths = [cp for cp in listdir_full_paths(source_dir) if cp.endswith(video_ext)]
25
+ clips_fps = []
26
+ tasks_kwargs = [dict(
27
+ clip_path=cp,
28
+ target_dir=target_dir,
29
+ broken_clips_dir=broken_clips_dir,
30
+ **process_video_kwargs,
31
+ ) for cp in clips_paths]
32
+ pool = Pool(processes=num_workers)
33
+
34
+ for fps in tqdm(pool.imap_unordered(task_proxy, tasks_kwargs), total=len(clips_paths)):
35
+ clips_fps.append(fps)
36
+
37
+ print(f'All possible fps: {Counter(clips_fps).most_common()}')
38
+
39
+
40
+ def task_proxy(kwargs):
41
+ """I do not know, how to pass several arguments to a pool job..."""
42
+ return process_video(**kwargs)
43
+
44
+
45
+ def process_video(
46
+ clip_path: os.PathLike, target_dir: os.PathLike, force_fps: int=None, target_size: int=None,
47
+ broken_clips_dir: os.PathLike=None, compute_fps_only: bool=False) -> int:
48
+
49
+ clip_name = os.path.basename(clip_path)
50
+ clip_name = clip_name[:clip_name.rfind('.')]
51
+
52
+ try:
53
+ clip = VideoFileClip(clip_path)
54
+ except KeyboardInterrupt:
55
+ raise
56
+ except Exception as e:
57
+ print(f"Couldn't process clip: {clip_path}")
58
+ if broken_clips_dir is not None:
59
+ Path(os.path.join(broken_clips_dir, clip_name)).touch()
60
+ return 0
61
+
62
+ if compute_fps_only:
63
+ return clip.fps
64
+
65
+ fps = clip.fps if force_fps is None else force_fps
66
+ clip_target_dir = os.path.join(target_dir, clip_name)
67
+ clip_target_dir = clip_target_dir.replace('#', '_')
68
+ os.makedirs(clip_target_dir, exist_ok=True)
69
+
70
+ frame_idx = 0
71
+ for frame in clip.iter_frames(fps=fps):
72
+ frame = Image.fromarray(frame)
73
+ if target_size is not None:
74
+ frame = TVF.resize(frame, size=target_size, interpolation=Image.LANCZOS)
75
+ frame = TVF.center_crop(frame, output_size=(target_size, target_size))
76
+ frame.save(os.path.join(clip_target_dir, f'{frame_idx:06d}.jpg'), quality=95)
77
+ frame_idx += 1
78
+
79
+ return clip.fps
80
+
81
+
82
+ def listdir_full_paths(d) -> List[os.PathLike]:
83
+ return sorted([os.path.join(d, x) for x in os.listdir(d)])
84
+
85
+
86
+ if __name__ == "__main__":
87
+ parser = argparse.ArgumentParser(description='Convert a dataset of mp4 files into a dataset of individual frames')
88
+ parser.add_argument('-s', '--source_dir', type=str, help='Path to the source dataset')
89
+ parser.add_argument('-t', '--target_dir', type=str, help='Where to save the new dataset')
90
+ parser.add_argument('--video_ext', type=str, default='mp4', help='Video extension')
91
+ parser.add_argument('--target_size', type=int, default=128, help='What size should we resize to?')
92
+ parser.add_argument('--force_fps', type=int, help='What fps should we run videos with?')
93
+ parser.add_argument('--num_workers', type=int, default=8, help='Number of processes to launch')
94
+ parser.add_argument('--compute_fps_only', action='store_true', help='Should we just compute fps?')
95
+ args = parser.parse_args()
96
+
97
+ convert_videos_to_frames(
98
+ source_dir=args.source_dir,
99
+ target_dir=args.target_dir,
100
+ target_size=args.target_size,
101
+ force_fps=args.force_fps,
102
+ num_workers=args.num_workers,
103
+ video_ext=args.video_ext,
104
+ compute_fps_only=args.compute_fps_only,
105
+ )
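The task_proxy indirection above exists because Pool.imap_unordered hands each worker exactly one argument. A self-contained sketch of the same keyword-arguments pattern (the worker here is a toy stand-in):

    from multiprocessing import Pool

    def worker(a: int, b: int) -> int:
        return a + b

    def worker_proxy(kwargs):
        # Unpack the dict so the real worker keeps a normal keyword signature.
        return worker(**kwargs)

    if __name__ == '__main__':
        tasks = [dict(a=i, b=2 * i) for i in range(4)]
        with Pool(processes=2) as pool:
            print(sorted(pool.imap_unordered(worker_proxy, tasks)))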
src/scripts/crop_video_dataset.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ import shutil
3
+ import argparse
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from PIL import Image
9
+
10
+
11
+ def crop_video_dataset(source_dir: str, max_num_frames: int=None, slice_n_left_frames: int=0, resize: int=None, target_dir: str=None):
12
+ dataset_name = os.path.basename(source_dir)
13
+ if target_dir is None:
14
+ max_num_frames_prefix = "" if max_num_frames is None else f"_cut{max_num_frames}"
15
+ slice_prefix = "" if slice_n_left_frames == 0 else f"_slice{slice_n_left_frames}"
16
+ new_dataset_name = f"{dataset_name}{max_num_frames_prefix}{slice_prefix}"
17
+ target_dir = os.path.join(os.path.dirname(source_dir), new_dataset_name)
18
+ all_clips_paths = listdir_full_paths(source_dir)
19
+ os.makedirs(target_dir, exist_ok=True)
20
+ slice_proportions = []
21
+
22
+ total_num_frames = 0
23
+
24
+ for source_clip_dir in tqdm(all_clips_paths, desc=f'Cropping the dataset into {target_dir}'):
25
+ all_frames = listdir_full_paths(source_clip_dir)
26
+ if len(all_frames) == 0:
27
+ print(f'{source_clip_dir} is empty. Skipping it.')
28
+ continue
29
+ target_clip_dir = os.path.join(target_dir, os.path.basename(source_clip_dir))
30
+ os.makedirs(target_clip_dir, exist_ok=True)
31
+ total_num_frames += len(all_frames)
32
+ slice_proportions.append(slice_n_left_frames / len(all_frames))
33
+ all_frames = all_frames[slice_n_left_frames:]
34
+
35
+ if max_num_frames is not None:
36
+ all_frames = all_frames[:max_num_frames]
37
+
38
+ for source_frame_path in all_frames:
39
+ target_frame_path = os.path.join(target_clip_dir, os.path.basename(source_frame_path))
40
+
41
+ if resize is None:
42
+ shutil.copy(source_frame_path, target_frame_path)
43
+ else:
44
+ assert target_frame_path.endswith('.jpg')
45
+ Image.open(source_frame_path).resize((resize, resize), resample=Image.LANCZOS).save(target_frame_path, quality=95)
46
+
47
+ print(f'Done! Sliced {np.mean(slice_proportions) * 100.0 : .02f}% on average. {len(all_clips_paths) * slice_n_left_frames / total_num_frames * 100.0 : .02f}% of total num frames.')
48
+
49
+
50
+ def listdir_full_paths(d) -> List[os.PathLike]:
51
+ return sorted([os.path.join(d, x) for x in os.listdir(d)])
52
+
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser(description='Temporally crops a video dataset to a maximum number of frames')
56
+ parser.add_argument('source_dir', type=str, help='Path to the dataset')
57
+ parser.add_argument('-n', '--max_num_frames', type=int, default=None, help='Number of frames to preserve')
58
+ parser.add_argument('--slice_n_left_frames', type=int, default=0, help='Number of frames to slice from the left')
59
+ parser.add_argument('--resize', type=int, default=None, help='Resolution to resize the frames to (optional)')
60
+ parser.add_argument('--target_dir', type=str, default=None, help='Where to save the cropped dataset')
61
+ args = parser.parse_args()
62
+
63
+ crop_video_dataset(
64
+ source_dir=args.source_dir,
65
+ max_num_frames=args.max_num_frames,
66
+ slice_n_left_frames=args.slice_n_left_frames,
67
+ resize=args.resize,
68
+ target_dir=args.target_dir,
69
+ )
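Note the order of operations in crop_video_dataset: slice_n_left_frames frames are dropped from the start first, and only then is the clip cut down to max_num_frames. A toy illustration:

    all_frames = list(range(10))   # 10 frames
    all_frames = all_frames[2:]    # slice_n_left_frames=2 -> [2, 3, ..., 9]
    all_frames = all_frames[:5]    # max_num_frames=5      -> [2, 3, 4, 5, 6]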
src/scripts/frames_to_video_grid.py ADDED
@@ -0,0 +1,78 @@
1
+ """
2
+ Converts a directory of video frames into an mp4-grid
3
+ """
4
+ import sys; sys.path.extend(['.'])
5
+ import os
6
+ import argparse
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import Tensor
12
+ import torchvision.transforms.functional as TVF
13
+ from torchvision import utils
14
+ from PIL import Image
15
+ from tqdm import tqdm
16
+ import torchvision
17
+
18
+
19
+ def frames_to_video_grid(videos_dir: os.PathLike, num_videos: int, length: int, fps: int, output_path: os.PathLike, select_random: bool=False, random_seed: int=None):
20
+ clips_paths = [os.path.join(videos_dir, d) for d in os.listdir(videos_dir)]
21
+
22
+ # bad_idx = [0, 9, 11, 16]
23
+ # clips_paths = [c for i, c in enumerate(clips_paths) if not i in bad_idx]
24
+
25
+ if select_random:
26
+ random.seed(random_seed)
27
+ clips_paths = random.sample(clips_paths, k=num_videos)
28
+ else:
29
+ clips_paths = clips_paths[:num_videos]
30
+ videos = [read_first_n_frames(d, length) for d in tqdm(clips_paths, desc='Reading data...')] # [num_videos, length, c, h, w]
31
+ videos = [fill_with_black_squares(v, length) for v in tqdm(videos, desc='Adding empty frames')] # [num_videos, length, c, h, w]
32
+ frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [video_len, num_videos, c, h, w]
33
+ frame_grids = [utils.make_grid(fs, nrow=int(np.ceil(np.sqrt(num_videos)))) for fs in tqdm(frame_grids, desc='Making grids')]
34
+
35
+ if os.path.dirname(output_path) != "":
36
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
37
+ frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, H, W, C]
38
+ torchvision.io.write_video(output_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
39
+
40
+
41
+ def read_first_n_frames(d: os.PathLike, num_frames: int) -> Tensor:
42
+ images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))[:num_frames]]
43
+ images = [TVF.to_tensor(x) for x in images]
44
+
45
+ return torch.stack(images)
46
+
47
+
48
+ def fill_with_black_squares(video, desired_len: int) -> Tensor:
49
+ if len(video) >= desired_len:
50
+ return video
51
+
52
+ return torch.cat([
53
+ video,
54
+ torch.zeros_like(video[0]).unsqueeze(0).repeat(desired_len - len(video), 1, 1, 1),
55
+ ], dim=0)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ parser = argparse.ArgumentParser()
60
+ parser.add_argument('-d', '--directory', type=str, help='Directory with video frames')
61
+ parser.add_argument('-n', '--num_videos', type=int, help='Number of videos to consider')
62
+ parser.add_argument('-l', '--length', type=int, help='Video length (in frames)')
63
+ parser.add_argument('--fps', type=int, default=25, help='FPS to save with.')
64
+ parser.add_argument('-o', '--output_path', type=str, help='Where to save the output video')
65
+ parser.add_argument('--select_random', action='store_true', help='Select videos at random?')
66
+ parser.add_argument('--random_seed', type=int, default=None, help='Random seed when selecting videos at random')
67
+
68
+ args = parser.parse_args()
69
+
70
+ frames_to_video_grid(
71
+ videos_dir=args.directory,
72
+ num_videos=args.num_videos,
73
+ length=args.length,
74
+ fps=args.fps,
75
+ output_path=args.output_path,
76
+ select_random=args.select_random,
77
+ random_seed=args.random_seed,
78
+ )
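fill_with_black_squares pads shorter clips with black frames so every video in the grid has the same length; the stack is then transposed to time-major order and each time step becomes one grid image. A shape-level sketch (tensor sizes are illustrative):

    import torch
    from torchvision import utils

    videos = torch.rand(4, 16, 3, 64, 64)            # [num_videos, length, c, h, w]
    frame_grids = videos.permute(1, 0, 2, 3, 4)      # [length, num_videos, c, h, w]
    grid0 = utils.make_grid(frame_grids[0], nrow=2)  # one 2x2 grid image per time step
    print(grid0.shape)                               # torch.Size([3, 134, 134])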
src/scripts/generate.py ADDED
@@ -0,0 +1,148 @@
1
+ """Generates a dataset of images using pretrained network pickle."""
2
+
3
+ import sys; sys.path.extend(['.', 'src'])
4
+ import os
5
+ import json
6
+ import random
7
+ import warnings
8
+
9
+ import click
10
+ from src import dnnlib
11
+ import numpy as np
12
+ import torch
13
+ from tqdm import tqdm
14
+ from omegaconf import OmegaConf
15
+
16
+ import src.legacy as legacy
17
+ from src.training.logging import generate_videos, save_video_frames_as_mp4, save_video_frames_as_frames_parallel
18
+
19
+ torch.set_grad_enabled(False)
20
+
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ @click.command()
25
+ @click.pass_context
26
+ @click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
27
+ @click.option('--networks_dir', help='Network pickles directory. Selects a checkpoint from it automatically based on the fvd2048_16f metric.', metavar='PATH')
28
+ @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
29
+ @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
30
+ @click.option('--num_videos', type=int, help='Number of videos to generate', default=50000, show_default=True)
31
+ @click.option('--batch_size', type=int, help='Batch size to use for generation', default=32, show_default=True)
32
+ @click.option('--moco_decomposition', type=bool, help='Should we do content/motion decomposition (available only for `--as_grids 1` generation)?', default=False, show_default=True)
33
+ @click.option('--seed', type=int, help='Random seed', default=42, metavar='INT')
34
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
35
+ @click.option('--save_as_mp4', help='Should we save as independent frames or mp4?', type=bool, default=False, metavar='BOOL')
36
+ @click.option('--video_len', help='Number of frames to generate', type=int, default=16, metavar='INT')
37
+ @click.option('--fps', help='FPS for mp4 saving', type=int, default=25, metavar='INT')
38
+ @click.option('--as_grids', help='Save videos as grids', type=bool, default=False, metavar='BOOl')
39
+ @click.option('--time_offset', help='Additional time offset', default=0, type=int, metavar='INT')
40
+ @click.option('--dataset_path', help='Dataset path. In case we want to use the conditioning signal.', default="", type=str, metavar='PATH')
41
+ @click.option('--hydra_cfg_path', help='Config path', default="", type=str, metavar='PATH')
42
+ @click.option('--slowmo_coef', help='Increase this value if you want to produce slow-motion videos.', default=1, type=int, metavar='INT')
43
+ def generate(
44
+ ctx: click.Context,
45
+ network_pkl: str,
46
+ networks_dir: str,
47
+ truncation_psi: float,
48
+ noise_mode: str,
49
+ num_videos: int,
50
+ batch_size: int,
51
+ moco_decomposition: bool,
52
+ seed: int,
53
+ outdir: str,
54
+ save_as_mp4: bool,
55
+ video_len: int,
56
+ fps: int,
57
+ as_grids: bool,
58
+ time_offset: int,
59
+ dataset_path: os.PathLike,
60
+ hydra_cfg_path: os.PathLike,
61
+ slowmo_coef: int,
62
+ ):
63
+ if network_pkl is None:
64
+ # output_regex = "^network-snapshot-\d{6}.pkl$"
65
+ # ckpt_regex = re.compile("^network-snapshot-\d{6}.pkl$")
66
+ # ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
67
+ # network_pkl = os.path.join(networks_dir, ckpts[-1])
68
+ ckpt_select_metric = 'fvd2048_16f'
69
+ metrics_file = os.path.join(networks_dir, f'metric-{ckpt_select_metric}.jsonl')
70
+ with open(metrics_file, 'r') as f:
71
+ snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
72
+ best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results'][ckpt_select_metric])[0]
73
+ network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
74
+ print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results'][ckpt_select_metric])
75
+ # Selecting a checkpoint with the best score
76
+ else:
77
+ assert networks_dir is None, "Can't have both parameters: network_pkl and networks_dir"
78
+
79
+ if moco_decomposition:
80
+ assert as_grids, f"Content/motion decomposition is available only when we generate as grids."
81
+ assert batch_size == num_videos, "Same motion is supported only for batch_size == num_videos"
82
+
83
+ print('Loading networks from "%s"...' % network_pkl)
84
+ device = torch.device('cuda')
85
+ with dnnlib.util.open_url(network_pkl) as f:
86
+ G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
87
+
88
+ os.makedirs(outdir, exist_ok=True)
89
+
90
+ random.seed(seed)
91
+ np.random.seed(seed)
92
+ torch.manual_seed(seed)
93
+
94
+ all_z = torch.randn(num_videos, G.z_dim, device=device) # [num_videos, z_dim]
95
+ if dataset_path and G.c_dim > 0:
96
+ hydra_cfg_path = hydra_cfg_path or os.path.join(networks_dir, '..', "experiment_config.yaml")
97
+ hydra_cfg = OmegaConf.load(hydra_cfg_path)
98
+ training_set_kwargs = dnnlib.EasyDict(
99
+ class_name='training.dataset.VideoFramesFolderDataset',
100
+ path=dataset_path, cfg=hydra_cfg.dataset, use_labels=True, max_size=None, xflip=False)
101
+ training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)
102
+ all_c = [training_set.get_label(random.choice(range(len(training_set)))) for _ in range(num_videos)] # [num_videos, c_dim]
103
+ all_c = torch.from_numpy(np.array(all_c)).to(device) # [num_videos, c_dim]
104
+ elif G.c_dim > 0:
105
+ warnings.warn('Assuming that the conditioning is one-hot!')
106
+ c_idx = torch.randint(low=0, high=G.c_dim, size=(num_videos, 1), device=device)
107
+ all_c = torch.zeros(num_videos, G.c_dim, device=device) # [num_videos, c_dim]
108
+ all_c.scatter_(1, c_idx, 1)
109
+ else:
110
+ all_c = torch.zeros(num_videos, G.c_dim, device=device) # [num_videos, c_dim]
111
+ ts = time_offset + torch.arange(video_len, device=device).float().unsqueeze(0).repeat(batch_size, 1) / slowmo_coef # [batch_size, video_len]
112
+ if moco_decomposition:
113
+ num_rows = num_cols = int(np.sqrt(num_videos))
114
+ motion_z = G.synthesis.motion_encoder(c=all_c[:num_rows], t=ts[:num_rows])['motion_z'] # [num_rows, *motion_dims]
115
+ motion_z = motion_z.repeat_interleave(num_cols, dim=0) # [batch_size, *motion_dims]
116
+
117
+ all_z = all_z[:num_cols].repeat(num_rows, 1) # [num_videos, z_dim]
118
+ all_c = all_c[:num_cols].repeat(num_rows, 1) # [num_videos, c_dim]
119
+ else:
120
+ motion_z = None
121
+
122
+ # Generate images.
123
+ for batch_idx in tqdm(range((num_videos + batch_size - 1) // batch_size), desc='Generating videos'):
124
+ curr_batch_size = batch_size if batch_size * (batch_idx + 1) <= num_videos else num_videos % batch_size
125
+ z = all_z[batch_idx * batch_size:batch_idx * batch_size + curr_batch_size] # [curr_batch_size, z_dim]
126
+ c = all_c[batch_idx * batch_size:batch_idx * batch_size + curr_batch_size] # [curr_batch_size, c_dim]
127
+ videos = generate_videos(
128
+ G, z, c, ts, motion_z=motion_z, noise_mode=noise_mode,
129
+ truncation_psi=truncation_psi, as_grids=as_grids, batch_size_num_frames=128)
130
+
131
+ if as_grids:
132
+ videos = [videos]
133
+
134
+ for video_idx, video in enumerate(videos):
135
+ if save_as_mp4:
136
+ save_path = os.path.join(outdir, f'{batch_idx * batch_size + video_idx:06d}.mp4')
137
+ save_video_frames_as_mp4(video, fps, save_path)
138
+ else:
139
+ save_dir = os.path.join(outdir, f'{batch_idx * batch_size + video_idx:06d}')
140
+ video = (video * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy() # [video_len, h, w, c]
141
+ save_video_frames_as_frames_parallel(video, save_dir, time_offset=time_offset, num_processes=8)
142
+
143
+ #----------------------------------------------------------------------------
144
+
145
+ if __name__ == "__main__":
146
+ generate() # pylint: disable=no-value-for-parameter
147
+
148
+ #----------------------------------------------------------------------------
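The generation loop pads nothing: the last batch is simply smaller whenever batch_size does not divide num_videos. A worked check of the batch arithmetic above (numbers are illustrative):

    num_videos, batch_size = 10, 4
    num_batches = (num_videos + batch_size - 1) // batch_size   # ceil division -> 3
    for batch_idx in range(num_batches):
        curr = batch_size if batch_size * (batch_idx + 1) <= num_videos else num_videos % batch_size
        print(batch_idx, curr)   # 0 4, 1 4, 2 2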
src/scripts/preprocess_ffs.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ This file preprocesses FaceForensics dataset by cropping it
3
+ Copied from https://github.com/pfnet-research/tgan2/blob/master/scripts/make_face_forensics.py
4
+ """
5
+
6
+ import argparse
7
+ import os
8
+ from typing import List
9
+ from multiprocessing import Pool
10
+ from PIL import Image
11
+
12
+ import cv2
13
+ # import h5py
14
+ import imageio
15
+ import numpy as np
16
+ import pandas
17
+ from tqdm import tqdm
18
+
19
+
20
+ def parse_videos(source_dir, splits: List[str], categories: List[str]):
21
+ results = []
22
+ for split in splits:
23
+ for category in categories:
24
+ target_dir = os.path.join(source_dir, split, category)
25
+ filenames = sorted(os.listdir(target_dir))
26
+ for filename in filenames:
27
+ results.append({
28
+ 'split': split,
29
+ 'category': category,
30
+ 'filename': filename,
31
+ 'filepath': os.path.join(split, category, filename),
32
+ })
33
+ return pandas.DataFrame(results)
34
+
35
+
36
+ def crop(img, left, right, top, bottom, margin):
37
+ cols = right - left
38
+ rows = bottom - top
39
+ if cols < rows:
40
+ padding = rows - cols
41
+ left -= padding // 2
42
+ right += (padding // 2) + (padding % 2)
43
+ cols = right - left
44
+ else:
45
+ padding = cols - rows
46
+ top -= padding // 2
47
+ bottom += (padding // 2) + (padding % 2)
48
+ rows = bottom - top
49
+ assert(rows == cols)
50
+ return img[top:bottom, left:right]
51
+
52
+
53
+ def job_proxy(kwargs):
54
+ process_and_save_video(**kwargs)
55
+
56
+
57
+ def process_and_save_video(video_path: os.PathLike, mask_path: os.PathLike, img_size: int, wide_crop: bool, output_dir: os.PathLike):
58
+ try:
59
+ video = process_video(video_path, mask_path, img_size=img_size, wide_crop=wide_crop)
60
+ except KeyboardInterrupt:
61
+ raise
62
+ except Exception:
63
+ print(f"Couldn't process {video_path}")
64
+ return
65
+
66
+ os.makedirs(output_dir, exist_ok=True)
67
+
68
+ # if os.path.isdir(output_dir) and len(os.listdir(output_dir)) > 0:
69
+ # return
70
+
71
+ for i, frame in enumerate(video):
72
+ frame = frame.transpose(1, 2, 0)
73
+ Image.fromarray(frame).save(os.path.join(output_dir, f'{i:06d}.jpg'), quality=95)
74
+
75
+
76
+ def process_video(video_path, mask_path, img_size, threshold=5, margin=0.02, wide_crop: bool=False):
77
+ video_reader = imageio.get_reader(video_path)
78
+ mask_reader = imageio.get_reader(mask_path)
79
+ assert(video_reader.get_length() == mask_reader.get_length())
80
+
81
+ # Searching for the widest crop which would work for the whole video
82
+ if wide_crop:
83
+ left_most = float('inf')
84
+ top_most = float('inf')
85
+ right_most = float('-inf')
86
+ bottom_most = float('-inf')
87
+
88
+ for img, mask in zip(video_reader, mask_reader):
89
+ hist = (255 - mask).astype(np.float64).sum(axis=2)
90
+ horiz_hist = np.where(hist.mean(axis=0) > threshold)[0]
91
+ vert_hist = np.where(hist.mean(axis=1) > threshold)[0]
92
+ left, right = horiz_hist[0], horiz_hist[-1]
93
+ top, bottom = vert_hist[0], vert_hist[-1]
94
+ left_most = min(left_most, left)
95
+ top_most = min(top_most, top)
96
+ right_most = max(right_most, right)
97
+ bottom_most = max(bottom_most, bottom)
98
+
99
+ video = []
100
+ for img, mask in zip(video_reader, mask_reader):
101
+ if wide_crop:
102
+ left, right, top, bottom = left_most, right_most, top_most, bottom_most
103
+ else:
104
+ hist = (255 - mask).astype(np.float64).sum(axis=2)
105
+ horiz_hist = np.where(hist.mean(axis=0) > threshold)[0]
106
+ vert_hist = np.where(hist.mean(axis=1) > threshold)[0]
107
+ left, right = horiz_hist[0], horiz_hist[-1]
108
+ top, bottom = vert_hist[0], vert_hist[-1]
109
+
110
+ dst_img = crop(img, left, right, top, bottom, margin)
111
+
112
+ try:
113
+ dst_img = cv2.resize(
114
+ dst_img, (img_size, img_size),
115
+ interpolation=cv2.INTER_LANCZOS4).transpose(2, 0, 1)
116
+ video.append(dst_img)
117
+ except KeyboardInterrupt:
118
+ raise
119
+ except Exception:
120
+ print(img.shape, dst_img.shape, left, right, top, bottom)
121
+
122
+ T = len(video)
123
+ video = np.concatenate(video).reshape(T, 3, img_size, img_size)
124
+ return video
125
+
126
+
127
+ # def count_frames(path):
128
+ # reader = imageio.get_reader(path)
129
+ # n_frames = 0
130
+ # while True:
131
+ # try:
132
+ # img = reader.get_next_data()
133
+ # except IndexError as e:
134
+ # break
135
+ # else:
136
+ # n_frames += 1
137
+ # return n_frames
138
+
139
+
140
+ def main():
141
+ parser = argparse.ArgumentParser()
142
+ parser.add_argument('--source_dir', type=str, default='data/FaceForensics_compressed')
143
+ parser.add_argument('--output_dir', type=str, default='data/ffs_processed')
144
+ parser.add_argument('--img_size', type=int, default=256)
145
+ parser.add_argument('--num_workers', type=int, default=8)
146
+ parser.add_argument('--wide_crop', action='store_true', help="Use a single crop for the whole video (otherwise each frame is cropped independently, which makes the video shake)")
147
+ args = parser.parse_args()
148
+
149
+ # splits = ['train', 'val', 'test']
150
+ # categories = ['original', 'mask', 'altered']
151
+ splits = ['train']
152
+ categories = ['original', 'mask']
153
+ df = parse_videos(args.source_dir, splits, categories)
154
+ os.makedirs(args.output_dir, exist_ok=True)
155
+
156
+ for split in splits:
157
+ target_frame = df[df['split'] == split]
158
+ filenames = target_frame['filename'].unique()
159
+
160
+ # print('Count # of frames')
161
+ # rets = []
162
+ # for i, filename in enumerate(filenames):
163
+ # fn_frame = target_frame[target_frame['filename'] == filename]
164
+ # video_path = os.path.join(
165
+ # args.source_dir, fn_frame[fn_frame['category'] == 'original'].iloc[0]['filepath'])
166
+ # rets.append(p.apply_async(count_frames, args=(video_path,)))
167
+ # n_frames = 0
168
+ # for ret in tqdm(rets):
169
+ # n_frames += ret.get()
170
+ # print('# of frames: {}'.format(n_frames))
171
+
172
+ # h5file = h5py.File(os.path.join(args.output_dir, '{}.h5'.format(split)), 'w')
173
+ # dset = h5file.create_dataset('image', (n_frames, 3, args.img_size, args.img_size), dtype=np.uint8)
174
+ # conf = []
175
+ # start = 0
176
+
177
+ pool = Pool(processes=args.num_workers)
178
+ job_kwargs_list = []
179
+
180
+ for i, filename in enumerate(filenames):
181
+ fn_frame = target_frame[target_frame['filename'] == filename]
182
+ video_path = os.path.join(args.source_dir, fn_frame[fn_frame['category'] == 'original'].iloc[0]['filepath'])
183
+ mask_path = os.path.join(args.source_dir, fn_frame[fn_frame['category'] == 'mask'].iloc[0]['filepath'])
184
+
185
+ job_kwargs_list.append(dict(
186
+ video_path=video_path,
187
+ mask_path=mask_path,
188
+ img_size=args.img_size,
189
+ wide_crop=args.wide_crop,
190
+ output_dir=os.path.join(args.output_dir, filename[:filename.rfind('.')]),
191
+ ))
192
+
193
+ for _ in tqdm(pool.imap_unordered(job_proxy, job_kwargs_list), desc=f'Processing {split}', total=len(job_kwargs_list)):
194
+ pass
195
+ # T = len(video)
196
+ #dset[start:(start + T)] = video
197
+ # conf.append({'start': start, 'end': (start + T)})
198
+ # start += T
199
+ # conf = pandas.DataFrame(conf)
200
+ # conf.to_json(os.path.join(args.output_dir, '{}.json'.format(split)), orient='records')
201
+
202
+
203
+ if __name__ == '__main__':
204
+ main()
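crop() pads the shorter side of the detected face box so the region is square before resizing. A worked example of the padding arithmetic (numbers are illustrative):

    left, right, top, bottom = 50, 150, 0, 160   # cols=100, rows=160
    padding = (bottom - top) - (right - left)    # 60: grow the width
    left -= padding // 2                         # 20
    right += (padding // 2) + (padding % 2)      # 180
    assert (right - left) == (bottom - top)      # now 160 x 160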
src/scripts/profile_model.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ This script computes imgs/sec for a generator in eval mode
3
+ for different batch sizes
4
+ """
5
+ import sys; sys.path.extend(['..', '.', 'src'])
6
+ import time
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import hydra
12
+ from hydra.experimental import initialize
13
+ from omegaconf import DictConfig, OmegaConf
14
+ from tqdm import tqdm
15
+ import torch.autograd.profiler as profiler
16
+
17
+ from src import dnnlib
18
+ from src.infra.utils import recursive_instantiate
19
+
20
+
21
+ DEVICE = 'cuda'
22
+ BATCH_SIZES = [32]
23
+ NUM_WARMUP_ITERS = 5
24
+ NUM_PROFILE_ITERS = 25
25
+
26
+
27
+ def instantiate_G(cfg: DictConfig) -> nn.Module:
28
+ G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
29
+ G_kwargs.synthesis_kwargs.channel_base = int(cfg.model.generator.get('fmaps', 0.5) * 32768)
30
+ G_kwargs.synthesis_kwargs.channel_max = 512
31
+ G_kwargs.mapping_kwargs.num_layers = cfg.model.generator.get('mapping_net_n_layers', 2)
32
+ if cfg.get('num_fp16_res', 0) > 0:
33
+ G_kwargs.synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
34
+ G_kwargs.synthesis_kwargs.conv_clamp = 256
35
+ G_kwargs.cfg = cfg.model.generator
36
+ G_kwargs.c_dim = 0
37
+ G_kwargs.img_resolution = cfg.get('resolution', 256)
38
+ G_kwargs.img_channels = 3
39
+
40
+ G = dnnlib.util.construct_class_by_name(**G_kwargs).eval().requires_grad_(False).to(DEVICE)
41
+
42
+ return G
43
+
44
+
45
+ @torch.no_grad()
46
+ def profile_for_batch_size(G: nn.Module, cfg: DictConfig, batch_size: int):
47
+ z = torch.randn(batch_size, G.z_dim, device=DEVICE)
48
+ c = torch.zeros(batch_size, G.c_dim, device=DEVICE)
49
+ t = torch.zeros(batch_size, 2, device=DEVICE)
50
+ times = []
51
+
52
+ for i in tqdm(range(NUM_WARMUP_ITERS), desc='Warming up'):
53
+ torch.cuda.synchronize()
54
+ fake_img = G(z, c=c, t=t).contiguous()
55
+ y = fake_img[0, 0, 0, 0].item() # sync
56
+ torch.cuda.synchronize()
57
+
58
+ time.sleep(1)
59
+
60
+ torch.cuda.reset_peak_memory_stats()
61
+
62
+ with profiler.profile(record_shapes=True, use_cuda=True) as prof:
63
+ for i in tqdm(range(NUM_PROFILE_ITERS), desc='Profiling'):
64
+ torch.cuda.synchronize()
65
+ start_time = time.time()
66
+ with profiler.record_function("forward"):
67
+ fake_img = G(z, c=c, t=t).contiguous()
68
+ y = fake_img[0, 0, 0, 0].item() # sync
69
+ torch.cuda.synchronize()
70
+ times.append(time.time() - start_time)
71
+
72
+ torch.cuda.empty_cache()
73
+ num_imgs_processed = len(times) * batch_size
74
+ total_time_spent = np.sum(times)
75
+ bandwidth = num_imgs_processed / total_time_spent
76
+ summary = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
77
+
78
+ print(f'[Batch size: {batch_size}] Mean: {np.mean(times):.05f}s/it. Std: {np.std(times):.05f}s')
79
+ print(f'[Batch size: {batch_size}] Imgs/sec: {bandwidth:.03f}')
80
+ print(f'[Batch size: {batch_size}] Max mem: {torch.cuda.max_memory_allocated(DEVICE) / 2**30:<6.2f} gb')
81
+
82
+ return bandwidth, summary
83
+
84
+
85
+ @hydra.main(config_path="../../configs", config_name="config.yaml")
86
+ def profile(cfg: DictConfig):
87
+ recursive_instantiate(cfg)
88
+ G = instantiate_G(cfg)
89
+ bandwidths = []
90
+ summaries = []
91
+ print(f'Number of parameters: {sum(p.numel() for p in G.parameters())}')
92
+
93
+ for batch_size in BATCH_SIZES:
94
+ bandwidth, summary = profile_for_batch_size(G, cfg, batch_size)
95
+ bandwidths.append(bandwidth)
96
+ summaries.append(summary)
97
+
98
+ best_batch_size_idx = int(np.argmax(bandwidths))
99
+ print(f'------------ Best batch size is {BATCH_SIZES[best_batch_size_idx]} ------------')
100
+ print(summaries[best_batch_size_idx])
101
+
102
+
103
+ if __name__ == '__main__':
104
+ profile()
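Because CUDA kernels run asynchronously, the profiling loop brackets each forward pass with torch.cuda.synchronize() (plus an .item() read that forces a device-to-host copy), so the wall-clock numbers reflect actual GPU work. The bare timing pattern, as a sketch:

    import time
    import torch

    def timed(fn):
        # Drain pending kernels, time the call, then drain again before reading the clock.
        torch.cuda.synchronize()
        start = time.time()
        out = fn()
        torch.cuda.synchronize()
        return out, time.time() - start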
src/scripts/project.py ADDED
@@ -0,0 +1,479 @@
1
+ """
2
+ Given a dataset of images, it (optionally crops it) and embeds into the model
3
+ Also optionally generates random videos from the found w
4
+ """
5
+
6
+ import sys; sys.path.extend(['.', 'src'])
7
+ import os
8
+ import re
9
+ import json
10
+ import random
11
+ from typing import List, Optional, Callable
13
+
14
+ from PIL import Image
15
+ import click
16
+ from src import dnnlib
17
+ import numpy as np
18
+ import torch
19
+ from tqdm import tqdm
20
+ from omegaconf import OmegaConf
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torchvision import utils
24
+ from torch import Tensor
25
+ import torchvision.transforms.functional as TVF
26
+ from torchvision.utils import save_image
27
+
28
+ import legacy
29
+ from src.training.logging import generate_videos, save_video_frames_as_mp4, save_video_frames_as_frames
30
+ from src.torch_utils import misc
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def project(
35
+ _sentinel=None,
36
+ G: Callable=None,
37
+ vgg16: nn.Module=None,
38
+ target_images: List[Tensor]=None,
39
+ device: str='cuda',
40
+ use_w_init: bool=False,
41
+ use_motion_init: bool=False,
42
+ w_avg_samples = 10000,
43
+ num_steps = 1000,
44
+ initial_learning_rate = 0.1,
45
+ initial_noise_factor = 0.05,
46
+ noise_ramp_length = 0.75,
47
+ lr_rampdown_length = 0.25,
48
+ lr_rampup_length = 0.05,
49
+ #regularize_noise_weight = 1e5,
50
+ regularize_noise_weight = 0.0001,
51
+ motion_reg_type: str=None,
52
+ ):
53
+ num_videos = len(target_images)
54
+
55
+ # misc.assert_shape(target_images, [None, G.img_channels, G.img_resolution, G.img_resolution])
56
+ G = G.eval().requires_grad_(False).to(device) # type: ignore
57
+
58
+ c = torch.zeros(num_videos, G.c_dim, device=device)
59
+ ts = torch.zeros(num_videos, 1, device=device)
60
+
61
+ # Compute w stats.
62
+ z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
63
+ w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
64
+ w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
65
+ w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
66
+ w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
67
+
68
+ # img_mean = G.synthesis(
69
+ # ws=torch.from_numpy(w_avg).repeat(1, G.num_ws, 1).to(device),
70
+ # c=c[0], t=ts[[0]],
71
+ # )
72
+ # img_mean = (img_mean * 0.5 + 0.5).cpu().detach()
73
+ # TVF.to_pil_image(img_mean[0]).save('/tmp/data/mean.png')
74
+ # print('saved!')
75
+
76
+ # Load VGG16 feature detector.
77
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
78
+ with dnnlib.util.open_url(url) as f:
79
+ vgg16 = torch.jit.load(f).eval().to(device)
80
+
81
+ # Features for target image.
82
+ target_features = []
83
+ for img in target_images:
84
+ img = img.to(device).to(torch.float32).unsqueeze(0) * 255.0
85
+ if img.shape[2] > 256:
86
+ img = F.interpolate(img, size=(256, 256), mode='area')
87
+ target_features.append(vgg16(img, resize_images=False, return_lpips=True).squeeze(0))
88
+ target_features = torch.stack(target_features) # [num_images, lpips_dim]
89
+
90
+ if use_w_init:
91
+ w_opt = find_w_init(G=G, vgg16=vgg16, target_features=target_features, c=c, t=ts) # [num_videos, num_ws, w_dim]
92
+ w_opt = w_opt.detach().requires_grad_(True) # [num_videos, num_ws, w_dim]
93
+ else:
94
+ w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
95
+ w_opt = w_opt.repeat(num_videos, G.num_ws, 1).detach().requires_grad_(True) # [num_videos, num_ws, w_dim]
96
+
97
+ # w_opt_to_ws = lambda w_opt: torch.cat([w_opt[:, [0]].repeat(1, G.num_ws // 2, 1), w_opt[:, 1:]], dim=1)
98
+
99
+ # Trying a lot of motions to find which one works best
100
+ if use_motion_init:
101
+ motion_z_opt = find_motions_init(G=G, vgg16=vgg16, target_features=target_features, c=c, t=ts, w_opt=w_opt)
102
+ else:
103
+ motion_z_opt = G.synthesis.motion_encoder(c=c, t=ts)['motion_z']
104
+ # motion_z_opt.data = torch.randn_like(motion_z_opt.data) * 1e-3
105
+
106
+ motion_z_opt.requires_grad_(True)
107
+
108
+ w_result = torch.zeros([num_steps] + list(w_opt.shape), dtype=torch.float32, device=device)
109
+ # optimizer = torch.optim.Adam([w_opt] + [motion_z_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
110
+ optimizer = torch.optim.Adam([w_opt], betas=(0.9, 0.999), lr=initial_learning_rate)
111
+
112
+ for step in tqdm(range(num_steps)):
113
+ # Learning rate schedule.
114
+ t = step / num_steps
115
+ w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
116
+ lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
117
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
118
+ lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
119
+ lr = initial_learning_rate * lr_ramp
120
+
121
+ for param_group in optimizer.param_groups:
122
+ param_group['lr'] = lr
123
+
124
+ # Synth images from opt_w.
125
+ w_noise = torch.randn_like(w_opt) * w_noise_scale
126
+ ws = w_opt + w_noise
127
+ #ws = w_opt_to_ws(w_opt + w_noise)
128
+ #ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
129
+ #synth_images = G.synthesis(ws, c=c, t=ts, motion_z=motion_z_opt + torch.randn_like(motion_z_opt) * w_noise_scale)
130
+ synth_images = G.synthesis(ws, c=c, t=ts, motion_z=motion_z_opt)
131
+ #synth_images = G.synthesis(ws, c=c, t=ts)
132
+
133
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
134
+ synth_images = (synth_images * 0.5 + 0.5) * 255.0
135
+ if synth_images.shape[2] > 256:
136
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
137
+
138
+ # Features for synth images.
139
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
140
+ dist = (target_features - synth_features).square().sum()
141
+
142
+ # Noise regularization.
143
+ reg_loss = 0.0
144
+ if motion_reg_type is None:
+ pass # no extra regularization
145
+ elif motion_reg_type == "norm":
146
+ reg_loss = motion_z_opt.norm(dim=2).mean()
147
+ elif motion_reg_type == "dist":
148
+ reg_loss = motion_z_opt.mean().pow(2) + (motion_z_opt.var() - 1).pow(2)
149
+ elif motion_reg_type == "sg2":
150
+ for v in noise_bufs.values():
151
+ noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
152
+ while True:
153
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
154
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
155
+ if noise.shape[2] <= 8:
156
+ break
157
+ noise = F.avg_pool2d(noise, kernel_size=2)
158
+ else:
159
+ raise NotImplementedError(f"Uknown motion_reg_type: {motion_reg_type}")
160
+
161
+ loss = dist + reg_loss * regularize_noise_weight
162
+
163
+ # Step
164
+ optimizer.zero_grad(set_to_none=True)
165
+ loss.backward()
166
+ optimizer.step()
167
+
168
+ # Save projected W for each optimization step.
169
+ w_result[step] = w_opt.detach()
170
+
171
+ # Normalize noise.
172
+ # with torch.no_grad():
173
+ # for buf in motion_z_opt.values():
174
+ # buf -= buf.mean()
175
+ # buf *= buf.square().mean().rsqrt()
176
+
177
+ return w_result, motion_z_opt
178
+
179
+ #----------------------------------------------------------------------------
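The learning-rate schedule inside project() is the standard StyleGAN2-projector ramp: a linear warm-up over the first 5% of steps and a cosine ramp-down over the last 25%. The same schedule as a standalone function:

    import numpy as np

    def projector_lr(step, num_steps, base_lr=0.1, rampdown=0.25, rampup=0.05):
        t = step / num_steps
        ramp = min(1.0, (1.0 - t) / rampdown)
        ramp = 0.5 - 0.5 * np.cos(ramp * np.pi)  # cosine ease-out at the end
        ramp *= min(1.0, t / rampup)             # linear warm-up at the start
        return base_lr * ramp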
180
+
181
+ @torch.no_grad()
182
+ def find_motions_init(G: Callable, vgg16: nn.Module, target_features: Tensor, c: Tensor, t: Tensor, w_opt: Tensor, num_motions_to_try: int=128):
183
+ num_videos = c.shape[0]
+ motions = G.synthesis.motion_encoder(
184
+ c=c.repeat_interleave(num_motions_to_try, dim=0),
185
+ t=t.repeat_interleave(num_motions_to_try, dim=0))['motion_z'] # [num_videos * num_motions_to_try, ...]
186
+
187
+ synth_images = G.synthesis(
188
+ w_opt.repeat_interleave(num_motions_to_try, dim=0),
189
+ c=c.repeat_interleave(num_motions_to_try, dim=0),
190
+ t=t.repeat_interleave(num_motions_to_try, dim=0),
191
+ motion_z=motions)
192
+
193
+ if synth_images.shape[2] > 256:
194
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
195
+
196
+ synth_images = (synth_images * 0.5 + 0.5) * 255.0
197
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) # [num_videos * num_motions_to_try, ...]
198
+ dist = (target_features.repeat_interleave(num_motions_to_try, dim=0) - synth_features).square().sum(dim=1) # [num_videos * num_motions_to_try]
199
+ best_motions_idx = dist.view(num_videos, num_motions_to_try).argmin(dim=1) # [num_videos]
200
+ motion_z_opt = motions[best_motions_idx] # [num_videos, ...]
201
+
202
+ return motion_z_opt
203
+
204
+ #----------------------------------------------------------------------------
205
+
206
+ @torch.no_grad()
207
+ def find_w_init(G: Callable, vgg16: nn.Module, target_features: Tensor, c: Tensor, t: Tensor, l: Tensor, num_w_to_try: int=128):
208
+ z = torch.randn(num_videos * num_w_to_try, G.z_dim, device=device)
209
+ w = G.mapping(z=z, c=None) # [N, L, C]
210
+
211
+ synth_images = G.synthesis(
212
+ ws=w,
213
+ c=c.repeat_interleave(num_w_to_try, dim=0),
214
+ t=t.repeat_interleave(num_w_to_try, dim=0))
215
+ if synth_images.shape[2] > 256:
216
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
217
+ synth_images = (synth_images * 0.5 + 0.5) * 255.0
218
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) # [num_videos * num_motions_to_try, ...]
219
+ dist = (target_features.repeat_interleave(num_w_to_try, dim=0) - synth_features).square().sum(dim=1) # [num_videos * num_motions_to_try]
220
+ best_w_idx = dist.view(num_videos, num_w_to_try).argmin(dim=1) # [num_videos]
221
+ w_opt = w[best_w_idx] # [num_videos, num_ws, w_dim]
222
+
223
+ return w_opt
224
+
225
+ #----------------------------------------------------------------------------
226
+
227
+ @torch.no_grad()
228
+ def load_target_images(img_paths: List[os.PathLike], extract_faces: bool=False, ref_image: Tensor=None):
229
+ images = [Image.open(f) for f in tqdm(img_paths, desc='Loading images')]
230
+
231
+ if extract_faces:
232
+ images = extract_faces_from_images(imgs=images, ref_image=ref_image)
233
+ for p, img in zip(img_paths, images):
234
+ img.save('/tmp/data/faces_extracted/' + os.path.basename(p), quality=95)
235
+ assert False
236
+ # grid = torch.stack([TVF.to_tensor(x) for x in images])
237
+ # grid = utils.make_grid(grid, nrow=8)
238
+ # save_image(grid, f'/tmp/data/faces_extracted.png')
239
+ # print('Saved the extracted images!')
240
+
241
+ # images = [x[:, 200:-400, 450:-200] for x in images]
242
+ images = [TVF.to_tensor(x) for x in images]
243
+ images = [TVF.resize(x, size=(256, 256)) for x in images]
244
+
245
+ return images
246
+
247
+ #----------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def extract_faces_from_images(_sentinel=None, imgs: List=None, ref_image: "Image"=None, device: str='cuda'):
+     assert _sentinel is None
+     try:
+         import face_alignment
+     except ImportError:
+         raise ImportError("To project images with alignment, you need to install the `face_alignment` library.")
+
+     SELECTED_LANDMARKS = [38, 44]
+     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
+
+     ref_landmarks = fa.get_landmarks_from_image(np.array(ref_image))[0][SELECTED_LANDMARKS]  # [2, 2]
+     landmarks = [fa.get_landmarks_from_image(np.array(x))[0][SELECTED_LANDMARKS] for x in imgs]  # [num_imgs, 2, 2]
+     ref_dist = ((ref_landmarks[0] - ref_landmarks[1]) ** 2).sum() ** 0.5  # [1]
+     dists = [((p[0] - p[1]) ** 2).sum() ** 0.5 for p in landmarks]  # [num_imgs]
+     resize_ratios = [ref_dist / d for d in dists]  # [num_imgs]
+     new_sizes = [(int(r * x.size[1]), int(r * x.size[0])) for r, x in zip(resize_ratios, imgs)]
+     imgs_resized = [TVF.resize(x, size=s, interpolation=Image.LANCZOS) for x, s in zip(imgs, new_sizes)]  # [num_imgs, Image]
+     bbox_left = [p[0][0] * r - ref_landmarks[0][0] for p, r in zip(landmarks, resize_ratios)]
+     bbox_top = [p[0][1] * r - ref_landmarks[0][1] for p, r in zip(landmarks, resize_ratios)]
+
+     out = [x.crop(box=(l, t, l + ref_image.size[0], t + ref_image.size[1])) for x, l, t in zip(imgs_resized, bbox_left, bbox_top)]
+
+     return out
+
+ #----------------------------------------------------------------------------
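Concretely, the alignment above rescales each target image so that the distance between its two selected eye landmarks matches the reference image's, then crops a reference-sized box anchored at the matched landmark. Toy numbers (values illustrative only):

```python
# Landmark-based scale matching, as in extract_faces_from_images above.
ref_dist = 60.0    # eye-landmark distance in the reference image, in pixels
img_dist = 120.0   # the same distance measured in a target image
ratio = ref_dist / img_dist                            # 0.5: shrink the target 2x
orig_w, orig_h = 800, 600                              # PIL's size is (width, height)
new_size = (int(ratio * orig_h), int(ratio * orig_w))  # (300, 400) as (h, w) for TVF.resize
```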
+
+ def pad_box_to_square(left, upper, right, lower):
+     h = lower - upper
+     w = right - left
+
+     if h == w:
+         return left, upper, right, lower
+     elif w > h:
+         assert False, "Not implemented"
+     else:
+         pad = (h - w) // 2
+
+         return (left - pad, upper, right + pad, lower)
+
+ #----------------------------------------------------------------------------
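A quick worked check of the implemented (tall-box) branch, given the definition above; note that for an odd height-width difference the integer division leaves the result one pixel short of square:

```python
assert pad_box_to_square(10, 0, 20, 30) == (0, 0, 30, 30)  # h=30, w=10 -> pad 10 on each side
assert pad_box_to_square(0, 0, 5, 5) == (0, 0, 5, 5)       # already square: returned unchanged
```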
+
+ def add_margins(box, margin, width: int=float('inf'), height: int=float('inf')):
+     left, upper, right, lower = box
+
+     return (
+         max(0, left - margin[0]),
+         max(0, upper - margin[1]),
+         min(width, right + margin[2]),
+         min(height, lower + margin[3]),
+     )
+
+ #----------------------------------------------------------------------------
+
+ def add_top_margin(box, margin_ratio: float=0.0):
+     left, upper, right, lower = box
+     height = lower - upper
+     margin = int(height * margin_ratio)
+
+     return (left, max(0, upper - margin), right, lower)
+
+ #----------------------------------------------------------------------------
+
+ def save_edited_w(
+     _sentinel=None,
+     G: Callable=None,
+     w_outdir: os.PathLike=None,
+     samples_outdir: os.PathLike=None,
+     img_names: List[str]=None,
+     stack_samples: bool=False,
+     num_frames: int = 16,
+     each_nth_frame: int = 3,
+     all_w: Tensor=None,
+     all_motion_z: Tensor=None,
+     stacked_samples_out_path: os.PathLike=None,
+ ):
+     assert _sentinel is None
+
+     os.makedirs(w_outdir, exist_ok=True)
+     num_videos = len(img_names)
+     device = all_w.device
+
+     if not stack_samples:
+         os.makedirs(samples_outdir, exist_ok=True)
+     else:
+         all_samples = []
+
+     # Generate samples from the given w and save them.
+     with torch.no_grad():
+         c = torch.zeros(num_videos, G.c_dim, device=device)  # [num_videos, c_dim]
+
+         for i, w in enumerate(all_w):
+             torch.save(w.cpu(), os.path.join(w_outdir, f'{img_names[i]}_w.pt'))
+
+             if all_motion_z is None:
+                 motion_z = None
+             else:
+                 motion_z = all_motion_z[i]  # [...<any>...]
+                 torch.save(motion_z.cpu(), os.path.join(w_outdir, f'{img_names[i]}_motion.pt'))
+                 motion_z = motion_z.unsqueeze(0).to(device)  # [1, ...<any>...]
+                 # motion_z = torch.randn_like(motion_z)  # (debug) resample the motion instead of using the projected code
+
+             w = w.unsqueeze(0).to(device)  # [1, num_ws, w_dim]
+             t = torch.linspace(0, num_frames * (1 + each_nth_frame), num_frames, device=device).unsqueeze(0)
+             imgs = G.synthesis(w, c=c[[i]], t=t, motion_z=motion_z)
+             imgs = (imgs * 0.5 + 0.5).clamp(0, 1)
+             grid = utils.make_grid(imgs, nrow=num_frames).cpu()
+
+             if stack_samples:
+                 all_samples.append(grid)
+             else:
+                 save_image(grid, os.path.join(samples_outdir, img_names[i]) + '.png')
+
+     if stack_samples:
+         main_grid = torch.stack(all_samples)  # [num_videos, c, h, w * num_frames]
+         main_grid = utils.make_grid(main_grid, nrow=1)
+         save_image(main_grid, stacked_samples_out_path)
+
+ #----------------------------------------------------------------------------
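For reference, the timestep grid built above spreads `num_frames` samples over a clip that is `(1 + each_nth_frame)` times longer, so at the defaults consecutive samples sit roughly every fourth raw frame:

```python
import torch

num_frames, each_nth_frame = 16, 3  # the defaults above
t = torch.linspace(0, num_frames * (1 + each_nth_frame), num_frames)
print(t[1] - t[0])  # tensor(4.2667): 64 raw frames covered in 16 samples
```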
+
+ @click.command()
+ @click.pass_context
+ @click.option('--network_pkl', help='Network pickle filename', metavar='PATH')
+ @click.option('--networks_dir', help='Network pickles directory', metavar='PATH')
+ # @click.option('--truncation_psi', type=float, help='Truncation psi', default=1.0, show_default=True)
+ # @click.option('--noise_mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+ # @click.option('--same_motion_codes', type=bool, help='Should we use the same motion codes for all videos?', default=False, show_default=True)
+ @click.option('--seed', type=int, help='Random seed', default=42, metavar='INT')
+ @click.option('--images_dir', help='Directory with the target images to project', type=str, required=True, metavar='DIR')
+ # @click.option('--save_as_mp4', help='Should we save as independent frames or mp4?', type=bool, default=False, metavar='BOOL')
+ # @click.option('--video_len', help='Number of frames to generate', type=int, default=16, metavar='INT')
+ # @click.option('--fps', help='FPS for mp4 saving', type=int, default=25, metavar='INT')
+ # @click.option('--as_grids', help='Save videos as grids', type=bool, default=False, metavar='BOOL')
+ @click.option('--zero_periods', help='Zero-out periods predictor?', default=False, type=bool, metavar='BOOL')
+ @click.option('--num_weights_to_slice', help='Number of high-frequency coords to remove.', default=0, type=int, metavar='INT')
+ @click.option('--use_w_init', help='Init w by LPIPS.', default=False, type=bool, metavar='BOOL')
+ @click.option('--use_motion_init', help='Init motions by LPIPS.', default=False, type=bool, metavar='BOOL')
+ @click.option('--motion_reg_type', help='Type of the regularization for motion', default=None, type=str, metavar='STR')
+ @click.option('--num_steps', help='Number of the optimization steps to perform.', default=1000, type=int, metavar='INT')
+ @click.option('--stack_samples', help='When saving, should we stack samples together?', default=False, type=bool, metavar='BOOL')
+ @click.option('--extract_faces', help='Use face alignment to extract the face?', default=False, type=bool, metavar='BOOL')
+ def main(
+     ctx: click.Context,
+     network_pkl: str,
+     networks_dir: str,
+     seed: int,
+     images_dir: str,
+     # save_as_mp4: bool,
+     # video_len: int,
+     # fps: int,
+     # as_grids: bool,
+     zero_periods: bool,
+     num_weights_to_slice: int,
+     use_w_init: bool,
+     use_motion_init: bool,
+     motion_reg_type: str,
+     num_steps: int,
+     stack_samples: bool,
+     extract_faces: bool,
+ ):
+     if network_pkl is None:
+         # Select the checkpoint with the best (lowest) FVD score.
+         ckpt_regex = re.compile(r"^network-snapshot-\d{6}\.pkl$")
+         # ckpts = sorted([f for f in os.listdir(networks_dir) if ckpt_regex.match(f)])
+         # network_pkl = os.path.join(networks_dir, ckpts[-1])
+         metrics_file = os.path.join(networks_dir, 'metric-fvd2048_16f.jsonl')
+         with open(metrics_file, 'r') as f:
+             snapshot_metrics_vals = [json.loads(line) for line in f.read().splitlines()]
+         best_snapshot = sorted(snapshot_metrics_vals, key=lambda m: m['results']['fvd2048_16f'])[0]
+         network_pkl = os.path.join(networks_dir, best_snapshot['snapshot_pkl'])
+         print(f'Using checkpoint: {network_pkl} with FVD16 of', best_snapshot['results']['fvd2048_16f'])
+     else:
+         assert networks_dir is None, "Can't have both parameters: network_pkl and networks_dir"
+
+     print('Loading networks from "%s"...' % network_pkl, end='')
+     device = torch.device('cuda')
+     with dnnlib.util.open_url(network_pkl) as f:
+         G = legacy.load_network_pkl(f)['G_ema'].to(device).eval() # type: ignore
+     print('Loaded!')
+
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     if zero_periods:
+         G.synthesis.motion_encoder.time_encoder.periods_predictor.weight.data.zero_()
+
+     if num_weights_to_slice > 0:
+         G.synthesis.motion_encoder.time_encoder.weights.data[:, -num_weights_to_slice:] = 0.0
+
+     img_paths = sorted([os.path.join(images_dir, p) for p in os.listdir(images_dir) if p.endswith('.jpg')])
+     img_names = [n[:n.rfind('.')] for n in [os.path.basename(p) for p in img_paths]]
+     target_images = load_target_images(img_paths, extract_faces, ref_image=Image.open('/tmp/data/mean.png'))  # [b, c, h, w]
+
+     assert G.c_dim == 0, "G.c_dim > 0 is not supported"
+
+     w_all_iters, motion_z_final = project(
+         G=G,
+         target_images=target_images,
+         num_steps=num_steps,
+         device=device,
+         use_w_init=use_w_init,
+         use_motion_init=use_motion_init,
+         motion_reg_type=motion_reg_type,
+     )  # [num_videos, num_ws, w_dim]
+
+     save_edited_w(
+         G=G,
+         w_outdir=f'{images_dir}_projected',
+         samples_outdir=f'{images_dir}_projected_samples',
+         img_names=img_names,
+         stack_samples=stack_samples,
+         all_w=w_all_iters[-1],
+         all_motion_z=motion_z_final,
+         stacked_samples_out_path=f'{images_dir}.png',
+     )
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     main() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
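Taken together, `main` wires the pieces into a simple flow: pick a checkpoint, load `G_ema`, load (and optionally face-align) the targets, optimize the latents with `project`, then render and save the results. A minimal programmatic sketch of the same flow, assuming the `project()` defined earlier in this file; the checkpoint and image paths below are hypothetical:

```python
import torch
import dnnlib, legacy

device = torch.device('cuda')
with dnnlib.util.open_url('ckpts/network-snapshot-000400.pkl') as f:  # hypothetical path
    G = legacy.load_network_pkl(f)['G_ema'].to(device).eval()

target_images = load_target_images(['targets/person_0.jpg'])          # hypothetical image
w_all_iters, motion_z = project(
    G=G, target_images=target_images, num_steps=1000, device=device,
    use_w_init=True, use_motion_init=True, motion_reg_type=None)
w_final = w_all_iters[-1]  # [num_videos, num_ws, w_dim]
```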
src/torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty
src/torch_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ import os
+ import glob
+ import torch
+ import torch.utils.cpp_extension
+ import importlib
+ import hashlib
+ import shutil
+ from pathlib import Path
+
+ from torch.utils.file_baton import FileBaton
+
+ #----------------------------------------------------------------------------
+ # Global options.
+
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+ #----------------------------------------------------------------------------
+ # Internal helper funcs.
+
+ def _find_compiler_bindir():
+     patterns = [
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+         'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+     ]
+     for pattern in patterns:
+         matches = sorted(glob.glob(pattern))
+         if len(matches):
+             return matches[-1]
+     return None
+
+ #----------------------------------------------------------------------------
+ # Main entry point for compiling and loading C++/CUDA plugins.
+
+ _cached_plugins = dict()
+
+ def get_plugin(module_name, sources, **build_kwargs):
+     assert verbosity in ['none', 'brief', 'full']
+
+     # Already cached?
+     if module_name in _cached_plugins:
+         return _cached_plugins[module_name]
+
+     # Print status.
+     if verbosity == 'full':
+         print(f'Setting up PyTorch plugin "{module_name}"...')
+     elif verbosity == 'brief':
+         print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+     try: # pylint: disable=too-many-nested-blocks
+         # Make sure we can find the necessary compiler binaries.
+         if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+             compiler_bindir = _find_compiler_bindir()
+             if compiler_bindir is None:
+                 raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+             os.environ['PATH'] += ';' + compiler_bindir
+
+         # Compile and load.
+         verbose_build = (verbosity == 'full')
+
+         # Incremental build md5sum trickery. Copies all the input source files
+         # into a cached build directory under a combined md5 digest of the input
+         # source files. Copying is done only if the combined digest has changed.
+         # This keeps input file timestamps and filenames the same as in previous
+         # extension builds, allowing for fast incremental rebuilds.
+         #
+         # This optimization is done only in case all the source files reside in
+         # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+         # environment variable is set (we take this as a signal that the user
+         # actually cares about this.)
+         source_dirs_set = set(os.path.dirname(source) for source in sources)
+         if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+             all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+             # Compute a combined hash digest for all source files in the same
+             # custom op directory (usually .cu, .cpp, .py and .h files).
+             hash_md5 = hashlib.md5()
+             for src in all_source_files:
+                 with open(src, 'rb') as f:
+                     hash_md5.update(f.read())
+             build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+             digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+             if not os.path.isdir(digest_build_dir):
+                 os.makedirs(digest_build_dir, exist_ok=True)
+                 baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+                 if baton.try_acquire():
+                     try:
+                         for src in all_source_files:
+                             shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+                     finally:
+                         baton.release()
+                 else:
+                     # Someone else is copying source files under the digest dir,
+                     # wait until done and continue.
+                     baton.wait()
+             digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+             torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+                 verbose=verbose_build, sources=digest_sources, **build_kwargs)
+         else:
+             torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+         module = importlib.import_module(module_name)
+
+     except:
+         if verbosity == 'brief':
+             print('Failed!')
+         raise
+
+     # Print status and add to cache.
+     if verbosity == 'full':
+         print(f'Done setting up PyTorch plugin "{module_name}".')
+     elif verbosity == 'brief':
+         print('Done.')
+     _cached_plugins[module_name] = module
+     return module
+
+ #----------------------------------------------------------------------------
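For context, custom ops in StyleGAN-style codebases call `get_plugin` roughly as below; the compile happens once, and later calls return the cached module. The module name and source files here are illustrative, and `extra_cuda_cflags` is simply forwarded to `torch.utils.cpp_extension.load` via `**build_kwargs`:

```python
from src.torch_utils import custom_ops

_plugin = custom_ops.get_plugin(
    module_name='example_plugin',                  # illustrative name
    sources=['example_op.cpp', 'example_op.cu'],   # illustrative source files
    extra_cuda_cflags=['--use_fast_math'],         # passed through to cpp_extension.load
)
```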