File size: 4,616 Bytes
50261d7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | import gradio as gr
import torch
import numpy as np
import requests
import random
from io import BytesIO
from utils import *
from constants import *
from inversion_utils import *
from modified_pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline
from torch import autocast, inference_mode
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from transformers import AutoProcessor, BlipForConditionalGeneration
from share_btn import community_icon_html, loading_icon_html, share_js
from PIL import ImageFile
import random
# load pipelines
# Model location passed to from_pretrained (presumably a local diffusers-format
# checkpoint directory for SD v1.5 -- confirm).
sd_model_id = "sd_model_v1-5"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Plain SD pipeline: used by invert() and sample().
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
# Replace the default scheduler with DDIM, configured from the model's own
# "scheduler" subfolder. from_pretrained is the supported loader for a path;
# passing a path to from_config is deprecated in diffusers.
sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
# Semantic-guidance (SEGA) pipeline from the project-local modified module; used by edit().
sega_pipe = SemanticStableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
# Allow PIL to open truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Default run parameters ("# @param" markers look like notebook form annotations).
input_image = "./examples/00001.png" # @param
source_prompt = "human face" # @param
target_prompt = "make up like a clown" # @param
num_diffusion_steps = 100 # @param
source_guidance_scale = 0 # @param
reconstruct = True # @param
skip_steps = 50 # @param
target_guidance_scale = 10 # @param
# SEGA only params (one entry per edit concept)
edit_concepts = ["star makeup", "heart makeup"] # @param
edit_guidance_scales = [7, 15] # @param
warmup_steps = [1, 1] # @param
reverse_editing = [True, False] # @param
thresholds = [0.95, 0.95] # @param
def invert(x0: torch.FloatTensor, prompt_src: str = "", num_inference_steps=100, cfg_scale_src=3.5, eta=1):
    """Invert a real image into DDPM noise maps and intermediate latents.

    Follows Algorithm 1 of https://arxiv.org/pdf/2304.06140.pdf, based on the
    code in https://github.com/inbarhub/DDPM_inversion.

    Args:
        x0: image tensor to invert, as returned by ``load_512`` (shape/range
            presumably (1, 3, H, W) in [-1, 1] -- TODO confirm).
        prompt_src: text prompt describing the source image.
        num_inference_steps: number of diffusion timesteps used for inversion.
        cfg_scale_src: classifier-free guidance scale for ``prompt_src``.
        eta: forwarded to the forward process as ``etas``.

    Returns:
        ``(zs, wts)``: ``zs`` are the noise maps, ``wts`` the intermediate
        inverted latents. The final inverted latent produced by the forward
        process is discarded.
    """
    # Use the argument, not the module-level num_diffusion_steps, so callers
    # actually control the step count.
    sd_pipe.scheduler.set_timesteps(num_inference_steps)

    # VAE-encode the image; 0.18215 is the latent scaling constant that
    # sample() undoes with 1 / 0.18215 before decoding.
    with autocast("cuda"), inference_mode():
        w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()

    # Find zs and wts via the forward process.
    _, zs, wts = inversion_forward_process(
        sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
        prog_bar=True, num_inference_steps=num_inference_steps,
    )
    return zs, wts
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
    """Regenerate an image from inverted latents under a target prompt.

    The reverse process starts at ``wts[skip]`` and replays the remaining
    noise maps ``zs[skip:]``; the resulting latent is VAE-decoded.

    Args:
        zs: noise maps from ``invert``.
        wts: intermediate latents from ``invert``.
        prompt_tar: target text prompt.
        cfg_scale_tar: classifier-free guidance scale for ``prompt_tar``.
        skip: number of leading timesteps to skip.
        eta: forwarded to the reverse process as ``etas``.

    Returns:
        The result of ``image_grid`` applied to the decoded batch.
    """
    start_latent = wts[skip]
    remaining_noise = zs[skip:]
    latent0, _ = inversion_reverse_process(
        sd_pipe,
        xT=start_latent,
        etas=eta,
        prompts=[prompt_tar],
        cfg_scales=[cfg_scale_tar],
        prog_bar=True,
        zs=remaining_noise,
    )

    # Undo the latent scaling applied at encode time, then decode.
    with autocast("cuda"), inference_mode():
        decoded = sd_pipe.vae.decode(1 / 0.18215 * latent0).sample

    # Promote a lone image to a batch of one (rank 4) before gridding.
    if decoded.dim() < 4:
        decoded = decoded[None, :, :, :]
    return image_grid(decoded)
