Spaces:
Configuration error
Configuration error
Erwann Millon committed on
Commit ·
e0f92a0
1
Parent(s): ec39fe8
refactoring and cleanup
Browse files- ImageState.py +118 -74
- animation.py +8 -6
- app.py +7 -7
- backend.py +104 -90
- edit.py +17 -12
- img_processing.py +40 -36
- loaders.py +20 -20
- masking.py +21 -23
- presets.py +30 -4
- prompts.py +31 -7
- unwrapped.yaml +0 -37
- utils.py +3 -1
ImageState.py
CHANGED
|
@@ -1,183 +1,227 @@
|
|
| 1 |
-
|
| 2 |
import gc
|
|
|
|
| 3 |
import imageio
|
| 4 |
import glob
|
| 5 |
import uuid
|
| 6 |
from animation import clear_img_dir
|
| 7 |
-
from backend import
|
| 8 |
-
import importlib
|
| 9 |
-
import gradio as gr
|
| 10 |
-
import matplotlib.pyplot as plt
|
| 11 |
import torch
|
| 12 |
import torchvision
|
| 13 |
import wandb
|
| 14 |
-
from icecream import ic
|
| 15 |
-
from torch import nn
|
| 16 |
-
from torchvision.transforms.functional import resize
|
| 17 |
-
from tqdm import tqdm
|
| 18 |
-
from transformers import CLIPModel, CLIPProcessor
|
| 19 |
-
import lpips
|
| 20 |
-
from backend import get_resized_tensor
|
| 21 |
from edit import blend_paths
|
| 22 |
-
from img_processing import *
|
| 23 |
from img_processing import custom_to_pil
|
| 24 |
-
from
|
|
|
|
| 25 |
num = 0
|
| 26 |
|
| 27 |
-
|
|
|
|
| 28 |
def __init__(self, iterations) -> None:
|
| 29 |
self.iterations = iterations
|
| 30 |
self.transforms = []
|
| 31 |
|
|
|
|
| 32 |
class ImageState:
|
| 33 |
-
def __init__(self, vqgan, prompt_optimizer:
|
| 34 |
self.vqgan = vqgan
|
| 35 |
self.device = vqgan.device
|
| 36 |
self.blend_latent = None
|
| 37 |
self.quant = True
|
| 38 |
self.path1 = None
|
| 39 |
self.path2 = None
|
|
|
|
|
|
|
|
|
|
| 40 |
self.transform_history = []
|
| 41 |
self.attn_mask = None
|
| 42 |
self.prompt_optim = prompt_optimizer
|
| 43 |
self._load_vectors()
|
| 44 |
self.init_transforms()
|
|
|
|
| 45 |
def _load_vectors(self):
|
| 46 |
-
self.lip_vector = torch.load(
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
|
| 50 |
images = []
|
| 51 |
-
folder = self.
|
| 52 |
paths = glob.glob(folder + "/*")
|
| 53 |
frame_duration = total_duration / len(paths)
|
| 54 |
print(len(paths), "frame dur", frame_duration)
|
| 55 |
durations = [frame_duration] * len(paths)
|
| 56 |
if extend_frames:
|
| 57 |
-
durations
|
| 58 |
-
durations
|
| 59 |
for file_name in os.listdir(folder):
|
| 60 |
-
if file_name.endswith(
|
| 61 |
file_path = os.path.join(folder, file_name)
|
| 62 |
images.append(imageio.imread(file_path))
|
| 63 |
imageio.mimsave(gif_name, images, duration=durations)
|
| 64 |
return gif_name
|
|
|
|
| 65 |
def init_transforms(self):
|
| 66 |
self.blue_eyes = torch.zeros_like(self.lip_vector)
|
| 67 |
self.lip_size = torch.zeros_like(self.lip_vector)
|
| 68 |
self.asian_transform = torch.zeros_like(self.lip_vector)
|
| 69 |
self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
|
|
|
|
| 70 |
def clear_transforms(self):
|
| 71 |
-
global num
|
| 72 |
self.init_transforms()
|
| 73 |
clear_img_dir("./img_history")
|
| 74 |
-
num = 0
|
| 75 |
return self._render_all_transformations()
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
return new_latent
|
| 79 |
-
def _decode_latent_to_pil(self, latent):
|
| 80 |
current_im = self.vqgan.decode(latent.to(self.device))[0]
|
| 81 |
return custom_to_pil(current_im)
|
|
|
|
| 82 |
def _get_mask(self, img, mask=None):
|
| 83 |
if img and "mask" in img and img["mask"] is not None:
|
| 84 |
attn_mask = torchvision.transforms.ToTensor()(img["mask"])
|
| 85 |
attn_mask = torch.ceil(attn_mask[0].to(self.device))
|
| 86 |
print("mask set successfully")
|
| 87 |
-
print(type(attn_mask))
|
| 88 |
-
print(attn_mask.shape)
|
| 89 |
else:
|
| 90 |
attn_mask = mask
|
| 91 |
return attn_mask
|
|
|
|
| 92 |
def set_mask(self, img):
|
| 93 |
self.attn_mask = self._get_mask(img)
|
| 94 |
x = self.attn_mask.clone()
|
| 95 |
x = x.detach().cpu()
|
| 96 |
-
x = torch.clamp(x, -1., 1.)
|
| 97 |
-
x = (x + 1.)/2.
|
| 98 |
x = x.numpy()
|
| 99 |
x = (255 * x).astype(np.uint8)
|
| 100 |
x = Image.fromarray(x, "L")
|
| 101 |
return x
|
| 102 |
-
|
|
|
|
| 103 |
def _render_all_transformations(self, return_twice=True):
|
| 104 |
global num
|
| 105 |
-
|
| 106 |
-
self.
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
new_latent = self.blend_latent + sum(current_vector_transforms)
|
| 110 |
if self.quant:
|
| 111 |
new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
|
| 112 |
-
image = self.
|
| 113 |
-
|
| 114 |
-
if not os.path.exists("img_history"):
|
| 115 |
-
os.mkdir("./img_history")
|
| 116 |
-
if not os.path.exists(img_dir):
|
| 117 |
-
os.mkdir(img_dir)
|
| 118 |
-
image.save(f"{img_dir}/img_{num:06}.png")
|
| 119 |
num += 1
|
| 120 |
return (image, image) if return_twice else image
|
|
|
|
| 121 |
def apply_rb_vector(self, weight):
|
| 122 |
self.blue_eyes = weight * self.blue_eyes_vector
|
| 123 |
return self._render_all_transformations()
|
|
|
|
| 124 |
def apply_lip_vector(self, weight):
|
| 125 |
self.lip_size = weight * self.lip_vector
|
| 126 |
return self._render_all_transformations()
|
|
|
|
| 127 |
def update_quant(self, val):
|
| 128 |
self.quant = val
|
| 129 |
return self._render_all_transformations()
|
|
|
|
| 130 |
def apply_asian_vector(self, weight):
|
| 131 |
self.asian_transform = weight * self.asian_vector
|
| 132 |
return self._render_all_transformations()
|
|
|
|
| 133 |
def update_images(self, path1, path2, blend_weight):
|
| 134 |
if path1 is None and path2 is None:
|
| 135 |
return None
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
self.path1, self.path2 = path1, path2
|
| 139 |
-
if self.
|
| 140 |
-
clear_img_dir(self.
|
| 141 |
return self.blend(blend_weight)
|
| 142 |
-
|
|
|
|
| 143 |
def blend(self, weight):
|
| 144 |
-
_, latent = blend_paths(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
self.blend_latent = latent
|
| 146 |
return self._render_all_transformations()
|
| 147 |
-
|
|
|
|
| 148 |
def rewind(self, index):
|
| 149 |
if not self.transform_history:
|
| 150 |
-
print("
|
| 151 |
return self._render_all_transformations()
|
| 152 |
prompt_transform = self.transform_history[-1]
|
| 153 |
latent_index = int(index / 100 * (prompt_transform.iterations - 1))
|
| 154 |
print(latent_index)
|
| 155 |
-
self.current_prompt_transforms[-1] = prompt_transform.transforms[
|
|
|
|
|
|
|
| 156 |
return self._render_all_transformations()
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
if log:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
|
| 173 |
negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
|
| 174 |
-
self.prompt_optim.set_params(
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
transform_log.transforms.append(transform.detach().cpu())
|
| 179 |
self.current_prompt_transforms[-1] = transform
|
| 180 |
-
with torch.
|
| 181 |
image = self._render_all_transformations(return_twice=False)
|
| 182 |
if log:
|
| 183 |
wandb.log({"image": wandb.Image(image)})
|
|
@@ -187,4 +231,4 @@ class ImageState:
|
|
| 187 |
self.attn_mask = None
|
| 188 |
self.transform_history.append(transform_log)
|
| 189 |
gc.collect()
|
| 190 |
-
torch.cuda.empty_cache()
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
import gc
|
| 3 |
+
import os
|
| 4 |
import imageio
|
| 5 |
import glob
|
| 6 |
import uuid
|
| 7 |
from animation import clear_img_dir
|
| 8 |
+
from backend import ImagePromptEditor, log
|
|
|
|
|
|
|
|
|
|
| 9 |
import torch
|
| 10 |
import torchvision
|
| 11 |
import wandb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from edit import blend_paths
|
|
|
|
| 13 |
from img_processing import custom_to_pil
|
| 14 |
+
from PIL import Image
|
| 15 |
+
|
| 16 |
num = 0
|
| 17 |
|
| 18 |
+
|
| 19 |
+
class PromptTransformHistory:
    """Record of the latent offsets collected during one prompt-editing run.

    Attributes are plain data:
      * ``iterations`` - the step count supplied by the caller.
      * ``transforms`` - a list the caller appends offsets to, one per step.
    """

    def __init__(self, iterations) -> None:
        # Start with an empty record; entries are appended by the caller.
        self.transforms = list()
        self.iterations = iterations
|
| 23 |
|
| 24 |
+
|
| 25 |
class ImageState:
|
| 26 |
+
def __init__(self, vqgan, prompt_optimizer: ImagePromptEditor) -> None:
    """Set up the editing state around a VQGAN model and a prompt optimizer.

    Args:
        vqgan: Model object. Its ``device`` attribute is reused for every
            tensor handled here; other methods call its ``decode`` and
            ``quantize``.
        prompt_optimizer: Editor stored as ``self.prompt_optim``.
    """
    self.vqgan = vqgan
    self.device = vqgan.device
    self.prompt_optim = prompt_optimizer

    # Nothing has been loaded or edited yet.
    self.blend_latent = None
    self.path1 = None
    self.path2 = None
    self.attn_mask = None
    self.quant = True
    self.transform_history = []

    # Rendered frames are written here. exist_ok avoids the window between
    # an existence check and the creation in which another process could
    # create the directory and make mkdir fail.
    self.img_dir = "./img_history"
    os.makedirs(self.img_dir, exist_ok=True)

    self._load_vectors()
    self.init_transforms()
|
| 41 |
+
|
| 42 |
def _load_vectors(self):
    """Load the three stored latent direction vectors onto ``self.device``."""
    # (attribute name, file path) pairs, loaded in this order.
    vector_files = (
        ("lip_vector", "./latent_vectors/lipvector.pt"),
        ("blue_eyes_vector", "./latent_vectors/2blue_eyes.pt"),
        ("asian_vector", "./latent_vectors/asian10.pt"),
    )
    for attr_name, file_path in vector_files:
        setattr(self, attr_name, torch.load(file_path, map_location=self.device))
|
| 52 |
+
|
| 53 |
def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
    """Assemble the PNG frames stored in ``self.img_dir`` into an animated GIF.

    Args:
        total_duration: Time budget split evenly across the frames; the
            per-frame values are handed to ``imageio.mimsave`` as ``duration``.
        extend_frames: When truthy, the first frame is held for 1.5 and the
            last for 3 instead of the even share.
        gif_name: Path of the GIF file to write.

    Returns:
        ``gif_name``, unchanged.

    Raises:
        ValueError: If the folder contains no ``.png`` frame.
    """
    folder = self.img_dir
    # Only .png files become frames, so only they may count towards the
    # timing (otherwise durations and frames can differ in length). Sorting
    # keeps playback in file-name order; os.listdir order is arbitrary.
    frame_names = sorted(
        name for name in os.listdir(folder) if name.endswith(".png")
    )
    if not frame_names:
        raise ValueError(f"no .png frames found in {folder!r}; nothing to animate")

    frame_duration = total_duration / len(frame_names)
    print(len(frame_names), "frame dur", frame_duration)
    durations = [frame_duration] * len(frame_names)
    if extend_frames:
        durations[0] = 1.5
        durations[-1] = 3

    images = [imageio.imread(os.path.join(folder, name)) for name in frame_names]
    imageio.mimsave(gif_name, images, duration=durations)
    return gif_name
|
| 69 |
+
|
| 70 |
def init_transforms(self):
    """Reset every latent offset to zeros shaped like ``self.lip_vector``."""

    def blank():
        # A fresh tensor per call, so the offsets never share storage.
        return torch.zeros_like(self.lip_vector)

    self.blue_eyes = blank()
    self.lip_size = blank()
    self.asian_transform = blank()
    # The prompt offsets are kept as a list; it starts with one zero entry.
    self.current_prompt_transforms = [blank()]
|
| 75 |
+
|
| 76 |
def clear_transforms(self):
    """Drop all latent edits, empty the frame folder and re-render.

    Returns:
        Whatever ``_render_all_transformations`` returns.
    """
    self.init_transforms()
    # Use the configured folder rather than a second hard-coded copy of the
    # path, so this stays in step with where frames are actually written.
    clear_img_dir(self.img_dir)
    return self._render_all_transformations()
|
| 80 |
+
|
| 81 |
+
def _latent_to_pil(self, latent):
    """Decode ``latent`` with the VQGAN and convert the result to a PIL image.

    Only the first element of the decoder output is converted.
    """
    decoded = self.vqgan.decode(latent.to(self.device))
    first_image = decoded[0]
    return custom_to_pil(first_image)
|
| 84 |
+
|
| 85 |
def _get_mask(self, img, mask=None):
|
| 86 |
if img and "mask" in img and img["mask"] is not None:
|
| 87 |
attn_mask = torchvision.transforms.ToTensor()(img["mask"])
|
| 88 |
attn_mask = torch.ceil(attn_mask[0].to(self.device))
|
| 89 |
print("mask set successfully")
|
|
|
|
|
|
|
| 90 |
else:
|
| 91 |
attn_mask = mask
|
| 92 |
return attn_mask
|
| 93 |
+
|
| 94 |
def set_mask(self, img):
    """Store the attention mask found in ``img`` and return a preview of it.

    Args:
        img: Input forwarded to ``_get_mask``.

    Returns:
        A PIL image in mode ``"L"`` built from the stored mask, or ``None``
        when ``img`` carries no mask.
    """
    self.attn_mask = self._get_mask(img)
    if self.attn_mask is None:
        # Nothing was drawn (or the input was cleared): there is no tensor to
        # preview, and calling .clone() on None would raise AttributeError.
        return None

    preview = self.attn_mask.clone().detach().cpu()
    # Clamp to [-1, 1] and rescale to [0, 1] before converting to bytes.
    preview = torch.clamp(preview, -1.0, 1.0)
    preview = (preview + 1.0) / 2.0
    pixels = (255 * preview.numpy()).astype(np.uint8)
    return Image.fromarray(pixels, "L")
|
| 104 |
+
|
| 105 |
+
@torch.inference_mode()
def _render_all_transformations(self, return_twice=True):
    """Apply every active latent offset, decode, save the frame and return it.

    Args:
        return_twice: When true the image is returned as a 2-tuple of the same
            object; otherwise the image alone is returned.

    Side effects:
        Writes ``img_<num>.png`` into ``self.img_dir`` and increments the
        module-level frame counter ``num``.
    """
    global num

    # Combine the prompt offsets first, then add them to the fixed directions.
    prompt_offset = sum(self.current_prompt_transforms)
    total_offset = sum(
        (self.blue_eyes, self.lip_size, self.asian_transform, prompt_offset)
    )
    latent = self.blend_latent + total_offset

    if self.quant:
        latent, _, _ = self.vqgan.quantize(latent.to(self.device))

    frame = self._latent_to_pil(latent)
    frame.save(f"{self.img_dir}/img_{num:06}.png")
    num += 1

    if return_twice:
        return (frame, frame)
    return frame
|
| 121 |
+
|
| 122 |
def apply_rb_vector(self, weight):
    """Set the blue-eyes offset to ``weight`` times its direction and re-render."""
    scaled_direction = weight * self.blue_eyes_vector
    self.blue_eyes = scaled_direction
    rendered = self._render_all_transformations()
    return rendered
|
| 125 |
+
|
| 126 |
def apply_lip_vector(self, weight):
    """Set the lip offset to ``weight`` times its direction and re-render."""
    scaled_direction = weight * self.lip_vector
    self.lip_size = scaled_direction
    rendered = self._render_all_transformations()
    return rendered
|
| 129 |
+
|
| 130 |
def update_quant(self, val):
    """Store the quantisation switch and re-render with the new setting."""
    self.quant = val
    rendered = self._render_all_transformations()
    return rendered
|
| 133 |
+
|
| 134 |
def apply_asian_vector(self, weight):
    """Set this offset to ``weight`` times its stored direction and re-render."""
    scaled_direction = weight * self.asian_vector
    self.asian_transform = scaled_direction
    rendered = self._render_all_transformations()
    return rendered
|
| 137 |
+
|
| 138 |
def update_images(self, path1, path2, blend_weight):
    """Register the two source images and render their blend.

    Args:
        path1: First image, or None.
        path2: Second image, or None.
        blend_weight: Weight forwarded to ``blend``.

    Returns:
        None when both inputs are None; otherwise the result of ``blend``.
    """
    if path1 is None and path2 is None:
        return None

    # A missing side is filled with the side that is present.
    first = path2 if path1 is None else path1
    second = path1 if path2 is None else path2
    self.path1 = first
    self.path2 = second

    # New inputs start a new frame sequence, so old frames are removed.
    if self.img_dir:
        clear_img_dir(self.img_dir)
    return self.blend(blend_weight)
|
| 152 |
+
|
| 153 |
+
@torch.inference_mode()
def blend(self, weight):
    """Blend the two registered images in latent space and render the result.

    ``blend_paths`` returns a pair; its second element is stored as
    ``self.blend_latent`` and the first is discarded.
    """
    blend_kwargs = dict(weight=weight, show=False, device=self.device)
    _unused, blended_latent = blend_paths(
        self.vqgan, self.path1, self.path2, **blend_kwargs
    )
    self.blend_latent = blended_latent
    return self._render_all_transformations()
|
| 165 |
+
|
| 166 |
+
@torch.inference_mode()
def rewind(self, index):
    """Replace the latest prompt offset with an earlier step of the last run.

    Args:
        index: Position within the last run as a percentage (divided by 100
            here; presumably a 0-100 slider value - confirm against caller).

    Returns:
        Whatever ``_render_all_transformations`` returns.
    """
    if not self.transform_history:
        print("No history")
        return self._render_all_transformations()

    prompt_transform = self.transform_history[-1]
    recorded = prompt_transform.transforms
    if not recorded:
        # Nothing was recorded for this run, so there is no step to restore.
        return self._render_all_transformations()

    latent_index = int(index / 100 * (prompt_transform.iterations - 1))
    # The run may have recorded fewer steps than ``iterations`` promises
    # (e.g. it stopped early); clamp so the lookup cannot raise IndexError.
    latent_index = max(0, min(latent_index, len(recorded) - 1))
    print(latent_index)
    self.current_prompt_transforms[-1] = recorded[latent_index].to(self.device)
    return self._render_all_transformations()
|
| 178 |
+
|
| 179 |
+
@staticmethod
def _init_logging(lr, iterations, lpips_weight, positive_prompts, negative_prompts):
    """Start a wandb run and record the prompt-editing parameters in its config.

    Declared as a staticmethod because the signature has no ``self`` while the
    class invokes it as ``self._init_logging(...)``; without the decorator the
    bound call would pass the instance as ``lr`` and raise TypeError.

    Args:
        lr: Learning rate to record.
        iterations: Iteration count to record.
        lpips_weight: LPIPS weight to record.
        positive_prompts: Stored under the "Positive Prompts" config key.
        negative_prompts: Stored under the "Negative Prompts" config key.
    """
    wandb.init(reinit=True, project="face-editor")
    wandb.config.update({"Positive Prompts": positive_prompts})
    wandb.config.update({"Negative Prompts": negative_prompts})
    wandb.config.update(
        dict(lr=lr, iterations=iterations, lpips_weight=lpips_weight)
    )
|
| 186 |
+
|
| 187 |
+
def apply_prompts(
|
| 188 |
+
self,
|
| 189 |
+
positive_prompts,
|
| 190 |
+
negative_prompts,
|
| 191 |
+
lr,
|
| 192 |
+
iterations,
|
| 193 |
+
lpips_weight,
|
| 194 |
+
reconstruction_steps,
|
| 195 |
+
):
|
| 196 |
if log:
|
| 197 |
+
self._init_logging(
|
| 198 |
+
lr, iterations, lpips_weight, positive_prompts, negative_prompts
|
| 199 |
+
)
|
| 200 |
+
transform_log = PromptTransformHistory(iterations + reconstruction_steps)
|
| 201 |
+
transform_log.transforms.append(
|
| 202 |
+
torch.zeros_like(self.blend_latent, requires_grad=False)
|
| 203 |
+
)
|
| 204 |
+
self.current_prompt_transforms.append(
|
| 205 |
+
torch.zeros_like(self.blend_latent, requires_grad=False)
|
| 206 |
+
)
|
| 207 |
positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
|
| 208 |
negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
|
| 209 |
+
self.prompt_optim.set_params(
|
| 210 |
+
lr,
|
| 211 |
+
iterations,
|
| 212 |
+
lpips_weight,
|
| 213 |
+
attn_mask=self.attn_mask,
|
| 214 |
+
reconstruction_steps=reconstruction_steps,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
for i, transform in enumerate(
|
| 218 |
+
self.prompt_optim.optimize(
|
| 219 |
+
self.blend_latent, positive_prompts, negative_prompts
|
| 220 |
+
)
|
| 221 |
+
):
|
| 222 |
transform_log.transforms.append(transform.detach().cpu())
|
| 223 |
self.current_prompt_transforms[-1] = transform
|
| 224 |
+
with torch.inference_mode():
|
| 225 |
image = self._render_all_transformations(return_twice=False)
|
| 226 |
if log:
|
| 227 |
wandb.log({"image": wandb.Image(image)})
|
|
|
|
| 231 |
self.attn_mask = None
|
| 232 |
self.transform_history.append(transform_log)
|
| 233 |
gc.collect()
|
| 234 |
+
torch.cuda.empty_cache()
|
animation.py
CHANGED
|
@@ -8,21 +8,23 @@ def clear_img_dir(img_dir):
|
|
| 8 |
os.mkdir("img_history")
|
| 9 |
if not os.path.exists(img_dir):
|
| 10 |
os.mkdir(img_dir)
|
| 11 |
-
for filename in glob.glob(img_dir+"/*"):
|
| 12 |
os.remove(filename)
|
| 13 |
|
| 14 |
|
| 15 |
-
def create_gif(
|
|
|
|
|
|
|
| 16 |
images = []
|
| 17 |
paths = glob.glob(folder + "/*")
|
| 18 |
frame_duration = total_duration / len(paths)
|
| 19 |
print(len(paths), "frame dur", frame_duration)
|
| 20 |
durations = [frame_duration] * len(paths)
|
| 21 |
if extend_frames:
|
| 22 |
-
durations
|
| 23 |
-
durations
|
| 24 |
for file_name in os.listdir(folder):
|
| 25 |
-
if file_name.endswith(
|
| 26 |
file_path = os.path.join(folder, file_name)
|
| 27 |
images.append(imageio.imread(file_path))
|
| 28 |
imageio.mimsave(gif_name, images, duration=durations)
|
|
@@ -30,4 +32,4 @@ def create_gif(total_duration, extend_frames, folder="./img_history", gif_name="
|
|
| 30 |
|
| 31 |
|
| 32 |
if __name__ == "__main__":
|
| 33 |
-
create_gif()
|
|
|
|
| 8 |
os.mkdir("img_history")
|
| 9 |
if not os.path.exists(img_dir):
|
| 10 |
os.mkdir(img_dir)
|
| 11 |
+
for filename in glob.glob(img_dir + "/*"):
|
| 12 |
os.remove(filename)
|
| 13 |
|
| 14 |
|
| 15 |
+
def create_gif(
|
| 16 |
+
total_duration, extend_frames, folder="./img_history", gif_name="face_edit.gif"
|
| 17 |
+
):
|
| 18 |
images = []
|
| 19 |
paths = glob.glob(folder + "/*")
|
| 20 |
frame_duration = total_duration / len(paths)
|
| 21 |
print(len(paths), "frame dur", frame_duration)
|
| 22 |
durations = [frame_duration] * len(paths)
|
| 23 |
if extend_frames:
|
| 24 |
+
durations[0] = 1.5
|
| 25 |
+
durations[-1] = 3
|
| 26 |
for file_name in os.listdir(folder):
|
| 27 |
+
if file_name.endswith(".png"):
|
| 28 |
file_path = os.path.join(folder, file_name)
|
| 29 |
images.append(imageio.imread(file_path))
|
| 30 |
imageio.mimsave(gif_name, images, duration=durations)
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|
| 35 |
+
create_gif()
|
app.py
CHANGED
|
@@ -14,7 +14,7 @@ from transformers import CLIPModel, CLIPProcessor
|
|
| 14 |
from lpips import LPIPS
|
| 15 |
|
| 16 |
import edit
|
| 17 |
-
from backend import
|
| 18 |
from ImageState import ImageState
|
| 19 |
from loaders import load_default
|
| 20 |
# from animation import create_gif
|
|
@@ -29,14 +29,14 @@ processor = ProcessorGradientFlow(device=device)
|
|
| 29 |
# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 30 |
lpips_fn = LPIPS(net='vgg').to(device)
|
| 31 |
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 32 |
-
promptoptim =
|
|
|
|
| 33 |
def set_img_from_example(state, img):
|
| 34 |
return state.update_images(img, img, 0)
|
| 35 |
def get_cleared_mask():
|
| 36 |
return gr.Image.update(value=None)
|
| 37 |
-
# mask.clear()
|
| 38 |
-
|
| 39 |
class StateWrapper:
|
|
|
|
| 40 |
def create_gif(state, *args, **kwargs):
|
| 41 |
return state, state[0].create_gif(*args, **kwargs)
|
| 42 |
def apply_asian_vector(state, *args, **kwargs):
|
|
@@ -46,7 +46,6 @@ class StateWrapper:
|
|
| 46 |
def apply_lip_vector(state, *args, **kwargs):
|
| 47 |
return state, *state[0].apply_lip_vector(*args, **kwargs)
|
| 48 |
def apply_prompts(state, *args, **kwargs):
|
| 49 |
-
print(state[1])
|
| 50 |
for image in state[0].apply_prompts(*args, **kwargs):
|
| 51 |
yield state, *image
|
| 52 |
def apply_rb_vector(state, *args, **kwargs):
|
|
@@ -69,9 +68,10 @@ class StateWrapper:
|
|
| 69 |
return state, *state[0].update_images(*args, **kwargs)
|
| 70 |
def update_requant(state, *args, **kwargs):
|
| 71 |
return state, *state[0].update_requant(*args, **kwargs)
|
|
|
|
|
|
|
| 72 |
with gr.Blocks(css="styles.css") as demo:
|
| 73 |
-
|
| 74 |
-
state = gr.State([ImageState(vqgan, promptoptim), str(uuid.uuid4())])
|
| 75 |
with gr.Row():
|
| 76 |
with gr.Column(scale=1):
|
| 77 |
with gr.Row():
|
|
|
|
| 14 |
from lpips import LPIPS
|
| 15 |
|
| 16 |
import edit
|
| 17 |
+
from backend import ImagePromptEditor, ProcessorGradientFlow
|
| 18 |
from ImageState import ImageState
|
| 19 |
from loaders import load_default
|
| 20 |
# from animation import create_gif
|
|
|
|
| 29 |
# clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
| 30 |
lpips_fn = LPIPS(net='vgg').to(device)
|
| 31 |
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 32 |
+
promptoptim = ImagePromptEditor(vqgan, clip, processor, lpips_fn=lpips_fn, quantize=True)
|
| 33 |
+
|
| 34 |
def set_img_from_example(state, img):
    """Use one example image for both blend inputs, with a blend weight of 0."""
    blend_weight = 0
    result = state.update_images(img, img, blend_weight)
    return result
|
| 36 |
def get_cleared_mask():
    """Return a Gradio image update whose value is None."""
    cleared = gr.Image.update(value=None)
    return cleared
|
|
|
|
|
|
|
| 38 |
class StateWrapper:
|
| 39 |
+
"""This extremely ugly code is a hacky fix to allow con"""
|
| 40 |
def create_gif(state, *args, **kwargs):
|
| 41 |
return state, state[0].create_gif(*args, **kwargs)
|
| 42 |
def apply_asian_vector(state, *args, **kwargs):
|
|
|
|
| 46 |
def apply_lip_vector(state, *args, **kwargs):
|
| 47 |
return state, *state[0].apply_lip_vector(*args, **kwargs)
|
| 48 |
def apply_prompts(state, *args, **kwargs):
|
|
|
|
| 49 |
for image in state[0].apply_prompts(*args, **kwargs):
|
| 50 |
yield state, *image
|
| 51 |
def apply_rb_vector(state, *args, **kwargs):
|
|
|
|
| 68 |
return state, *state[0].update_images(*args, **kwargs)
|
| 69 |
def update_requant(state, *args, **kwargs):
|
| 70 |
return state, *state[0].update_requant(*args, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
with gr.Blocks(css="styles.css") as demo:
|
| 74 |
+
state = gr.State([ImageState(vqgan, promptoptim)])
|
|
|
|
| 75 |
with gr.Row():
|
| 76 |
with gr.Column(scale=1):
|
| 77 |
with gr.Row():
|
backend.py
CHANGED
|
@@ -1,77 +1,65 @@
|
|
| 1 |
-
# from functools import cache
|
| 2 |
-
import importlib
|
| 3 |
-
|
| 4 |
-
import gradio as gr
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import torch
|
| 7 |
import torchvision
|
| 8 |
import wandb
|
| 9 |
-
from icecream import ic
|
| 10 |
from torch import nn
|
| 11 |
-
from torchvision.transforms.functional import resize
|
| 12 |
from tqdm import tqdm
|
| 13 |
-
from transformers import
|
| 14 |
-
import
|
| 15 |
-
|
| 16 |
-
from img_processing import *
|
| 17 |
-
from img_processing import custom_to_pil
|
| 18 |
-
from loaders import load_default
|
| 19 |
-
import glob
|
| 20 |
-
import gc
|
| 21 |
|
| 22 |
global log
|
| 23 |
-
log=False
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# ic.enable()
|
| 27 |
-
def get_resized_tensor(x):
|
| 28 |
-
if len(x.shape) == 2:
|
| 29 |
-
re = x.unsqueeze(0)
|
| 30 |
-
else: re = x
|
| 31 |
-
re = resize(re, (10, 10))
|
| 32 |
-
return re
|
| 33 |
-
class ProcessorGradientFlow():
|
| 34 |
"""
|
| 35 |
This wraps the huggingface CLIP processor to allow backprop through the image processing step.
|
| 36 |
-
The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.
|
| 37 |
"""
|
|
|
|
| 38 |
def __init__(self, device="cuda") -> None:
|
| 39 |
self.device = device
|
| 40 |
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
| 41 |
self.image_mean = [0.48145466, 0.4578275, 0.40821073]
|
| 42 |
self.image_std = [0.26862954, 0.26130258, 0.27577711]
|
| 43 |
self.normalize = torchvision.transforms.Normalize(
|
| 44 |
-
self.image_mean,
|
| 45 |
-
self.image_std
|
| 46 |
)
|
| 47 |
self.resize = torchvision.transforms.Resize(224)
|
| 48 |
self.center_crop = torchvision.transforms.CenterCrop(224)
|
|
|
|
| 49 |
def preprocess_img(self, images):
|
| 50 |
images = self.center_crop(images)
|
| 51 |
images = self.resize(images)
|
| 52 |
images = self.center_crop(images)
|
| 53 |
images = self.normalize(images)
|
| 54 |
return images
|
|
|
|
| 55 |
def __call__(self, images=[], **kwargs):
|
| 56 |
processed_inputs = self.processor(**kwargs)
|
| 57 |
processed_inputs["pixel_values"] = self.preprocess_img(images)
|
| 58 |
-
processed_inputs = {
|
|
|
|
|
|
|
| 59 |
return processed_inputs
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
super().__init__()
|
| 76 |
self.latent = None
|
| 77 |
self.device = vqgan.device
|
|
@@ -86,14 +74,17 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 86 |
self.quantize = quantize
|
| 87 |
self.lpips_weight = lpips_weight
|
| 88 |
self.perceptual_loss = lpips_fn
|
|
|
|
| 89 |
def set_latent(self, latent):
|
| 90 |
self.latent = latent.detach().to(self.device)
|
|
|
|
| 91 |
def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
|
| 92 |
self._attn_mask = attn_mask
|
| 93 |
self.iterations = iterations
|
| 94 |
self.lr = lr
|
| 95 |
self.lpips_weight = lpips_weight
|
| 96 |
self.reconstruction_steps = reconstruction_steps
|
|
|
|
| 97 |
def forward(self, vector):
|
| 98 |
base_latent = self.latent.detach().requires_grad_()
|
| 99 |
trans_latent = base_latent + vector
|
|
@@ -103,19 +94,22 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 103 |
z_q = trans_latent
|
| 104 |
dec = self.vqgan.decode(z_q)
|
| 105 |
return dec
|
|
|
|
| 106 |
def _get_clip_similarity(self, prompts, image, weights=None):
|
| 107 |
if isinstance(prompts, str):
|
| 108 |
prompts = [prompts]
|
| 109 |
elif not isinstance(prompts, list):
|
| 110 |
raise TypeError("Provide prompts as string or list of strings")
|
| 111 |
-
clip_inputs = self.clip_preprocessor(
|
| 112 |
-
images=image, return_tensors="pt", padding=True
|
|
|
|
| 113 |
clip_outputs = self.clip(**clip_inputs)
|
| 114 |
similarity_logits = clip_outputs.logits_per_image
|
| 115 |
if weights:
|
| 116 |
similarity_logits *= weights
|
| 117 |
return similarity_logits.sum()
|
| 118 |
-
|
|
|
|
| 119 |
pos_logits = self._get_clip_similarity(pos_prompts, image)
|
| 120 |
if neg_prompts:
|
| 121 |
neg_logits = self._get_clip_similarity(neg_prompts, image)
|
|
@@ -123,6 +117,7 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 123 |
neg_logits = torch.tensor([1], device=self.device)
|
| 124 |
loss = -torch.log(pos_logits) + torch.log(neg_logits)
|
| 125 |
return loss
|
|
|
|
| 126 |
def visualize(self, processed_img):
|
| 127 |
if self.make_grid:
|
| 128 |
self.index += 1
|
|
@@ -131,74 +126,93 @@ class ImagePromptOptimizer(nn.Module):
|
|
| 131 |
else:
|
| 132 |
plt.imshow(get_pil(processed_img[0]).detach().cpu())
|
| 133 |
plt.show()
|
|
|
|
| 134 |
def _attn_mask(self, grad):
|
| 135 |
newgrad = grad
|
| 136 |
if self._attn_mask is not None:
|
| 137 |
newgrad = grad * (self._attn_mask)
|
| 138 |
return newgrad
|
|
|
|
| 139 |
def _attn_mask_inverse(self, grad):
|
| 140 |
newgrad = grad
|
| 141 |
if self._attn_mask is not None:
|
| 142 |
newgrad = grad * ((self._attn_mask - 1) * -1)
|
| 143 |
return newgrad
|
|
|
|
| 144 |
def _get_next_inputs(self, transformed_img):
|
| 145 |
-
processed_img = loop_post_process(transformed_img)
|
| 146 |
processed_img.retain_grad()
|
|
|
|
| 147 |
lpips_input = processed_img.clone()
|
| 148 |
lpips_input.register_hook(self._attn_mask_inverse)
|
| 149 |
lpips_input.retain_grad()
|
|
|
|
| 150 |
clip_input = processed_img.clone()
|
| 151 |
clip_input.register_hook(self._attn_mask)
|
| 152 |
clip_input.retain_grad()
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
def optimize(self, latent, pos_prompts, neg_prompts):
|
| 156 |
self.set_latent(latent)
|
| 157 |
-
transformed_img = self(
|
|
|
|
|
|
|
| 158 |
original_img = loop_post_process(transformed_img)
|
| 159 |
vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
|
| 160 |
optim = torch.optim.Adam([vector], lr=self.lr)
|
| 161 |
-
|
| 162 |
-
plt.figure(figsize=(35, 25))
|
| 163 |
-
self.index = 1
|
| 164 |
for i in tqdm(range(self.iterations)):
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
with torch.autocast("cuda"):
|
| 169 |
-
clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
|
| 170 |
-
print("CLIP loss", clip_loss)
|
| 171 |
-
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
| 172 |
-
print("LPIPS loss: ", perceptual_loss)
|
| 173 |
-
if log:
|
| 174 |
-
wandb.log({"Perceptual Loss": perceptual_loss})
|
| 175 |
-
wandb.log({"CLIP Loss": clip_loss})
|
| 176 |
-
clip_loss.backward(retain_graph=True)
|
| 177 |
-
perceptual_loss.backward(retain_graph=True)
|
| 178 |
-
p2 = processed_img.grad
|
| 179 |
-
print("Sum Loss", perceptual_loss + clip_loss)
|
| 180 |
-
optim.step()
|
| 181 |
-
# if i % self.iterations // 10 == 0:
|
| 182 |
-
# self.visualize(transformed_img)
|
| 183 |
-
yield vector
|
| 184 |
-
if self.make_grid:
|
| 185 |
-
plt.savefig(f"plot {pos_prompts[0]}.png")
|
| 186 |
-
plt.show()
|
| 187 |
-
print("lpips solo op")
|
| 188 |
for i in range(self.reconstruction_steps):
|
| 189 |
-
|
| 190 |
-
transformed_img = self(vector)
|
| 191 |
-
processed_img = loop_post_process(transformed_img) #* self.attn_mask
|
| 192 |
-
processed_img.retain_grad()
|
| 193 |
-
lpips_input = processed_img.clone()
|
| 194 |
-
lpips_input.register_hook(self._attn_mask_inverse)
|
| 195 |
-
lpips_input.retain_grad()
|
| 196 |
-
with torch.autocast("cuda"):
|
| 197 |
-
perceptual_loss = self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
|
| 198 |
-
if log:
|
| 199 |
-
wandb.log({"Perceptual Loss": perceptual_loss})
|
| 200 |
-
print("LPIPS loss: ", perceptual_loss)
|
| 201 |
-
perceptual_loss.backward(retain_graph=True)
|
| 202 |
-
optim.step()
|
| 203 |
-
yield vector
|
| 204 |
yield vector if self.return_val == "vector" else self.latent + vector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
import torch
|
| 3 |
import torchvision
|
| 4 |
import wandb
|
|
|
|
| 5 |
from torch import nn
|
|
|
|
| 6 |
from tqdm import tqdm
|
| 7 |
+
from transformers import CLIPProcessor
|
| 8 |
+
from img_processing import get_pil, loop_post_process
|
| 9 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
global log
|
| 12 |
+
log = False
|
| 13 |
+
|
| 14 |
+
class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to numpy then PIL images, which is faster for image processing but breaks gradient flow.

    Text (and any other keyword input) is still handled by the wrapped
    processor; only the image path is replaced with torchvision transforms.
    """

    def __init__(self, device="cuda") -> None:
        """Load the CLIP processor and build the image transforms.

        Args:
            device: Device every processed input is moved to in ``__call__``.
        """
        self.device = device
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        # Per-channel statistics handed to Normalize below.
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(
            self.image_mean, self.image_std
        )
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        """Crop, resize, crop again and normalise ``images``; returns the result."""
        # NOTE(review): the crop runs both before and after the resize; the
        # order is kept exactly as found because changing it would change the
        # pixels CLIP sees - confirm whether the first crop is intended.
        images = self.center_crop(images)
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, images=None, **kwargs):
        """Build CLIP inputs from ``images`` plus the processor's own outputs.

        Args:
            images: Images passed to ``preprocess_img``. Defaults to an empty
                list, created per call rather than shared between calls.
            **kwargs: Forwarded unchanged to the wrapped processor.

        Returns:
            A dict of the processor outputs plus ``"pixel_values"``, every
            value moved to ``self.device``.
        """
        if images is None:
            images = []
        processed_inputs = self.processor(**kwargs)
        processed_inputs["pixel_values"] = self.preprocess_img(images)
        processed_inputs = {
            key: value.to(self.device) for (key, value) in processed_inputs.items()
        }
        return processed_inputs
|
| 45 |
|
| 46 |
+
|
| 47 |
+
class ImagePromptEditor(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
vqgan,
|
| 51 |
+
clip,
|
| 52 |
+
clip_preprocessor,
|
| 53 |
+
lpips_fn,
|
| 54 |
+
iterations=100,
|
| 55 |
+
lr=0.01,
|
| 56 |
+
save_vector=True,
|
| 57 |
+
return_val="vector",
|
| 58 |
+
quantize=True,
|
| 59 |
+
make_grid=False,
|
| 60 |
+
lpips_weight=6.2,
|
| 61 |
+
) -> None:
|
| 62 |
+
|
| 63 |
super().__init__()
|
| 64 |
self.latent = None
|
| 65 |
self.device = vqgan.device
|
|
|
|
| 74 |
self.quantize = quantize
|
| 75 |
self.lpips_weight = lpips_weight
|
| 76 |
self.perceptual_loss = lpips_fn
|
| 77 |
+
|
| 78 |
def set_latent(self, latent):
|
| 79 |
self.latent = latent.detach().to(self.device)
|
| 80 |
+
|
| 81 |
def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
|
| 82 |
self._attn_mask = attn_mask
|
| 83 |
self.iterations = iterations
|
| 84 |
self.lr = lr
|
| 85 |
self.lpips_weight = lpips_weight
|
| 86 |
self.reconstruction_steps = reconstruction_steps
|
| 87 |
+
|
| 88 |
def forward(self, vector):
|
| 89 |
base_latent = self.latent.detach().requires_grad_()
|
| 90 |
trans_latent = base_latent + vector
|
|
|
|
| 94 |
z_q = trans_latent
|
| 95 |
dec = self.vqgan.decode(z_q)
|
| 96 |
return dec
|
| 97 |
+
|
| 98 |
def _get_clip_similarity(self, prompts, image, weights=None):
|
| 99 |
if isinstance(prompts, str):
|
| 100 |
prompts = [prompts]
|
| 101 |
elif not isinstance(prompts, list):
|
| 102 |
raise TypeError("Provide prompts as string or list of strings")
|
| 103 |
+
clip_inputs = self.clip_preprocessor(
|
| 104 |
+
text=prompts, images=image, return_tensors="pt", padding=True
|
| 105 |
+
)
|
| 106 |
clip_outputs = self.clip(**clip_inputs)
|
| 107 |
similarity_logits = clip_outputs.logits_per_image
|
| 108 |
if weights:
|
| 109 |
similarity_logits *= weights
|
| 110 |
return similarity_logits.sum()
|
| 111 |
+
|
| 112 |
+
def _get_CLIP_loss(self, pos_prompts, neg_prompts, image):
|
| 113 |
pos_logits = self._get_clip_similarity(pos_prompts, image)
|
| 114 |
if neg_prompts:
|
| 115 |
neg_logits = self._get_clip_similarity(neg_prompts, image)
|
|
|
|
| 117 |
neg_logits = torch.tensor([1], device=self.device)
|
| 118 |
loss = -torch.log(pos_logits) + torch.log(neg_logits)
|
| 119 |
return loss
|
| 120 |
+
|
| 121 |
def visualize(self, processed_img):
|
| 122 |
if self.make_grid:
|
| 123 |
self.index += 1
|
|
|
|
| 126 |
else:
|
| 127 |
plt.imshow(get_pil(processed_img[0]).detach().cpu())
|
| 128 |
plt.show()
|
| 129 |
+
|
| 130 |
def _attn_mask(self, grad):
|
| 131 |
newgrad = grad
|
| 132 |
if self._attn_mask is not None:
|
| 133 |
newgrad = grad * (self._attn_mask)
|
| 134 |
return newgrad
|
| 135 |
+
|
| 136 |
def _attn_mask_inverse(self, grad):
|
| 137 |
newgrad = grad
|
| 138 |
if self._attn_mask is not None:
|
| 139 |
newgrad = grad * ((self._attn_mask - 1) * -1)
|
| 140 |
return newgrad
|
| 141 |
+
|
| 142 |
def _get_next_inputs(self, transformed_img):
    """Post-process a decoded image and derive separately-masked loss inputs.

    Args:
        transformed_img: Decoder output passed to ``loop_post_process``.

    Returns:
        Tuple ``(processed_img, lpips_input, clip_input)``. The two clones
        carry gradient hooks: ``lpips_input`` uses ``_attn_mask_inverse`` and
        ``clip_input`` uses ``_attn_mask``.
    """
    processed_img = loop_post_process(transformed_img)
    processed_img.retain_grad()

    lpips_input = processed_img.clone()
    lpips_input.register_hook(self._attn_mask_inverse)
    lpips_input.retain_grad()

    clip_input = processed_img.clone()
    # `set_params` stores the mask as the instance attribute `_attn_mask`,
    # which shadows the method of the same name. `self._attn_mask` would then
    # be the mask value (or None), not a callable hook, so look the method up
    # on the class and bind it explicitly.
    clip_input.register_hook(lambda grad: type(self)._attn_mask(self, grad))
    clip_input.retain_grad()

    return (processed_img, lpips_input, clip_input)
|
| 155 |
+
|
| 156 |
+
def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
    """Generator that runs one joint CLIP + LPIPS update step and yields ``vector``.

    This is a generator function: nothing executes until it is advanced.
    When advanced it zeroes the gradients, decodes ``vector``, computes both
    losses, back-propagates each and steps ``optim``.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing;
    # confirm which statements belong under the autocast context.
    optim.zero_grad()
    edited_img = self(vector)
    _, lpips_in, clip_in = self._get_next_inputs(edited_img)
    with torch.autocast("cuda"):
        clip_term = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_in)
        print("CLIP loss", clip_term)
        reference = original_img.clone()
        lpips_term = self.perceptual_loss(lpips_in, reference) * self.lpips_weight
        print("LPIPS loss: ", lpips_term)
        print("Sum Loss", lpips_term + clip_term)
        if log:
            wandb.log({"Perceptual Loss": lpips_term})
            wandb.log({"CLIP Loss": clip_term})

    # These gradients will be masked if attn_mask has been set
    clip_term.backward(retain_graph=True)
    lpips_term.backward(retain_graph=True)

    optim.step()
    yield vector
|
| 181 |
+
|
| 182 |
+
def _optimize_LPIPS(self, vector, original_img, optim):
    """Generator that runs one LPIPS-only update step and yields ``vector``.

    This is a generator function: nothing executes until it is advanced.
    The perceptual loss compares the decoded image against ``original_img``
    and is scaled by ``self.lpips_weight``.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing;
    # confirm which statements belong under the autocast context.
    optim.zero_grad()
    decoded = self(vector)
    processed = loop_post_process(decoded)
    processed.retain_grad()

    lpips_in = processed.clone()
    lpips_in.register_hook(self._attn_mask_inverse)
    lpips_in.retain_grad()
    with torch.autocast("cuda"):
        reference = original_img.clone()
        lpips_term = self.perceptual_loss(lpips_in, reference) * self.lpips_weight
        if log:
            wandb.log({"Perceptual Loss": lpips_term})
        print("LPIPS loss: ", lpips_term)
    lpips_term.backward(retain_graph=True)
    optim.step()
    yield vector
|
| 202 |
|
| 203 |
def optimize(self, latent, pos_prompts, neg_prompts):
    """Optimise an edit vector for ``latent``, yielding it after every step.

    Runs ``self.iterations`` joint CLIP + LPIPS steps, then
    ``self.reconstruction_steps`` LPIPS-only steps, and finally yields the
    result selected by ``self.return_val``.

    Args:
        latent: Base latent, stored via ``set_latent``.
        pos_prompts: Prompts passed to the CLIP loss as positives.
        neg_prompts: Prompts passed to the CLIP loss as negatives.

    Yields:
        The edit vector after each step. The last item is the vector itself
        when ``self.return_val == "vector"``, otherwise ``self.latent + vector``.
    """
    self.set_latent(latent)
    # Decoding with a zero edit vector gives the reference image for LPIPS.
    transformed_img = self(
        torch.zeros_like(self.latent, requires_grad=True, device=self.device)
    )
    original_img = loop_post_process(transformed_img)
    vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
    optim = torch.optim.Adam([vector], lr=self.lr)

    # The step helpers are generator functions. A plain `yield helper(...)`
    # would emit an un-run generator object and never execute the step, so
    # delegate with `yield from`.
    for _ in tqdm(range(self.iterations)):
        yield from self._optimize_CLIP_LPIPS(
            optim, original_img, vector, pos_prompts, neg_prompts
        )

    print("Running LPIPS optim only")
    for _ in range(self.reconstruction_steps):
        # `_optimize_LPIPS` takes (vector, original_img, optim); the former
        # extra `transformed_img` argument did not match its signature.
        yield from self._optimize_LPIPS(vector, original_img, optim)

    yield vector if self.return_val == "vector" else self.latent + vector
|
edit.py
CHANGED
|
@@ -12,7 +12,7 @@ import PIL
|
|
| 12 |
import taming
|
| 13 |
import torch
|
| 14 |
|
| 15 |
-
from loaders import load_config
|
| 16 |
from utils import get_device
|
| 17 |
|
| 18 |
|
|
@@ -25,11 +25,14 @@ def get_embedding(model, path=None, img=None, device="cpu"):
|
|
| 25 |
z, _, [_, _, indices] = model.encode(x_processed)
|
| 26 |
return z
|
| 27 |
|
| 28 |
-
|
| 29 |
-
def blend_paths(
|
|
|
|
|
|
|
| 30 |
x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
|
| 31 |
y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
|
| 32 |
-
x_latent
|
|
|
|
| 33 |
z = torch.lerp(x_latent, y_latent, weight)
|
| 34 |
if quantize:
|
| 35 |
z = model.quantize(z)[0]
|
|
@@ -45,14 +48,16 @@ def blend_paths(model, path1, path2, quantize=False, weight=0.5, show=True, devi
|
|
| 45 |
plt.show()
|
| 46 |
return custom_to_pil(decoded), z
|
| 47 |
|
|
|
|
| 48 |
if __name__ == "__main__":
|
| 49 |
device = get_device()
|
| 50 |
-
|
| 51 |
-
conf_path = "./unwrapped.yaml"
|
| 52 |
-
config = load_config(conf_path, display=False)
|
| 53 |
-
model = taming.models.vqgan.VQModel(**config.model.params)
|
| 54 |
-
sd = torch.load("./vqgan_only.pt", map_location="mps")
|
| 55 |
-
model.load_state_dict(sd, strict=True)
|
| 56 |
model.to(device)
|
| 57 |
-
blend_paths(
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
import taming
|
| 13 |
import torch
|
| 14 |
|
| 15 |
+
from loaders import load_config, load_default
|
| 16 |
from utils import get_device
|
| 17 |
|
| 18 |
|
|
|
|
| 25 |
z, _, [_, _, indices] = model.encode(x_processed)
|
| 26 |
return z
|
| 27 |
|
| 28 |
+
|
| 29 |
+
def blend_paths(
|
| 30 |
+
model, path1, path2, quantize=False, weight=0.5, show=True, device="cuda"
|
| 31 |
+
):
|
| 32 |
x = preprocess(PIL.Image.open(path1), target_image_size=256).to(device)
|
| 33 |
y = preprocess(PIL.Image.open(path2), target_image_size=256).to(device)
|
| 34 |
+
x_latent = get_embedding(model, path=path1, device=device)
|
| 35 |
+
y_latent = get_embedding(model, path=path2, device=device)
|
| 36 |
z = torch.lerp(x_latent, y_latent, weight)
|
| 37 |
if quantize:
|
| 38 |
z = model.quantize(z)[0]
|
|
|
|
| 48 |
plt.show()
|
| 49 |
return custom_to_pil(decoded), z
|
| 50 |
|
| 51 |
+
|
| 52 |
if __name__ == "__main__":
|
| 53 |
device = get_device()
|
| 54 |
+
model = load_default(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
model.to(device)
|
| 56 |
+
blend_paths(
|
| 57 |
+
model,
|
| 58 |
+
"./test_data/face.jpeg",
|
| 59 |
+
"./test_data/face2.jpeg",
|
| 60 |
+
quantize=False,
|
| 61 |
+
weight=0.5,
|
| 62 |
+
)
|
| 63 |
+
plt.show()
|
img_processing.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
import io
|
| 2 |
-
import os
|
| 3 |
-
import sys
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
import PIL
|
| 7 |
import requests
|
| 8 |
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
import torchvision.transforms as T
|
| 11 |
import torchvision.transforms.functional as TF
|
| 12 |
from PIL import Image, ImageDraw, ImageFont
|
|
@@ -20,10 +17,10 @@ def download_image(url):
|
|
| 20 |
|
| 21 |
def preprocess(img, target_image_size=256, map_dalle=False):
|
| 22 |
s = min(img.size)
|
| 23 |
-
|
| 24 |
if s < target_image_size:
|
| 25 |
-
raise ValueError(f
|
| 26 |
-
|
| 27 |
r = target_image_size / s
|
| 28 |
s = (round(r * img.size[1]), round(r * img.size[0]))
|
| 29 |
img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
|
|
@@ -31,42 +28,49 @@ def preprocess(img, target_image_size=256, map_dalle=False):
|
|
| 31 |
img = torch.unsqueeze(T.ToTensor()(img), 0)
|
| 32 |
return img
|
| 33 |
|
|
|
|
| 34 |
def preprocess_vqgan(x):
|
| 35 |
-
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
def custom_to_pil(x, process=True, mode="RGB"):
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
| 50 |
|
| 51 |
def get_pil(x):
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
| 56 |
|
| 57 |
def loop_post_process(x):
|
| 58 |
-
|
| 59 |
-
|
|
|
|
| 60 |
|
| 61 |
def stack_reconstructions(input, x0, x1, x2, x3, titles=[]):
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import io
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
import PIL
|
| 5 |
import requests
|
| 6 |
import torch
|
|
|
|
| 7 |
import torchvision.transforms as T
|
| 8 |
import torchvision.transforms.functional as TF
|
| 9 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
| 17 |
|
| 18 |
def preprocess(img, target_image_size=256, map_dalle=False):
|
| 19 |
s = min(img.size)
|
| 20 |
+
|
| 21 |
if s < target_image_size:
|
| 22 |
+
raise ValueError(f"min dim for image {s} < {target_image_size}")
|
| 23 |
+
|
| 24 |
r = target_image_size / s
|
| 25 |
s = (round(r * img.size[1]), round(r * img.size[0]))
|
| 26 |
img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
|
|
|
|
| 28 |
img = torch.unsqueeze(T.ToTensor()(img), 0)
|
| 29 |
return img
|
| 30 |
|
| 31 |
+
|
| 32 |
def preprocess_vqgan(x):
    """Apply the affine map ``2.0 * x - 1.0`` to ``x`` and return the result."""
    return 2.0 * x - 1.0
|
| 35 |
+
|
| 36 |
|
| 37 |
def custom_to_pil(x, process=True, mode="RGB"):
    """Convert a tensor into a PIL image in the requested ``mode``.

    Args:
        x: Tensor that is detached, moved to CPU and permuted with (1, 2, 0).
        process: When true, clamp to [-1, 1], rescale to [0, 1] and convert
            to ``uint8`` in 0..255 before building the image.
        mode: Target PIL mode; the image is converted if it differs.
    """
    tensor = x.detach().cpu()
    if process:
        tensor = (torch.clamp(tensor, -1.0, 1.0) + 1.0) / 2.0
    array = tensor.permute(1, 2, 0).numpy()
    if process:
        array = (255 * array).astype(np.uint8)
    image = Image.fromarray(array)
    return image if image.mode == mode else image.convert(mode)
|
| 49 |
+
|
| 50 |
|
| 51 |
def get_pil(x):
    """Clamp ``x`` to [-1, 1], rescale to [0, 1] and permute axes with (1, 2, 0).

    Despite the name, the result is still a tensor, not a PIL image.
    """
    unit_range = (torch.clamp(x, -1.0, 1.0) + 1.0) / 2.0
    return unit_range.permute(1, 2, 0)
|
| 56 |
+
|
| 57 |
|
| 58 |
def loop_post_process(x):
    """Run ``get_pil`` on the squeezed input, then permute back and re-add a leading axis."""
    rescaled = get_pil(x.squeeze())
    restored = rescaled.permute(2, 0, 1)
    return restored.unsqueeze(0)
|
| 61 |
+
|
| 62 |
|
| 63 |
def stack_reconstructions(input, x0, x1, x2, x3, titles=()):
    """Paste five images side by side on one RGB canvas and caption them.

    Args:
        input: First image; its size defines the width and height of a slot.
        x0, x1, x2, x3: Images pasted into slots 1 to 4.
        titles: Iterable of captions; caption ``i`` is drawn at the top-left
            corner of slot ``i``. Defaults to no captions.

    Returns:
        A new image of size ``(5 * w, h)``.
    """
    # NOTE(review): `x0` is not part of this size check; confirm whether that
    # is intentional before tightening it.
    assert input.size == x1.size == x2.size == x3.size
    w, h = input.size[0], input.size[1]
    img = Image.new("RGB", (5 * w, h))
    for slot, panel in enumerate((input, x0, x1, x2, x3)):
        img.paste(panel, (slot * w, 0))
    # One drawing context is enough; creating it per title was redundant.
    draw = ImageDraw.Draw(img)
    for i, title in enumerate(titles):
        # `font` is not defined in this function; it is looked up at module level.
        draw.text(
            (i * w, 0), f"{title}", (255, 255, 255), font=font
        )  # coordinates, text, color, font
    return img
|
loaders.py
CHANGED
|
@@ -10,17 +10,17 @@ from utils import get_device
|
|
| 10 |
|
| 11 |
|
| 12 |
def load_config(config_path, display=False):
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
def load_default(device):
|
| 19 |
-
|
| 20 |
-
conf_path = "./unwrapped.yaml"
|
| 21 |
config = load_config(conf_path, display=False)
|
| 22 |
model = taming.models.vqgan.VQModel(**config.model.params)
|
| 23 |
-
sd = torch.load("./
|
| 24 |
model.load_state_dict(sd, strict=True)
|
| 25 |
model.to(device)
|
| 26 |
del sd
|
|
@@ -34,17 +34,14 @@ def load_vqgan(config, ckpt_path=None, is_gumbel=False):
|
|
| 34 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 35 |
return model.eval()
|
| 36 |
|
| 37 |
-
def load_ffhq():
|
| 38 |
-
conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
|
| 39 |
-
ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
|
| 40 |
-
vqgan = load_model(load_config(conf), ckpt, True, True)[0]
|
| 41 |
|
| 42 |
def reconstruct_with_vqgan(x, model):
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
| 48 |
def get_obj_from_str(string, reload=False):
|
| 49 |
module, cls = string.rsplit(".", 1)
|
| 50 |
if reload:
|
|
@@ -52,12 +49,13 @@ def get_obj_from_str(string, reload=False):
|
|
| 52 |
importlib.reload(module_imp)
|
| 53 |
return getattr(importlib.import_module(module, package=None), cls)
|
| 54 |
|
| 55 |
-
def instantiate_from_config(config):
|
| 56 |
|
| 57 |
-
|
|
|
|
| 58 |
raise KeyError("Expected key `target` to instantiate.")
|
| 59 |
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
| 60 |
|
|
|
|
| 61 |
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
| 62 |
model = instantiate_from_config(config)
|
| 63 |
if sd is not None:
|
|
@@ -78,5 +76,7 @@ def load_model(config, ckpt, gpu, eval_mode):
|
|
| 78 |
else:
|
| 79 |
pl_sd = {"state_dict": None}
|
| 80 |
global_step = None
|
| 81 |
-
model = load_model_from_config(
|
| 82 |
-
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def load_config(config_path, display=False):
    """Load the configuration at ``config_path`` with OmegaConf.

    When ``display`` is true, a YAML dump of the configuration is also
    printed. The loaded configuration object is returned either way.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg
|
| 17 |
+
|
| 18 |
|
| 19 |
def load_default(device):
|
| 20 |
+
conf_path = "./celeba_vqgan/unwrapped.yaml"
|
|
|
|
| 21 |
config = load_config(conf_path, display=False)
|
| 22 |
model = taming.models.vqgan.VQModel(**config.model.params)
|
| 23 |
+
sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
|
| 24 |
model.load_state_dict(sd, strict=True)
|
| 25 |
model.to(device)
|
| 26 |
del sd
|
|
|
|
| 34 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 35 |
return model.eval()
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with ``model``, report the latent's trailing shape, and decode it.

    Returns whatever ``model.decode`` produces for the encoded latent.
    """
    latent, _, [_, _, indices] = model.encode(x)
    model_name = model.__class__.__name__
    print(f"VQGAN --- {model_name}: latent shape: {latent.shape[2:]}")
    return model.decode(latent)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
def get_obj_from_str(string, reload=False):
|
| 46 |
module, cls = string.rsplit(".", 1)
|
| 47 |
if reload:
|
|
|
|
| 49 |
importlib.reload(module_imp)
|
| 50 |
return getattr(importlib.import_module(module, package=None), cls)
|
| 51 |
|
|
|
|
| 52 |
|
| 53 |
+
def instantiate_from_config(config):
    """Instantiate the object named by ``config["target"]``.

    The target is resolved with ``get_obj_from_str`` and called with
    ``config["params"]`` as keyword arguments (none when the key is absent).

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    raise KeyError("Expected key `target` to instantiate.")
|
| 57 |
|
| 58 |
+
|
| 59 |
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
|
| 60 |
model = instantiate_from_config(config)
|
| 61 |
if sd is not None:
|
|
|
|
| 76 |
else:
|
| 77 |
pl_sd = {"state_dict": None}
|
| 78 |
global_step = None
|
| 79 |
+
model = load_model_from_config(
|
| 80 |
+
config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode
|
| 81 |
+
)["model"]
|
| 82 |
+
return model, global_step
|
masking.py
CHANGED
|
@@ -3,30 +3,28 @@ import sys
|
|
| 3 |
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import torch
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
from backend import ImagePromptOptimizer, ImageState, ProcessorGradientFlow
|
| 17 |
-
from loaders import load_default
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
plt.imshow(x)
|
| 30 |
-
plt.show()
|
| 31 |
-
state.apply_prompts("a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask)
|
| 32 |
-
print('done')
|
|
|
|
| 3 |
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import torch
|
| 6 |
+
from backend import ImagePromptEditor, ImageState, ProcessorGradientFlow
|
| 7 |
+
from loaders import load_default
|
| 8 |
+
from transformers import CLIPModel
|
| 9 |
|
| 10 |
+
if __name__ == "__main__":
|
| 11 |
+
sys.path.append("taming-transformers")
|
| 12 |
+
device = "cuda"
|
| 13 |
|
| 14 |
+
vqgan = load_default(device)
|
| 15 |
+
vqgan.eval()
|
| 16 |
|
| 17 |
+
processor = ProcessorGradientFlow(device=device)
|
| 18 |
+
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 19 |
+
clip.to(device)
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
promptoptim = ImagePromptEditor(vqgan, clip, processor, quantize=True)
|
| 22 |
+
state = ImageState(vqgan, promptoptim)
|
| 23 |
+
mask = torch.load("eyebrow_mask.pt")
|
| 24 |
+
x = state.blend("./test_data/face.jpeg", "./test_data/face2.jpeg", 0.5)
|
| 25 |
+
plt.imshow(x)
|
| 26 |
+
plt.show()
|
| 27 |
+
state.apply_prompts(
|
| 28 |
+
"a picture of a woman with big eyebrows", "", 0.009, 40, None, mask=mask
|
| 29 |
+
)
|
| 30 |
+
print("done")
|
|
|
|
|
|
|
|
|
|
|
|
presets.py
CHANGED
|
@@ -1,16 +1,42 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
|
|
|
| 3 |
def set_preset(config_str):
|
| 4 |
-
choices=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
if config_str == choices[0]:
|
| 6 |
return set_small_local()
|
| 7 |
elif config_str == choices[1]:
|
| 8 |
return set_major_local()
|
| 9 |
elif config_str == choices[2]:
|
| 10 |
return set_major_global()
|
|
|
|
|
|
|
| 11 |
def set_small_local():
|
| 12 |
-
return (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def set_major_local():
|
| 14 |
-
return (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def set_major_global():
|
| 16 |
-
return (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
+
|
| 4 |
def set_preset(config_str):
    """Return the slider updates for the preset whose label equals ``config_str``.

    Returns ``None`` when ``config_str`` matches none of the known labels.
    """
    choices = [
        "Small Masked Changes (e.g. add lipstick)",
        "Major Masked Changes (e.g. change hair color or nose size)",
        "Major Global Changes (e.g. change race / gender",
    ]
    if config_str not in choices:
        return None
    position = choices.index(config_str)
    if position == 0:
        return set_small_local()
    if position == 1:
        return set_major_local()
    return set_major_global()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
def set_small_local():
    """Return four ``gr.Slider.update`` objects with values 25, 0.15, 1 and 4, in that order."""
    slider_values = (25, 0.15, 1, 4)
    return tuple(gr.Slider.update(value=v) for v in slider_values)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
def set_major_local():
    """Return four ``gr.Slider.update`` objects with values 25, 0.25, 35 and 10, in that order."""
    slider_values = (25, 0.25, 35, 10)
    return tuple(gr.Slider.update(value=v) for v in slider_values)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
def set_major_global():
    """Return four ``gr.Slider.update`` objects with values 30, 0.1, 2 and 0.2, in that order."""
    slider_values = (30, 0.1, 2, 0.2)
    return tuple(gr.Slider.update(value=v) for v in slider_values)
|
prompts.py
CHANGED
|
@@ -1,17 +1,41 @@
|
|
| 1 |
import random
|
|
|
|
|
|
|
| 2 |
class PromptSet:
|
| 3 |
def __init__(self, pos, neg, config=None):
|
| 4 |
self.positive = pos
|
| 5 |
self.negative = neg
|
| 6 |
self.config = config
|
|
|
|
|
|
|
| 7 |
example_prompts = (
|
| 8 |
-
PromptSet(
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
PromptSet(
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
)
|
|
|
|
|
|
|
| 15 |
def get_random_prompts():
|
| 16 |
prompt = random.choice(example_prompts)
|
| 17 |
-
return prompt.positive, prompt.negative
|
|
|
|
| 1 |
import random
|
| 2 |
+
|
| 3 |
+
|
| 4 |
class PromptSet:
    """A positive prompt, a negative prompt and an optional config, kept together."""

    def __init__(self, pos, neg, config=None):
        """Store ``pos``, ``neg`` and ``config`` as ``positive``, ``negative`` and ``config``."""
        self.positive, self.negative, self.config = pos, neg, config
|
| 9 |
+
|
| 10 |
+
|
| 11 |
example_prompts = (
|
| 12 |
+
PromptSet(
|
| 13 |
+
"a picture of a woman with light blonde hair",
|
| 14 |
+
"a picture of a person with dark hair | a picture of a person with brown hair",
|
| 15 |
+
),
|
| 16 |
+
PromptSet(
|
| 17 |
+
"A picture of a woman with very thick eyebrows",
|
| 18 |
+
"a picture of a person with very thin eyebrows | a picture of a person with no eyebrows",
|
| 19 |
+
),
|
| 20 |
+
PromptSet(
|
| 21 |
+
"A picture of a woman wearing bright red lipstick",
|
| 22 |
+
"a picture of a person wearing no lipstick | a picture of a person wearing dark lipstick",
|
| 23 |
+
),
|
| 24 |
+
PromptSet(
|
| 25 |
+
"A picture of a beautiful chinese woman | a picture of a Japanese woman | a picture of an Asian woman",
|
| 26 |
+
"a picture of a white woman | a picture of an Indian woman | a picture of a black woman",
|
| 27 |
+
),
|
| 28 |
+
PromptSet(
|
| 29 |
+
"A picture of a handsome man | a picture of a masculine man",
|
| 30 |
+
"a picture of a woman | a picture of a feminine person",
|
| 31 |
+
),
|
| 32 |
+
PromptSet(
|
| 33 |
+
"A picture of a woman with a very big nose",
|
| 34 |
+
"a picture of a person with a small nose | a picture of a person with a normal nose",
|
| 35 |
+
),
|
| 36 |
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
def get_random_prompts():
    """Pick one entry of ``example_prompts`` at random and return its (positive, negative) pair."""
    chosen = random.choice(example_prompts)
    return (chosen.positive, chosen.negative)
|
unwrapped.yaml
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
model:
|
| 2 |
-
target: taming.models.vqgan.VQModel
|
| 3 |
-
params:
|
| 4 |
-
embed_dim: 256
|
| 5 |
-
n_embed: 1024
|
| 6 |
-
ddconfig:
|
| 7 |
-
double_z: false
|
| 8 |
-
z_channels: 256
|
| 9 |
-
resolution: 256
|
| 10 |
-
in_channels: 3
|
| 11 |
-
out_ch: 3
|
| 12 |
-
ch: 128
|
| 13 |
-
ch_mult:
|
| 14 |
-
- 1
|
| 15 |
-
- 1
|
| 16 |
-
- 2
|
| 17 |
-
- 2
|
| 18 |
-
- 4
|
| 19 |
-
num_res_blocks: 2
|
| 20 |
-
attn_resolutions:
|
| 21 |
-
- 16
|
| 22 |
-
dropout: 0.0
|
| 23 |
-
lossconfig:
|
| 24 |
-
target: taming.modules.losses.vqperceptual.DummyLoss
|
| 25 |
-
data:
|
| 26 |
-
target: cutlit.DataModuleFromConfig
|
| 27 |
-
params:
|
| 28 |
-
batch_size: 24
|
| 29 |
-
num_workers: 24
|
| 30 |
-
train:
|
| 31 |
-
target: taming.data.faceshq.CelebAHQTrain
|
| 32 |
-
params:
|
| 33 |
-
size: 256
|
| 34 |
-
validation:
|
| 35 |
-
target: taming.data.faceshq.CelebAHQValidation
|
| 36 |
-
params:
|
| 37 |
-
size: 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
|
@@ -7,9 +7,11 @@ import torch.nn.functional as F
|
|
| 7 |
from skimage.color import lab2rgb, rgb2lab
|
| 8 |
from torch import nn
|
| 9 |
|
|
|
|
| 10 |
def freeze_module(module):
|
| 11 |
for param in module.parameters():
|
| 12 |
-
|
|
|
|
| 13 |
|
| 14 |
def get_device():
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 7 |
from skimage.color import lab2rgb, rgb2lab
|
| 8 |
from torch import nn
|
| 9 |
|
| 10 |
+
|
| 11 |
def freeze_module(module):
    """Set ``requires_grad`` to ``False`` on every parameter of ``module``."""
    for weight in module.parameters():
        weight.requires_grad = False
|
| 14 |
+
|
| 15 |
|
| 16 |
def get_device():
|
| 17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|