camenduru committed on
Commit be1693d · verified · 1 Parent(s): b541c0f

thanks to InstantX ❤

.gitattributes CHANGED
@@ -33,6 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- examples/0.png filter=lfs diff=lfs merge=lfs -text
- examples/1.png filter=lfs diff=lfs merge=lfs -text
- examples/applications.png filter=lfs diff=lfs merge=lfs -text
+ examples/kaifu_resize.png filter=lfs diff=lfs merge=lfs -text
+ examples/sam_resize.png filter=lfs diff=lfs merge=lfs -text
+ examples/schmidhuber_resize.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,128 +1,13 @@
  ---
+ title: InstantID
+ emoji: 😻
+ colorFrom: gray
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 4.15.0
+ app_file: app.py
+ pinned: false
  license: apache-2.0
- language:
- - en
- library_name: diffusers
- pipeline_tag: text-to-image
  ---
 
- # InstantID Model Card
-
- <div align="center">
-
- [**Project Page**](https://instantid.github.io/) **|** [**Paper**](https://arxiv.org/abs/2401.07519) **|** [**Code**](https://github.com/InstantID/InstantID) **|** [🤗 **Gradio demo**](https://huggingface.co/spaces/InstantX/InstantID)
-
- </div>
-
- ## Introduction
-
- InstantID is a new state-of-the-art tuning-free method to achieve ID-preserving generation with only a single image, supporting various downstream tasks.
-
- <div align="center">
- <img src='examples/applications.png'>
- </div>
-
- ## Usage
-
- You can download the model directly from this repository.
- You can also download the model in a Python script:
-
- ```python
- from huggingface_hub import hf_hub_download
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
- hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
- hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
- ```
-
- For the face encoder, you need to manually download it via this [URL](https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304) to `models/antelopev2`.
-
- ```python
- # !pip install opencv-python transformers accelerate insightface
- import diffusers
- from diffusers.utils import load_image
- from diffusers.models import ControlNetModel
-
- import cv2
- import torch
- import numpy as np
- from PIL import Image
-
- from insightface.app import FaceAnalysis
- from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
-
- # prepare 'antelopev2' under ./models
- app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
- app.prepare(ctx_id=0, det_size=(640, 640))
-
- # prepare models under ./checkpoints
- face_adapter = './checkpoints/ip-adapter.bin'
- controlnet_path = './checkpoints/ControlNetModel'
-
- # load IdentityNet
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
-
- pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
-     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.cuda()
-
- # load adapter
- pipe.load_ip_adapter_instantid(face_adapter)
- ```
-
- Then, you can customize your own face images:
-
- ```python
- # load an image
- face_image = load_image("your-example.jpg")
-
- # prepare face emb
- face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
- face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # only use the largest face
- face_emb = face_info['embedding']
- face_kps = draw_kps(face_image, face_info['kps'])
-
- pipe.set_ip_adapter_scale(0.8)
-
- prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
- negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
-
- # generate image
- image = pipe(
-     prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
- ).images[0]
- ```
-
- For more details, please follow the instructions in our [GitHub repository](https://github.com/InstantID/InstantID).
-
- ## Usage Tips
- 1. If you're not satisfied with the similarity, try increasing the "IdentityNet Strength" and "Adapter Strength" weights.
- 2. If the saturation is too high, first decrease the Adapter strength. If it is still too high, then decrease the IdentityNet strength.
- 3. If text control is not as expected, decrease the Adapter strength.
- 4. If the realistic style is not good enough, go to our GitHub repo and use a more realistic base model.
-
- ## Demos
-
- <div align="center">
- <img src='examples/0.png'>
- </div>
-
- <div align="center">
- <img src='examples/1.png'>
- </div>
-
- ## Disclaimer
-
- This project is released under the Apache License and aims to positively impact the field of AI-driven image generation. Users are granted the freedom to create images using this tool, but they are obligated to comply with local laws and to use it responsibly. The developers assume no responsibility for potential misuse by users.
-
- ## Citation
- ```bibtex
- @article{wang2024instantid,
-   title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
-   author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
-   journal={arXiv preprint arXiv:2401.07519},
-   year={2024}
- }
- ```
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,408 @@
+ import os
+ import cv2
+ import math
+ import torch
+ import random
+ import numpy as np
+
+ import PIL
+ from PIL import Image
+
+ import diffusers
+ from diffusers.utils import load_image
+ from diffusers.models import ControlNetModel
+
+ import insightface
+ from insightface.app import FaceAnalysis
+
+ from style_template import styles
+ from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
+
+ import spaces
+ import gradio as gr
+
+ # global variables
+ MAX_SEED = np.iinfo(np.int32).max
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ STYLE_NAMES = list(styles.keys())
+ DEFAULT_STYLE_NAME = "Watercolor"
+
+ # download checkpoints
+ from huggingface_hub import hf_hub_download
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
+
+ # load face encoder
+ app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+ app.prepare(ctx_id=0, det_size=(640, 640))
+
+ # paths to InstantID models
+ face_adapter = './checkpoints/ip-adapter.bin'
+ controlnet_path = './checkpoints/ControlNetModel'
+
+ # load pipeline
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+ base_model_path = 'GHArt/Unstable_Diffusers_YamerMIX_V9_xl_fp16'
+
+ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+     base_model_path,
+     controlnet=controlnet,
+     torch_dtype=torch.float16,
+     safety_checker=None,
+     feature_extractor=None,
+ )
+ pipe.cuda()
+ pipe.load_ip_adapter_instantid(face_adapter)
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+ def swap_to_gallery(images):
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
+
+ def upload_example_to_gallery(images, prompt, style, negative_prompt):
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
+
+ def remove_back_to_files():
+     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+ def remove_tips():
+     return gr.update(visible=False)
+
+ def get_example():
+     case = [
+         [
+             ['./examples/yann-lecun_resize.jpg'],
+             "a man",
+             "Snow",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+         [
+             ['./examples/musk_resize.jpeg'],
+             "a man",
+             "Mars",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+         [
+             ['./examples/sam_resize.png'],
+             "a man",
+             "Jungle",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+         [
+             ['./examples/schmidhuber_resize.png'],
+             "a man",
+             "Neon",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+         [
+             ['./examples/kaifu_resize.png'],
+             "a man",
+             "Vibrant Color",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+     ]
+     return case
+
+ def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+     return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+     return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+
+ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
+     stickwidth = 4
+     limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+     kps = np.array(kps)
+
+     w, h = image_pil.size
+     out_img = np.zeros([h, w, 3])
+
+     # draw the limbs connecting the five facial keypoints
+     for i in range(len(limbSeq)):
+         index = limbSeq[i]
+         color = color_list[index[0]]
+
+         x = kps[index][:, 0]
+         y = kps[index][:, 1]
+         length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+         angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+         polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+         out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+     out_img = (out_img * 0.6).astype(np.uint8)
+
+     # draw the keypoints themselves
+     for idx_kp, kp in enumerate(kps):
+         color = color_list[idx_kp]
+         x, y = kp
+         out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+     out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+     return out_img_pil
+
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
+                pad_to_max_side=False, mode=PIL.Image.BILINEAR, base_pixel_number=64):
+     w, h = input_image.size
+     if size is not None:
+         w_resize_new, h_resize_new = size
+     else:
+         ratio = min_side / min(h, w)
+         w, h = round(ratio * w), round(ratio * h)
+         ratio = max_side / max(h, w)
+         input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+     input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+     if pad_to_max_side:
+         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+         offset_x = (max_side - w_resize_new) // 2
+         offset_y = (max_side - h_resize_new) // 2
+         res[offset_y:offset_y + h_resize_new, offset_x:offset_x + w_resize_new] = np.array(input_image)
+         input_image = Image.fromarray(res)
+     return input_image
+
+ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
+     p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+     return p.replace("{prompt}", positive), n + ' ' + negative
+
+ @spaces.GPU
+ def generate_image(face_image, pose_image, prompt, negative_prompt, style_name, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
+     if face_image is None:
+         raise gr.Error("Cannot find any input face image! Please upload a face image.")
+
+     if prompt is None:
+         prompt = "a person"
+
+     # apply the style template
+     prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
+
+     face_image = load_image(face_image[0])
+     face_image = resize_img(face_image)
+     face_image_cv2 = convert_from_image_to_cv2(face_image)
+     height, width, _ = face_image_cv2.shape
+
+     # extract face features
+     face_info = app.get(face_image_cv2)
+
+     if len(face_info) == 0:
+         raise gr.Error("Cannot find any face in the image! Please upload another person image.")
+
+     face_info = face_info[-1]
+     face_emb = face_info['embedding']
+     face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info['kps'])
+
+     if pose_image is not None:
+         pose_image = load_image(pose_image[0])
+         pose_image = resize_img(pose_image)
+         pose_image_cv2 = convert_from_image_to_cv2(pose_image)
+
+         face_info = app.get(pose_image_cv2)
+
+         if len(face_info) == 0:
+             raise gr.Error("Cannot find any face in the reference image! Please upload another person image.")
+
+         face_info = face_info[-1]
+         face_kps = draw_kps(pose_image, face_info['kps'])
+
+         width, height = face_kps.size
+
+     if enhance_face_region:
+         control_mask = np.zeros([height, width, 3])
+         x1, y1, x2, y2 = face_info['bbox']
+         x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+         control_mask[y1:y2, x1:x2] = 255
+         control_mask = Image.fromarray(control_mask.astype(np.uint8))
+     else:
+         control_mask = None
+
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     print("Start inference...")
+     print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
+
+     pipe.set_ip_adapter_scale(adapter_strength_ratio)
+     images = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         image_embeds=face_emb,
+         image=face_kps,
+         control_mask=control_mask,
+         controlnet_conditioning_scale=float(identitynet_strength_ratio),
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+         height=height,
+         width=width,
+         generator=generator,
+     ).images
+
+     return images, gr.update(visible=True)
+
+ ### Description
+ title = r"""
+ <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
+ """
+
+ description = r"""
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
+
+ How to use:<br>
+ 1. Upload a person image. For images with multiple people, we will only detect the biggest face. Make sure the face is not too small and not significantly blocked or blurred.
+ 2. (Optionally) upload another person image as the reference pose. If not uploaded, we will use the first person image to extract landmarks. If you use a cropped face at step 1, it is recommended to upload it to extract a new pose.
+ 3. Enter a text prompt, as in normal text-to-image models.
+ 4. Click the <b>Submit</b> button to start customizing.
+ 5. Share your customized photo with your friends and enjoy! 😊
+ """
+
+ article = r"""
+ ---
+ 📝 **Citation**
+ <br>
+ If our work is helpful for your research or applications, please cite us via:
+ ```bibtex
+ @article{wang2024instantid,
+   title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
+   author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
+   journal={arXiv preprint arXiv:2401.07519},
+   year={2024}
+ }
+ ```
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to open an issue or reach out to us directly at <b>haofanwang.ai@gmail.com</b>.
+ """
+
+ tips = r"""
+ ### Usage tips of InstantID
+ 1. If you're unsatisfied with the similarity, increase the weight of controlnet_conditioning_scale (IdentityNet) and ip_adapter_scale (Adapter).
+ 2. If the generated image is over-saturated, decrease ip_adapter_scale. If that does not work, decrease controlnet_conditioning_scale.
+ 3. If text control is not as expected, decrease ip_adapter_scale.
+ 4. Finding a good base model always makes a difference.
+ """
+
+ css = '''
+ .gradio-container {width: 85% !important}
+ '''
+ with gr.Blocks(css=css) as demo:
+
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column():
+
+             # upload face image
+             face_files = gr.Files(
+                 label="Upload a photo of your face",
+                 file_types=["image"]
+             )
+             uploaded_faces = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
+             with gr.Column(visible=False) as clear_button_face:
+                 remove_and_reupload_faces = gr.ClearButton(value="Remove and upload new ones", components=face_files, size="sm")
+
+             # optional: upload a reference pose image
+             pose_files = gr.Files(
+                 label="Upload a reference pose image (optional)",
+                 file_types=["image"]
+             )
+             uploaded_poses = gr.Gallery(label="Your images", visible=False, columns=1, rows=1, height=512)
+             with gr.Column(visible=False) as clear_button_pose:
+                 remove_and_reupload_poses = gr.ClearButton(value="Remove and upload new ones", components=pose_files, size="sm")
+
+             # prompt
+             prompt = gr.Textbox(label="Prompt",
+                                 info="A simple prompt is enough to achieve good face fidelity",
+                                 placeholder="A photo of a person",
+                                 value="")
+
+             submit = gr.Button("Submit", variant="primary")
+
+             style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
+
+             # strength
+             identitynet_strength_ratio = gr.Slider(
+                 label="IdentityNet strength (for fidelity)",
+                 minimum=0,
+                 maximum=1.5,
+                 step=0.05,
+                 value=0.80,
+             )
+             adapter_strength_ratio = gr.Slider(
+                 label="Image adapter strength (for detail)",
+                 minimum=0,
+                 maximum=1.5,
+                 step=0.05,
+                 value=0.80,
+             )
+
+             with gr.Accordion(open=False, label="Advanced Options"):
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     placeholder="low quality",
+                     value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+                 )
+                 num_steps = gr.Slider(
+                     label="Number of sample steps",
+                     minimum=20,
+                     maximum=100,
+                     step=1,
+                     value=30,
+                 )
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.1,
+                     maximum=10.0,
+                     step=0.1,
+                     value=5,
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=42,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
+
+         with gr.Column():
+             gallery = gr.Gallery(label="Generated Images")
+             usage_tips = gr.Markdown(label="Usage tips of InstantID", value=tips, visible=False)
+
+     face_files.upload(fn=swap_to_gallery, inputs=face_files, outputs=[uploaded_faces, clear_button_face, face_files])
+     pose_files.upload(fn=swap_to_gallery, inputs=pose_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
+
+     remove_and_reupload_faces.click(fn=remove_back_to_files, outputs=[uploaded_faces, clear_button_face, face_files])
+     remove_and_reupload_poses.click(fn=remove_back_to_files, outputs=[uploaded_poses, clear_button_pose, pose_files])
+
+     submit.click(
+         fn=remove_tips,
+         outputs=usage_tips,
+     ).then(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate_image,
+         inputs=[face_files, pose_files, prompt, negative_prompt, style, enhance_face_region, num_steps, identitynet_strength_ratio, adapter_strength_ratio, guidance_scale, seed],
+         outputs=[gallery, usage_tips]
+     )
+
+     gr.Examples(
+         examples=get_example(),
+         inputs=[face_files, prompt, style, negative_prompt],
+         run_on_click=True,
+         fn=upload_example_to_gallery,
+         outputs=[uploaded_faces, clear_button_face, face_files],
+     )
+
+     gr.Markdown(article)
+
+ demo.launch()
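
A quick worked example of the `resize_img` rounding above (an editor's sketch, not part of the commit; it assumes the `resize_img` function from `app.py` is in scope). With the defaults `min_side=1024` and `max_side=1280`, both output sides are floored to multiples of `base_pixel_number=64`, which keeps the SDXL latent shapes valid:

```python
from PIL import Image

# hypothetical 2:3 portrait input
img = Image.new("RGB", (1365, 2048))

# short side is scaled toward 1024, long side is clamped to 1280,
# then both sides are snapped down to multiples of 64
out = resize_img(img)
print(out.size)  # (832, 1280)
```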
examples/kaifu_resize.png ADDED

Git LFS Details

  • SHA256: b7302f0f7d0ff61be67bf13d172ad2393b6cb2bc985f048089f4e901145324d7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
examples/musk_resize.jpeg ADDED
examples/sam_resize.png ADDED

Git LFS Details

  • SHA256: 1390d8a9a1be7b8f5388c0bc8483b2d5cca6c0f0adeb6eecd970a4413b1f1deb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
examples/schmidhuber_resize.png ADDED

Git LFS Details

  • SHA256: 51beaa72d1eb9f56118118fae8775bda818bcb56b03220f3cd39daa425f57a9a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.23 MB
examples/yann-lecun_resize.jpg ADDED
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,308 @@
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ try:
+     import xformers
+     import xformers.ops
+     xformers_available = True
+ except Exception:
+     xformers_available = False
+
+
+ class RegionControler(object):
+     def __init__(self) -> None:
+         self.prompt_image_conditioning = []
+ region_control = RegionControler()
+
+
+ class AttnProcessor(nn.Module):
+     r"""
+     Default processor for performing attention-related computations.
+     """
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+     ):
+         super().__init__()
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class IPAttnProcessor(nn.Module):
+     r"""
+     Attention processor for IP-Adapter.
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         scale (`float`, defaults to 1.0):
+             the weight scale of the image prompt.
+         num_tokens (`int`, defaults to 4; for ip_adapter_plus it should be 16):
+             The context length of the image features.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+         self.num_tokens = num_tokens
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # split encoder_hidden_states into text tokens and trailing image (ip) tokens
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         if xformers_available:
+             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+         else:
+             attention_probs = attn.get_attention_scores(query, key, attention_mask)
+             hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # for ip-adapter
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = attn.head_to_batch_dim(ip_key)
+         ip_value = attn.head_to_batch_dim(ip_value)
+
+         if xformers_available:
+             ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+         else:
+             ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+             ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+         # region control
+         if len(region_control.prompt_image_conditioning) == 1:
+             region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
+             if region_mask is not None:
+                 h, w = region_mask.shape[:2]
+                 ratio = (h * w / query.shape[1]) ** 0.5
+                 mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
+             else:
+                 mask = torch.ones_like(ip_hidden_states)
+             ip_hidden_states = ip_hidden_states * mask
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+     def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+         # TODO attention_mask
+         query = query.contiguous()
+         key = key.contiguous()
+         value = value.contiguous()
+         hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+         # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+         return hidden_states
+
+
+ class AttnProcessor2_0(torch.nn.Module):
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+     """
+     def __init__(
+         self,
+         hidden_size=None,
+         cross_attention_dim=None,
+     ):
+         super().__init__()
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
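
These processors are not installed automatically. For reference, a minimal sketch (an editor's addition, not part of the commit) of how an `IPAttnProcessor` is typically assigned to the cross-attention layers of a diffusers `UNet2DConditionModel` while self-attention keeps the default `AttnProcessor` — the wiring below follows the usual IP-Adapter convention, which this commit's pipeline performs internally in `load_ip_adapter_instantid`:

```python
# Sketch only: layer names and dims follow diffusers' UNet conventions.
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

def set_ip_adapter(unet, num_tokens=16, scale=0.5):
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" layers are self-attention and receive no image tokens
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=scale,
                num_tokens=num_tokens,
            ).to(unet.device, dtype=unet.dtype)
    unet.set_attn_processor(attn_procs)
```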
ip_adapter/resampler.py ADDED
@@ -0,0 +1,121 @@
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+ import math
+
+ import torch
+ import torch.nn as nn
+
+
+ # FFN
+ def FeedForward(dim, mult=4):
+     inner_dim = int(dim * mult)
+     return nn.Sequential(
+         nn.LayerNorm(dim),
+         nn.Linear(dim, inner_dim, bias=False),
+         nn.GELU(),
+         nn.Linear(inner_dim, dim, bias=False),
+     )
+
+
+ def reshape_tensor(x, heads):
+     bs, length, width = x.shape
+     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+     x = x.view(bs, length, heads, -1)
+     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+     x = x.transpose(1, 2)
+     # materialize the transposed layout: (bs, n_heads, length, dim_per_head)
+     x = x.reshape(bs, heads, length, -1)
+     return x
+
+
+ class PerceiverAttention(nn.Module):
+     def __init__(self, *, dim, dim_head=64, heads=8):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.dim_head = dim_head
+         self.heads = heads
+         inner_dim = dim_head * heads
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+         self.to_q = nn.Linear(dim, inner_dim, bias=False)
+         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+         self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+     def forward(self, x, latents):
+         """
+         Args:
+             x (torch.Tensor): image features
+                 shape (b, n1, D)
+             latents (torch.Tensor): latent features
+                 shape (b, n2, D)
+         """
+         x = self.norm1(x)
+         latents = self.norm2(latents)
+
+         b, l, _ = latents.shape
+
+         q = self.to_q(latents)
+         kv_input = torch.cat((x, latents), dim=-2)
+         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+         q = reshape_tensor(q, self.heads)
+         k = reshape_tensor(k, self.heads)
+         v = reshape_tensor(v, self.heads)
+
+         # attention
+         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+         weight = (q * scale) @ (k * scale).transpose(-2, -1)  # more stable with f16 than dividing afterwards
+         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+         out = weight @ v
+
+         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+         return self.to_out(out)
+
+
+ class Resampler(nn.Module):
+     def __init__(
+         self,
+         dim=1024,
+         depth=8,
+         dim_head=64,
+         heads=16,
+         num_queries=8,
+         embedding_dim=768,
+         output_dim=1024,
+         ff_mult=4,
+     ):
+         super().__init__()
+
+         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+         self.proj_in = nn.Linear(embedding_dim, dim)
+
+         self.proj_out = nn.Linear(dim, output_dim)
+         self.norm_out = nn.LayerNorm(output_dim)
+
+         self.layers = nn.ModuleList([])
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         latents = self.latents.repeat(x.size(0), 1, 1)
+
+         x = self.proj_in(x)
+
+         for attn, ff in self.layers:
+             latents = attn(x, latents) + latents
+             latents = ff(latents) + latents
+
+         latents = self.proj_out(latents)
+         return self.norm_out(latents)
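
For intuition, a shape-check sketch (an editor's addition; the dimensions mirror typical InstantID settings only as an assumption: a 512-d insightface embedding resampled into 16 tokens in SDXL's 2048-d cross-attention space):

```python
import torch
from ip_adapter.resampler import Resampler

# assumed dims: embedding_dim matches the face embedding,
# output_dim matches the UNet's cross_attention_dim
resampler = Resampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=16, embedding_dim=512, output_dim=2048, ff_mult=4,
)
face_emb = torch.randn(1, 1, 512)  # (batch, n1, embedding_dim)
tokens = resampler(face_emb)
print(tokens.shape)  # torch.Size([1, 16, 2048])
```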
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
+ import torch.nn.functional as F
+
+
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")
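
A hypothetical use of this flag, in the spirit of the upstream IP-Adapter code (an editor's sketch, not part of the commit): select the PyTorch 2.0 SDPA processor when available and fall back to the default otherwise.

```python
from ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor
```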
models/antelopev2/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
+ size 143607619
models/antelopev2/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
+ size 5030888
models/antelopev2/genderage.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
+ size 1322532
models/antelopev2/glintr100.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
+ size 260665334
models/antelopev2/scrfd_10g_bnkps.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
+ size 16923827
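
The five ONNX files above make up insightface's `antelopev2` pack (detection, landmarks, gender/age, and recognition). If you are reproducing this Space outside of a clone, a hedged sketch of fetching them into the layout `FaceAnalysis(name='antelopev2', root='./')` expects — the repo id below is a placeholder for wherever you host them, not a confirmed source:

```python
from huggingface_hub import hf_hub_download

for fn in ["1k3d68.onnx", "2d106det.onnx", "genderage.onnx",
           "glintr100.onnx", "scrfd_10g_bnkps.onnx"]:
    hf_hub_download(repo_id="your-username/InstantID",  # hypothetical repo id
                    filename=f"models/antelopev2/{fn}", local_dir="./")
```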
pipeline_stable_diffusion_xl_instantid.py ADDED
@@ -0,0 +1,1126 @@
1
+ # Copyright 2024 The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import cv2
19
+ import math
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import CLIPTokenizer
26
+
27
+ from diffusers.image_processor import PipelineImageInput
28
+
29
+ from diffusers.models import ControlNetModel
30
+
31
+ from diffusers.utils import (
32
+ deprecate,
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
37
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
38
+
39
+ from diffusers import StableDiffusionXLControlNetPipeline
40
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
41
+ from diffusers.utils.import_utils import is_xformers_available
42
+
43
+ from ip_adapter.resampler import Resampler
44
+
45
+ from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
46
+ from ip_adapter.attention_processor import region_control
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ EXAMPLE_DOC_STRING = """
52
+ Examples:
53
+ ```py
54
+ >>> # !pip install opencv-python transformers accelerate insightface
55
+ >>> import diffusers
56
+ >>> from diffusers.utils import load_image
57
+ >>> from diffusers.models import ControlNetModel
58
+
59
+ >>> import cv2
60
+ >>> import torch
61
+ >>> import numpy as np
62
+ >>> from PIL import Image
63
+
64
+ >>> from insightface.app import FaceAnalysis
65
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
66
+
67
+ >>> # download 'antelopev2' under ./models
68
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
69
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
70
+
71
+ >>> # download models under ./checkpoints
72
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
73
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
74
+
75
+ >>> # load IdentityNet
76
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
77
+
78
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
79
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
80
+ ... )
81
+ >>> pipe.cuda()
82
+
83
+ >>> # load adapter
84
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
85
+
86
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
87
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
88
+
89
+ >>> # load an image
90
+ >>> image = load_image("your-example.jpg")
91
+
92
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
93
+ >>> face_emb = face_info['embedding']
94
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
95
+
96
+ >>> pipe.set_ip_adapter_scale(0.8)
97
+
98
+ >>> # generate image
99
+ >>> image = pipe(
100
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
101
+ ... ).images[0]
102
+ ```
103
+ """
104
+
105
+
106
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
107
+ class LongPromptWeight(object):
108
+
109
+ """
110
+ Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py
111
+ """
112
+
113
+ def __init__(self) -> None:
114
+ pass
115
+
116
+ def parse_prompt_attention(self, text):
117
+ """
118
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
119
+ Accepted tokens are:
120
+ (abc) - increases attention to abc by a multiplier of 1.1
121
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
122
+ [abc] - decreases attention to abc by a multiplier of 1.1
123
+ \( - literal character '('
124
+ \[ - literal character '['
125
+ \) - literal character ')'
126
+ \] - literal character ']'
127
+ \\ - literal character '\'
128
+ anything else - just text
129
+
130
+ >>> parse_prompt_attention('normal text')
131
+ [['normal text', 1.0]]
132
+ >>> parse_prompt_attention('an (important) word')
133
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
134
+ >>> parse_prompt_attention('(unbalanced')
135
+ [['unbalanced', 1.1]]
136
+ >>> parse_prompt_attention('\(literal\]')
137
+ [['(literal]', 1.0]]
138
+ >>> parse_prompt_attention('(unnecessary)(parens)')
139
+ [['unnecessaryparens', 1.1]]
140
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
141
+ [['a ', 1.0],
142
+ ['house', 1.5730000000000004],
143
+ [' ', 1.1],
144
+ ['on', 1.0],
145
+ [' a ', 1.1],
146
+ ['hill', 0.55],
147
+ [', sun, ', 1.1],
148
+ ['sky', 1.4641000000000006],
149
+ ['.', 1.1]]
150
+ """
151
+ import re
152
+
153
+ re_attention = re.compile(
154
+ r"""
155
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
156
+ \)|]|[^\\()\[\]:]+|:
157
+ """,
158
+ re.X,
159
+ )
160
+
161
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
162
+
163
+ res = []
164
+ round_brackets = []
165
+ square_brackets = []
166
+
167
+ round_bracket_multiplier = 1.1
168
+ square_bracket_multiplier = 1 / 1.1
169
+
170
+ def multiply_range(start_position, multiplier):
171
+ for p in range(start_position, len(res)):
172
+ res[p][1] *= multiplier
173
+
174
+ for m in re_attention.finditer(text):
175
+ text = m.group(0)
176
+ weight = m.group(1)
177
+
178
+ if text.startswith("\\"):
179
+ res.append([text[1:], 1.0])
180
+ elif text == "(":
181
+ round_brackets.append(len(res))
182
+ elif text == "[":
183
+ square_brackets.append(len(res))
184
+ elif weight is not None and len(round_brackets) > 0:
185
+ multiply_range(round_brackets.pop(), float(weight))
186
+ elif text == ")" and len(round_brackets) > 0:
187
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
188
+ elif text == "]" and len(square_brackets) > 0:
189
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
190
+ else:
191
+ parts = re.split(re_break, text)
192
+ for i, part in enumerate(parts):
193
+ if i > 0:
194
+ res.append(["BREAK", -1])
195
+ res.append([part, 1.0])
196
+
197
+ for pos in round_brackets:
198
+ multiply_range(pos, round_bracket_multiplier)
199
+
200
+ for pos in square_brackets:
201
+ multiply_range(pos, square_bracket_multiplier)
202
+
203
+ if len(res) == 0:
204
+ res = [["", 1.0]]
205
+
206
+ # merge runs of identical weights
207
+ i = 0
208
+ while i + 1 < len(res):
209
+ if res[i][1] == res[i + 1][1]:
210
+ res[i][0] += res[i + 1][0]
211
+ res.pop(i + 1)
212
+ else:
213
+ i += 1
214
+
215
+ return res
216
+
217
+ def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str):
218
+ """
219
+ Get prompt token ids and weights, this function works for both prompt and negative prompt
220
+
221
+ Args:
222
+ pipe (CLIPTokenizer)
223
+ A CLIPTokenizer
224
+ prompt (str)
225
+ A prompt string with weights
226
+
227
+ Returns:
228
+ text_tokens (list)
229
+ A list contains token ids
230
+ text_weight (list)
231
+ A list contains the correspodent weight of token ids
232
+
233
+ Example:
234
+ import torch
235
+ from transformers import CLIPTokenizer
236
+
237
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
238
+ "stablediffusionapi/deliberate-v2"
239
+ , subfolder = "tokenizer"
240
+ , dtype = torch.float16
241
+ )
242
+
243
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
244
+ clip_tokenizer = clip_tokenizer
245
+ ,prompt = "a (red:1.5) cat"*70
246
+ )
247
+ """
248
+ texts_and_weights = self.parse_prompt_attention(prompt)
249
+ text_tokens, text_weights = [], []
250
+ for word, weight in texts_and_weights:
251
+ # tokenize and discard the starting and the ending token
252
+ token = clip_tokenizer(word, truncation=False).input_ids[1:-1] # so that tokenize whatever length prompt
253
+ # the returned token is a 1d list: [320, 1125, 539, 320]
254
+
255
+ # merge the new tokens to the all tokens holder: text_tokens
256
+ text_tokens = [*text_tokens, *token]
257
+
258
+ # each token chunk will come with one weight, like ['red cat', 2.0]
259
+ # need to expand weight for each token.
260
+ chunk_weights = [weight] * len(token)
261
+
262
+ # append the weight back to the weight holder: text_weights
263
+ text_weights = [*text_weights, *chunk_weights]
264
+ return text_tokens, text_weights
265
+
266
+ def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False):
267
+ """
268
+ Produce tokens and weights in groups and pad the missing tokens
269
+
270
+ Args:
271
+ token_ids (list)
272
+ The token ids from tokenizer
273
+ weights (list)
274
+ The weights list from function get_prompts_tokens_with_weights
275
+ pad_last_block (bool)
276
+ Control if fill the last token list to 75 tokens with eos
277
+ Returns:
278
+ new_token_ids (2d list)
279
+ new_weights (2d list)
280
+
281
+ Example:
282
+ token_groups,weight_groups = group_tokens_and_weights(
283
+ token_ids = token_id_list
284
+ , weights = token_weight_list
285
+ )
286
+ """
287
+ bos, eos = 49406, 49407
288
+
289
+ # this will be a 2d list
290
+ new_token_ids = []
291
+ new_weights = []
292
+ while len(token_ids) >= 75:
293
+ # get the first 75 tokens
294
+ head_75_tokens = [token_ids.pop(0) for _ in range(75)]
295
+ head_75_weights = [weights.pop(0) for _ in range(75)]
296
+
297
+ # extract token ids and weights
298
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
299
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
300
+
301
+ # add 77 token and weights chunk to the holder list
302
+ new_token_ids.append(temp_77_token_ids)
303
+ new_weights.append(temp_77_weights)
304
+
305
+ # padding the left
306
+ if len(token_ids) >= 0:
307
+ padding_len = 75 - len(token_ids) if pad_last_block else 0
308
+
309
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
310
+ new_token_ids.append(temp_77_token_ids)
311
+
312
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
313
+ new_weights.append(temp_77_weights)
314
+
315
+ return new_token_ids, new_weights
316
+
317
+ def get_weighted_text_embeddings_sdxl(
318
+ self,
319
+ pipe: StableDiffusionXLPipeline,
320
+ prompt: str = "",
321
+ prompt_2: str = None,
322
+ neg_prompt: str = "",
323
+ neg_prompt_2: str = None,
324
+ prompt_embeds=None,
325
+ negative_prompt_embeds=None,
326
+ pooled_prompt_embeds=None,
327
+ negative_pooled_prompt_embeds=None,
328
+ extra_emb=None,
329
+ extra_emb_alpha=0.6,
330
+ ):
331
+ """
332
+ This function can process long prompt with weights, no length limitation
333
+ for Stable Diffusion XL
334
+
335
+ Args:
336
+ pipe (StableDiffusionPipeline)
337
+ prompt (str)
338
+ prompt_2 (str)
339
+ neg_prompt (str)
340
+ neg_prompt_2 (str)
341
+ Returns:
342
+ prompt_embeds (torch.Tensor)
343
+ neg_prompt_embeds (torch.Tensor)
344
+ """
345
+ #
346
+ if prompt_embeds is not None and \
347
+ negative_prompt_embeds is not None and \
348
+ pooled_prompt_embeds is not None and \
349
+ negative_pooled_prompt_embeds is not None:
350
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
351
+
352
+ if prompt_2:
353
+ prompt = f"{prompt} {prompt_2}"
354
+
355
+ if neg_prompt_2:
356
+ neg_prompt = f"{neg_prompt} {neg_prompt_2}"
357
+
358
+ eos = pipe.tokenizer.eos_token_id
359
+
360
+ # tokenizer 1
361
+ prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
362
+ neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
363
+
364
+ # tokenizer 2
365
+ # prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
366
+ # neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
367
+ # tokenizer 2 遇到 !! !!!! 等多感叹号和tokenizer 1的效果不一致
368
+ prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
369
+ neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
370
+
371
+ # padding the shorter one for prompt set 1
372
+ prompt_token_len = len(prompt_tokens)
373
+ neg_prompt_token_len = len(neg_prompt_tokens)
374
+
375
+ if prompt_token_len > neg_prompt_token_len:
376
+ # padding the neg_prompt with eos token
377
+ neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
378
+ neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
379
+ else:
380
+ # padding the prompt
381
+ prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
382
+ prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
383
+
384
+ # padding the shorter one for token set 2
385
+ prompt_token_len_2 = len(prompt_tokens_2)
386
+ neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
387
+
388
+ if prompt_token_len_2 > neg_prompt_token_len_2:
389
+ # padding the neg_prompt with eos token
390
+ neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
391
+ neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
392
+ else:
393
+ # padding the prompt
394
+ prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
395
+ prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
396
+
397
+ embeds = []
398
+ neg_embeds = []
399
+
400
+ prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
401
+
402
+ neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights(
403
+ neg_prompt_tokens.copy(), neg_prompt_weights.copy()
404
+ )
405
+
406
+ prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights(
407
+ prompt_tokens_2.copy(), prompt_weights_2.copy()
408
+ )
409
+
410
+ neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights(
411
+ neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
412
+ )
413
+
414
+ # get prompt embeddings chunk by chunk (a single encoder pass cannot exceed the 77-token window)
415
+ for i in range(len(prompt_token_groups)):
416
+ # get positive prompt embeddings with weights
417
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
418
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
419
+
420
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
421
+
422
+ # use first text encoder
423
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
424
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
425
+
426
+ # use second text encoder
427
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
428
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
429
+ pooled_prompt_embeds = prompt_embeds_2[0]
430
+
431
+ prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
432
+ token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
433
+
434
+ for j in range(len(weight_tensor)):
435
+ if weight_tensor[j] != 1.0:
436
+ token_embedding[j] = (
437
+ token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
438
+ )
439
+
440
+ token_embedding = token_embedding.unsqueeze(0)
441
+ embeds.append(token_embedding)
442
+
443
+ # get negative prompt embeddings with weights
444
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
445
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
446
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
447
+
448
+ # use first text encoder
449
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
450
+ neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
451
+
452
+ # use second text encoder
453
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
454
+ neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
455
+ negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
456
+
457
+ neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
458
+ neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
459
+
460
+ for z in range(len(neg_weight_tensor)):
461
+ if neg_weight_tensor[z] != 1.0:
462
+ neg_token_embedding[z] = (
463
+ neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
464
+ )
465
+
466
+ neg_token_embedding = neg_token_embedding.unsqueeze(0)
467
+ neg_embeds.append(neg_token_embedding)
468
+
469
+ prompt_embeds = torch.cat(embeds, dim=1)
470
+ negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
471
+
472
+ if extra_emb is not None:
473
+ extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha
474
+ prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1)
475
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1)
476
+ print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}')
477
+
478
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
479
+
480
+ def get_prompt_embeds(self, *args, **kwargs):
481
+ prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs)
482
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
483
+ return prompt_embeds
484
+
485
+
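The per-token weighting inside `get_weighted_text_embeddings_sdxl` interpolates each embedding relative to the last (padding) embedding. A toy illustration under assumed shapes:

```python
import torch

# Toy shapes: 4 token embeddings of dim 3. Weight 1.0 leaves an embedding
# unchanged, 0.0 collapses it onto the padding embedding, >1.0 amplifies it.
emb = torch.randn(4, 3)
weights = torch.tensor([1.0, 1.3, 0.7, 1.0])
for j in range(len(weights)):
    if weights[j] != 1.0:
        emb[j] = emb[-1] + (emb[j] - emb[-1]) * weights[j]
```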
486
+ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
487
+
488
+ def cuda(self, dtype=torch.float16, use_xformers=False):
489
+ self.to('cuda', dtype)
490
+
491
+ if hasattr(self, 'image_proj_model'):
492
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
493
+
494
+ if use_xformers:
495
+ if is_xformers_available():
496
+ import xformers
497
+ from packaging import version
498
+
499
+ xformers_version = version.parse(xformers.__version__)
500
+ if xformers_version == version.parse("0.0.16"):
501
+ logger.warning(
502
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
503
+ )
504
+ self.enable_xformers_memory_efficient_attention()
505
+ else:
506
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
507
+
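`cuda()` moves the pipeline (and the image projection model, if loaded) to the GPU and can optionally enable xFormers attention; a usage sketch (the flag values are illustrative):

```python
# Move the pipeline to GPU in fp16 and, if xformers is installed,
# enable memory-efficient attention.
pipe.cuda(dtype=torch.float16, use_xformers=True)
```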
508
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
509
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
510
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
511
+
512
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
513
+
514
+ image_proj_model = Resampler(
515
+ dim=1280,
516
+ depth=4,
517
+ dim_head=64,
518
+ heads=20,
519
+ num_queries=num_tokens,
520
+ embedding_dim=image_emb_dim,
521
+ output_dim=self.unet.config.cross_attention_dim,
522
+ ff_mult=4,
523
+ )
524
+
525
+ image_proj_model.eval()
526
+
527
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
528
+ state_dict = torch.load(model_ckpt, map_location="cpu")
529
+ if 'image_proj' in state_dict:
530
+ state_dict = state_dict["image_proj"]
531
+ self.image_proj_model.load_state_dict(state_dict)
532
+
533
+ self.image_proj_model_in_features = image_emb_dim
534
+
535
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
536
+
537
+ unet = self.unet
538
+ attn_procs = {}
539
+ for name in unet.attn_processors.keys():
540
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
541
+ if name.startswith("mid_block"):
542
+ hidden_size = unet.config.block_out_channels[-1]
543
+ elif name.startswith("up_blocks"):
544
+ block_id = int(name[len("up_blocks.")])
545
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
546
+ elif name.startswith("down_blocks"):
547
+ block_id = int(name[len("down_blocks.")])
548
+ hidden_size = unet.config.block_out_channels[block_id]
549
+ if cross_attention_dim is None:
550
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
551
+ else:
552
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
553
+ cross_attention_dim=cross_attention_dim,
554
+ scale=scale,
555
+ num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
556
+ unet.set_attn_processor(attn_procs)
557
+
558
+ state_dict = torch.load(model_ckpt, map_location="cpu")
559
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
560
+ if 'ip_adapter' in state_dict:
561
+ state_dict = state_dict['ip_adapter']
562
+ ip_layers.load_state_dict(state_dict)
563
+
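`set_image_proj_model` and `set_ip_adapter` each accept either a flat state dict or one nested under the `image_proj` / `ip_adapter` keys. A quick inspection sketch (the checkpoint path mirrors where the weights were downloaded earlier):

```python
import torch

# The InstantID checkpoint is expected to carry both sub-dicts.
state_dict = torch.load("./checkpoints/ip-adapter.bin", map_location="cpu")
print(list(state_dict.keys()))  # typically ['image_proj', 'ip_adapter']
```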
564
+ def set_ip_adapter_scale(self, scale):
565
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
566
+ for attn_processor in unet.attn_processors.values():
567
+ if isinstance(attn_processor, IPAttnProcessor):
568
+ attn_processor.scale = scale
569
+
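`set_ip_adapter_scale` lets you retune the identity strength between generations without reloading any weights; for example:

```python
# 0.0 disables the face-identity tokens entirely; higher values strengthen them.
pipe.set_ip_adapter_scale(0.8)
```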
570
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
571
+
572
+ if isinstance(prompt_image_emb, torch.Tensor):
573
+ prompt_image_emb = prompt_image_emb.clone().detach()
574
+ else:
575
+ prompt_image_emb = torch.tensor(prompt_image_emb)
576
+
577
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
578
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
579
+
580
+ if do_classifier_free_guidance:
581
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
582
+ else:
583
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
584
+
585
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
586
+ return prompt_image_emb
587
+
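`_encode_prompt_image_emb` reshapes a flat face embedding to `[1, N, image_emb_dim]`, duplicates it for classifier-free guidance, and lets the Resampler turn it into `num_tokens` cross-attention tokens. A shape sketch assuming the defaults `image_emb_dim=512` and `num_tokens=16`:

```python
import torch

face_emb = torch.randn(512)                            # ArcFace-style face embedding
emb = face_emb.reshape([1, -1, 512])                   # -> [1, 1, 512]
emb = torch.cat([torch.zeros_like(emb), emb], dim=0)   # -> [2, 1, 512] (uncond + cond)
# after image_proj_model(emb): [2, 16, unet.config.cross_attention_dim]
```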
588
+ @torch.no_grad()
589
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
590
+ def __call__(
591
+ self,
592
+ prompt: Union[str, List[str]] = None,
593
+ prompt_2: Optional[Union[str, List[str]]] = None,
594
+ image: PipelineImageInput = None,
595
+ height: Optional[int] = None,
596
+ width: Optional[int] = None,
597
+ num_inference_steps: int = 50,
598
+ guidance_scale: float = 5.0,
599
+ negative_prompt: Optional[Union[str, List[str]]] = None,
600
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
601
+ num_images_per_prompt: Optional[int] = 1,
602
+ eta: float = 0.0,
603
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
604
+ latents: Optional[torch.FloatTensor] = None,
605
+ prompt_embeds: Optional[torch.FloatTensor] = None,
606
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
607
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
608
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
609
+ image_embeds: Optional[torch.FloatTensor] = None,
610
+ output_type: Optional[str] = "pil",
611
+ return_dict: bool = True,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
614
+ guess_mode: bool = False,
615
+ control_guidance_start: Union[float, List[float]] = 0.0,
616
+ control_guidance_end: Union[float, List[float]] = 1.0,
617
+ original_size: Tuple[int, int] = None,
618
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
619
+ target_size: Tuple[int, int] = None,
620
+ negative_original_size: Optional[Tuple[int, int]] = None,
621
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
622
+ negative_target_size: Optional[Tuple[int, int]] = None,
623
+ clip_skip: Optional[int] = None,
624
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
625
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
626
+ control_mask=None,
627
+ **kwargs,
628
+ ):
629
+ r"""
630
+ The call function to the pipeline for generation.
631
+
632
+ Args:
633
+ prompt (`str` or `List[str]`, *optional*):
634
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
635
+ prompt_2 (`str` or `List[str]`, *optional*):
636
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
637
+ used in both text-encoders.
638
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
639
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
640
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
641
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
642
+ accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height
643
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
644
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
645
+ input to a single ControlNet.
646
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
647
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
648
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
649
+ and checkpoints that are not specifically fine-tuned on low resolutions.
650
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
651
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
652
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
653
+ and checkpoints that are not specifically fine-tuned on low resolutions.
654
+ num_inference_steps (`int`, *optional*, defaults to 50):
655
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
656
+ expense of slower inference.
657
+ guidance_scale (`float`, *optional*, defaults to 5.0):
658
+ A higher guidance scale value encourages the model to generate images closely linked to the text
659
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
660
+ negative_prompt (`str` or `List[str]`, *optional*):
661
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
662
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
663
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
664
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
665
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
666
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
667
+ The number of images to generate per prompt.
668
+ eta (`float`, *optional*, defaults to 0.0):
669
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
670
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
671
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
672
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
673
+ generation deterministic.
674
+ latents (`torch.FloatTensor`, *optional*):
675
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
676
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
677
+ tensor is generated by sampling using the supplied random `generator`.
678
+ prompt_embeds (`torch.FloatTensor`, *optional*):
679
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
680
+ provided, text embeddings are generated from the `prompt` input argument.
681
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
682
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
683
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
684
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
685
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
686
+ not provided, pooled text embeddings are generated from `prompt` input argument.
687
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
688
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
689
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
690
+ argument.
691
+ image_embeds (`torch.FloatTensor`, *optional*):
692
+ Pre-generated image embeddings (for InstantID, the face identity embedding).
693
+ output_type (`str`, *optional*, defaults to `"pil"`):
694
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
695
+ return_dict (`bool`, *optional*, defaults to `True`):
696
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
697
+ plain tuple.
698
+ cross_attention_kwargs (`dict`, *optional*):
699
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
700
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
701
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
702
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
703
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
704
+ the corresponding scale as a list.
705
+ guess_mode (`bool`, *optional*, defaults to `False`):
706
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
707
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
708
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
709
+ The percentage of total steps at which the ControlNet starts applying.
710
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
711
+ The percentage of total steps at which the ControlNet stops applying.
712
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
713
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
714
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
715
+ explained in section 2.2 of
716
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
717
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
718
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
719
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
720
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
721
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
722
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
723
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
724
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
725
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
726
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
727
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
728
+ micro-conditioning as explained in section 2.2 of
729
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
730
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
731
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
732
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
733
+ micro-conditioning as explained in section 2.2 of
734
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
735
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
736
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
737
+ To negatively condition the generation process based on a target image resolution. In most cases it should
738
+ be the same as the `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
739
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
740
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
741
+ clip_skip (`int`, *optional*):
742
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
743
+ the output of the pre-final layer will be used for computing the prompt embeddings.
744
+ callback_on_step_end (`Callable`, *optional*):
745
+ A function that is called at the end of each denoising step during inference. The function is called
746
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
747
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
748
+ `callback_on_step_end_tensor_inputs`.
749
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
750
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
751
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
752
+ `._callback_tensor_inputs` attribute of your pipeline class.
753
+
754
+ Examples:
755
+
756
+ Returns:
757
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
758
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
759
+ otherwise a `tuple` is returned containing the output images.
760
+ """
761
+ lpw = LongPromptWeight()
762
+
763
+ callback = kwargs.pop("callback", None)
764
+ callback_steps = kwargs.pop("callback_steps", None)
765
+
766
+ if callback is not None:
767
+ deprecate(
768
+ "callback",
769
+ "1.0.0",
770
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
771
+ )
772
+ if callback_steps is not None:
773
+ deprecate(
774
+ "callback_steps",
775
+ "1.0.0",
776
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
777
+ )
778
+
779
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
780
+
781
+ # align format for control guidance
782
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
783
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
784
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
785
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
786
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
787
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
788
+ control_guidance_start, control_guidance_end = (
789
+ mult * [control_guidance_start],
790
+ mult * [control_guidance_end],
791
+ )
792
+
793
+ # 1. Check inputs. Raise error if not correct
794
+ self.check_inputs(
795
+ prompt,
796
+ prompt_2,
797
+ image,
798
+ callback_steps,
799
+ negative_prompt,
800
+ negative_prompt_2,
801
+ prompt_embeds,
802
+ negative_prompt_embeds,
803
+ pooled_prompt_embeds,
804
+ negative_pooled_prompt_embeds,
805
+ controlnet_conditioning_scale,
806
+ control_guidance_start,
807
+ control_guidance_end,
808
+ callback_on_step_end_tensor_inputs,
809
+ )
810
+
811
+ self._guidance_scale = guidance_scale
812
+ self._clip_skip = clip_skip
813
+ self._cross_attention_kwargs = cross_attention_kwargs
814
+
815
+ # 2. Define call parameters
816
+ if prompt is not None and isinstance(prompt, str):
817
+ batch_size = 1
818
+ elif prompt is not None and isinstance(prompt, list):
819
+ batch_size = len(prompt)
820
+ else:
821
+ batch_size = prompt_embeds.shape[0]
822
+
823
+ device = self._execution_device
824
+
825
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
826
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
827
+
828
+ global_pool_conditions = (
829
+ controlnet.config.global_pool_conditions
830
+ if isinstance(controlnet, ControlNetModel)
831
+ else controlnet.nets[0].config.global_pool_conditions
832
+ )
833
+ guess_mode = guess_mode or global_pool_conditions
834
+
835
+ # 3.1 Encode input prompt
836
+ (
837
+ prompt_embeds,
838
+ negative_prompt_embeds,
839
+ pooled_prompt_embeds,
840
+ negative_pooled_prompt_embeds,
841
+ ) = lpw.get_weighted_text_embeddings_sdxl(
842
+ pipe=self,
843
+ prompt=prompt,
844
+ neg_prompt=negative_prompt,
845
+ prompt_embeds=prompt_embeds,
846
+ negative_prompt_embeds=negative_prompt_embeds,
847
+ pooled_prompt_embeds=pooled_prompt_embeds,
848
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
849
+ )
850
+
851
+ # 3.2 Encode image prompt
852
+ prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
853
+ device,
854
+ self.unet.dtype,
855
+ self.do_classifier_free_guidance)
856
+
857
+ # 4. Prepare image
858
+ if isinstance(controlnet, ControlNetModel):
859
+ image = self.prepare_image(
860
+ image=image,
861
+ width=width,
862
+ height=height,
863
+ batch_size=batch_size * num_images_per_prompt,
864
+ num_images_per_prompt=num_images_per_prompt,
865
+ device=device,
866
+ dtype=controlnet.dtype,
867
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
868
+ guess_mode=guess_mode,
869
+ )
870
+ height, width = image.shape[-2:]
871
+ elif isinstance(controlnet, MultiControlNetModel):
872
+ images = []
873
+
874
+ for image_ in image:
875
+ image_ = self.prepare_image(
876
+ image=image_,
877
+ width=width,
878
+ height=height,
879
+ batch_size=batch_size * num_images_per_prompt,
880
+ num_images_per_prompt=num_images_per_prompt,
881
+ device=device,
882
+ dtype=controlnet.dtype,
883
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
884
+ guess_mode=guess_mode,
885
+ )
886
+
887
+ images.append(image_)
888
+
889
+ image = images
890
+ height, width = image[0].shape[-2:]
891
+ else:
892
+ assert False
893
+
894
+ # 4.1 Region control
895
+ if control_mask is not None:
896
+ mask_weight_image = control_mask
897
+ mask_weight_image = np.array(mask_weight_image)
898
+ mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype)
899
+ mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255.
900
+ mask_weight_image_tensor = mask_weight_image_tensor[None, None]
901
+ h, w = mask_weight_image_tensor.shape[-2:]
902
+ control_mask_weight_image_list = []
903
+ for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
904
+ scale_mask_weight_image_tensor = F.interpolate(
905
+ mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear')
906
+ control_mask_weight_image_list.append(scale_mask_weight_image_tensor)
907
+ region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255.
908
+ region_control.prompt_image_conditioning = [dict(region_mask=region_mask)]
909
+ else:
910
+ control_mask_weight_image_list = None
911
+ region_control.prompt_image_conditioning = [dict(region_mask=None)]
912
+
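The nine scales above line up with the nine ControlNet down-block residuals (three per resolution level), and the coarsest mask also gates the mid block. A standalone sketch, assuming a 1024x1024 control mask:

```python
import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 1024, 1024)
pyramid = [
    F.interpolate(mask, (1024 // s, 1024 // s), mode="bilinear")
    for s in [8, 8, 8, 16, 16, 16, 32, 32, 32]
]
print([m.shape[-1] for m in pyramid])  # [128, 128, 128, 64, 64, 64, 32, 32, 32]
```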
913
+ # 5. Prepare timesteps
914
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
915
+ timesteps = self.scheduler.timesteps
916
+ self._num_timesteps = len(timesteps)
917
+
918
+ # 6. Prepare latent variables
919
+ num_channels_latents = self.unet.config.in_channels
920
+ latents = self.prepare_latents(
921
+ batch_size * num_images_per_prompt,
922
+ num_channels_latents,
923
+ height,
924
+ width,
925
+ prompt_embeds.dtype,
926
+ device,
927
+ generator,
928
+ latents,
929
+ )
930
+
931
+ # 6.5 Optionally get Guidance Scale Embedding
932
+ timestep_cond = None
933
+ if self.unet.config.time_cond_proj_dim is not None:
934
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
935
+ timestep_cond = self.get_guidance_scale_embedding(
936
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
937
+ ).to(device=device, dtype=latents.dtype)
938
+
939
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
940
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
941
+
942
+ # 7.1 Create tensor stating which controlnets to keep
943
+ controlnet_keep = []
944
+ for i in range(len(timesteps)):
945
+ keeps = [
946
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
947
+ for s, e in zip(control_guidance_start, control_guidance_end)
948
+ ]
949
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
950
+
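A worked example of the keep schedule above: ControlNet stays active only while the step fraction lies inside the [start, end] window.

```python
# With 4 steps, start=0.0 and end=0.5, ControlNet fires on the first two steps.
steps, s, e = 4, 0.0, 0.5
keeps = [1.0 - float(i / steps < s or (i + 1) / steps > e) for i in range(steps)]
print(keeps)  # [1.0, 1.0, 0.0, 0.0]
```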
951
+ # 7.2 Prepare added time ids & embeddings
952
+ if isinstance(image, list):
953
+ original_size = original_size or image[0].shape[-2:]
954
+ else:
955
+ original_size = original_size or image.shape[-2:]
956
+ target_size = target_size or (height, width)
957
+
958
+ add_text_embeds = pooled_prompt_embeds
959
+ if self.text_encoder_2 is None:
960
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
961
+ else:
962
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
963
+
964
+ add_time_ids = self._get_add_time_ids(
965
+ original_size,
966
+ crops_coords_top_left,
967
+ target_size,
968
+ dtype=prompt_embeds.dtype,
969
+ text_encoder_projection_dim=text_encoder_projection_dim,
970
+ )
971
+
972
+ if negative_original_size is not None and negative_target_size is not None:
973
+ negative_add_time_ids = self._get_add_time_ids(
974
+ negative_original_size,
975
+ negative_crops_coords_top_left,
976
+ negative_target_size,
977
+ dtype=prompt_embeds.dtype,
978
+ text_encoder_projection_dim=text_encoder_projection_dim,
979
+ )
980
+ else:
981
+ negative_add_time_ids = add_time_ids
982
+
983
+ if self.do_classifier_free_guidance:
984
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
985
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
986
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
987
+
988
+ prompt_embeds = prompt_embeds.to(device)
989
+ add_text_embeds = add_text_embeds.to(device)
990
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
991
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
992
+
993
+ # 8. Denoising loop
994
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
995
+ is_unet_compiled = is_compiled_module(self.unet)
996
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
997
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
998
+
999
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1000
+ for i, t in enumerate(timesteps):
1001
+ # Relevant thread:
1002
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1003
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1004
+ torch._inductor.cudagraph_mark_step_begin()
1005
+ # expand the latents if we are doing classifier free guidance
1006
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1007
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1008
+
1009
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1010
+
1011
+ # controlnet(s) inference
1012
+ if guess_mode and self.do_classifier_free_guidance:
1013
+ # Infer ControlNet only for the conditional batch.
1014
+ control_model_input = latents
1015
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1016
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1017
+ controlnet_added_cond_kwargs = {
1018
+ "text_embeds": add_text_embeds.chunk(2)[1],
1019
+ "time_ids": add_time_ids.chunk(2)[1],
1020
+ }
1021
+ else:
1022
+ control_model_input = latent_model_input
1023
+ controlnet_prompt_embeds = prompt_embeds
1024
+ controlnet_added_cond_kwargs = added_cond_kwargs
1025
+
1026
+ if isinstance(controlnet_keep[i], list):
1027
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1028
+ else:
1029
+ controlnet_cond_scale = controlnet_conditioning_scale
1030
+ if isinstance(controlnet_cond_scale, list):
1031
+ controlnet_cond_scale = controlnet_cond_scale[0]
1032
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1033
+
1034
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1035
+ control_model_input,
1036
+ t,
1037
+ encoder_hidden_states=prompt_image_emb,
1038
+ controlnet_cond=image,
1039
+ conditioning_scale=cond_scale,
1040
+ guess_mode=guess_mode,
1041
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1042
+ return_dict=False,
1043
+ )
1044
+
1045
+ # controlnet mask
1046
+ if control_mask_weight_image_list is not None:
1047
+ down_block_res_samples = [
1048
+ down_block_res_sample * mask_weight
1049
+ for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_weight_image_list)
1050
+ ]
1051
+ mid_block_res_sample *= control_mask_weight_image_list[-1]
1052
+
1053
+ if guess_mode and self.do_classifier_free_guidance:
1054
+ # Inferred ControlNet only for the conditional batch.
1055
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1056
+ # add 0 to the unconditional batch to keep it unchanged.
1057
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1058
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1059
+
1060
+ # predict the noise residual
1061
+ noise_pred = self.unet(
1062
+ latent_model_input,
1063
+ t,
1064
+ encoder_hidden_states=encoder_hidden_states,
1065
+ timestep_cond=timestep_cond,
1066
+ cross_attention_kwargs=self.cross_attention_kwargs,
1067
+ down_block_additional_residuals=down_block_res_samples,
1068
+ mid_block_additional_residual=mid_block_res_sample,
1069
+ added_cond_kwargs=added_cond_kwargs,
1070
+ return_dict=False,
1071
+ )[0]
1072
+
1073
+ # perform guidance
1074
+ if self.do_classifier_free_guidance:
1075
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1076
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1077
+
1078
+ # compute the previous noisy sample x_t -> x_t-1
1079
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1080
+
1081
+ if callback_on_step_end is not None:
1082
+ callback_kwargs = {}
1083
+ for k in callback_on_step_end_tensor_inputs:
1084
+ callback_kwargs[k] = locals()[k]
1085
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1086
+
1087
+ latents = callback_outputs.pop("latents", latents)
1088
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1089
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1090
+
1091
+ # call the callback, if provided
1092
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1093
+ progress_bar.update()
1094
+ if callback is not None and i % callback_steps == 0:
1095
+ step_idx = i // getattr(self.scheduler, "order", 1)
1096
+ callback(step_idx, t, latents)
1097
+
1098
+ if not output_type == "latent":
1099
+ # make sure the VAE is in float32 mode, as it overflows in float16
1100
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1101
+ if needs_upcasting:
1102
+ self.upcast_vae()
1103
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1104
+
1105
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1106
+
1107
+ # cast back to fp16 if needed
1108
+ if needs_upcasting:
1109
+ self.vae.to(dtype=torch.float16)
1110
+ else:
1111
+ image = latents
1112
+
1113
+ if not output_type == "latent":
1114
+ # apply watermark if available
1115
+ if self.watermark is not None:
1116
+ image = self.watermark.apply_watermark(image)
1117
+
1118
+ image = self.image_processor.postprocess(image, output_type=output_type)
1119
+
1120
+ # Offload all models
1121
+ self.maybe_free_model_hooks()
1122
+
1123
+ if not return_dict:
1124
+ return (image,)
1125
+
1126
+ return StableDiffusionXLPipelineOutput(images=image)
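Putting the pipeline together, an illustrative invocation (it assumes `app`, `pipe`, and `draw_kps` are set up as earlier in this README; the image path, prompt, and scales are examples):

```python
import cv2
import numpy as np
from diffusers.utils import load_image

pipe.load_ip_adapter_instantid("./checkpoints/ip-adapter.bin")

face_image = load_image("./examples/kaifu_resize.png")
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
face_info = sorted(
    face_info,
    key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
)[-1]  # keep the largest detected face
face_emb = face_info["embedding"]
face_kps = draw_kps(face_image, face_info["kps"])

image = pipe(
    prompt="analog film photo of a man, faded film, desaturated, film grain",
    image_embeds=face_emb,
    image=face_kps,
    controlnet_conditioning_scale=0.8,
    num_inference_steps=30,
).images[0]
```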
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ diffusers==0.25.0
2
+ torch==2.0.0
3
+ torchvision==0.15.1
4
+ transformers==4.36.2
5
+ accelerate
6
+ safetensors
7
+ einops
8
+ onnxruntime-gpu
9
+ spaces==0.19.4
10
+ omegaconf
11
+ peft
12
+ huggingface-hub==0.20.2
13
+ opencv-python
14
+ insightface
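A quick sanity check (illustrative) that the pinned core versions are the ones actually imported:

```python
import diffusers, torch, transformers

# Expected with this requirements file: 0.25.0 / 2.0.0 / 4.36.2
# (torch may report a local build suffix such as "+cu117").
print(diffusers.__version__, torch.__version__, transformers.__version__)
```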
style_template.py ADDED
@@ -0,0 +1,49 @@
1
+ style_list = [
2
+ {
3
+ "name": "(No style)",
4
+ "prompt": "{prompt}",
5
+ "negative_prompt": "",
6
+ },
7
+ {
8
+ "name": "Watercolor",
9
+ "prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic",
10
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
11
+ },
12
+ {
13
+ "name": "Film Noir",
14
+ "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic",
15
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
16
+ },
17
+ {
18
+ "name": "Neon",
19
+ "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished",
20
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
21
+ },
22
+ {
23
+ "name": "Jungle",
24
+ "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still',
25
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
26
+ },
27
+ {
28
+ "name": "Mars",
29
+ "prompt": "{prompt}, Post-apocalyptic. Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)",
30
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
31
+ },
32
+ {
33
+ "name": "Vibrant Color",
34
+ "prompt": "vibrant colorful, ink sketch|vector|2d colors, at nightfall, sharp focus, {prompt}, highly detailed, sharp focus, the clouds,colorful,ultra sharpness",
35
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
36
+ },
37
+ {
38
+ "name": "Snow",
39
+ "prompt": "cinema 4d render, {prompt}, high contrast, vibrant and saturated, sico style, surrounded by magical glow,floating ice shards, snow crystals, cold, windy background, frozen natural landscape in background cinematic atmosphere,highly detailed, sharp focus, intricate design, 3d, unreal engine, octane render, CG best quality, highres, photorealistic, dramatic lighting, artstation, concept art, cinematic, epic Steven Spielberg movie still, sharp focus, smoke, sparks, art by pascal blanche and greg rutkowski and repin, trending on artstation, hyperrealism painting, matte painting, 4k resolution",
40
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
41
+ },
42
+ {
43
+ "name": "Line art",
44
+ "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
45
+ "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
46
+ },
47
+ ]
48
+
49
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
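`styles` maps a style name to its (positive template, negative prompt) pair; a usage sketch (the style name and prompt are examples):

```python
style_name = "Watercolor"
positive_tmpl, negative = styles.get(style_name, styles["(No style)"])
positive = positive_tmpl.replace("{prompt}", "a portrait of a scientist")
```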