Updated prompt tweening, trying to fix error "Input type (float) and bias type (c10::Half) should be the same"
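Note on the error in the title: "Input type (float) and bias type (c10::Half) should be the same" is what PyTorch raises when a float32 tensor reaches a layer whose weights are float16, as happens when latents or embeddings are created in full precision but the pipeline was loaded with half-precision weights. A minimal sketch of the usual remedy follows; match_model_dtype is a hypothetical helper for illustration, not part of this commit:

import torch

def match_model_dtype(tensor: torch.Tensor, module: torch.nn.Module) -> torch.Tensor:
    # Cast `tensor` to the precision of `module`'s weights (float16 when the
    # pipeline was loaded with torch_dtype=torch.float16).
    return tensor.to(dtype=next(module.parameters()).dtype)

# For example, before the self.unet(...) call in the denoising loop below:
#   latents = match_model_dtype(latents, self.unet)
#   text_embeddings = match_model_dtype(text_embeddings, self.unet)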
pipeline.py (+32 -12)
@@ -1167,6 +1167,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 
     # Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
     def get_text_latent_space(self, prompt, guidance_scale = 7.5):
+
         # get prompt text embeddings
         text_input = self.tokenizer(
             prompt,
@@ -1176,7 +1177,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             return_tensors="pt",
         )
         text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
-
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -1195,7 +1196,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
         return text_embeddings
-
+
     def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
         """ helper function to spherically interpolate two arrays v1 v2
         from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
@@ -1292,11 +1293,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        save_n_steps: Optional[int] = None,
         **kwargs,):
 
-        from diffusers.schedulers import LMSDiscreteScheduler
         batch_size = 1
-
+
         if generator == None:
             generator = torch.Generator("cuda")
         generator_state = generator.get_state()
@@ -1330,8 +1331,27 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
-
+        if save_n_steps:
+            mid_latents = []
+            mid_images = []
+        else:
+            mid_latents = None
+            mid_images = None
         for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+            if save_n_steps:
+                if i % save_n_steps == 0:
+                    # scale and decode the image latents with vae
+                    dec_mid_latents = 1 / 0.18215 * latents
+                    mid_latents.append(dec_mid_latents)
+                    image = self.vae.decode(dec_mid_latents).sample
+
+                    image = (image / 2 + 0.5).clamp(0, 1)
+                    image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+                    if output_type == "pil":
+                        image = self.numpy_to_pil(image)
+                    mid_latents.append(image)
+                    mid_images.append(image)
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -1339,7 +1359,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -1348,21 +1368,21 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
-
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"image": image, "generator_state": generator_state}
+        return {"image": image, "generator_state": generator_state, "mid_latents": mid_latents, "mid_images": mid_images}
 
     def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
         # random vector to move in latent space
@@ -1370,7 +1390,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
         scaled_rand_t = rand_t / rand_mag
         variation_embedding = text_embeddings + scaled_rand_t
-
+
         generator = torch.Generator("cuda")
         generator.set_state(generator_state)
         result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
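Usage note for the new save_n_steps parameter: a hypothetical sketch, not code from this commit. It assumes the edited signature belongs to diffuse_from_inits (the method variation calls) and that pipe is a StableDiffusionLongPromptWeightingPipeline already loaded and moved to CUDA:

import torch

text_embeddings = pipe.get_text_latent_space("a watercolor fox, highly detailed")

result = pipe.diffuse_from_inits(
    text_embeddings,
    generator=torch.Generator("cuda").manual_seed(0),
    save_n_steps=10,  # decode and keep a snapshot every 10th denoising step
    output_type="pil",
)

result["image"][0].save("final.png")

# mid_images is None when save_n_steps is unset; otherwise it holds one
# numpy_to_pil batch (a list of PIL images) per saved step.
for step_idx, frame_batch in enumerate(result["mid_images"] or []):
    frame_batch[0].save(f"step_{step_idx}.png")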
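For the prompt tweening the commit title refers to, the helpers above compose directly: embed two prompts with get_text_latent_space, interpolate with slerp, and render each intermediate embedding from identical starting noise. A hedged sketch, assuming the gist-derived slerp accepts these embedding tensors and pipe as in the previous example:

import torch

start = pipe.get_text_latent_space("a photo of a cat")
end = pipe.get_text_latent_space("a photo of a tiger")

frames = []
for t in torch.linspace(0.0, 1.0, steps=8):
    # spherical interpolation keeps intermediate embeddings at a sensible norm
    tween = pipe.slerp(float(t), start, end)
    # reseeding each frame keeps the initial noise identical across the tween
    out = pipe.diffuse_from_inits(tween, generator=torch.Generator("cuda").manual_seed(0))
    frames.append(out["image"][0])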