AlanB committed on
Commit e40757f · 1 Parent(s): 1535184

Updated from Diffusers, still debugging new version

Files changed (1)
  1. pipeline.py +1470 -0
pipeline.py ADDED
@@ -0,0 +1,1470 @@
+ import inspect
+ import re
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import numpy as np
+ import PIL
+ import torch
+ from packaging import version
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+ from diffusers import DiffusionPipeline
+ from diffusers.configuration_utils import FrozenDict
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ from diffusers.utils import (
+     PIL_INTERPOLATION,
+     deprecate,
+     is_accelerate_available,
+     is_accelerate_version,
+     logging,
+     randn_tensor,
+ )
+
+
+ # ------------------------------------------------------------------------------
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ re_attention = re.compile(
+     r"""
+ \\\(|
+ \\\)|
+ \\\[|
+ \\]|
+ \\\\|
+ \\|
+ \(|
+ \[|
+ :([+-]?[.\d]+)\)|
+ \)|
+ ]|
+ [^\\()\[\]:]+|
+ :
+ """,
+     re.X,
+ )
+
+
+ def parse_prompt_attention(text):
+     r"""
+     Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+     Accepted tokens are:
+       (abc) - increases attention to abc by a multiplier of 1.1
+       (abc:3.12) - increases attention to abc by a multiplier of 3.12
+       [abc] - decreases attention to abc by a multiplier of 1.1
+       \( - literal character '('
+       \[ - literal character '['
+       \) - literal character ')'
+       \] - literal character ']'
+       \\ - literal character '\'
+       anything else - just text
+     >>> parse_prompt_attention('normal text')
+     [['normal text', 1.0]]
+     >>> parse_prompt_attention('an (important) word')
+     [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+     >>> parse_prompt_attention('(unbalanced')
+     [['unbalanced', 1.1]]
+     >>> parse_prompt_attention('\(literal\]')
+     [['(literal]', 1.0]]
+     >>> parse_prompt_attention('(unnecessary)(parens)')
+     [['unnecessaryparens', 1.1]]
+     >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+     [['a ', 1.0],
+      ['house', 1.5730000000000004],
+      [' ', 1.1],
+      ['on', 1.0],
+      [' a ', 1.1],
+      ['hill', 0.55],
+      [', sun, ', 1.1],
+      ['sky', 1.4641000000000006],
+      ['.', 1.1]]
+     """
+
+     res = []
+     round_brackets = []
+     square_brackets = []
+
+     round_bracket_multiplier = 1.1
+     square_bracket_multiplier = 1 / 1.1
+
+     def multiply_range(start_position, multiplier):
+         for p in range(start_position, len(res)):
+             res[p][1] *= multiplier
+
+     for m in re_attention.finditer(text):
+         text = m.group(0)
+         weight = m.group(1)
+
+         if text.startswith("\\"):
+             res.append([text[1:], 1.0])
+         elif text == "(":
+             round_brackets.append(len(res))
+         elif text == "[":
+             square_brackets.append(len(res))
+         elif weight is not None and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), float(weight))
+         elif text == ")" and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), round_bracket_multiplier)
+         elif text == "]" and len(square_brackets) > 0:
+             multiply_range(square_brackets.pop(), square_bracket_multiplier)
+         else:
+             res.append([text, 1.0])
+
+     for pos in round_brackets:
+         multiply_range(pos, round_bracket_multiplier)
+
+     for pos in square_brackets:
+         multiply_range(pos, square_bracket_multiplier)
+
+     if len(res) == 0:
+         res = [["", 1.0]]
+
+     # merge runs of identical weights
+     i = 0
+     while i + 1 < len(res):
+         if res[i][1] == res[i + 1][1]:
+             res[i][0] += res[i + 1][0]
+             res.pop(i + 1)
+         else:
+             i += 1
+
+     return res
+
+
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+     r"""
+     Tokenize a list of prompts and return their tokens with the weight of each token.
+
+     No padding, starting or ending token is included.
+     """
+     tokens = []
+     weights = []
+     truncated = False
+     for text in prompt:
+         texts_and_weights = parse_prompt_attention(text)
+         text_token = []
+         text_weight = []
+         for word, weight in texts_and_weights:
+             # tokenize and discard the starting and the ending token
+             token = pipe.tokenizer(word).input_ids[1:-1]
+             text_token += token
+             # copy the weight by length of token
+             text_weight += [weight] * len(token)
+             # stop if the text is too long (longer than truncation limit)
+             if len(text_token) > max_length:
+                 truncated = True
+                 break
+         # truncate
+         if len(text_token) > max_length:
+             truncated = True
+             text_token = text_token[:max_length]
+             text_weight = text_weight[:max_length]
+         tokens.append(text_token)
+         weights.append(text_weight)
+     if truncated:
+         logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+     return tokens, weights
+
+
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+     r"""
+     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+     """
+     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+     for i in range(len(tokens)):
+         tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+         if no_boseos_middle:
+             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+         else:
+             w = []
+             if len(weights[i]) == 0:
+                 w = [1.0] * weights_length
+             else:
+                 for j in range(max_embeddings_multiples):
+                     w.append(1.0)  # weight for starting token in this chunk
+                     w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+                     w.append(1.0)  # weight for ending token in this chunk
+                 w += [1.0] * (weights_length - len(w))
+             weights[i] = w[:]
+
+     return tokens, weights
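+
+ # Illustrative note (editor's sketch, not part of the upstream pipeline): with the
+ # CLIP defaults chunk_length=77 and max_embeddings_multiples=3, the padded length
+ # works out to (77 - 2) * 3 + 2 = 227, so e.g. a 100-token prompt is padded to
+ # 227 tokens (BOS + tokens + PAD... + EOS) and its weight list to the same length.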
+
+
+ def get_unweighted_text_embeddings(
+     pipe: DiffusionPipeline,
+     text_input: torch.Tensor,
+     chunk_length: int,
+     no_boseos_middle: Optional[bool] = True,
+ ):
+     """
+     When the length of tokens exceeds the capacity of the text encoder, the input is split into chunks
+     that are sent to the text encoder individually.
+     """
+     max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+     if max_embeddings_multiples > 1:
+         text_embeddings = []
+         for i in range(max_embeddings_multiples):
+             # extract the i-th chunk
+             text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+             # cover the head and the tail by the starting and the ending tokens
+             text_input_chunk[:, 0] = text_input[0, 0]
+             text_input_chunk[:, -1] = text_input[0, -1]
+             text_embedding = pipe.text_encoder(text_input_chunk)[0]
+
+             if no_boseos_middle:
+                 if i == 0:
+                     # discard the ending token
+                     text_embedding = text_embedding[:, :-1]
+                 elif i == max_embeddings_multiples - 1:
+                     # discard the starting token
+                     text_embedding = text_embedding[:, 1:]
+                 else:
+                     # discard both starting and ending tokens
+                     text_embedding = text_embedding[:, 1:-1]
+
+             text_embeddings.append(text_embedding)
+         text_embeddings = torch.concat(text_embeddings, axis=1)
+     else:
+         text_embeddings = pipe.text_encoder(text_input)[0]
+     return text_embeddings
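+
+ # Illustrative note (editor's sketch): for a padded input of 227 tokens and
+ # chunk_length=77, max_embeddings_multiples = (227 - 2) // 75 = 3, so the encoder
+ # is called three times on 77-token windows; with no_boseos_middle the duplicated
+ # BOS/EOS rows are dropped and the concatenated result has shape (batch, 227, hidden).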
+
+
+ def get_weighted_text_embeddings(
+     pipe: DiffusionPipeline,
+     prompt: Union[str, List[str]],
+     uncond_prompt: Optional[Union[str, List[str]]] = None,
+     max_embeddings_multiples: Optional[int] = 3,
+     no_boseos_middle: Optional[bool] = False,
+     skip_parsing: Optional[bool] = False,
+     skip_weighting: Optional[bool] = False,
+ ):
+     r"""
+     Prompts can be assigned with local weights using brackets. For example,
+     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+     and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+     Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+     Args:
+         pipe (`DiffusionPipeline`):
+             Pipe to provide access to the tokenizer and the text encoder.
+         prompt (`str` or `List[str]`):
+             The prompt or prompts to guide the image generation.
+         uncond_prompt (`str` or `List[str]`):
+             The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+             is provided, the embeddings of prompt and uncond_prompt are concatenated.
+         max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+             The max multiple length of prompt embeddings compared to the max output length of text encoder.
+         no_boseos_middle (`bool`, *optional*, defaults to `False`):
+             If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
+             the starting and ending tokens in each of the middle chunks.
+         skip_parsing (`bool`, *optional*, defaults to `False`):
+             Skip the parsing of brackets.
+         skip_weighting (`bool`, *optional*, defaults to `False`):
+             Skip the weighting. When the parsing is skipped, it is forced True.
+     """
+     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+     if isinstance(prompt, str):
+         prompt = [prompt]
+
+     if not skip_parsing:
+         prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+         if uncond_prompt is not None:
+             if isinstance(uncond_prompt, str):
+                 uncond_prompt = [uncond_prompt]
+             uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+     else:
+         prompt_tokens = [
+             token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+         ]
+         prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+         if uncond_prompt is not None:
+             if isinstance(uncond_prompt, str):
+                 uncond_prompt = [uncond_prompt]
+             uncond_tokens = [
+                 token[1:-1]
+                 for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+             ]
+             uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+     # round up the longest length of tokens to a multiple of (model_max_length - 2)
+     max_length = max([len(token) for token in prompt_tokens])
+     if uncond_prompt is not None:
+         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+     max_embeddings_multiples = min(
+         max_embeddings_multiples,
+         (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+     )
+     max_embeddings_multiples = max(1, max_embeddings_multiples)
+     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+     # pad the length of tokens and weights
+     bos = pipe.tokenizer.bos_token_id
+     eos = pipe.tokenizer.eos_token_id
+     pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+     prompt_tokens, prompt_weights = pad_tokens_and_weights(
+         prompt_tokens,
+         prompt_weights,
+         max_length,
+         bos,
+         eos,
+         pad,
+         no_boseos_middle=no_boseos_middle,
+         chunk_length=pipe.tokenizer.model_max_length,
+     )
+     prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+     if uncond_prompt is not None:
+         uncond_tokens, uncond_weights = pad_tokens_and_weights(
+             uncond_tokens,
+             uncond_weights,
+             max_length,
+             bos,
+             eos,
+             pad,
+             no_boseos_middle=no_boseos_middle,
+             chunk_length=pipe.tokenizer.model_max_length,
+         )
+         uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+     # get the embeddings
+     text_embeddings = get_unweighted_text_embeddings(
+         pipe,
+         prompt_tokens,
+         pipe.tokenizer.model_max_length,
+         no_boseos_middle=no_boseos_middle,
+     )
+     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
+     if uncond_prompt is not None:
+         uncond_embeddings = get_unweighted_text_embeddings(
+             pipe,
+             uncond_tokens,
+             pipe.tokenizer.model_max_length,
+             no_boseos_middle=no_boseos_middle,
+         )
+         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
+
+     # assign weights to the prompts and normalize in the sense of mean
+     # TODO: should we normalize by chunk or in a whole (current implementation)?
+     if (not skip_parsing) and (not skip_weighting):
+         previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+         text_embeddings *= prompt_weights.unsqueeze(-1)
+         current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+         text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+         if uncond_prompt is not None:
+             previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+             uncond_embeddings *= uncond_weights.unsqueeze(-1)
+             current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+     if uncond_prompt is not None:
+         return text_embeddings, uncond_embeddings
+     return text_embeddings, None
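+
+ # Illustrative usage (editor's sketch; `pipe` is a hypothetical pipeline instance):
+ # a weighted prompt such as "a (red:1.3) ball" yields embeddings whose 'red' token
+ # rows were multiplied by 1.3 and then rescaled so the overall mean matches the
+ # unweighted embedding:
+ #   embeds, uncond = get_weighted_text_embeddings(pipe, "a (red:1.3) ball", "")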
+
+
+ def preprocess_image(image, batch_size):
+     w, h = image.size
+     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+     image = np.array(image).astype(np.float32) / 255.0
+     image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+     image = torch.from_numpy(image)
+     return 2.0 * image - 1.0
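+
+ # Illustrative note (editor's sketch): a 512x512 RGB PIL image becomes a float
+ # tensor of shape (batch_size, 3, 512, 512), rescaled from [0, 255] to [-1, 1].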
+
+
+ def preprocess_mask(mask, batch_size, scale_factor=8):
+     if not isinstance(mask, torch.FloatTensor):
+         mask = mask.convert("L")
+         w, h = mask.size
+         w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
+         mask = np.array(mask).astype(np.float32) / 255.0
+         mask = np.tile(mask, (4, 1, 1))
+         mask = np.vstack([mask[None]] * batch_size)
+         mask = 1 - mask  # repaint white, keep black
+         mask = torch.from_numpy(mask)
+         return mask
+     else:
+         valid_mask_channel_sizes = [1, 3]
+         # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
+         if mask.shape[3] in valid_mask_channel_sizes:
+             mask = mask.permute(0, 3, 1, 2)
+         elif mask.shape[1] not in valid_mask_channel_sizes:
+             raise ValueError(
+                 f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
+                 f" but received mask of shape {tuple(mask.shape)}"
+             )
+         # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
+         mask = mask.mean(dim=1, keepdim=True)
+         h, w = mask.shape[-2:]
+         h, w = (x - x % 8 for x in (h, w))  # resize to integer multiple of 8
+         mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
+         return mask
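+
+ # Illustrative note (editor's sketch): a 512x512 PIL mask with scale_factor=8
+ # becomes a tensor of shape (batch_size, 4, 64, 64), tiled across the 4 latent
+ # channels and inverted, so white input pixels (areas to repaint) end up as 0.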
+
+
+ class StableDiffusionLongPromptWeightingPipeline(
+     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+ ):
+     r"""
+     Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
+     parsing weights in the prompt.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+         feature_extractor ([`CLIPImageProcessor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+
+     _optional_components = ["safety_checker", "feature_extractor"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: KarrasDiffusionSchedulers,
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPImageProcessor,
+         requires_safety_checker: bool = True,
+     ):
+         super().__init__()
+
+         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                 "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                 " file"
+             )
+             deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["steps_offset"] = 1
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                 " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                 " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                 " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+             )
+             deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["clip_sample"] = False
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if safety_checker is None and requires_safety_checker:
+             logger.warning(
+                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+             )
+
+         if safety_checker is not None and feature_extractor is None:
+             raise ValueError(
+                 f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+             )
+
+         is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+             version.parse(unet.config._diffusers_version).base_version
+         ) < version.parse("0.9.0.dev0")
+         is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+             deprecation_message = (
+                 "The configuration file of the unet has set the default `sample_size` to smaller than"
+                 " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                 " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                 " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                 " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                 " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                 " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                 " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                 " the `unet/config.json` file"
+             )
+             deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(unet.config)
+             new_config["sample_size"] = 64
+             unet._internal_dict = FrozenDict(new_config)
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.register_to_config(
+             requires_safety_checker=requires_safety_checker,
+         )
+
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding.
+
+         When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+         steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     def enable_vae_tiling(self):
+         r"""
+         Enable tiled VAE decoding.
+
+         When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+         several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+         """
+         self.vae.enable_tiling()
+
+     def disable_vae_tiling(self):
+         r"""
+         Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_tiling()
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+     def enable_sequential_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+         text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+         `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+         Note that offloading happens on a submodule basis. Memory savings are higher than with
+         `enable_model_cpu_offload`, but performance is lower.
+         """
+         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+             from accelerate import cpu_offload
+         else:
+             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         if self.device.type != "cpu":
+             self.to("cpu", silence_dtype_warnings=True)
+             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+             cpu_offload(cpu_offloaded_model, device)
+
+         if self.safety_checker is not None:
+             cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+     def enable_model_cpu_offload(self, gpu_id=0):
+         r"""
+         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+         """
+         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+             from accelerate import cpu_offload_with_hook
+         else:
+             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+         device = torch.device(f"cuda:{gpu_id}")
+
+         if self.device.type != "cpu":
+             self.to("cpu", silence_dtype_warnings=True)
+             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+         hook = None
+         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+         if self.safety_checker is not None:
+             _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+         # We'll offload the last model manually.
+         self.final_offload_hook = hook
+
+     @property
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if not hasattr(self.unet, "_hf_hook"):
+             return self.device
+         for module in self.unet.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device
+
+     def _encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt=None,
+         max_embeddings_multiples=3,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 prompt to be encoded
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                 The max multiple length of prompt embeddings compared to the max output length of text encoder.
+         """
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if negative_prompt_embeds is None:
+             if negative_prompt is None:
+                 negative_prompt = [""] * batch_size
+             elif isinstance(negative_prompt, str):
+                 negative_prompt = [negative_prompt] * batch_size
+             if batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+         if prompt_embeds is None or negative_prompt_embeds is None:
+             if isinstance(self, TextualInversionLoaderMixin):
+                 prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+                 if do_classifier_free_guidance and negative_prompt_embeds is None:
+                     negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
+
+             prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
+                 pipe=self,
+                 prompt=prompt,
+                 uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+                 max_embeddings_multiples=max_embeddings_multiples,
+             )
+             if prompt_embeds is None:
+                 prompt_embeds = prompt_embeds1
+             if negative_prompt_embeds is None:
+                 negative_prompt_embeds = negative_prompt_embeds1
+
+         bs_embed, seq_len, _ = prompt_embeds.shape
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+         if do_classifier_free_guidance:
+             bs_embed, seq_len, _ = negative_prompt_embeds.shape
+             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+             negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         return prompt_embeds
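+
+     # Illustrative note (editor's sketch): with classifier-free guidance the negative
+     # and positive embeddings are concatenated, so the returned tensor has shape
+     # (2 * batch_size * num_images_per_prompt, seq_len, hidden_dim).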
+
+     def check_inputs(
+         self,
+         prompt,
+         height,
+         width,
+         strength,
+         callback_steps,
+         negative_prompt=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+     ):
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if strength < 0 or strength > 1:
+             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+
+         if prompt_embeds is not None and negative_prompt_embeds is not None:
+             if prompt_embeds.shape != negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {negative_prompt_embeds.shape}."
+                 )
+
+     def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+         if is_text2img:
+             return self.scheduler.timesteps.to(device), num_inference_steps
+         else:
+             # get the original timestep using init_timestep
+             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+             t_start = max(num_inference_steps - init_timestep, 0)
+             timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+             return timesteps, num_inference_steps - t_start
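+
+     # Illustrative note (editor's sketch): with num_inference_steps=50 and
+     # strength=0.8 in img2img mode, init_timestep = 40 and t_start = 10, so
+     # denoising runs over the last 40 scheduler timesteps; text2img always
+     # uses the full schedule.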
+
+     def run_safety_checker(self, image, device, dtype):
+         if self.safety_checker is not None:
+             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+             )
+         else:
+             has_nsfw_concept = None
+         return image, has_nsfw_concept
+
+     def decode_latents(self, latents):
+         latents = 1 / self.vae.config.scaling_factor * latents
+         image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clamp(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+         return image
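+
+     # Illustrative note (editor's sketch, assuming a Stable Diffusion v1-style VAE):
+     # vae.config.scaling_factor is typically 0.18215, so decoding divides the latents
+     # by that constant before the VAE decodes them, then rescales from [-1, 1] to [0, 1].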
+
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def prepare_latents(
+         self,
+         image,
+         timestep,
+         num_images_per_prompt,
+         batch_size,
+         num_channels_latents,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+     ):
+         if image is None:
+             batch_size = batch_size * num_images_per_prompt
+             shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+             if isinstance(generator, list) and len(generator) != batch_size:
+                 raise ValueError(
+                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                 )
+
+             if latents is None:
+                 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+             else:
+                 latents = latents.to(device)
+
+             # scale the initial noise by the standard deviation required by the scheduler
+             latents = latents * self.scheduler.init_noise_sigma
+             return latents, None, None
+         else:
+             image = image.to(device=self.device, dtype=dtype)
+             init_latent_dist = self.vae.encode(image).latent_dist
+             init_latents = init_latent_dist.sample(generator=generator)
+             init_latents = self.vae.config.scaling_factor * init_latents
+
+             # Expand init_latents for batch_size and num_images_per_prompt
+             init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+             init_latents_orig = init_latents
+
+             # add noise to latents using the timesteps
+             noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+             init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+             latents = init_latents
+             return latents, init_latents_orig, noise
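+
+     # Illustrative note (editor's sketch): prepare_latents has two modes. With
+     # image=None (text2img) it samples fresh Gaussian latents and returns
+     # (latents, None, None); with an init image (img2img/inpaint) it encodes the
+     # image through the VAE, noises it to `timestep`, and also returns the clean
+     # latents and the noise so the masking step can re-blend them each iteration.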
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+         height: int = 512,
+         width: int = 512,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         strength: float = 0.8,
+         num_images_per_prompt: Optional[int] = 1,
+         add_predicted_noise: Optional[bool] = False,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         max_embeddings_multiples: Optional[int] = 3,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             image (`torch.FloatTensor` or `PIL.Image.Image`):
+                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
+                 process.
+             mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
+                 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
+                 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
+             height (`int`, *optional*, defaults to 512):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             strength (`float`, *optional*, defaults to 0.8):
+                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+                 `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+                 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+                 noise will be maximum and the denoising process will run for the full number of iterations specified in
+                 `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             add_predicted_noise (`bool`, *optional*, defaults to `False`):
+                 Use predicted noise instead of random noise when constructing noisy versions of the original image in
+                 the reverse diffusion process.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                 The max multiple length of prompt embeddings compared to the max output length of text encoder.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             is_cancelled_callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. If the function returns
+                 `True`, the inference will be cancelled.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+         Returns:
+             `None` if cancelled by `is_cancelled_callback`,
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+         )
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             max_embeddings_multiples,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+         dtype = prompt_embeds.dtype
+
+         # 4. Preprocess image and mask
+         if isinstance(image, PIL.Image.Image):
+             image = preprocess_image(image, batch_size)
+         if image is not None:
+             image = image.to(device=self.device, dtype=dtype)
+         if isinstance(mask_image, PIL.Image.Image):
+             mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
+         if mask_image is not None:
+             mask = mask_image.to(device=self.device, dtype=dtype)
+             mask = torch.cat([mask] * num_images_per_prompt)
+         else:
+             mask = None
+
+         # 5. set timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+         # 6. Prepare latent variables
+         latents, init_latents_orig, noise = self.prepare_latents(
+             image,
+             latent_timestep,
+             num_images_per_prompt,
+             batch_size,
+             self.unet.config.in_channels,
+             height,
+             width,
+             dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+         # 8. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 # predict the noise residual
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                 ).sample
+
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                 if mask is not None:
+                     # masking
+                     if add_predicted_noise:
+                         init_latents_proper = self.scheduler.add_noise(
+                             init_latents_orig, noise_pred_uncond, torch.tensor([t])
+                         )
+                     else:
+                         init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                     latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+                     if i % callback_steps == 0:
+                         if callback is not None:
+                             callback(i, t, latents)
+                         if is_cancelled_callback is not None and is_cancelled_callback():
+                             return None
+
+         if output_type == "latent":
+             image = latents
+             has_nsfw_concept = None
+         elif output_type == "pil":
+             # 9. Post-processing
+             image = self.decode_latents(latents)
+
+             # 10. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+             # 11. Convert to PIL
+             image = self.numpy_to_pil(image)
+         else:
+             # 9. Post-processing
+             image = self.decode_latents(latents)
+
+             # 10. Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return image, has_nsfw_concept
+
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
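+
+     # Illustrative usage (editor's sketch, not part of the pipeline itself; the model
+     # id and the `custom_pipeline` reference are assumptions for the example):
+     #
+     #   pipe = DiffusionPipeline.from_pretrained(
+     #       "runwayml/stable-diffusion-v1-5",
+     #       custom_pipeline="lpw_stable_diffusion",
+     #       torch_dtype=torch.float16,
+     #   ).to("cuda")
+     #   result = pipe(
+     #       prompt="a (very beautiful:1.2) mountain landscape, (masterpiece)",
+     #       negative_prompt="(low quality:1.4)",
+     #       num_inference_steps=30,
+     #   )
+     #   image = result.images[0]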
1122
+
1123
+ def text2img(
1124
+ self,
1125
+ prompt: Union[str, List[str]],
1126
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1127
+ height: int = 512,
1128
+ width: int = 512,
1129
+ num_inference_steps: int = 50,
1130
+ guidance_scale: float = 7.5,
1131
+ num_images_per_prompt: Optional[int] = 1,
1132
+ eta: float = 0.0,
1133
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1134
+ latents: Optional[torch.FloatTensor] = None,
1135
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1136
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1137
+ max_embeddings_multiples: Optional[int] = 3,
1138
+ output_type: Optional[str] = "pil",
1139
+ return_dict: bool = True,
1140
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1141
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1142
+ callback_steps: int = 1,
1143
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1144
+ ):
+         r"""
+         Function for text-to-image generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
+                 ignored if `guidance_scale` is less than `1`).
+             height (`int`, *optional*, defaults to 512):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. A higher guidance scale encourages the model to generate images closely linked to the text
+                 `prompt`, usually at the expense of lower image quality.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies
+                 to [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                 to make generation deterministic.
+             latents (`torch.FloatTensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a
+                 latents tensor will be generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
+                 not provided, text embeddings will be generated from the `prompt` input argument.
+             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
+                 input argument.
+             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                 The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of
+                 a plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+             is_cancelled_callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. If the function
+                 returns `True`, the inference will be cancelled.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will
+                 be called at every step.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+         Returns:
+             `None` if cancelled by `is_cancelled_callback`; otherwise, a
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, or a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element
+             is a list of `bool`s denoting whether the corresponding generated image likely represents
+             "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
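+
+         Example:
+             A minimal sketch. It assumes this file is loaded as a diffusers custom pipeline; the model id, the
+             `custom_pipeline="."` layout (a directory containing this `pipeline.py`), and the weighted
+             `(red:1.3)` sub-prompt syntax handled by this pipeline's prompt parser are illustrative assumptions.
+
+             ```py
+             import torch
+             from diffusers import DiffusionPipeline
+
+             pipe = DiffusionPipeline.from_pretrained(
+                 "runwayml/stable-diffusion-v1-5",
+                 custom_pipeline=".",  # assumed: current directory holds this pipeline.py
+                 torch_dtype=torch.float16,
+             ).to("cuda")
+
+             # `max_embeddings_multiples` relaxes the usual 77-token prompt limit.
+             image = pipe.text2img(
+                 "a (red:1.3) apple on a wooden table, best quality",
+                 negative_prompt="low quality, blurry",
+                 num_inference_steps=30,
+                 max_embeddings_multiples=3,
+             ).images[0]
+             image.save("apple.png")
+             ```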
+ """
1215
+ return self.__call__(
1216
+ prompt=prompt,
1217
+ negative_prompt=negative_prompt,
1218
+ height=height,
1219
+ width=width,
1220
+ num_inference_steps=num_inference_steps,
1221
+ guidance_scale=guidance_scale,
1222
+ num_images_per_prompt=num_images_per_prompt,
1223
+ eta=eta,
1224
+ generator=generator,
1225
+ latents=latents,
1226
+ prompt_embeds=prompt_embeds,
1227
+ negative_prompt_embeds=negative_prompt_embeds,
1228
+ max_embeddings_multiples=max_embeddings_multiples,
1229
+ output_type=output_type,
1230
+ return_dict=return_dict,
1231
+ callback=callback,
1232
+ is_cancelled_callback=is_cancelled_callback,
1233
+ callback_steps=callback_steps,
1234
+ cross_attention_kwargs=cross_attention_kwargs,
1235
+ )
1236
+
+     def img2img(
+         self,
+         image: Union[torch.FloatTensor, PIL.Image.Image],
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         strength: float = 0.8,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: Optional[float] = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         max_embeddings_multiples: Optional[int] = 3,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+ r"""
1259
+ Function for image-to-image generation.
1260
+ Args:
1261
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1262
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1263
+ process.
1264
+ prompt (`str` or `List[str]`):
1265
+ The prompt or prompts to guide the image generation.
1266
+ negative_prompt (`str` or `List[str]`, *optional*):
1267
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1268
+ if `guidance_scale` is less than `1`).
1269
+ strength (`float`, *optional*, defaults to 0.8):
1270
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1271
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1272
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1273
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1274
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1275
+ num_inference_steps (`int`, *optional*, defaults to 50):
1276
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1277
+ expense of slower inference. This parameter will be modulated by `strength`.
1278
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1279
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1280
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1281
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1282
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1283
+ usually at the expense of lower image quality.
1284
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1285
+ The number of images to generate per prompt.
1286
+ eta (`float`, *optional*, defaults to 0.0):
1287
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1288
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1289
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1290
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1291
+ to make generation deterministic.
1292
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1293
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1294
+ provided, text embeddings will be generated from `prompt` input argument.
1295
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1296
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1297
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1298
+ argument.
1299
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1300
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1301
+ output_type (`str`, *optional*, defaults to `"pil"`):
1302
+ The output format of the generate image. Choose between
1303
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1304
+ return_dict (`bool`, *optional*, defaults to `True`):
1305
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1306
+ plain tuple.
1307
+ callback (`Callable`, *optional*):
1308
+ A function that will be called every `callback_steps` steps during inference. The function will be
1309
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1310
+ is_cancelled_callback (`Callable`, *optional*):
1311
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1312
+ `True`, the inference will be cancelled.
1313
+ callback_steps (`int`, *optional*, defaults to 1):
1314
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1315
+ called at every step.
1316
+ cross_attention_kwargs (`dict`, *optional*):
1317
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1318
+ `self.processor` in
1319
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1320
+
1321
+ Returns:
1322
+ `None` if cancelled by `is_cancelled_callback`,
1323
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1324
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1325
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1326
+ (nsfw) content, according to the `safety_checker`.
1327
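+
+         Example:
+             A minimal sketch, reusing the `pipe` object from the `text2img` example above; `"sketch.png"` is a
+             placeholder input path.
+
+             ```py
+             from PIL import Image
+
+             init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
+             # strength=0.6 keeps the overall composition while repainting details;
+             # higher values add more noise and drift further from the input image.
+             image = pipe.img2img(
+                 image=init_image,
+                 prompt="an oil painting of a mountain lake",
+                 strength=0.6,
+                 num_inference_steps=50,
+             ).images[0]
+             image.save("painting.png")
+             ```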
+ """
1328
+ return self.__call__(
1329
+ prompt=prompt,
1330
+ negative_prompt=negative_prompt,
1331
+ image=image,
1332
+ num_inference_steps=num_inference_steps,
1333
+ guidance_scale=guidance_scale,
1334
+ strength=strength,
1335
+ num_images_per_prompt=num_images_per_prompt,
1336
+ eta=eta,
1337
+ generator=generator,
1338
+ prompt_embeds=prompt_embeds,
1339
+ negative_prompt_embeds=negative_prompt_embeds,
1340
+ max_embeddings_multiples=max_embeddings_multiples,
1341
+ output_type=output_type,
1342
+ return_dict=return_dict,
1343
+ callback=callback,
1344
+ is_cancelled_callback=is_cancelled_callback,
1345
+ callback_steps=callback_steps,
1346
+ cross_attention_kwargs=cross_attention_kwargs,
1347
+ )
1348
+
1349
+     def inpaint(
+         self,
+         image: Union[torch.FloatTensor, PIL.Image.Image],
+         mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         strength: float = 0.8,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         num_images_per_prompt: Optional[int] = 1,
+         add_predicted_noise: Optional[bool] = False,
+         eta: Optional[float] = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         max_embeddings_multiples: Optional[int] = 3,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
+         callback_steps: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+ r"""
1373
+ Function for inpaint.
1374
+ Args:
1375
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1376
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1377
+ process. This is the image whose masked region will be inpainted.
1378
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1379
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1380
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1381
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1382
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1383
+ prompt (`str` or `List[str]`):
1384
+ The prompt or prompts to guide the image generation.
1385
+ negative_prompt (`str` or `List[str]`, *optional*):
1386
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1387
+ if `guidance_scale` is less than `1`).
1388
+ strength (`float`, *optional*, defaults to 0.8):
1389
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1390
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1391
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1392
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1393
+ num_inference_steps (`int`, *optional*, defaults to 50):
1394
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1395
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1396
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1397
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1398
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1399
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1400
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1401
+ usually at the expense of lower image quality.
1402
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1403
+ The number of images to generate per prompt.
1404
+ add_predicted_noise (`bool`, *optional*, defaults to True):
1405
+ Use predicted noise instead of random noise when constructing noisy versions of the original image in
1406
+ the reverse diffusion process
1407
+ eta (`float`, *optional*, defaults to 0.0):
1408
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1409
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1410
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1411
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1412
+ to make generation deterministic.
1413
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1414
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1415
+ provided, text embeddings will be generated from `prompt` input argument.
1416
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1417
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1418
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1419
+ argument.
1420
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1421
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1422
+ output_type (`str`, *optional*, defaults to `"pil"`):
1423
+ The output format of the generate image. Choose between
1424
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1425
+ return_dict (`bool`, *optional*, defaults to `True`):
1426
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1427
+ plain tuple.
1428
+ callback (`Callable`, *optional*):
1429
+ A function that will be called every `callback_steps` steps during inference. The function will be
1430
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1431
+ is_cancelled_callback (`Callable`, *optional*):
1432
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1433
+ `True`, the inference will be cancelled.
1434
+ callback_steps (`int`, *optional*, defaults to 1):
1435
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1436
+ called at every step.
1437
+ cross_attention_kwargs (`dict`, *optional*):
1438
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1439
+ `self.processor` in
1440
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1441
+
1442
+ Returns:
1443
+ `None` if cancelled by `is_cancelled_callback`,
1444
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
1445
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1446
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1447
+ (nsfw) content, according to the `safety_checker`.
1448
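+
+         Example:
+             A minimal sketch, reusing the `pipe` object from the `text2img` example above; `"photo.png"` and
+             `"mask.png"` are placeholder paths.
+
+             ```py
+             from PIL import Image
+
+             init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
+             # White pixels in the mask are repainted; black pixels are preserved.
+             mask_image = Image.open("mask.png").convert("L").resize((512, 512))
+             image = pipe.inpaint(
+                 image=init_image,
+                 mask_image=mask_image,
+                 prompt="a sleeping cat on the sofa",
+                 strength=0.75,
+             ).images[0]
+             image.save("inpainted.png")
+             ```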
+ """
1449
+ return self.__call__(
1450
+ prompt=prompt,
1451
+ negative_prompt=negative_prompt,
1452
+ image=image,
1453
+ mask_image=mask_image,
1454
+ num_inference_steps=num_inference_steps,
1455
+ guidance_scale=guidance_scale,
1456
+ strength=strength,
1457
+ num_images_per_prompt=num_images_per_prompt,
1458
+ add_predicted_noise=add_predicted_noise,
1459
+ eta=eta,
1460
+ generator=generator,
1461
+ prompt_embeds=prompt_embeds,
1462
+ negative_prompt_embeds=negative_prompt_embeds,
1463
+ max_embeddings_multiples=max_embeddings_multiples,
1464
+ output_type=output_type,
1465
+ return_dict=return_dict,
1466
+ callback=callback,
1467
+ is_cancelled_callback=is_cancelled_callback,
1468
+ callback_steps=callback_steps,
1469
+ cross_attention_kwargs=cross_attention_kwargs,
1470
+ )