AlanB committed
Commit a0795ca · 1 Parent(s): 181c850

Update pipeline.py


Updates from the latest git. I re-added the prompt tweening functions (a short usage sketch follows the diff).

Files changed (1)
  1. pipeline.py +322 -53
pipeline.py CHANGED
@@ -8,14 +8,37 @@ import random
  import sys
  from tqdm.auto import tqdm

+ import diffusers
  import PIL
  from diffusers import SchedulerMixin, StableDiffusionPipeline
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
- from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
+ from diffusers.utils import deprecate, logging
+ from packaging import version
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


+ try:
+     from diffusers.utils import PIL_INTERPOLATION
+ except ImportError:
+     if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+         PIL_INTERPOLATION = {
+             "linear": PIL.Image.Resampling.BILINEAR,
+             "bilinear": PIL.Image.Resampling.BILINEAR,
+             "bicubic": PIL.Image.Resampling.BICUBIC,
+             "lanczos": PIL.Image.Resampling.LANCZOS,
+             "nearest": PIL.Image.Resampling.NEAREST,
+         }
+     else:
+         PIL_INTERPOLATION = {
+             "linear": PIL.Image.LINEAR,
+             "bilinear": PIL.Image.BILINEAR,
+             "bicubic": PIL.Image.BICUBIC,
+             "lanczos": PIL.Image.LANCZOS,
+             "nearest": PIL.Image.NEAREST,
+         }
+ # ------------------------------------------------------------------------------
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  re_attention = re.compile(
@@ -407,27 +430,75 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
              Model that extracts features from generated images to be used as inputs for the `safety_checker`.
      """

-     def __init__(
-         self,
-         vae: AutoencoderKL,
-         text_encoder: CLIPTextModel,
-         tokenizer: CLIPTokenizer,
-         unet: UNet2DConditionModel,
-         scheduler: SchedulerMixin,
-         safety_checker: StableDiffusionSafetyChecker,
-         feature_extractor: CLIPFeatureExtractor,
-         requires_safety_checker: bool = True,
-     ):
-         super().__init__(
-             vae=vae,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-             unet=unet,
-             scheduler=scheduler,
-             safety_checker=safety_checker,
-             feature_extractor=feature_extractor,
-             requires_safety_checker=requires_safety_checker,
-         )
+     if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
+
+         def __init__(
+             self,
+             vae: AutoencoderKL,
+             text_encoder: CLIPTextModel,
+             tokenizer: CLIPTokenizer,
+             unet: UNet2DConditionModel,
+             scheduler: SchedulerMixin,
+             safety_checker: StableDiffusionSafetyChecker,
+             feature_extractor: CLIPFeatureExtractor,
+             requires_safety_checker: bool = True,
+         ):
+             super().__init__(
+                 vae=vae,
+                 text_encoder=text_encoder,
+                 tokenizer=tokenizer,
+                 unet=unet,
+                 scheduler=scheduler,
+                 safety_checker=safety_checker,
+                 feature_extractor=feature_extractor,
+                 requires_safety_checker=requires_safety_checker,
+             )
+             self.__init__additional__()
+
+     else:
+
+         def __init__(
+             self,
+             vae: AutoencoderKL,
+             text_encoder: CLIPTextModel,
+             tokenizer: CLIPTokenizer,
+             unet: UNet2DConditionModel,
+             scheduler: SchedulerMixin,
+             safety_checker: StableDiffusionSafetyChecker,
+             feature_extractor: CLIPFeatureExtractor,
+         ):
+             super().__init__(
+                 vae=vae,
+                 text_encoder=text_encoder,
+                 tokenizer=tokenizer,
+                 unet=unet,
+                 scheduler=scheduler,
+                 safety_checker=safety_checker,
+                 feature_extractor=feature_extractor,
+             )
+             self.__init__additional__()
+
+     def __init__additional__(self):
+         if not hasattr(self, "vae_scale_factor"):
+             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
+
+     @property
+     def _execution_device(self):
+         r"""
+         Returns the device on which the pipeline's models will be executed. After calling
+         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+         hooks.
+         """
+         if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+             return self.device
+         for module in self.unet.modules():
+             if (
+                 hasattr(module, "_hf_hook")
+                 and hasattr(module._hf_hook, "execution_device")
+                 and module._hf_hook.execution_device is not None
+             ):
+                 return torch.device(module._hf_hook.execution_device)
+         return self.device

      def _encode_prompt(
          self,
@@ -755,37 +826,33 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
          extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

          # 8. Denoising loop
-         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                 # predict the noise residual
-                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-                 # perform guidance
-                 if do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                 if mask is not None:
-                     # masking
-                     init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
-                     latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
-                 # call the callback, if provided
-                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                     progress_bar.update()
-                 if i % callback_steps == 0:
-                     if callback is not None:
-                         callback(i, t, latents)
-                     if is_cancelled_callback is not None and is_cancelled_callback():
-                         return None
+         for i, t in enumerate(self.progress_bar(timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+             # predict the noise residual
+             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+             if mask is not None:
+                 # masking
+                 init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+                 latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+             # call the callback, if provided
+             if i % callback_steps == 0:
+                 if callback is not None:
+                     callback(i, t, latents)
+                 if is_cancelled_callback is not None and is_cancelled_callback():
+                     return None

          # 9. Post-processing
          image = self.decode_latents(latents)
@@ -1096,6 +1163,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
              callback_steps=callback_steps,
              **kwargs,
          )
+
+
      # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
      def get_text_latent_space(self, prompt, guidance_scale = 7.5):
          # get prompt text embeddings
@@ -1107,3 +1176,203 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
              return_tensors="pt",
          )
          text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+         # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # get unconditional embeddings for classifier free guidance
+         if do_classifier_free_guidance:
+             max_length = text_input.input_ids.shape[-1]
+             uncond_input = self.tokenizer(
+                 [""], padding="max_length", max_length=max_length, return_tensors="pt"
+             )
+             uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+             # For classifier free guidance, we need to do two forward passes.
+             # Here we concatenate the unconditional and text embeddings into a single batch
+             # to avoid doing two forward passes
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         return text_embeddings
+
+     def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
+         """helper function to spherically interpolate two arrays v0 v1
+         from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
+         this should be better than lerping for moving between noise spaces"""
+         inputs_are_torch = False  # track the input type so we can convert back at the end
+         if not isinstance(v0, np.ndarray):
+             inputs_are_torch = True
+             input_device = v0.device
+             v0 = v0.cpu().numpy()
+             v1 = v1.cpu().numpy()
+
+         dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+         if np.abs(dot) > DOT_THRESHOLD:
+             v2 = (1 - t) * v0 + t * v1
+         else:
+             theta_0 = np.arccos(dot)
+             sin_theta_0 = np.sin(theta_0)
+             theta_t = theta_0 * t
+             sin_theta_t = np.sin(theta_t)
+             s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+             s1 = sin_theta_t / sin_theta_0
+             v2 = s0 * v0 + s1 * v1
+
+         if inputs_are_torch:
+             v2 = torch.from_numpy(v2).to(input_device)
+
+         return v2
+
+     def lerp_between_prompts(self, first_prompt, second_prompt, seed=None, length=10, save=False, guidance_scale: Optional[float] = 7.5, **kwargs):
+         first_embedding = self.get_text_latent_space(first_prompt)
+         second_embedding = self.get_text_latent_space(second_prompt)
+         if not seed:
+             seed = random.randint(0, sys.maxsize)
+         generator = torch.Generator(self.device)
+         generator.manual_seed(seed)
+         generator_state = generator.get_state()
+         lerp_embed_points = []
+         for i in range(length):
+             weight = i / length
+             tensor_lerp = torch.lerp(first_embedding, second_embedding, weight)
+             lerp_embed_points.append(tensor_lerp)
+         images = []
+         for idx, latent_point in enumerate(lerp_embed_points):
+             generator.set_state(generator_state)
+             image = self.diffuse_from_inits(latent_point, **kwargs)["image"][0]
+             images.append(image)
+             if save:
+                 image.save(f"{first_prompt}-{second_prompt}-{idx:02d}.png", "PNG")
+         return {"images": images, "latent_points": lerp_embed_points, "generator_state": generator_state}
+
+     def slerp_through_seeds(self,
+         prompt,
+         height: Optional[int] = 512,
+         width: Optional[int] = 512,
+         save=False,
+         seed=None, steps=10, **kwargs):
+
+         if not seed:
+             seed = random.randint(0, sys.maxsize)
+         generator = torch.Generator(self.device)
+         generator.manual_seed(seed)
+         init_start = torch.randn(
+             (1, self.unet.in_channels, height // 8, width // 8),
+             generator=generator, device=self.device)
+         init_end = torch.randn(
+             (1, self.unet.in_channels, height // 8, width // 8),
+             generator=generator, device=self.device)
+         generator_state = generator.get_state()
+         slerp_embed_points = []
+         # weight from 0 to 1/(steps - 1), add init_end specifically so that we
+         # have len(images) = steps
+         for i in range(steps - 1):
+             weight = i / steps
+             tensor_slerp = self.slerp(weight, init_start, init_end)
+             slerp_embed_points.append(tensor_slerp)
+         slerp_embed_points.append(init_end)
+         images = []
+         embed_point = self.get_text_latent_space(prompt)
+         for idx, noise_point in enumerate(slerp_embed_points):
+             generator.set_state(generator_state)
+             image = self.diffuse_from_inits(embed_point, init=noise_point, **kwargs)["image"][0]
+             images.append(image)
+             if save:
+                 image.save(f"{seed}-{idx:02d}.png", "PNG")
+         return {"images": images, "noise_samples": slerp_embed_points, "generator_state": generator_state}
+
+     @torch.no_grad()
+     def diffuse_from_inits(self, text_embeddings,
+         init=None,
+         height: Optional[int] = 512,
+         width: Optional[int] = 512,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 7.5,
+         eta: Optional[float] = 0.0,
+         generator: Optional[torch.Generator] = None,
+         output_type: Optional[str] = "pil",
+         **kwargs,):
+
+         from diffusers.schedulers import LMSDiscreteScheduler
+         batch_size = 1
+
+         if generator is None:
+             generator = torch.Generator("cuda")
+         generator_state = generator.get_state()
+         # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+         # get the initial random noise
+         latents = init if init is not None else torch.randn(
+             (batch_size, self.unet.in_channels, height // 8, width // 8),
+             generator=generator,
+             device=self.device,)
+
+         # set timesteps
+         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+         extra_set_kwargs = {}
+         if accepts_offset:
+             extra_set_kwargs["offset"] = 1
+
+         self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+         if isinstance(self.scheduler, LMSDiscreteScheduler):
+             latents = latents * self.scheduler.sigmas[0]
+
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+             # expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 sigma = self.scheduler.sigmas[i]
+                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+             # predict the noise residual
+             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             if isinstance(self.scheduler, LMSDiscreteScheduler):
+                 latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+             else:
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+         # scale and decode the image latents with vae
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample  # AutoencoderKL.decode returns a DecoderOutput
+
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+         if output_type == "pil":
+             image = self.numpy_to_pil(image)
+
+         return {"image": image, "generator_state": generator_state}
+
+     def variation(self, text_embeddings, generator_state, variation_magnitude=100, **kwargs):
+         # random vector to move in latent space
+         rand_t = (torch.rand(text_embeddings.shape, device=self.device) * 2) - 1
+         rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
+         scaled_rand_t = rand_t / rand_mag
+         variation_embedding = text_embeddings + scaled_rand_t
+
+         generator = torch.Generator("cuda")
+         generator.set_state(generator_state)
+         result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
+         result.update({"latent_point": variation_embedding})
+         return result
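
For reference, a minimal usage sketch of the re-added tweening helpers (not part of the commit). It assumes this file is importable as a local module named `pipeline`, that a CUDA device is available (diffuse_from_inits defaults its generator to "cuda"), and a diffusers version covered by the version checks above; the model id, prompts, and step counts are placeholders.

import torch

from pipeline import StableDiffusionLongPromptWeightingPipeline  # this file

# Load the weights into this pipeline class (placeholder model id).
pipe = StableDiffusionLongPromptWeightingPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
).to("cuda")

# Tween between two prompts in text-embedding space; extra kwargs are
# forwarded to diffuse_from_inits (e.g. num_inference_steps).
result = pipe.lerp_between_prompts(
    "a photo of a cat",
    "a photo of a dog",
    length=8,
    num_inference_steps=30,
)
frames = result["images"]  # one PIL image per interpolation step

# Alternatively, hold the prompt fixed and slerp through initial noise seeds.
seed_walk = pipe.slerp_through_seeds("a misty forest at dawn", steps=8, num_inference_steps=30)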