def edit(wts, zs,
         tar_prompt="",
         steps=100,
         skip=36,
         tar_cfg_scale=8,
         edit_concept="",
         guidnace_scale=7,
         warmup=1,
         neg_guidance=False,
         threshold=0.95
         ):
    """Edit an inverted image with the SEGA pipeline.

    Args:
        wts: intermediate latents from ``invert``; generation starts from ``wts[skip]``.
        zs: noise maps from ``invert``; ``zs[skip:]`` is passed to the pipeline.
        tar_prompt: target text prompt.
        steps: ``num_inference_steps`` for the SEGA pipeline.
        skip: number of leading timesteps to skip.
        tar_cfg_scale: classifier-free guidance scale for ``tar_prompt``.
        edit_concept: passed as ``editing_prompt``.
        guidnace_scale: passed as ``edit_guidance_scale`` (misspelled name kept
            for caller compatibility).
        warmup: passed as ``edit_warmup_steps``.
        neg_guidance: passed as ``reverse_editing_direction``.
        threshold: passed as ``edit_threshold``.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Starting latent, expanded to a batch of one.
    start_latents = wts[skip].expand(1, -1, -1, -1)
    output = sega_pipe(
        prompt=tar_prompt,
        latents=start_latents,
        guidance_scale=tar_cfg_scale,
        num_images_per_prompt=1,
        num_inference_steps=steps,
        use_ddpm=True,
        wts=wts,
        zs=zs[skip:],
        # SEGA editing arguments
        editing_prompt=edit_concept,
        reverse_editing_direction=neg_guidance,
        edit_warmup_steps=warmup,
        edit_guidance_scale=guidnace_scale,
        edit_threshold=threshold,
        edit_momentum_scale=0.5,
        edit_mom_beta=0.6,
        eta=1,
    )
    return output.images[0]
# Batch run: for each origin_face<N>.png, invert it and reconstruct it under a
# prompt drawn at random from prompt.txt, saving to makeup/edit_<N>.png.
import os
import sys

with open('prompt.txt', encoding='utf-8') as file:
    # One prompt per line; strip whitespace and drop blank lines so an empty
    # prompt is never drawn.
    lines = [line.strip() for line in file if line.strip()]
if not lines:
    sys.exit("prompt.txt contains no prompts")

# Create the output directory up front; otherwise every save() fails.
os.makedirs('makeup', exist_ok=True)

for i in range(10000):
    input_image = 'origin_face' + str(i + 1) + '.png'
    target_prompt = random.choice(lines)
    try:
        x0 = load_512(input_image, device=device)
        # noise maps and latents
        zs, wts = invert(x0=x0, prompt_src=source_prompt, num_inference_steps=num_diffusion_steps,
                         cfg_scale_src=source_guidance_scale)
        if reconstruct:
            ddpm_out_img = sample(zs, wts, prompt_tar=target_prompt, skip=skip_steps,
                                  cfg_scale_tar=target_guidance_scale)
            ddpm_out_img.save(f'makeup/edit_{i+1}.png')
    except Exception as exc:
        # Best-effort batch: report the failure and move on. Catching Exception
        # (not a bare except) lets KeyboardInterrupt/SystemExit stop the run.
        print(f"skipping {input_image}: {exc!r}", file=sys.stderr)
|