Updates from diffusers
Browse files- pipeline.py +21 -53
pipeline.py
CHANGED
|
@@ -3,19 +3,19 @@ import re
|
|
| 3 |
from typing import Callable, List, Optional, Union
|
| 4 |
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import torch
|
|
|
|
|
|
|
| 7 |
import random
|
| 8 |
import sys
|
| 9 |
from tqdm.auto import tqdm
|
| 10 |
|
| 11 |
import diffusers
|
| 12 |
-
import PIL
|
| 13 |
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
| 14 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 15 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 16 |
-
from diffusers.utils import
|
| 17 |
-
from packaging import version
|
| 18 |
-
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
| 19 |
|
| 20 |
|
| 21 |
try:
|
|
@@ -255,7 +255,6 @@ def get_weighted_text_embeddings(
|
|
| 255 |
no_boseos_middle: Optional[bool] = False,
|
| 256 |
skip_parsing: Optional[bool] = False,
|
| 257 |
skip_weighting: Optional[bool] = False,
|
| 258 |
-
**kwargs,
|
| 259 |
):
|
| 260 |
r"""
|
| 261 |
Prompts can be assigned with local weights using brackets. For example,
|
|
@@ -603,7 +602,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 603 |
latents = 1 / 0.18215 * latents
|
| 604 |
image = self.vae.decode(latents).sample
|
| 605 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 606 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with
|
| 607 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 608 |
return image
|
| 609 |
|
|
@@ -684,8 +683,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 684 |
return_dict: bool = True,
|
| 685 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 686 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 687 |
-
callback_steps:
|
| 688 |
-
**kwargs,
|
| 689 |
):
|
| 690 |
r"""
|
| 691 |
Function invoked when calling the pipeline for generation.
|
|
@@ -761,10 +759,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 761 |
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 762 |
(nsfw) content, according to the `safety_checker`.
|
| 763 |
"""
|
| 764 |
-
message = "Please use `image` instead of `init_image`."
|
| 765 |
-
init_image = deprecate("init_image", "0.14.0", message, take_from=kwargs)
|
| 766 |
-
image = init_image or image
|
| 767 |
-
|
| 768 |
# 0. Default height and width to unet
|
| 769 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 770 |
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
|
@@ -886,8 +880,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 886 |
return_dict: bool = True,
|
| 887 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 888 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 889 |
-
callback_steps:
|
| 890 |
-
**kwargs,
|
| 891 |
):
|
| 892 |
r"""
|
| 893 |
Function for text-to-image generation.
|
|
@@ -963,7 +956,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 963 |
callback=callback,
|
| 964 |
is_cancelled_callback=is_cancelled_callback,
|
| 965 |
callback_steps=callback_steps,
|
| 966 |
-
**kwargs,
|
| 967 |
)
|
| 968 |
|
| 969 |
def img2img(
|
|
@@ -982,8 +974,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 982 |
return_dict: bool = True,
|
| 983 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 984 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 985 |
-
callback_steps:
|
| 986 |
-
**kwargs,
|
| 987 |
):
|
| 988 |
r"""
|
| 989 |
Function for image-to-image generation.
|
|
@@ -1059,7 +1050,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1059 |
callback=callback,
|
| 1060 |
is_cancelled_callback=is_cancelled_callback,
|
| 1061 |
callback_steps=callback_steps,
|
| 1062 |
-
**kwargs,
|
| 1063 |
)
|
| 1064 |
|
| 1065 |
def inpaint(
|
|
@@ -1079,8 +1069,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1079 |
return_dict: bool = True,
|
| 1080 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1081 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1082 |
-
callback_steps:
|
| 1083 |
-
**kwargs,
|
| 1084 |
):
|
| 1085 |
r"""
|
| 1086 |
Function for inpaint.
|
|
@@ -1161,13 +1150,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1161 |
callback=callback,
|
| 1162 |
is_cancelled_callback=is_cancelled_callback,
|
| 1163 |
callback_steps=callback_steps,
|
| 1164 |
-
**kwargs,
|
| 1165 |
)
|
| 1166 |
|
| 1167 |
|
| 1168 |
# Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
| 1169 |
def get_text_latent_space(self, prompt, guidance_scale = 7.5):
|
| 1170 |
-
|
| 1171 |
# get prompt text embeddings
|
| 1172 |
text_input = self.tokenizer(
|
| 1173 |
prompt,
|
|
@@ -1177,7 +1164,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1177 |
return_tensors="pt",
|
| 1178 |
)
|
| 1179 |
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
| 1180 |
-
|
| 1181 |
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 1182 |
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 1183 |
# corresponds to doing no classifier free guidance.
|
|
@@ -1196,7 +1183,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1196 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 1197 |
|
| 1198 |
return text_embeddings
|
| 1199 |
-
|
| 1200 |
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
| 1201 |
""" helper function to spherically interpolate two arrays v1 v2
|
| 1202 |
from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
|
|
@@ -1293,11 +1280,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1293 |
eta: Optional[float] = 0.0,
|
| 1294 |
generator: Optional[torch.Generator] = None,
|
| 1295 |
output_type: Optional[str] = "pil",
|
| 1296 |
-
save_n_steps: Optional[int] = None,
|
| 1297 |
**kwargs,):
|
|
|
|
| 1298 |
from diffusers.schedulers import LMSDiscreteScheduler
|
| 1299 |
batch_size = 1
|
| 1300 |
-
|
| 1301 |
if generator == None:
|
| 1302 |
generator = torch.Generator("cuda")
|
| 1303 |
generator_state = generator.get_state()
|
|
@@ -1331,27 +1318,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1331 |
extra_step_kwargs = {}
|
| 1332 |
if accepts_eta:
|
| 1333 |
extra_step_kwargs["eta"] = eta
|
| 1334 |
-
|
| 1335 |
-
mid_latents = []
|
| 1336 |
-
mid_images = []
|
| 1337 |
-
else:
|
| 1338 |
-
mid_latents = None
|
| 1339 |
-
mid_images = None
|
| 1340 |
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
| 1341 |
-
if save_n_steps:
|
| 1342 |
-
if i % save_n_steps == 0:
|
| 1343 |
-
# scale and decode the image latents with vae
|
| 1344 |
-
dec_mid_latents = 1 / 0.18215 * latents
|
| 1345 |
-
mid_latents.append(dec_mid_latents)
|
| 1346 |
-
image = self.vae.decode(dec_mid_latents).sample
|
| 1347 |
-
|
| 1348 |
-
image = (image / 2 + 0.5).clamp(0, 1)
|
| 1349 |
-
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
| 1350 |
-
|
| 1351 |
-
if output_type == "pil":
|
| 1352 |
-
image = self.numpy_to_pil(image)
|
| 1353 |
-
mid_latents.append(image)
|
| 1354 |
-
mid_images.append(image)
|
| 1355 |
# expand the latents if we are doing classifier free guidance
|
| 1356 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 1357 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
|
@@ -1359,7 +1327,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1359 |
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
| 1360 |
|
| 1361 |
# predict the noise residual
|
| 1362 |
-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
| 1363 |
|
| 1364 |
# perform guidance
|
| 1365 |
if do_classifier_free_guidance:
|
|
@@ -1368,21 +1336,21 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1368 |
|
| 1369 |
# compute the previous noisy sample x_t -> x_t-1
|
| 1370 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
| 1371 |
-
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)
|
| 1372 |
else:
|
| 1373 |
-
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)
|
| 1374 |
|
| 1375 |
# scale and decode the image latents with vae
|
| 1376 |
latents = 1 / 0.18215 * latents
|
| 1377 |
-
image = self.vae.decode(latents)
|
| 1378 |
|
| 1379 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 1380 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
| 1381 |
-
|
| 1382 |
if output_type == "pil":
|
| 1383 |
image = self.numpy_to_pil(image)
|
| 1384 |
|
| 1385 |
-
return {"image": image, "generator_state": generator_state
|
| 1386 |
|
| 1387 |
def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
|
| 1388 |
# random vector to move in latent space
|
|
@@ -1390,7 +1358,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 1390 |
rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
|
| 1391 |
scaled_rand_t = rand_t / rand_mag
|
| 1392 |
variation_embedding = text_embeddings + scaled_rand_t
|
| 1393 |
-
|
| 1394 |
generator = torch.Generator("cuda")
|
| 1395 |
generator.set_state(generator_state)
|
| 1396 |
result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
|
|
|
|
| 3 |
from typing import Callable, List, Optional, Union
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
+
import PIL
|
| 7 |
import torch
|
| 8 |
+
from packaging import version
|
| 9 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
| 10 |
import random
|
| 11 |
import sys
|
| 12 |
from tqdm.auto import tqdm
|
| 13 |
|
| 14 |
import diffusers
|
|
|
|
| 15 |
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
| 16 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 17 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 18 |
+
from diffusers.utils import logging
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
try:
|
|
|
|
| 255 |
no_boseos_middle: Optional[bool] = False,
|
| 256 |
skip_parsing: Optional[bool] = False,
|
| 257 |
skip_weighting: Optional[bool] = False,
|
|
|
|
| 258 |
):
|
| 259 |
r"""
|
| 260 |
Prompts can be assigned with local weights using brackets. For example,
|
|
|
|
| 602 |
latents = 1 / 0.18215 * latents
|
| 603 |
image = self.vae.decode(latents).sample
|
| 604 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 605 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 606 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 607 |
return image
|
| 608 |
|
|
|
|
| 683 |
return_dict: bool = True,
|
| 684 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 685 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 686 |
+
callback_steps: int = 1,
|
|
|
|
| 687 |
):
|
| 688 |
r"""
|
| 689 |
Function invoked when calling the pipeline for generation.
|
|
|
|
| 759 |
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 760 |
(nsfw) content, according to the `safety_checker`.
|
| 761 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
# 0. Default height and width to unet
|
| 763 |
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 764 |
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
|
|
|
| 880 |
return_dict: bool = True,
|
| 881 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 882 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 883 |
+
callback_steps: int = 1,
|
|
|
|
| 884 |
):
|
| 885 |
r"""
|
| 886 |
Function for text-to-image generation.
|
|
|
|
| 956 |
callback=callback,
|
| 957 |
is_cancelled_callback=is_cancelled_callback,
|
| 958 |
callback_steps=callback_steps,
|
|
|
|
| 959 |
)
|
| 960 |
|
| 961 |
def img2img(
|
|
|
|
| 974 |
return_dict: bool = True,
|
| 975 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 976 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 977 |
+
callback_steps: int = 1,
|
|
|
|
| 978 |
):
|
| 979 |
r"""
|
| 980 |
Function for image-to-image generation.
|
|
|
|
| 1050 |
callback=callback,
|
| 1051 |
is_cancelled_callback=is_cancelled_callback,
|
| 1052 |
callback_steps=callback_steps,
|
|
|
|
| 1053 |
)
|
| 1054 |
|
| 1055 |
def inpaint(
|
|
|
|
| 1069 |
return_dict: bool = True,
|
| 1070 |
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1071 |
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1072 |
+
callback_steps: int = 1,
|
|
|
|
| 1073 |
):
|
| 1074 |
r"""
|
| 1075 |
Function for inpaint.
|
|
|
|
| 1150 |
callback=callback,
|
| 1151 |
is_cancelled_callback=is_cancelled_callback,
|
| 1152 |
callback_steps=callback_steps,
|
|
|
|
| 1153 |
)
|
| 1154 |
|
| 1155 |
|
| 1156 |
# Borrowed from https://github.com/csaluski/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
|
| 1157 |
def get_text_latent_space(self, prompt, guidance_scale = 7.5):
|
|
|
|
| 1158 |
# get prompt text embeddings
|
| 1159 |
text_input = self.tokenizer(
|
| 1160 |
prompt,
|
|
|
|
| 1164 |
return_tensors="pt",
|
| 1165 |
)
|
| 1166 |
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
| 1167 |
+
|
| 1168 |
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 1169 |
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 1170 |
# corresponds to doing no classifier free guidance.
|
|
|
|
| 1183 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 1184 |
|
| 1185 |
return text_embeddings
|
| 1186 |
+
|
| 1187 |
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
|
| 1188 |
""" helper function to spherically interpolate two arrays v1 v2
|
| 1189 |
from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
|
|
|
|
| 1280 |
eta: Optional[float] = 0.0,
|
| 1281 |
generator: Optional[torch.Generator] = None,
|
| 1282 |
output_type: Optional[str] = "pil",
|
|
|
|
| 1283 |
**kwargs,):
|
| 1284 |
+
|
| 1285 |
from diffusers.schedulers import LMSDiscreteScheduler
|
| 1286 |
batch_size = 1
|
| 1287 |
+
|
| 1288 |
if generator == None:
|
| 1289 |
generator = torch.Generator("cuda")
|
| 1290 |
generator_state = generator.get_state()
|
|
|
|
| 1318 |
extra_step_kwargs = {}
|
| 1319 |
if accepts_eta:
|
| 1320 |
extra_step_kwargs["eta"] = eta
|
| 1321 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1322 |
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1323 |
# expand the latents if we are doing classifier free guidance
|
| 1324 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 1325 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
|
|
|
| 1327 |
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
| 1328 |
|
| 1329 |
# predict the noise residual
|
| 1330 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
| 1331 |
|
| 1332 |
# perform guidance
|
| 1333 |
if do_classifier_free_guidance:
|
|
|
|
| 1336 |
|
| 1337 |
# compute the previous noisy sample x_t -> x_t-1
|
| 1338 |
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
| 1339 |
+
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
|
| 1340 |
else:
|
| 1341 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 1342 |
|
| 1343 |
# scale and decode the image latents with vae
|
| 1344 |
latents = 1 / 0.18215 * latents
|
| 1345 |
+
image = self.vae.decode(latents)
|
| 1346 |
|
| 1347 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 1348 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
| 1349 |
+
|
| 1350 |
if output_type == "pil":
|
| 1351 |
image = self.numpy_to_pil(image)
|
| 1352 |
|
| 1353 |
+
return {"image": image, "generator_state": generator_state}
|
| 1354 |
|
| 1355 |
def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
|
| 1356 |
# random vector to move in latent space
|
|
|
|
| 1358 |
rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
|
| 1359 |
scaled_rand_t = rand_t / rand_mag
|
| 1360 |
variation_embedding = text_embeddings + scaled_rand_t
|
| 1361 |
+
|
| 1362 |
generator = torch.Generator("cuda")
|
| 1363 |
generator.set_state(generator_state)
|
| 1364 |
result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
|