|
|
import os |
|
|
import spaces |
|
|
import time |
|
|
from glob import glob |
|
|
from typing import Callable, Optional, Tuple, Union, Dict |
|
|
import random |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision.datasets import VisionDataset |
|
|
from tqdm import tqdm |
|
|
from util.img_utils import clear_color |
|
|
|
|
|
from latent_models import PipelineWrapper |
|
|
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every random number generator used by this module with ``seed``.

    The same value is handed, in turn, to PyTorch's default generator,
    NumPy's global generator, Python's ``random`` module and finally to
    PyTorch's generators for all CUDA devices.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MinusOneToOne(torch.nn.Module):
    """Transform module applying the affine map ``x -> 2 * x - 1``.

    For inputs in ``[0, 1]`` the output lies in ``[-1, 1]``.
    """

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` scaled by two and shifted down by one."""
        doubled = tensor * 2
        return doubled - 1
|
|
|
|
|
|
|
|
class ResizePIL(torch.nn.Module):
    """Transform module that optionally resizes a PIL image.

    Args:
        image_size: Target size. An ``int`` is expanded to a square
            ``(image_size, image_size)`` tuple; a tuple is stored as given and
            later passed verbatim to ``Image.resize``; ``None`` disables
            resizing altogether.
    """

    def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
        super().__init__()
        # Normalise the scalar form so ``forward`` only deals with a tuple or None.
        self.image_size = (image_size, image_size) if isinstance(image_size, int) else image_size

    def forward(self, pil_image: Image.Image) -> Image.Image:
        """Return ``pil_image`` resized to ``self.image_size``, or untouched if no size is set."""
        if self.image_size is None:
            return pil_image
        return pil_image.resize(self.image_size)
|
|
|
|
|
|
|
|
def get_loader(datadir: str, batch_size: int = 1,
               crop_to: Optional[Union[int, Tuple[int, int]]] = None,
               include_path: bool = False) -> DataLoader:
    """Create a shuffled ``DataLoader`` over the images located at ``datadir``.

    Each image goes through ``ResizePIL(crop_to)``, ``transforms.ToTensor()``
    and ``MinusOneToOne()`` in that order.

    Args:
        datadir: Directory (searched recursively) or single image file, as
            accepted by ``FoldersDataset``.
        batch_size: Number of samples per batch.
        crop_to: Forwarded to ``ResizePIL``; despite the name the image is
            resized, not cropped. ``None`` keeps the original size.
        include_path: Forwarded to ``FoldersDataset``.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``num_workers=0`` and
        ``drop_last=False``.
    """
    preprocessing = transforms.Compose([
        ResizePIL(crop_to),
        transforms.ToTensor(),
        MinusOneToOne(),
    ])
    dataset = FoldersDataset(datadir, preprocessing, include_path=include_path)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=True, num_workers=0, drop_last=False)
|
|
|
|
|
|
|
|
class FoldersDataset(VisionDataset):
    """Dataset over image files under a directory tree, or over one image file.

    When ``root`` is a directory, every ``*.png``, ``*.JPEG`` and ``*.jpg``
    file below it (recursively) is collected and sorted. When ``root`` is an
    existing non-directory path, the dataset contains exactly that file.

    Each item is a triple ``(image, name, path)``:

    * ``image``: the file loaded as RGB and, if given, passed through
      ``transforms`` (called with the image as its only argument).
    * ``name``: the file's base name up to its first ``os.extsep``
      (so ``"a.b.png"`` yields ``"a"``).
    * ``path``: when ``include_path`` is true, the directory of the file
      relative to ``root`` (empty for files directly in ``root`` or when
      ``root`` is a single file); otherwise the empty string.

    Args:
        root: Directory to scan or path of a single image.
        transforms: Optional callable applied to each loaded PIL image.
        include_path: Whether items carry their root-relative directory.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        AssertionError: If ``root`` is a directory containing no matching file.
    """

    def __init__(self, root: str, transforms: Optional[Callable] = None,
                 include_path: bool = False) -> None:
        super().__init__(root, transforms)
        self.include_path = include_path
        self.root = root
        # Remembered so __getitem__ can build relative paths without touching
        # the filesystem again.
        self._root_is_dir = os.path.isdir(root)

        if self._root_is_dir:
            self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
            self.fpaths = sorted(self.fpaths)
            # Kept as an assert so the exception type seen by callers is unchanged.
            assert len(self.fpaths) > 0, "File list is empty. Check the root."
        elif os.path.exists(root):
            self.fpaths = [root]
        else:
            raise FileNotFoundError(f"File not found: {root}")

    def __len__(self) -> int:
        """Return the number of image files in the dataset."""
        return len(self.fpaths)

    def __getitem__(self, index: int) -> Tuple[Union[torch.Tensor, Image.Image], str, str]:
        """Load item ``index`` and return ``(image, name, path)`` as described on the class."""
        fpath = self.fpaths[index]
        # The context manager releases the file handle as soon as the RGB copy exists.
        with Image.open(fpath) as raw_image:
            img = raw_image.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        path = ""
        if self.include_path and self._root_is_dir:
            # relpath (instead of slicing off len(root) + 1 characters) stays
            # correct when ``root`` ends with a path separator.
            rel_dir = os.path.relpath(os.path.dirname(fpath), self.root)
            if rel_dir != os.curdir:
                path = rel_dir
        return img, os.path.basename(fpath).split(os.extsep)[0], path
|
|
|
|
|
|
|
|
@spaces.GPU
def compress(model: PipelineWrapper,
             img_to_compress: torch.Tensor,
             num_noises: int,
             loaded_indices,
             device,
             ):
    """Run the sampling loop while picking, per step, one noise out of ``num_noises`` candidates.

    The prompt is always the empty string. At every timestep the RNG is
    re-seeded with the step index before the candidates are drawn, so the same
    candidates are produced on every call. For steps with ``t >= 1`` the
    chosen candidate is either

    * the one with the largest dot product with ``enc_im - x_0_hat``
      (when ``loaded_indices`` is ``None``), or
    * the one named by ``loaded_indices[step]``.

    For steps with ``t < 1`` the first candidate is used and no index is
    recorded.

    Args:
        model: Pipeline wrapper providing the encode/denoise/decode steps.
        img_to_compress: Image to encode; ``None`` is replaced by an all-zero
            ``1 x 3 x S x S`` image with ``S = model.get_image_size()``.
        num_noises: Number of candidate noises drawn per step.
        loaded_indices: Previously selected indices to replay, or ``None`` to
            search for them.
        device: Device on which the random tensors are created.

    Returns:
        ``(img, indices)``: the decoded final latent and the recorded indices
        as a squeezed CPU tensor.
    """
    dtype = model.dtype

    prompt_embeds = model.encode_prompt("", None)

    set_seed(88888888)
    if img_to_compress is None:
        img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(), device=device)
    enc_im = model.encode_image(img_to_compress.to(dtype))
    kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2], width=img_to_compress.shape[-1],
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)

    chosen_indices = []

    for step, t in enumerate(tqdm(model.timesteps)):
        # Re-seeding per step makes the candidate set a function of the step index only.
        set_seed(step)
        candidates = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)

        _, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        if t >= 1:
            if loaded_indices is None:
                # Score each candidate by its dot product with the residual.
                residual = (enc_im - x_0_hat).view(enc_im.shape[0], -1)
                scores = torch.matmul(candidates.view(candidates.shape[0], -1),
                                      residual.transpose(0, 1))
                best_idx = torch.argmax(scores)
            else:
                best_idx = loaded_indices[step]
            best_noise = candidates[best_idx]
            chosen_indices.append(best_idx)
        else:
            best_noise = candidates[0]

        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))

    return img, torch.tensor(chosen_indices).squeeze().cpu()
|
|
|
|
|
|
|
|
@spaces.GPU
def generate_ours(model: PipelineWrapper,
                  num_noises: int,
                  num_noises_to_optimize: int,
                  prompt: str = "",
                  negative_prompt: Optional[str] = None,
                  indices=None,
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate an image by choosing one of ``num_noises`` candidate noises at every step.

    Candidates for a step are drawn after re-seeding the RNG with the step
    index. For steps with ``t >= 1``:

    * if ``indices`` is ``None``, the RNG is re-seeded from the clock,
      ``num_noises_to_optimize`` candidate indices are sampled (with
      replacement) and the one whose noise has the largest dot product with
      ``epst_uncond - epst_cond`` is kept;
    * otherwise ``indices[step]`` is used.

    For steps with ``t < 1`` an all-zero noise is used and nothing is recorded.

    Args:
        model: Pipeline wrapper providing the encode/denoise/decode steps.
        num_noises: Number of candidate noises drawn per step.
        num_noises_to_optimize: Size of the random subset that is scored.
        prompt: Text prompt; ``None`` is treated as the empty string.
        negative_prompt: Optional negative prompt.
        indices: Per-step indices to replay instead of searching.

    Returns:
        ``(img, indices)``: the decoded final latent and the stacked,
        squeezed, CPU-resident indices that were used.
    """
    dtype = model.dtype
    device = model.device

    set_seed(88888888)
    if prompt is None:
        prompt = ""
    prompt_embeds = model.encode_prompt(prompt, negative_prompt)

    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)

    used_indices = []
    for step, t in enumerate(tqdm(model.timesteps)):
        # Candidates depend only on the step index.
        set_seed(step)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)

        _, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0, return_everything=True, **kwargs)

        x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)
        if t >= 1:
            if indices is None:
                score_direction = epst_uncond - epst_cond
                # Clock-based seed: the scored subset differs from call to call.
                set_seed(int(time.time_ns() & 0xFFFFFFFF))
                subset = torch.randint(0, num_noises, size=(num_noises_to_optimize,), device=device)
                flat_subset = noise[subset].view(num_noises_to_optimize, -1)
                flat_direction = score_direction.view(score_direction.shape[0], -1).transpose(0, 1)
                alignment = torch.matmul(flat_subset, flat_direction)
                best_idx = subset[torch.argmax(alignment)]
            else:
                best_idx = indices[step]
            best_noise = noise[best_idx]
            used_indices.append(best_idx)
        else:
            best_noise = torch.zeros_like(noise[0])
        xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))
    return img, torch.stack(used_indices).squeeze().cpu()
|
|
|
|
|
|
|
|
def decompress(model: PipelineWrapper,
               image_size: Tuple[int, int],
               indices: Dict[str, torch.Tensor],
               num_noises: int,
               prompt: str = "",
               negative_prompt: Optional[str] = None,
               tedit: int = 0,
               new_prompt: str = "",
               new_negative_prompt: Optional[str] = None,
               guidance_scale: float = 3.0,
               num_pursuit_noises: Optional[int] = 1,
               num_pursuit_coef_bits: Optional[int] = 3,
               t_range: Tuple[int, int] = (999, 0),
               robust_randn: bool = False
               ) -> torch.Tensor:
    """Rebuild an image by replaying stored per-step noise (and coefficient) indices.

    The RNG is re-seeded with the step index before each step's candidate
    noises are drawn, and the noise injected at that step is looked up through
    ``indices``.

    Args:
        model: Pipeline wrapper providing the denoise/decode steps.
        image_size: Image size; its last two entries are used as height and
            width, and the whole value is passed to ``model.get_latent_shape``.
        indices: Mapping with keys ``'noise_indices'`` and ``'coeff_indices'``,
            each indexed by step.
        num_noises: Number of candidate noises drawn at steps inside ``t_range``.
        prompt: Prompt used for steps with index below ``tedit``.
        negative_prompt: Negative prompt paired with ``prompt``.
        tedit: Step index from which ``new_prompt`` is used instead of ``prompt``.
        new_prompt: Prompt used from step ``tedit`` onwards.
        new_negative_prompt: Negative prompt paired with ``new_prompt``.
        guidance_scale: Guidance scale passed to ``model.get_epst``.
        num_pursuit_noises: Number of noises combined per step; ``None`` means 1.
        num_pursuit_coef_bits: Bits per mixing coefficient; ``None`` means 1.
        t_range: ``(high, low)`` timestep bounds; outside them only a single
            noise is drawn for the step.
        robust_randn: Draw noises with ``get_robust_randn`` instead of
            ``torch.randn``.

    Returns:
        The image decoded from the final latent.
    """
    noise_indices = indices['noise_indices']
    coeffs_indices = indices['coeff_indices']
    num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
    num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1

    device = model.device
    dtype = model.dtype

    set_seed(88888888)
    orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs_orig = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                       prompt_embeds=orig_prompt_embeds)
    # Only encode the edit prompt when it actually differs from the original one.
    if new_prompt != prompt or new_negative_prompt != negative_prompt:
        new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
        kwargs_new = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                          prompt_embeds=new_prompt_embeds)
    else:
        new_prompt_embeds = orig_prompt_embeds
        kwargs_new = kwargs_orig

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        # Re-seed per step so the candidate noises depend only on the step index.
        set_seed(idx)

        # True when t lies outside [t_range[1], t_range[0]]; then a single noise is drawn.
        dont_optimize_t = not (t_range[0] >= t >= t_range[1])

        if robust_randn:
            # NOTE(review): get_robust_randn is neither defined nor imported in the
            # visible part of this file; confirm it is available at module scope.
            noise = get_robust_randn(num_noises if not dont_optimize_t else 1, xt.shape[1:], device, dtype)
        else:
            noise = torch.randn(num_noises if not dont_optimize_t else 1, *xt.shape[1:], device=device, dtype=dtype)

        # Steps before ``tedit`` use the original prompt, later ones the new prompt.
        curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
        curr_kwargs = kwargs_orig if idx < tedit else kwargs_new
        epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        curr_t_noise_indices = noise_indices[idx]
        best_noise = noise[curr_t_noise_indices[0]]
        # 2 ** num_pursuit_coef_bits evenly spaced coefficients in (0, 1] (the leading 0 is dropped).
        pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
        if num_pursuit_noises > 1:
            curr_t_coeffs_indices = coeffs_indices[idx]
            if curr_t_coeffs_indices[0] == -1:
                # NOTE(review): this skips finish_step, leaving xt unchanged for the
                # step; confirm that is the intended meaning of the -1 marker.
                continue
            for pursuit_idx in range(1, num_pursuit_noises):
                pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
                # Blend the running noise with the next indexed noise using
                # sqrt(coef) / sqrt(1 - coef) weights, then rescale to unit std.
                best_noise = best_noise * torch.sqrt(pursuit_coef) + noise[
                    curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef)
                best_noise /= best_noise.std()
        best_noise = best_noise.unsqueeze(0)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)
    img = model.decode_image(xt)
    return img
|
|
|
|
|
|
|
|
def inf_generate(model: PipelineWrapper,
                 prompt: str = "",
                 negative_prompt: Optional[str] = None,
                 guidance_scale: float = 7.0,
                 record: int = 0,
                 save_root: str = "") -> torch.Tensor:
    """Generate an image with plain sampling (fresh random noise at every step).

    Unlike the other sampling routines in this file, no seed is set here, so
    results vary between calls unless the caller seeds the RNG beforehand.

    Args:
        model: Pipeline wrapper providing the denoise/decode steps.
        prompt: Text prompt.
        negative_prompt: Optional negative prompt.
        guidance_scale: Guidance scale passed to ``model.get_epst``.
        record: If non-zero, every ``record``-th step (starting with step 0)
            the decoded ``x_0_hat`` is saved as a PNG.
        save_root: Directory under which the ``progress`` folder holding the
            recorded PNGs is placed.

    Returns:
        The image decoded from the final latent.
    """
    device = model.device
    dtype = model.dtype

    model.set_timesteps(model.num_timesteps, device=device)

    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)

    progress_dir = os.path.join(save_root, "progress")
    if record:
        # plt.imsave does not create missing directories; make sure the
        # target folder exists before the first snapshot is written.
        os.makedirs(progress_dir, exist_ok=True)

    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)

        epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)
        xt = model.finish_step(xt, x_0_hat, epst, t, noise)

        if record and idx % record == 0:
            img = model.decode_image(x_0_hat)
            # File name carries the timestep value, zero-padded to four digits.
            plt.imsave(os.path.join(progress_dir, f"x_0_hat_{str(t.item()).zfill(4)}.png"),
                       clear_color(img[0].unsqueeze(0), normalize=False))

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        # Retry the decode with the latent moved to the CPU.
        img = model.decode_image(xt.to('cpu'))

    return img
|
|
|