import gradio as gr
import os

import numpy as np
import torch
from PIL import Image
from torch import autocast, inference_mode
from torch.optim import AdamW
from tqdm import tqdm

from diffusers import StableDiffusionPipeline

from scheduling_ddim import DDIMScheduler  # local scheduler exposing reverse_step() for DDIM inversion
device = "cuda"

# don't forget to add your token (or comment it out if you are already logged in)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    scheduler=DDIMScheduler(beta_end=0.012,
                            beta_schedule="scaled_linear",
                            beta_start=0.00085),
    use_auth_token="",
).to(device)
# freeze the VAE, text encoder, and UNet; no network weights are trained
_ = pipe.vae.requires_grad_(False)
_ = pipe.text_encoder.requires_grad_(False)
_ = pipe.unet.requires_grad_(False)
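# Note: with every network frozen, the only trainable tensor below is the
# unconditional ("null") text embedding created in image_mod(); this mirrors
# the null-text-inversion idea of keeping the model fixed and tuning only the
# embedding used for classifier-free guidance.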
def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to an integer multiple of 32
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # scale from [0, 1] to [-1, 1], as the VAE expects
def im2latent(pipe, im, generator):
    init_image = preprocess(im).to(pipe.device)
    init_latent_dist = pipe.vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    return init_latents * 0.18215  # Stable Diffusion's latent scaling factor
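# For reference: a 512x512 input gives a (1, 3, 512, 512) tensor out of
# preprocess() and a (1, 4, 64, 64) latent out of im2latent(), since the
# SD 1.5 VAE downsamples by a factor of 8 into 4 latent channels.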
def image_mod(init_image, source_prompt, prompt, scale, steps, seed):
    # fix the seed so the inversion is reproducible
    g = torch.Generator(device=pipe.device).manual_seed(int(seed))
    image_latents = im2latent(pipe, init_image, g)
    pipe.scheduler.set_timesteps(steps)
    # encode a text description of the input image, e.g. "a photo of a woman"
    context = pipe._encode_prompt(source_prompt, pipe.device, 1, False, "")
    decoded_latents = image_latents.clone()
    with autocast(device), inference_mode():
        # flip the timesteps: DDIM inversion moves in the opposite direction,
        # from the clean image toward noise
        timesteps = pipe.scheduler.timesteps.flip(0)
        # latents recorded along the inversion; these become the pivot targets
        init_trajectory = torch.empty(len(timesteps), *decoded_latents.size()[1:],
                                      device=decoded_latents.device,
                                      dtype=decoded_latents.dtype)
        for i, t in enumerate(tqdm(timesteps)):
            init_trajectory[i:i + 1] = decoded_latents
            noise_pred = pipe.unet(decoded_latents, t, encoder_hidden_states=context).sample
            decoded_latents = pipe.scheduler.reverse_step(noise_pred, t, decoded_latents).next_sample
    # flip the trajectory so it is indexed in sampling (denoising) order
    init_trajectory = init_trajectory.cpu().flip(0)
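    # At this point decoded_latents holds the fully noised DDIM inversion of the
    # input image, and init_trajectory (now in sampling order) stores the
    # intermediate latent at every timestep; the loop below uses these as pivot
    # targets while optimizing the null embedding.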
    latents = decoded_latents.clone()
    context_uncond = pipe._encode_prompt("", pipe.device, 1, False, "")
    # the unconditional ("null") text embedding is the only thing we optimize
    context_uncond.requires_grad_(True)
    # the target prompt: pass the same text to reconstruct, or an edited text,
    # e.g. "a photo of a woman"
    context_cond = pipe._encode_prompt(prompt, pipe.device, 1, False, "")
    # the default learning rate works
    opt = AdamW([context_uncond])
    # duplicate latents for classifier-free guidance
    latents = torch.cat([latents, latents])
    latents.requires_grad_(True)
    context = torch.cat((context_uncond, context_cond))
    with autocast(device):
        for i, t in enumerate(tqdm(pipe.scheduler.timesteps)):
            latents = pipe.scheduler.scale_model_input(latents, t)
            uncond, cond = pipe.unet(latents, t, encoder_hidden_states=context).sample.chunk(2)
            with torch.enable_grad():
                latents = pipe.scheduler.step(uncond + scale * (cond - uncond), t, latents, generator=g).prev_sample
            opt.zero_grad()
            # optimize the null embedding: pull the current latents toward the
            # recorded pivot for this timestep
            pivot_value = init_trajectory[[i]].to(pipe.device)
            (latents - pivot_value).mean().backward()
            opt.step()
            latents = latents.detach()
    images = pipe.decode_latents(latents)
    im = pipe.numpy_to_pil(images)[0]
    return im
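# A quick, hypothetical smoke test without the UI (the image path is only an
# example and is assumed to exist):
#
#   out = image_mod(Image.open("images/00001.jpg"),
#                   "a photo of a person",          # source prompt (describes the input)
#                   "a photo of a smiling person",  # target prompt (the edit)
#                   scale=5.0, steps=50, seed=42)
#   out.save("edited.png")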
demo = gr.Interface(
    image_mod,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox("a photo of a person", label="source prompt"),
        gr.Textbox("a photo of a person", label="target prompt"),
        gr.Slider(0, 10, value=0.5, step=0.1, label="guidance scale"),
        gr.Slider(0, 100, value=51, step=1, label="steps"),
        gr.Number(42, label="seed"),
    ],
    outputs="image",
    flagging_options=["blurry", "incorrect", "other"],
    examples=[
        # with multiple inputs, each example must supply a value for every input
        [os.path.join(os.path.dirname(__file__), "images/00001.jpg"),
         "a photo of a person", "a photo of a person", 0.5, 51, 42],
    ],
)

if __name__ == "__main__":
    demo.launch()