mythicalguy committed on
Commit
fef867e
·
1 Parent(s): 1b25ea7

Add imagic_stable_diffusion.py and update requirements

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. imagic_stable_diffusion.py +470 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -16,7 +16,7 @@ def infer(
16
  randomize_seed=False,
17
  guidance_scale=10.0,
18
  num_inference_steps=30,
19
- num_opt_steps=500, # Fine-tuning steps; lower if timeouts occur
20
  embed_lr=1e-5,
21
  model_lr=1e-6,
22
  num_inner_steps=1,
 
16
  randomize_seed=False,
17
  guidance_scale=10.0,
18
  num_inference_steps=30,
19
+ num_opt_steps=500,
20
  embed_lr=1e-5,
21
  model_lr=1e-6,
22
  num_inner_steps=1,
imagic_stable_diffusion.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeled after the textual_inversion.py / train_dreambooth.py and the work
3
+ of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
4
+ """
5
+
6
+ import inspect
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from accelerate import Accelerator
15
+
16
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
17
+ from packaging import version
18
+ from tqdm.auto import tqdm
19
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
20
+
21
+ from diffusers import DiffusionPipeline
22
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
23
+ from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
24
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
25
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
26
+ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
27
+ from diffusers.utils import logging
28
+
29
+
30
# Pillow 9.1.0 moved the resampling filters into the `Image.Resampling` enum and
# deprecated the bare `Image.<FILTER>` aliases, so select the right names per version.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
# ------------------------------------------------------------------------------

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
49
+
50
+
51
def preprocess(image):
    """Convert a PIL image into a normalized NCHW float tensor in [-1, 1].

    The image is first cropped (via resize) so both sides are exact multiples
    of 32, then scaled from [0, 255] uint8 range to the [-1, 1] range the VAE
    expects.
    """
    width, height = image.size
    # Round each dimension down to the nearest multiple of 32.
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = np.array(resized).astype(np.float32) / 255.0
    # HWC -> NCHW with a leading batch axis of 1.
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)
    return tensor * 2.0 - 1.0
59
+
60
+
61
class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for imagic image editing.
    See paper here: https://huggingface.co/papers/2210.09276

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        """Register all sub-models with the DiffusionPipeline machinery."""
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def train(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        generator: Optional[torch.Generator] = None,
        embedding_learning_rate: float = 0.001,
        diffusion_model_learning_rate: float = 2e-6,
        text_embedding_optimization_steps: int = 500,
        model_fine_tuning_optimization_steps: int = 1000,
        **kwargs,
    ):
        r"""
        Run the two-stage Imagic fine-tuning for a (prompt, image) pair.

        Stage 1 optimizes a copy of the prompt's text embedding so the frozen UNet
        reconstructs `image`; stage 2 unfreezes the UNet and fine-tunes it against
        that optimized embedding. The results are stored on the pipeline as
        `self.text_embeddings` (optimized) and `self.text_embeddings_orig`
        (original prompt embedding) for later use by `__call__`.

        Args:
            prompt (`str` or `List[str]`):
                The target prompt describing the desired edit.
            image (`torch.Tensor` or `PIL.Image.Image`):
                The input image to be edited. PIL images are preprocessed to a
                [-1, 1] NCHW tensor; tensors are assumed to already be in that
                format.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image. Must be divisible by 8.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) used when
                sampling the VAE latent distribution, to make training deterministic.
            embedding_learning_rate (`float`, *optional*, defaults to 0.001):
                Adam learning rate for the text-embedding optimization stage.
            diffusion_model_learning_rate (`float`, *optional*, defaults to 2e-6):
                Adam learning rate for the UNet fine-tuning stage.
            text_embedding_optimization_steps (`int`, *optional*, defaults to 500):
                Number of optimization steps in the embedding stage.
            model_fine_tuning_optimization_steps (`int`, *optional*, defaults to 1000):
                Number of optimization steps in the UNet fine-tuning stage.
            kwargs:
                Only the deprecated `torch_device` key is consumed; anything else is ignored.

        Returns:
            `None`. Results are stored on the pipeline instance.
        """
        # fp16 mixed precision keeps memory low enough for a single-GPU Space.
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision="fp16",
        )

        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )

            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Freeze vae and unet: only the text embedding is trained in stage 1.
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()

        if accelerator.is_main_process:
            accelerator.init_trackers(
                "imagic",
                config={
                    "embedding_learning_rate": embedding_learning_rate,
                    "text_embedding_optimization_steps": text_embedding_optimization_steps,
                },
            )

        # get text embeddings for prompt
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = torch.nn.Parameter(
            self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
        )
        # detach() drops the Parameter wrapper; requires_grad_() re-enables grads on
        # the plain tensor so Adam can optimize it directly.
        text_embeddings = text_embeddings.detach()
        text_embeddings.requires_grad_()
        # Keep an untouched copy of the prompt embedding for interpolation at
        # generation time (alpha blending in __call__).
        text_embeddings_orig = text_embeddings.clone()

        # Initialize the optimizer
        optimizer = torch.optim.Adam(
            [text_embeddings],  # only optimize the embeddings
            lr=embedding_learning_rate,
        )

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        latents_dtype = text_embeddings.dtype
        image = image.to(device=self.device, dtype=latents_dtype)
        init_latent_image_dist = self.vae.encode(image).latent_dist
        image_latents = init_latent_image_dist.sample(generator=generator)
        # 0.18215 is the SD v1 VAE latent scaling factor.
        image_latents = 0.18215 * image_latents

        progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        global_step = 0

        logger.info("First optimizing the text embedding to better reconstruct the init image")
        for _ in range(text_embedding_optimization_steps):
            with accelerator.accumulate(text_embeddings):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                # Random timestep in [0, 1000) — standard DDPM training objective.
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        text_embeddings.requires_grad_(False)

        # Now we fine tune the unet to better reconstruct the image
        self.unet.requires_grad_(True)
        self.unet.train()
        # A fresh optimizer is intentional: stage 2 trains different parameters.
        optimizer = torch.optim.Adam(
            self.unet.parameters(),  # only optimize unet
            lr=diffusion_model_learning_rate,
        )
        progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)

        logger.info("Next fine tuning the entire model to better reconstruct the init image")
        for _ in range(model_fine_tuning_optimization_steps):
            with accelerator.accumulate(self.unet.parameters()):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()
        # Stash both embeddings for __call__'s alpha interpolation.
        self.text_embeddings_orig = text_embeddings_orig
        self.text_embeddings = text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        alpha: float = 1.2,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Args:
            alpha (`float`, *optional*, defaults to 1.2):
                The interpolation factor between the original and optimized text embeddings. A value closer to 0
                will resemble the original input image.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        # NOTE(review): these attributes are only created by train(); calling
        # __call__ before train() raises AttributeError here rather than the
        # intended ValueError — confirm whether the class should initialize them
        # to None in __init__.
        if self.text_embeddings is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")
        if self.text_embeddings_orig is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")

        # Linear interpolation between the target-prompt embedding (orig) and the
        # image-reconstructing optimized embedding; alpha > 1 extrapolates toward
        # the edit prompt.
        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens = [""]
            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if self.device.type == "mps":
            # randn does not exist on mps
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                self.device
            )
        else:
            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # Undo the VAE latent scaling before decoding.
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # Map [-1, 1] model output back to [0, 1] pixel range.
        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
requirements.txt CHANGED
@@ -5,4 +5,6 @@ spaces
5
  numpy
6
  pillow
7
  torch
8
- transformers
 
 
 
5
  numpy
6
  pillow
7
  torch
8
+ transformers
9
+ safetensors
10
+ tqdm