yonishafir commited on
Commit
fd8f569
·
verified ·
1 Parent(s): 9ba3183

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +118 -0
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### How To Use
2
+
3
+ ```python
4
+ from diffusers import (
5
+ AutoencoderKL,
6
+ StableDiffusionXLControlNetInpaintPipeline,
7
+ LCMScheduler,
8
+ )
9
+ from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
10
+ from controlnet import ControlNetModel, ControlNetConditioningEmbedding
11
+ import os
12
+ from torchvision import transforms
13
+ import torch
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ import pandas as pd
17
+ from PIL import Image
18
+
19
+
20
+ def download_image(url):
21
+ response = requests.get(url)
22
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
23
+
24
+
25
+ def get_masked_image(path_to_images_dir, image_name, image, image_mask, width, height):
26
+ image_mask = image_mask # inpaint area is white
27
+ image_mask = add_margins_to_ratio(image_mask, 1.5)
28
+ image_mask = image_mask.resize((width, height)) # object to remove is white (1)
29
+ image_mask_pil = image_mask
30
+ orig_image = np.array(image.convert("RGB")).astype(np.uint8)
31
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
32
+ image_mask = np.array(image_mask_pil.convert("L")).astype(np.float32) / 255.0
33
+ assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
34
+ masked_image_to_present = image.copy()
35
+ masked_image_to_present[image_mask > 0.5] = (0.5,0.5,0.5) # set as masked pixel
36
+ image[image_mask > 0.5] = 0.5 # set as masked pixel - s.t. will be grey
37
+ image = Image.fromarray((image * 255.0).astype(np.uint8))
38
+ masked_image_to_present = Image.fromarray((masked_image_to_present * 255.0).astype(np.uint8))
39
+ return image, image_mask_pil, masked_image_to_present
40
+
41
+
42
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
43
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
44
+
45
+ init_image = download_image(img_url).resize((1024, 1024))
46
+ mask_image = download_image(mask_url).resize((1024, 1024))
47
+
48
+ # Load, init model
49
+ controlnet = ControlNetModel().from_config('/home/ubuntu/spring/Infra/project_x/bria2_controlnet_inpainting/config_controlnet_vae.json', torch_dtype=torch.float16)
50
+ controlnet.controlnet_cond_embedding = ControlNetConditioningEmbedding(
51
+ conditioning_embedding_channels=320,
52
+ conditioning_channels = 5
53
+ )
54
+
55
+ controlnet.load_state_dict(torch.load(local_ckpt_dir + local_ckpt_dir_suffix))
56
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
57
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet.to(dtype=torch.float16), torch_dtype=torch.float16, vae=vae) #force_zeros_for_empty_prompt=False, # vae=vae)
58
+
59
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
60
+ pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
61
+ pipe.fuse_lora()
62
+
63
+ pipe = pipe.to('cuda:0')
64
+ pipe.enable_xformers_memory_efficient_attention()
65
+
66
+ generator = torch.Generator(device='cuda:0').manual_seed(123456)
67
+
68
+
69
+
70
+
71
+ vae = pipe.vae
72
+
73
+ masked_image, image_mask, masked_image_to_present = get_masked_image(path_to_images_dir, image_name, image_mask, img, width, height)
74
+ masked_image_tensor = image_transforms(masked_image)
75
+ masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
76
+
77
+ masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda")
78
+ # masked_image_tensor = masked_image_tensor.permute((0,3,1,2))
79
+ control_latents = vae.encode(
80
+ masked_image_tensor[:, :3, :, :].to(vae.dtype)
81
+ ).latent_dist.sample()
82
+ control_latents = control_latents * vae.config.scaling_factor
83
+
84
+
85
+ image_mask = np.array(image_mask)[:,:]
86
+ mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, ...]
87
+ # binarize the mask
88
+ mask_tensor = torch.where(mask_tensor > 128.0, 255.0, 0)
89
+
90
+ if normalize_mask_to_0_1:
91
+ mask_tensor = mask_tensor / 255.0
92
+
93
+ mask_tensor = mask_tensor.to(device="cuda")
94
+ mask_resized = torch.nn.functional.interpolate(mask_tensor[None, ...], size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
95
+ # mask_resized = mask_resized.to(torch.float16)
96
+ masked_image = torch.cat([control_latents, mask_resized], dim=1)
97
+
98
+ gen_img = pipe(negative_prompt=default_negative_prompt, prompt=caption,
99
+ controlnet_conditioning_sale=1.0,
100
+ num_inference_steps=12,
101
+ height=height, width=width,
102
+ image = masked_image, # control image
103
+ init_image = img,
104
+ mask_image = mask_tensor,
105
+ guidance_scale = 1.2,
106
+ generator=generator).images[0]
107
+
108
+
109
+
110
+
111
+
112
+
113
+ prompt = "A park bench"
114
+ generator = torch.Generator(device='cuda:0').manual_seed(123456)
115
+ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image,generator=generator,guidance_scale=5,strength=1).images[0]
116
+ image.save("./a_park_bench.png")
117
+ ```
118
+