AlanB committed
Commit 69e0e4e · 1 Parent(s): 0ca23a8

Updated prompt tweening, trying to fix the error "Input type (float) and bias type (c10::Half) should be the same"
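
That error is the usual symptom of mixing a float32 input with float16 (c10::Half) weights, e.g. float32 latents fed to a UNet or VAE loaded in fp16. A minimal repro sketch of the dtype mismatch and the common fix (hypothetical, not code from this repo; requires a CUDA device):

    import torch

    # module whose weights and bias are fp16 (c10::Half)
    conv = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1).cuda().half()
    x = torch.randn(1, 4, 64, 64, device="cuda")  # float32 input

    try:
        conv(x)  # dtype mismatch -> RuntimeError (exact wording varies by PyTorch version)
    except RuntimeError as err:
        print(err)

    out = conv(x.to(conv.weight.dtype))  # common fix: cast the input to the module's dtype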

Files changed (1):
  1. pipeline.py (+32 -12)
pipeline.py CHANGED

@@ -1167,6 +1167,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 
     # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
     def get_text_latent_space(self, prompt, guidance_scale = 7.5):
+
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -1176,7 +1177,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -1195,7 +1196,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
         return text_embeddings
-
+
     def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
         """ helper function to spherically interpolate two arrays v1 v2
         from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
@@ -1292,11 +1293,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        save_n_steps: Optional[int] = None,
         **kwargs,):
 
-        from diffusers.schedulers import LMSDiscreteScheduler
         batch_size = 1
-
+
         if generator == None:
             generator = torch.Generator("cuda")
         generator_state = generator.get_state()
@@ -1330,8 +1331,27 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
-
+        if save_n_steps:
+            mid_latents = []
+            mid_images = []
+        else:
+            mid_latents = None
+            mid_images = None
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+            if save_n_steps:
+                if i % save_n_steps == 0:
+                    # scale and decode the image latents with vae
+                    dec_mid_latents = 1 / 0.18215 * latents
+                    mid_latents.append(dec_mid_latents)
+                    image = self.vae.decode(dec_mid_latents).sample
+
+                    image = (image / 2 + 0.5).clamp(0, 1)
+                    image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+                    if output_type == "pil":
+                        image = self.numpy_to_pil(image)
+                    mid_latents.append(image)
+                    mid_images.append(image)
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1339,7 +1359,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -1348,21 +1368,21 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
-
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"image": image, "generator_state": generator_state}
+        return {"image": image, "generator_state": generator_state, "mid_latents": mid_latents, "mid_images": mid_images}
 
     def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
         # random vector to move in latent space
@@ -1370,7 +1390,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
         scaled_rand_t = rand_t / rand_mag
         variation_embedding = text_embeddings + scaled_rand_t
-
+
         generator = torch.Generator("cuda")
         generator.set_state(generator_state)
         result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
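
For reference, a usage sketch of the new save_n_steps argument (assumptions: the pipeline instance is named `pipe`, and `diffuse_from_inits` takes the precomputed text embeddings as its first positional argument, as the call in `variation` suggests; the prompt string is just an example):

    # embed the prompt once, then diffuse, saving a decoded snapshot every 10 steps
    text_embeddings = pipe.get_text_latent_space("a photo of an astronaut riding a horse")
    result = pipe.diffuse_from_inits(text_embeddings, save_n_steps=10)

    final_image = result["image"][0]  # a list of PIL images when output_type == "pil"
    snapshots = result["mid_images"]  # intermediate decodes, one every save_n_steps steps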