nateraw commited on
Commit
006a2b7
·
1 Parent(s): ff645b2

Create new file

Browse files
Files changed (1) hide show
  1. pipeline.py +674 -0
pipeline.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import json
3
+ import subprocess
4
+ from pathlib import Path
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ import cv2
12
+ from diffusers.configuration_utils import FrozenDict
13
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
14
+ from diffusers.pipeline_utils import DiffusionPipeline
15
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
16
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
17
+ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
18
+ from diffusers.utils import deprecate, logging
19
+ from huggingface_hub import hf_hub_download
20
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
21
+
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ default_scheduler = PNDMScheduler(
26
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
27
+ )
28
+ ddim_scheduler = DDIMScheduler(
29
+ beta_start=0.00085,
30
+ beta_end=0.012,
31
+ beta_schedule="scaled_linear",
32
+ clip_sample=False,
33
+ set_alpha_to_one=False,
34
+ )
35
+ klms_scheduler = LMSDiscreteScheduler(
36
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
37
+ )
38
+ SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
39
+
40
+
41
+ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
42
+ """helper function to spherically interpolate two arrays v1 v2"""
43
+
44
+ if not isinstance(v0, np.ndarray):
45
+ inputs_are_torch = True
46
+ input_device = v0.device
47
+ v0 = v0.cpu().numpy()
48
+ v1 = v1.cpu().numpy()
49
+
50
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
51
+ if np.abs(dot) > DOT_THRESHOLD:
52
+ v2 = (1 - t) * v0 + t * v1
53
+ else:
54
+ theta_0 = np.arccos(dot)
55
+ sin_theta_0 = np.sin(theta_0)
56
+ theta_t = theta_0 * t
57
+ sin_theta_t = np.sin(theta_t)
58
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
59
+ s1 = sin_theta_t / sin_theta_0
60
+ v2 = s0 * v0 + s1 * v1
61
+
62
+ if inputs_are_torch:
63
+ v2 = torch.from_numpy(v2).to(input_device)
64
+
65
+ return v2
66
+
67
+
68
+ class RealESRGANModel(torch.nn.Module):
69
+ def __init__(self, model_path, tile=0, tile_pad=10, pre_pad=0, fp32=False):
70
+ super().__init__()
71
+ try:
72
+ from basicsr.archs.rrdbnet_arch import RRDBNet
73
+ from realesrgan import RealESRGANer
74
+ except ImportError as e:
75
+ raise ImportError(
76
+ "You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
77
+ "pip install realesrgan"
78
+ )
79
+
80
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
81
+ self.upsampler = RealESRGANer(
82
+ scale=4,
83
+ model_path=model_path,
84
+ model=model,
85
+ tile=tile,
86
+ tile_pad=tile_pad,
87
+ pre_pad=pre_pad,
88
+ half=not fp32
89
+ )
90
+
91
+ def forward(self, image, outscale=4, convert_to_pil=True):
92
+ """Upsample an image array or path.
93
+
94
+ Args:
95
+ image (Union[np.ndarray, str]): Either a np array or an image path. np array is assumed to be in RGB format,
96
+ and we convert it to BGR.
97
+ outscale (int, optional): Amount to upscale the image. Defaults to 4.
98
+ convert_to_pil (bool, optional): If True, return PIL image. Otherwise, return numpy array (BGR). Defaults to True.
99
+
100
+ Returns:
101
+ Union[np.ndarray, PIL.Image.Image]: An upsampled version of the input image.
102
+ """
103
+ if isinstance(image, (str, Path)):
104
+ img = cv2.imread(image, cv2.IMREAD_UNCHANGED)
105
+ else:
106
+ img = image
107
+ img = (img * 255).round().astype("uint8")
108
+ img = img[:, :, ::-1]
109
+
110
+ image, _ = self.upsampler.enhance(img, outscale=outscale)
111
+
112
+ if convert_to_pil:
113
+ image = Image.fromarray(image[:, :, ::-1])
114
+
115
+ return image
116
+
117
+ @classmethod
118
+ def from_pretrained(cls, model_name_or_path='nateraw/real-esrgan'):
119
+ """Initialize a pretrained Real-ESRGAN upsampler.
120
+
121
+ Example:
122
+ ```python
123
+ >>> from stable_diffusion_videos import PipelineRealESRGAN
124
+ >>> pipe = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
125
+ >>> im_out = pipe('input_img.jpg')
126
+ ```
127
+
128
+ Args:
129
+ model_name_or_path (str, optional): The Hugging Face repo ID or path to local model. Defaults to 'nateraw/real-esrgan'.
130
+
131
+ Returns:
132
+ stable_diffusion_videos.PipelineRealESRGAN: An instance of `PipelineRealESRGAN` instantiated from pretrained model.
133
+ """
134
+ # reuploaded form official ones mentioned here:
135
+ # https://github.com/xinntao/Real-ESRGAN
136
+ if Path(model_name_or_path).exists():
137
+ file = model_name_or_path
138
+ else:
139
+ file = hf_hub_download(model_name_or_path, 'RealESRGAN_x4plus.pth')
140
+ return cls(file)
141
+
142
+
143
+ def upsample_imagefolder(self, in_dir, out_dir, suffix='out', outfile_ext='.png'):
144
+ in_dir, out_dir = Path(in_dir), Path(out_dir)
145
+ if not in_dir.exists():
146
+ raise FileNotFoundError(f"Provided input directory {in_dir} does not exist")
147
+
148
+ out_dir.mkdir(exist_ok=True, parents=True)
149
+
150
+ image_paths = [x for x in in_dir.glob('*') if x.suffix.lower() in ['.png', '.jpg', '.jpeg']]
151
+ for image in image_paths:
152
+ im = self(str(image))
153
+ out_filepath = out_dir / (image.stem + suffix + outfile_ext)
154
+ im.save(out_filepath)
155
+
156
+ class NoUpsamplingModel(torch.nn.Module):
157
+
158
+ def __init__(self):
159
+ super().__init__()
160
+
161
+ def forward(self, images):
162
+ return images
163
+
164
+
165
+ def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.png", fps=30):
166
+ frame_ref_path = str(frame_dir / frame_filename)
167
+ video_path = str(frame_dir / output_file_name)
168
+ subprocess.call(
169
+ f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
170
+ f" {video_path}".split()
171
+ )
172
+ return video_path
173
+
174
+
175
+ class StableDiffusionWalkPipeline(DiffusionPipeline):
176
+ r"""
177
+ Pipeline for generating videos by interpolating Stable Diffusion's latent space.
178
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
179
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
180
+ Args:
181
+ vae ([`AutoencoderKL`]):
182
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
183
+ text_encoder ([`CLIPTextModel`]):
184
+ Frozen text-encoder. Stable Diffusion uses the text portion of
185
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
186
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
187
+ tokenizer (`CLIPTokenizer`):
188
+ Tokenizer of class
189
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
190
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
191
+ scheduler ([`SchedulerMixin`]):
192
+ A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
193
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
194
+ safety_checker ([`StableDiffusionSafetyChecker`]):
195
+ Classification module that estimates whether generated images could be considered offensive or harmful.
196
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
197
+ feature_extractor ([`CLIPFeatureExtractor`]):
198
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
199
+ """
200
+
201
+ def __init__(
202
+ self,
203
+ vae: AutoencoderKL,
204
+ text_encoder: CLIPTextModel,
205
+ tokenizer: CLIPTokenizer,
206
+ unet: UNet2DConditionModel,
207
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
208
+ safety_checker: StableDiffusionSafetyChecker,
209
+ feature_extractor: CLIPFeatureExtractor,
210
+ ):
211
+ super().__init__()
212
+
213
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
214
+ deprecation_message = (
215
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
216
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
217
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
218
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
219
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
220
+ " file"
221
+ )
222
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
223
+ new_config = dict(scheduler.config)
224
+ new_config["steps_offset"] = 1
225
+ scheduler._internal_dict = FrozenDict(new_config)
226
+
227
+ self.register_modules(
228
+ vae=vae,
229
+ text_encoder=text_encoder,
230
+ tokenizer=tokenizer,
231
+ unet=unet,
232
+ scheduler=scheduler,
233
+ safety_checker=safety_checker,
234
+ feature_extractor=feature_extractor,
235
+ )
236
+
237
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
238
+ r"""
239
+ Enable sliced attention computation.
240
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
241
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
242
+ Args:
243
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
244
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
245
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
246
+ `attention_head_dim` must be a multiple of `slice_size`.
247
+ """
248
+ if slice_size == "auto":
249
+ # half the attention head size is usually a good trade-off between
250
+ # speed and memory
251
+ slice_size = self.unet.config.attention_head_dim // 2
252
+ self.unet.set_attention_slice(slice_size)
253
+
254
+ def disable_attention_slicing(self):
255
+ r"""
256
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
257
+ back to computing attention in one step.
258
+ """
259
+ # set slice_size = `None` to disable `attention slicing`
260
+ self.enable_attention_slicing(None)
261
+
262
+ @torch.no_grad()
263
+ def step(
264
+ self,
265
+ prompt: Optional[Union[str, List[str]]] = None,
266
+ height: int = 512,
267
+ width: int = 512,
268
+ num_inference_steps: int = 50,
269
+ guidance_scale: float = 7.5,
270
+ eta: float = 0.0,
271
+ generator: Optional[torch.Generator] = None,
272
+ latents: Optional[torch.FloatTensor] = None,
273
+ output_type: Optional[str] = "pil",
274
+ return_dict: bool = True,
275
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
276
+ callback_steps: Optional[int] = 1,
277
+ text_embeddings: Optional[torch.FloatTensor] = None,
278
+ **kwargs,
279
+ ):
280
+ r"""
281
+ Function invoked when calling the pipeline for generation.
282
+ Args:
283
+ prompt (`str` or `List[str]`):
284
+ The prompt or prompts to guide the image generation.
285
+ height (`int`, *optional*, defaults to 512):
286
+ The height in pixels of the generated image.
287
+ width (`int`, *optional*, defaults to 512):
288
+ The width in pixels of the generated image.
289
+ num_inference_steps (`int`, *optional*, defaults to 50):
290
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
291
+ expense of slower inference.
292
+ guidance_scale (`float`, *optional*, defaults to 7.5):
293
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
294
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
295
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
296
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
297
+ usually at the expense of lower image quality.
298
+ eta (`float`, *optional*, defaults to 0.0):
299
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
300
+ [`schedulers.DDIMScheduler`], will be ignored for others.
301
+ generator (`torch.Generator`, *optional*):
302
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
303
+ deterministic.
304
+ latents (`torch.FloatTensor`, *optional*):
305
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
306
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
307
+ tensor will ge generated by sampling using the supplied random `generator`.
308
+ output_type (`str`, *optional*, defaults to `"pil"`):
309
+ The output format of the generate image. Choose between
310
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
311
+ return_dict (`bool`, *optional*, defaults to `True`):
312
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
313
+ plain tuple.
314
+ callback (`Callable`, *optional*):
315
+ A function that will be called every `callback_steps` steps during inference. The function will be
316
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
317
+ callback_steps (`int`, *optional*, defaults to 1):
318
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
319
+ called at every step.
320
+ text_embeddings(`torch.FloatTensor`, *optional*):
321
+ Pre-generated text embeddings.
322
+ Returns:
323
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
324
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
325
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
326
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
327
+ (nsfw) content, according to the `safety_checker`.
328
+ """
329
+
330
+ if height % 8 != 0 or width % 8 != 0:
331
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
332
+
333
+ if (callback_steps is None) or (
334
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
335
+ ):
336
+ raise ValueError(
337
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
338
+ f" {type(callback_steps)}."
339
+ )
340
+
341
+ if text_embeddings is None:
342
+ if isinstance(prompt, str):
343
+ batch_size = 1
344
+ elif isinstance(prompt, list):
345
+ batch_size = len(prompt)
346
+ else:
347
+ raise ValueError(
348
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
349
+ )
350
+
351
+ # get prompt text embeddings
352
+ text_inputs = self.tokenizer(
353
+ prompt,
354
+ padding="max_length",
355
+ max_length=self.tokenizer.model_max_length,
356
+ return_tensors="pt",
357
+ )
358
+ text_input_ids = text_inputs.input_ids
359
+
360
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
361
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
362
+ logger.warning(
363
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
364
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
365
+ )
366
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
367
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
368
+ else:
369
+ batch_size = text_embeddings.shape[0]
370
+
371
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
372
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
373
+ # corresponds to doing no classifier free guidance.
374
+ do_classifier_free_guidance = guidance_scale > 1.0
375
+ # get unconditional embeddings for classifier free guidance
376
+ if do_classifier_free_guidance:
377
+ # HACK - Not setting text_input_ids here when walking, so hard coding to max length of tokenizer
378
+ # TODO - Determine if this is OK to do
379
+ # max_length = text_input_ids.shape[-1]
380
+ max_length = self.tokenizer.model_max_length
381
+ uncond_input = self.tokenizer(
382
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
383
+ )
384
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
385
+
386
+ # For classifier free guidance, we need to do two forward passes.
387
+ # Here we concatenate the unconditional and text embeddings into a single batch
388
+ # to avoid doing two forward passes
389
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
390
+
391
+ # get the initial random noise unless the user supplied it
392
+
393
+ # Unlike in other pipelines, latents need to be generated in the target device
394
+ # for 1-to-1 results reproducibility with the CompVis implementation.
395
+ # However this currently doesn't work in `mps`.
396
+ latents_device = "cpu" if self.device.type == "mps" else self.device
397
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
398
+ if latents is None:
399
+ latents = torch.randn(
400
+ latents_shape,
401
+ generator=generator,
402
+ device=latents_device,
403
+ dtype=text_embeddings.dtype,
404
+ )
405
+ else:
406
+ if latents.shape != latents_shape:
407
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
408
+ latents = latents.to(latents_device)
409
+
410
+ # set timesteps
411
+ self.scheduler.set_timesteps(num_inference_steps)
412
+
413
+ # Some schedulers like PNDM have timesteps as arrays
414
+ # It's more optimzed to move all timesteps to correct device beforehand
415
+ if torch.is_tensor(self.scheduler.timesteps):
416
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
417
+ else:
418
+ timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
419
+
420
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
421
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
422
+ latents = latents * self.scheduler.sigmas[0]
423
+
424
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
425
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
426
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
427
+ # and should be between [0, 1]
428
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
429
+ extra_step_kwargs = {}
430
+ if accepts_eta:
431
+ extra_step_kwargs["eta"] = eta
432
+
433
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
434
+ # expand the latents if we are doing classifier free guidance
435
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
436
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
437
+ sigma = self.scheduler.sigmas[i]
438
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
439
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
440
+
441
+ # predict the noise residual
442
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
443
+
444
+ # perform guidance
445
+ if do_classifier_free_guidance:
446
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
447
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
448
+
449
+ # compute the previous noisy sample x_t -> x_t-1
450
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
451
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
452
+ else:
453
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
454
+
455
+ # call the callback, if provided
456
+ if callback is not None and i % callback_steps == 0:
457
+ callback(i, t, latents)
458
+
459
+ latents = 1 / 0.18215 * latents
460
+ image = self.vae.decode(latents).sample
461
+
462
+ image = (image / 2 + 0.5).clamp(0, 1)
463
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
464
+
465
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
466
+ image, has_nsfw_concept = self.safety_checker(
467
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
468
+ )
469
+
470
+ if output_type == "pil":
471
+ image = self.numpy_to_pil(image)
472
+
473
+ if not return_dict:
474
+ return (image, has_nsfw_concept)
475
+
476
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
477
+
478
+ def __call__(
479
+ self,
480
+ prompts: List[str] = ["blueberry spaghetti", "strawberry spaghetti"],
481
+ seeds: List[int] = [42, 123],
482
+ num_interpolation_steps: Union[int, List[int]] = 5,
483
+ output_dir: str = "dreams",
484
+ name: str = "berry_good_spaghetti",
485
+ height: int = 512,
486
+ width: int = 512,
487
+ guidance_scale: float = 7.5,
488
+ eta: float = 0.0,
489
+ num_inference_steps: int = 50,
490
+ do_loop: bool = False,
491
+ make_video: bool = False,
492
+ use_lerp_for_text: bool = True,
493
+ scheduler: str = "klms", # choices: default, ddim, klms
494
+ disable_tqdm: bool = False,
495
+ upsample: bool = False,
496
+ fps: int = 30,
497
+ resume: bool = False,
498
+ batch_size: int = 1,
499
+ frame_filename_ext: str = '.png',
500
+ ):
501
+ if upsample:
502
+ if getattr(self, 'upsampler', None) is None:
503
+ self.upsampler = RealESRGANModel.from_pretrained('nateraw/real-esrgan')
504
+ self.upsampler.to(self.device)
505
+
506
+ output_path = Path(output_dir) / name
507
+ output_path.mkdir(exist_ok=True, parents=True)
508
+ prompt_config_path = output_path / 'prompt_config.json'
509
+
510
+ if not resume:
511
+ # Write prompt info to file in output dir so we can keep track of what we did
512
+ prompt_config_path.write_text(
513
+ json.dumps(
514
+ dict(
515
+ prompts=prompts,
516
+ seeds=seeds,
517
+ num_interpolation_steps=num_interpolation_steps,
518
+ name=name,
519
+ guidance_scale=guidance_scale,
520
+ eta=eta,
521
+ num_inference_steps=num_inference_steps,
522
+ do_loop=do_loop,
523
+ make_video=make_video,
524
+ use_lerp_for_text=use_lerp_for_text,
525
+ scheduler=scheduler,
526
+ upsample=upsample,
527
+ fps=fps,
528
+ height=height,
529
+ width=width,
530
+ ),
531
+ indent=2,
532
+ sort_keys=False,
533
+ )
534
+ )
535
+ else:
536
+ # When resuming, we load all available info from existing prompt config, using kwargs passed in where necessary
537
+ if not prompt_config_path.exists():
538
+ raise FileNotFoundError(f"You specified resume=True, but no prompt config file was found at {prompt_config_path}")
539
+
540
+ data = json.load(open(prompt_config_path))
541
+ prompts = data['prompts']
542
+ seeds = data['seeds']
543
+ # NOTE - num_steps was renamed to num_interpolation_steps. Including it here for backwards compatibility.
544
+ num_interpolation_steps = data.get('num_interpolation_steps') or data.get('num_steps')
545
+ height = data['height'] if 'height' in data else height
546
+ width = data['width'] if 'width' in data else width
547
+ guidance_scale = data['guidance_scale']
548
+ eta = data['eta']
549
+ num_inference_steps = data['num_inference_steps']
550
+ do_loop = data['do_loop']
551
+ make_video = data['make_video']
552
+ use_lerp_for_text = data['use_lerp_for_text']
553
+ scheduler = data['scheduler']
554
+ disable_tqdm=disable_tqdm
555
+ upsample = data['upsample'] if 'upsample' in data else upsample
556
+ fps = data['fps'] if 'fps' in data else fps
557
+
558
+ resume_step = int(sorted(output_path.glob(f"frame*{frame_filename_ext}"))[-1].stem[5:])
559
+ print(f"\nResuming {output_path} from step {resume_step}...")
560
+
561
+ self.set_progress_bar_config(disable=disable_tqdm)
562
+ self.scheduler = SCHEDULERS[scheduler]
563
+
564
+ if isinstance(num_interpolation_steps, int):
565
+ num_interpolation_steps = [num_interpolation_steps] * (len(prompts)-1)
566
+
567
+ assert len(prompts) == len(seeds) == len(num_interpolation_steps) +1
568
+
569
+ first_prompt, *prompts = prompts
570
+ embeds_a = self.embed_text(first_prompt)
571
+
572
+ first_seed, *seeds = seeds
573
+
574
+ latents_a = torch.randn(
575
+ (1, self.unet.in_channels, height // 8, width // 8),
576
+ device=self.device,
577
+ generator=torch.Generator(device=self.device).manual_seed(first_seed),
578
+ )
579
+
580
+ if do_loop:
581
+ prompts.append(first_prompt)
582
+ seeds.append(first_seed)
583
+ num_interpolation_steps.append(num_interpolation_steps[0])
584
+
585
+
586
+ frame_index = 0
587
+ total_frame_count = sum(num_interpolation_steps)
588
+ for prompt, seed, num_step in zip(prompts, seeds, num_interpolation_steps):
589
+ # Text
590
+ embeds_b = self.embed_text(prompt)
591
+
592
+ # Latent Noise
593
+ latents_b = torch.randn(
594
+ (1, self.unet.in_channels, height // 8, width // 8),
595
+ device=self.device,
596
+ generator=torch.Generator(device=self.device).manual_seed(seed),
597
+ )
598
+
599
+ latents_batch, embeds_batch = None, None
600
+ for i, t in enumerate(np.linspace(0, 1, num_step)):
601
+
602
+ frame_filepath = output_path / (f"frame%06d{frame_filename_ext}" % frame_index)
603
+ if resume and frame_filepath.is_file():
604
+ frame_index += 1
605
+ continue
606
+
607
+ if use_lerp_for_text:
608
+ embeds = torch.lerp(embeds_a, embeds_b, float(t))
609
+ else:
610
+ embeds = slerp(float(t), embeds_a, embeds_b)
611
+ latents = slerp(float(t), latents_a, latents_b)
612
+
613
+ embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
614
+ latents_batch = latents if latents_batch is None else torch.cat([latents_batch, latents])
615
+
616
+ del embeds
617
+ del latents
618
+ torch.cuda.empty_cache()
619
+
620
+ batch_is_ready = embeds_batch.shape[0] == batch_size or t == 1.0
621
+ if not batch_is_ready:
622
+ continue
623
+
624
+ do_print_progress = (i == 0) or ((frame_index) % 20 == 0)
625
+ if do_print_progress:
626
+ print(f"COUNT: {frame_index}/{total_frame_count}")
627
+
628
+ with torch.autocast("cuda"):
629
+ outputs = self.step(
630
+ latents=latents_batch,
631
+ text_embeddings=embeds_batch,
632
+ height=height,
633
+ width=width,
634
+ guidance_scale=guidance_scale,
635
+ eta=eta,
636
+ num_inference_steps=num_inference_steps,
637
+ output_type='pil' if not upsample else 'numpy'
638
+ )["sample"]
639
+
640
+ del embeds_batch
641
+ del latents_batch
642
+ torch.cuda.empty_cache()
643
+ latents_batch, embeds_batch = None, None
644
+
645
+ if upsample:
646
+ images = []
647
+ for output in outputs:
648
+ images.append(self.upsampler(output))
649
+ else:
650
+ images = outputs
651
+ for image in images:
652
+ frame_filepath = output_path / (f"frame%06d{frame_filename_ext}" % frame_index)
653
+ image.save(frame_filepath)
654
+ frame_index += 1
655
+
656
+ embeds_a = embeds_b
657
+ latents_a = latents_b
658
+
659
+ if make_video:
660
+ return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps, frame_filename=f"frame%06d{frame_filename_ext}")
661
+
662
+ def embed_text(self, text):
663
+ """Helper to embed some text"""
664
+ with torch.autocast("cuda"):
665
+ text_input = self.tokenizer(
666
+ text,
667
+ padding="max_length",
668
+ max_length=self.tokenizer.model_max_length,
669
+ truncation=True,
670
+ return_tensors="pt",
671
+ )
672
+ with torch.no_grad():
673
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
674
+ return embed