quickjkee commited on
Commit
b1d3789
·
verified ·
1 Parent(s): c125a80

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +430 -0
pipeline.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
7
+ from diffusers.image_processor import PipelineImageInput
8
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
9
+ StableDiffusionXLPipelineOutput,
10
+ )
11
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
12
+ StableDiffusionXLPipeline,
13
+ rescale_noise_cfg,
14
+ )
15
+ from diffusers.utils import deprecate, is_torch_xla_available
16
+
17
+ if is_torch_xla_available():
18
+ import torch_xla.core.xla_model as xm
19
+
20
+ XLA_AVAILABLE = True
21
+ else:
22
+ XLA_AVAILABLE = False
23
+
24
+
25
+ def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]) -> torch.Tensor:
26
+ b, *_ = t.shape
27
+ out = a.gather(-1, t.long())
28
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
29
+
30
+
31
+ class SwDPipeline(StableDiffusionXLPipeline):
32
+ @torch.no_grad()
33
+ def __call__(
34
+ self,
35
+ prompt: Union[str, List[str]] = None,
36
+ prompt_2: Optional[Union[str, List[str]]] = None,
37
+ height: Optional[int] = None,
38
+ width: Optional[int] = None,
39
+ num_inference_steps: int = 50,
40
+ timesteps: Optional[List[int]] = None,
41
+ sigmas: Optional[List[float]] = None,
42
+ scales: Optional[List[float]] = None,
43
+ denoising_end: Optional[float] = None,
44
+ guidance_scale: float = 5.0,
45
+ negative_prompt: Optional[Union[str, List[str]]] = None,
46
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
47
+ num_images_per_prompt: int = 1,
48
+ eta: float = 0.0,
49
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
50
+ latents: Optional[torch.Tensor] = None,
51
+ prompt_embeds: Optional[torch.Tensor] = None,
52
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
53
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
54
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
55
+ ip_adapter_image: Optional[PipelineImageInput] = None,
56
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
57
+ output_type: str = "pil",
58
+ return_dict: bool = True,
59
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
60
+ guidance_rescale: float = 0.0,
61
+ original_size: Optional[Tuple[int, int]] = None,
62
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
63
+ target_size: Optional[Tuple[int, int]] = None,
64
+ negative_original_size: Optional[Tuple[int, int]] = None,
65
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
66
+ negative_target_size: Optional[Tuple[int, int]] = None,
67
+ clip_skip: Optional[int] = None,
68
+ callback_on_step_end: Optional[
69
+ Union[
70
+ Callable[[int, int, Dict[str, Any]], None],
71
+ PipelineCallback,
72
+ MultiPipelineCallbacks,
73
+ ]
74
+ ] = None,
75
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
76
+ **kwargs: Any,
77
+ ) -> StableDiffusionXLPipelineOutput:
78
+ if callback_on_step_end_tensor_inputs is None:
79
+ callback_on_step_end_tensor_inputs = ["latents"]
80
+
81
+ callback = kwargs.pop("callback", None)
82
+ callback_steps = kwargs.pop("callback_steps", None)
83
+
84
+ if callback is not None:
85
+ deprecate(
86
+ "callback",
87
+ "1.0.0",
88
+ (
89
+ "Passing `callback` as an input argument to `__call__` is deprecated, "
90
+ "consider use `callback_on_step_end`"
91
+ ),
92
+ )
93
+ if callback_steps is not None:
94
+ deprecate(
95
+ "callback_steps",
96
+ "1.0.0",
97
+ (
98
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, "
99
+ "consider use `callback_on_step_end`"
100
+ ),
101
+ )
102
+
103
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
104
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
105
+
106
+ # 0. Default height and width to unet
107
+ height = height or self.default_sample_size * self.vae_scale_factor
108
+ width = width or self.default_sample_size * self.vae_scale_factor
109
+
110
+ original_size = original_size or (height, width)
111
+ target_size = target_size or (height, width)
112
+
113
+ # 1. Check inputs. Raise error if not correct
114
+ self.check_inputs(
115
+ prompt,
116
+ prompt_2,
117
+ height,
118
+ width,
119
+ callback_steps,
120
+ negative_prompt,
121
+ negative_prompt_2,
122
+ prompt_embeds,
123
+ negative_prompt_embeds,
124
+ pooled_prompt_embeds,
125
+ negative_pooled_prompt_embeds,
126
+ ip_adapter_image,
127
+ ip_adapter_image_embeds,
128
+ callback_on_step_end_tensor_inputs,
129
+ )
130
+
131
+ self._guidance_scale = guidance_scale
132
+ self._guidance_rescale = guidance_rescale
133
+ self._clip_skip = clip_skip
134
+ self._cross_attention_kwargs = cross_attention_kwargs
135
+ self._denoising_end = denoising_end
136
+ self._interrupt = False
137
+
138
+ # 2. Define call parameters
139
+ if prompt is not None and isinstance(prompt, str):
140
+ batch_size = 1
141
+ elif prompt is not None and isinstance(prompt, list):
142
+ batch_size = len(prompt)
143
+ else:
144
+ batch_size = prompt_embeds.shape[0]
145
+
146
+ device = self._execution_device
147
+
148
+ # 3. Encode input prompt
149
+ lora_scale = None
150
+ if self.cross_attention_kwargs is not None:
151
+ lora_scale = self.cross_attention_kwargs.get("scale", None)
152
+
153
+ (
154
+ prompt_embeds,
155
+ negative_prompt_embeds,
156
+ pooled_prompt_embeds,
157
+ negative_pooled_prompt_embeds,
158
+ ) = self.encode_prompt(
159
+ prompt=prompt,
160
+ prompt_2=prompt_2,
161
+ device=device,
162
+ num_images_per_prompt=num_images_per_prompt,
163
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
164
+ negative_prompt=negative_prompt,
165
+ negative_prompt_2=negative_prompt_2,
166
+ prompt_embeds=prompt_embeds,
167
+ negative_prompt_embeds=negative_prompt_embeds,
168
+ pooled_prompt_embeds=pooled_prompt_embeds,
169
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
170
+ lora_scale=lora_scale,
171
+ clip_skip=self.clip_skip,
172
+ )
173
+
174
+ # 4. Prepare timesteps
175
+ if timesteps is None:
176
+ raise ValueError("`timesteps` must be provided for SwDPipeline.__call__().")
177
+
178
+ timesteps_tensor = torch.tensor(timesteps, dtype=torch.long)
179
+ timesteps = self.scheduler.timesteps[(1000 - timesteps_tensor)[:-1]].to(
180
+ device=device,
181
+ dtype=torch.long,
182
+ )
183
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device=device)
184
+
185
+ # 5. Prepare latent variables
186
+ if not scales:
187
+ raise ValueError("`scales` must be a non-empty list.")
188
+
189
+ num_channels_latents = self.unet.config.in_channels
190
+ latents = self.prepare_latents(
191
+ batch_size * num_images_per_prompt,
192
+ num_channels_latents,
193
+ scales[0] * self.vae_scale_factor,
194
+ scales[0] * self.vae_scale_factor,
195
+ prompt_embeds.dtype,
196
+ device,
197
+ generator,
198
+ latents,
199
+ )
200
+
201
+ # 6. Prepare extra step kwargs
202
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
203
+
204
+ # 7. Prepare added time ids & embeddings
205
+ _ = extra_step_kwargs # kept for parity with original pipeline flow
206
+
207
+ add_text_embeds = pooled_prompt_embeds
208
+ if self.text_encoder_2 is None:
209
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
210
+ else:
211
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
212
+
213
+ add_time_ids = self._get_add_time_ids(
214
+ original_size,
215
+ crops_coords_top_left,
216
+ target_size,
217
+ dtype=prompt_embeds.dtype,
218
+ text_encoder_projection_dim=text_encoder_projection_dim,
219
+ )
220
+
221
+ if negative_original_size is not None and negative_target_size is not None:
222
+ negative_add_time_ids = self._get_add_time_ids(
223
+ negative_original_size,
224
+ negative_crops_coords_top_left,
225
+ negative_target_size,
226
+ dtype=prompt_embeds.dtype,
227
+ text_encoder_projection_dim=text_encoder_projection_dim,
228
+ )
229
+ else:
230
+ negative_add_time_ids = add_time_ids
231
+
232
+ if self.do_classifier_free_guidance:
233
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
234
+ add_text_embeds = torch.cat(
235
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
236
+ )
237
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
238
+
239
+ prompt_embeds = prompt_embeds.to(device)
240
+ add_text_embeds = add_text_embeds.to(device)
241
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
242
+
243
+ image_embeds = None
244
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
245
+ image_embeds = self.prepare_ip_adapter_image_embeds(
246
+ ip_adapter_image,
247
+ ip_adapter_image_embeds,
248
+ device,
249
+ batch_size * num_images_per_prompt,
250
+ self.do_classifier_free_guidance,
251
+ )
252
+
253
+ # 8. Denoising loop
254
+ num_warmup_steps = max(
255
+ len(timesteps) - num_inference_steps * self.scheduler.order,
256
+ 0,
257
+ )
258
+
259
+ # 8.1 Apply denoising_end
260
+ if (
261
+ self.denoising_end is not None
262
+ and isinstance(self.denoising_end, float)
263
+ and 0 < self.denoising_end < 1
264
+ ):
265
+ discrete_timestep_cutoff = int(
266
+ round(
267
+ self.scheduler.config.num_train_timesteps
268
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
269
+ )
270
+ )
271
+ num_inference_steps = len([ts for ts in timesteps if ts >= discrete_timestep_cutoff])
272
+ timesteps = timesteps[:num_inference_steps]
273
+
274
+ self._num_timesteps = len(timesteps)
275
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
276
+ for i, t in enumerate(timesteps):
277
+ if self.interrupt:
278
+ continue
279
+
280
+ latent_model_input = (
281
+ torch.cat([latents] * 2)
282
+ if self.do_classifier_free_guidance
283
+ else latents
284
+ )
285
+
286
+ added_cond_kwargs: Dict[str, Any] = {
287
+ "text_embeds": add_text_embeds,
288
+ "time_ids": add_time_ids,
289
+ }
290
+ added_cond_kwargs["time_ids"][:, :2] = scales[i] * 8
291
+
292
+ if image_embeds is not None:
293
+ added_cond_kwargs["image_embeds"] = image_embeds
294
+
295
+ noise_pred = self.unet(
296
+ latent_model_input,
297
+ t,
298
+ encoder_hidden_states=prompt_embeds,
299
+ cross_attention_kwargs=self.cross_attention_kwargs,
300
+ added_cond_kwargs=added_cond_kwargs,
301
+ return_dict=False,
302
+ )[0]
303
+
304
+ if self.do_classifier_free_guidance:
305
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
306
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
307
+ noise_pred_text - noise_pred_uncond
308
+ )
309
+
310
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
311
+ noise_pred = rescale_noise_cfg(
312
+ noise_pred,
313
+ noise_pred_text,
314
+ guidance_rescale=self.guidance_rescale,
315
+ )
316
+
317
+ alphas = torch.sqrt(self.scheduler.alphas_cumprod)[t]
318
+ sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod)[t]
319
+ x0_pred = (latents - sigmas * noise_pred) / alphas
320
+
321
+ if scales and i + 1 < len(scales):
322
+ x0_pred = torch.nn.functional.interpolate(
323
+ x0_pred,
324
+ size=scales[i + 1],
325
+ mode="bicubic",
326
+ )
327
+
328
+ noise = torch.randn(
329
+ x0_pred.shape,
330
+ generator=generator,
331
+ dtype=x0_pred.dtype,
332
+ device=x0_pred.device,
333
+ )
334
+
335
+ if i + 1 < len(timesteps):
336
+ next_t = timesteps[i + 1]
337
+ alphas = torch.sqrt(self.scheduler.alphas_cumprod)[next_t]
338
+ sigmas = torch.sqrt(1 - self.scheduler.alphas_cumprod)[next_t]
339
+ latents = alphas * x0_pred + sigmas * noise
340
+ else:
341
+ latents = x0_pred
342
+
343
+ latents_dtype = latents.dtype
344
+ if latents.dtype != latents_dtype:
345
+ if torch.backends.mps.is_available():
346
+ latents = latents.to(latents_dtype)
347
+
348
+ if callback_on_step_end is not None:
349
+ callback_kwargs: Dict[str, Any] = {
350
+ k: locals()[k] for k in callback_on_step_end_tensor_inputs
351
+ }
352
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
353
+
354
+ latents = callback_outputs.pop("latents", latents)
355
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
356
+ add_text_embeds = callback_outputs.pop(
357
+ "add_text_embeds", add_text_embeds
358
+ )
359
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
360
+
361
+ if (
362
+ i == len(timesteps) - 1
363
+ or (i + 1) > num_warmup_steps
364
+ and (i + 1) % self.scheduler.order == 0
365
+ ):
366
+ progress_bar.update()
367
+ if callback is not None and i % callback_steps == 0:
368
+ step_idx = i // getattr(self.scheduler, "order", 1)
369
+ callback(step_idx, t, latents)
370
+
371
+ if XLA_AVAILABLE:
372
+ xm.mark_step()
373
+
374
+ if output_type != "latent":
375
+ needs_upcasting = (
376
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
377
+ )
378
+
379
+ if needs_upcasting:
380
+ self.upcast_vae()
381
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
382
+ elif latents.dtype != self.vae.dtype:
383
+ if torch.backends.mps.is_available():
384
+ self.vae = self.vae.to(latents.dtype)
385
+
386
+ has_latents_mean = (
387
+ hasattr(self.vae.config, "latents_mean")
388
+ and self.vae.config.latents_mean is not None
389
+ )
390
+ has_latents_std = (
391
+ hasattr(self.vae.config, "latents_std")
392
+ and self.vae.config.latents_std is not None
393
+ )
394
+
395
+ if has_latents_mean and has_latents_std:
396
+ latents_mean = (
397
+ torch.tensor(self.vae.config.latents_mean)
398
+ .view(1, 4, 1, 1)
399
+ .to(latents.device, latents.dtype)
400
+ )
401
+ latents_std = (
402
+ torch.tensor(self.vae.config.latents_std)
403
+ .view(1, 4, 1, 1)
404
+ .to(latents.device, latents.dtype)
405
+ )
406
+ latents = (
407
+ latents * latents_std / self.vae.config.scaling_factor + latents_mean
408
+ )
409
+ else:
410
+ latents = latents / self.vae.config.scaling_factor
411
+
412
+ image = self.vae.decode(latents, return_dict=False)[0]
413
+
414
+ if needs_upcasting:
415
+ self.vae.to(dtype=torch.float16)
416
+ else:
417
+ image = latents
418
+
419
+ if output_type != "latent":
420
+ if self.watermark is not None:
421
+ image = self.watermark.apply_watermark(image)
422
+
423
+ image = self.image_processor.postprocess(image, output_type=output_type)
424
+
425
+ self.maybe_free_model_hooks()
426
+
427
+ if not return_dict:
428
+ return (image,)
429
+
430
+ return StableDiffusionXLPipelineOutput(images=image)