from __future__ import annotations

import json
import sys
from argparse import ArgumentParser
from pathlib import Path

import einops
import k_diffusion as K
import matplotlib.pyplot as plt
import seaborn
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch import autocast
from tqdm.auto import tqdm

sys.path.append("./")

from clip_similarity import ClipSimilarity
from edit_dataset import EditDatasetEval

sys.path.append("./stable_diffusion")

from ldm.util import instantiate_from_config

|
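# CFGDenoiser implements the two-scale classifier-free guidance used by
# InstructPix2Pix: one batched UNet call produces the fully conditioned
# (text + image), image-only-conditioned, and unconditional noise
# predictions, which are combined with separate text and image scales.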
| class CFGDenoiser(nn.Module): |
| def __init__(self, model): |
| super().__init__() |
| self.inner_model = model |
|
| def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): |
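        # Triplicate the latent and sigma so the three conditionings below run
        # in a single batched forward pass.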
| cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) |
| cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3) |
| cfg_cond = { |
| "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], |
| "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], |
| } |
| out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) |
| return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) |
|
|
| def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): |
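    """Instantiate the model from an OmegaConf config and load `ckpt` weights.

    If `vae_ckpt` is given, its first-stage (VAE) weights replace the
    corresponding `first_stage_model.*` entries from the main checkpoint.
    """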
| print(f"Loading model from {ckpt}") |
| pl_sd = torch.load(ckpt, map_location="cpu") |
| if "global_step" in pl_sd: |
| print(f"Global Step: {pl_sd['global_step']}") |
| sd = pl_sd["state_dict"] |
| if vae_ckpt is not None: |
| print(f"Loading VAE from {vae_ckpt}") |
| vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] |
| sd = { |
| k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v |
| for k, v in sd.items() |
| } |
| model = instantiate_from_config(config.model) |
| m, u = model.load_state_dict(sd, strict=False) |
| if len(m) > 0 and verbose: |
| print("missing keys:") |
| print(m) |
| if len(u) > 0 and verbose: |
| print("unexpected keys:") |
| print(u) |
| return model |
|
|
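# ImageEditor wraps model loading, conditioning, and k-diffusion sampling into
# a single module mapping (input image, edit instruction) -> edited image.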
| class ImageEditor(nn.Module): |
    def __init__(self, config, ckpt, vae_ckpt=None):
        super().__init__()
        config = OmegaConf.load(config)
        self.model = load_model_from_config(config, ckpt, vae_ckpt)
        self.model.eval().cuda()
        # Wrap for k-diffusion sampling and add the two-scale CFG denoiser.
        self.model_wrap = K.external.CompVisDenoiser(self.model)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.null_token = self.model.get_learned_conditioning([""])
|
| def forward( |
| self, |
| image: torch.Tensor, |
| edit: str, |
| scale_txt: float = 7.5, |
| scale_img: float = 1.0, |
| steps: int = 100, |
| ) -> torch.Tensor: |
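        """Apply `edit` to `image` (C x H x W, expected in [-1, 1], with height
        and width divisible by 64); returns the edited image tensor."""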
| assert image.dim() == 3 |
| assert image.size(1) % 64 == 0 |
| assert image.size(2) % 64 == 0 |
| with torch.no_grad(), autocast("cuda"), self.model.ema_scope(): |
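            # Condition on the instruction's text embedding and the encoded
            # input image; the unconditional branch pairs the empty prompt
            # with an all-zero image latent.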
| cond = { |
| "c_crossattn": [self.model.get_learned_conditioning([edit])], |
| "c_concat": [self.model.encode_first_stage(image[None]).mode()], |
| } |
| uncond = { |
| "c_crossattn": [self.model.get_learned_conditioning([""])], |
| "c_concat": [torch.zeros_like(cond["c_concat"][0])], |
| } |
| extra_args = { |
| "uncond": uncond, |
| "cond": cond, |
| "image_cfg_scale": scale_img, |
| "text_cfg_scale": scale_txt, |
| } |
| sigmas = self.model_wrap.get_sigmas(steps) |
| x = torch.randn_like(cond["c_concat"][0]) * sigmas[0] |
| x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args) |
| x = self.model.decode_first_stage(x)[0] |
| return x |
|
|
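# Example of using ImageEditor on its own (paths and prompt are illustrative
# and assume the repo's default checkpoint layout):
#
#   editor = ImageEditor("configs/generate.yaml", "checkpoints/instruct-pix2pix-00-22000.ckpt")
#   edited = editor(image, "make it look like winter")  # image: C x H x W in [-1, 1]


# compute_metrics sweeps (text, image) guidance-scale pairs, averages the CLIP
# metrics over num_samples randomly ordered samples per pair, and appends one
# JSON line per pair to a .jsonl file whose name encodes the run settings.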
def compute_metrics(
    config,
    model_path,
    vae_ckpt,
    data_path,
    output_path,
    scales_img,
    scales_txt,
    num_samples=5000,
    split="test",
    steps=50,
    res=512,
    seed=0,
):
| editor = ImageEditor(config, model_path, vae_ckpt).cuda() |
| clip_similarity = ClipSimilarity().cuda() |
|
| outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl") |
    Path(output_path).mkdir(parents=True, exist_ok=True)

    dataset = EditDatasetEval(path=data_path, split=split, res=res)
    assert num_samples <= len(dataset)

    for scale_txt in scales_txt:
        for scale_img in scales_img:
            print(f"Processing t={scale_txt}, i={scale_img}")
            # Re-seed per scale pair so every pair sees the same sample order.
            torch.manual_seed(seed)
            perm = torch.randperm(len(dataset))

            # Running sums of the per-sample CLIP metrics.
            sim_0_avg = 0
            sim_1_avg = 0
            sim_direction_avg = 0
            sim_image_avg = 0
            count = 0
            i = 0

            pbar = tqdm(total=num_samples)
            while count < num_samples:
                idx = perm[i].item()
                sample = dataset[idx]
                i += 1

                gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps)

                # sim_0 / sim_1: CLIP similarity of the input / edited image to
                # its caption; sim_direction: directional similarity between the
                # image change and the caption change; sim_image: similarity of
                # the edited image to the input image.
                sim_0, sim_1, sim_direction, sim_image = clip_similarity(
                    sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]]
                )
                sim_0_avg += sim_0.item()
                sim_1_avg += sim_1.item()
                sim_direction_avg += sim_direction.item()
                sim_image_avg += sim_image.item()
                count += 1
                pbar.update(1)  # update() takes an increment, not an absolute count
            pbar.close()
|
| sim_0_avg /= count |
| sim_1_avg /= count |
| sim_direction_avg /= count |
| sim_image_avg /= count |
|
            with open(outpath, "a") as f:
                f.write(json.dumps(dict(
                    sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg,
                    sim_image=sim_image_avg, num_samples=num_samples, split=split,
                    scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed,
                )) + "\n")
| return outpath |
|
|
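# Plot the guidance trade-off: points with higher image-CFG scale sit higher
# (more faithful to the input image) but further left (weaker edits).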
def plot_metrics(metrics_file, output_path):
    with open(metrics_file, "r") as f:
        data = [json.loads(line) for line in f]

    plt.rcParams.update({"font.size": 11.5})
    seaborn.set_style("darkgrid")
    plt.figure(figsize=(20.5 * 0.7, 10.8 * 0.7), dpi=200)

    x = [d["sim_direction"] for d in data]
    y = [d["sim_image"] for d in data]
    plt.plot(x, y, marker="o", linewidth=2, markersize=4)

    plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
    plt.ylabel("CLIP Image Similarity", labelpad=10)
    plt.savefig(Path(output_path) / "plot.pdf", bbox_inches="tight")
|
|
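# CLI entry point; the invocation below is illustrative and assumes the script
# is run from the repository root (so the relative sys.path entries resolve):
#
#   python compute_metrics.py --ckpt checkpoints/instruct-pix2pix-00-22000.ckpt \
#       --dataset data/clip-filtered-dataset/ --output_path analysis/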
| def main(): |
| parser = ArgumentParser() |
| parser.add_argument("--resolution", default=512, type=int) |
| parser.add_argument("--steps", default=100, type=int) |
| parser.add_argument("--config", default="configs/generate.yaml", type=str) |
| parser.add_argument("--output_path", default="analysis/", type=str) |
| parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str) |
| parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str) |
| parser.add_argument("--vae-ckpt", default=None, type=str) |
    args = parser.parse_args()

    # Sweep image-guidance scales at a fixed text-guidance scale of 7.5.
    scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
    scales_txt = [7.5]

    metrics_file = compute_metrics(
        args.config,
        args.ckpt,
        args.vae_ckpt,
        args.dataset,
        args.output_path,
        scales_img,
        scales_txt,
        steps=args.steps,
        res=args.resolution,
    )

    plot_metrics(metrics_file, args.output_path)
|
|
| if __name__ == "__main__": |
| main() |
|
|