from __future__ import annotations

import json
import math
import os
import sys
from argparse import ArgumentParser
from pathlib import Path

import einops
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import autocast

from edm_sampler.external import CompVisDenoiser
from edm_sampler.sampling import sample_euler_ancestral

sys.path.append("./stable_diffusion")

from stable_diffusion.ldm.util import instantiate_from_config
|
|
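# Two-way classifier-free guidance: three denoiser branches (text+image
# conditioned, image-only conditioned, unconditional) run as one batched
# forward pass, and the text and image guidance scales weight the
# differences between their outputs.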
class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
|
|
def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        # Swap in the first-stage autoencoder weights from a separately
        # trained VAE checkpoint.
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        sd = {
            k: vae_sd[k[len("first_stage_model."):]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
|
|
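# Example invocation (the script name, checkpoint path, and identifier are
# illustrative placeholders):
#   python eval_qualitative.py --ckpt checkpoints/model.ckpt \
#       --identifier my-run --eval-type edit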
def main():
    parser = ArgumentParser()
    parser.add_argument("--data-path", default="../data/clip-filtered-dataset/", type=str)
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--ckpt", required=True, type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    parser.add_argument("--identifier", required=True, type=str)
    parser.add_argument("--cfg-text-list", default=[3.5, 5.5, 7.5, 9.5, 11.5], type=float, nargs="+")
    parser.add_argument("--cfg-image-list", default=[1.5], type=float, nargs="+")
    parser.add_argument("--seed", type=int, default=10086)
    parser.add_argument("--sample-num", type=int, default=200)
    parser.add_argument("--eval-type", default="edit", choices=["edit", "depth", "hed", "seg"], type=str)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
|
    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
    model.eval().cuda()
    model_wrap = CompVisDenoiser(model)
    model_wrap_cfg = CFGDenoiser(model_wrap)
    null_token = model.get_learned_conditioning([""])

    with open(Path(args.data_path, "seeds.json")) as f:
        seeds = json.load(f)

    # Sample from the last 10% of seeds.json (presumably the split held out
    # from training).
    total = len(seeds)
    i_start = int(total * 0.9)
    i_end = i_start + args.sample_num

    output_dir = f"imgs/qualitative/{args.identifier}"
    os.makedirs(output_dir, exist_ok=True)
    output_dir = os.path.join(output_dir, args.eval_type)
|
    for i in range(i_start, i_end):
        print(f"===========================> Processing {i}/{total} <===========================")
        name, i_seeds = seeds[i]
        output_sub_dir = os.path.join(output_dir, name)
        os.makedirs(output_sub_dir, exist_ok=True)
        prompt_dir = Path(args.data_path, name)
        if args.eval_type == "edit":
            with open(prompt_dir.joinpath("prompt.json")) as fp:
                edit_instruction = json.load(fp)["edit"]
        elif args.eval_type == "depth":
            edit_instruction = "Transfer to a depth map"
        elif args.eval_type == "hed":
            # HED: holistically-nested edge detection.
            edit_instruction = "Transfer to a hed map"
        elif args.eval_type == "seg":
            edit_instruction = "Transfer to a segmentation map"
        else:
            raise NotImplementedError

        image_seed = i_seeds[0]

        input_image_path = prompt_dir.joinpath(f"{image_seed}_0.jpg")
        input_image = Image.open(input_image_path).convert("RGB")
        # Rescale so the longer side is close to args.resolution while both
        # dimensions snap to multiples of 64, so the UNet's downsampling
        # stages divide evenly.
        width, height = input_image.size
        factor = args.resolution / max(width, height)
        factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
        width = int((width * factor) // 64) * 64
        height = int((height * factor) // 64) * 64
        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
        for cfg_text in args.cfg_text_list:
            for cfg_image in args.cfg_image_list:
                output_image_path = os.path.join(output_sub_dir, f"{image_seed}_text{cfg_text}_image{cfg_image}.jpg")
                with torch.no_grad(), autocast("cuda"), model.ema_scope():
                    cond = {}
                    cond["c_crossattn"] = [model.get_learned_conditioning([edit_instruction])]
                    # Map the input image to [-1, 1] and encode it with the
                    # first-stage VAE, taking the mode of the latent distribution.
                    new_input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
                    new_input_image = rearrange(new_input_image, "h w c -> 1 c h w").to(model.device)
                    cond["c_concat"] = [model.encode_first_stage(new_input_image).mode()]

                    uncond = {}
                    uncond["c_crossattn"] = [null_token]
                    uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

                    sigmas = model_wrap.get_sigmas(args.steps)
                    print(f"Editing the image {input_image_path}, with cfg_text={cfg_text}, cfg_image={cfg_image}")
                    extra_args = {
                        "cond": cond,
                        "uncond": uncond,
                        "text_cfg_scale": cfg_text,
                        "image_cfg_scale": cfg_image,
                    }
                    # Start from pure noise at the highest noise level, then
                    # integrate down the sigma schedule with Euler ancestral
                    # sampling.
                    z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
                    z = sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
                    # Decode the latent and convert to an 8-bit RGB image.
                    x = model.decode_first_stage(z)
                    x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
                    x = 255.0 * rearrange(x, "1 c h w -> h w c")
                    edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
                    edited_image.save(output_image_path)
|
|
if __name__ == "__main__":
    main()