Add TextualInversion to processed prompts. Other updates.
pipeline.py (+18 -9)
@@ -6,13 +6,14 @@ import numpy as np
 import PIL
 import torch
 from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 import random
 import sys
 from tqdm.auto import tqdm
 
 import diffusers
 from diffusers import SchedulerMixin, StableDiffusionPipeline
+from diffusers.loaders import TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
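The import swap tracks the rename of `CLIPFeatureExtractor` to `CLIPImageProcessor` in transformers, where the old name now survives only as a deprecated alias; the docstring and constructor annotations further down change for the same reason. The new `TextualInversionLoaderMixin` import supports the prompt conversion added below. A minimal sketch of the renamed class in its safety-checker role (the checkpoint id is illustrative):

    from transformers import CLIPImageProcessor

    # Same job the old CLIPFeatureExtractor did: turn PIL images into
    # pixel tensors for the safety checker.
    feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")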
@@ -182,14 +183,14 @@ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
     return tokens, weights
 
 
-def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
     r"""
     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
     """
     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
     for i in range(len(tokens)):
-        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
+        tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
         if no_boseos_middle:
             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
         else:
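With the extra `pad` argument, each prompt is now filled with the tokenizer's pad token and closed by a single `eos`, instead of being padded entirely with `eos`. A toy illustration of the new rule (ids are made up; real values come from the tokenizer):

    bos, eos, pad = 0, 2, 1
    max_length = 8
    tokens = [[5, 6, 7]]

    # 1 bos + prompt + pads + 1 eos == max_length tokens
    tokens[0] = [bos] + tokens[0] + [pad] * (max_length - 1 - len(tokens[0]) - 1) + [eos]
    print(tokens[0])  # [0, 5, 6, 7, 1, 1, 1, 2]

For the SD 1.x CLIP tokenizer, `pad_token_id` equals `eos_token_id`, so outputs are unchanged; the distinction matters for tokenizers such as Stable Diffusion 2.x's, which pad with a different token.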
@@ -320,12 +321,14 @@ def get_weighted_text_embeddings(
     # pad the length of tokens and weights
     bos = pipe.tokenizer.bos_token_id
     eos = pipe.tokenizer.eos_token_id
+    pad = getattr(pipe.tokenizer, "pad_token_id", eos)
     prompt_tokens, prompt_weights = pad_tokens_and_weights(
         prompt_tokens,
         prompt_weights,
         max_length,
         bos,
         eos,
+        pad,
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
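The `getattr` fallback handles tokenizers that expose no `pad_token_id` attribute at all. One subtlety: if the attribute exists but is `None`, `getattr` returns `None` rather than `eos`, so a stricter variant (a hypothetical alternative, not what this change does) would be:

    pad = pipe.tokenizer.pad_token_id
    if pad is None:
        pad = eos  # fall back to eos when no pad token is configured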
@@ -337,6 +340,7 @@ def get_weighted_text_embeddings(
             max_length,
             bos,
             eos,
+            pad,
             no_boseos_middle=no_boseos_middle,
             chunk_length=pipe.tokenizer.model_max_length,
         )
@@ -379,7 +383,7 @@ def get_weighted_text_embeddings(
 
 def preprocess_image(image):
     w, h = image.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
@@ -390,7 +394,7 @@ def preprocess_image(image):
 def preprocess_mask(mask, scale_factor=8):
     mask = mask.convert("L")
     w, h = mask.size
-    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
     mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
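Both preprocessors replace the `map(lambda ...)` form with a generator expression; behavior is identical, snapping width and height down to the nearest multiple of 32 so the downsampled latents have whole-number sizes. For example:

    w, h = 513, 767
    w, h = (x - x % 32 for x in (w, h))  # a 2-element generator unpacks like a tuple
    print(w, h)  # 512 736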
@@ -425,7 +429,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         safety_checker ([`StableDiffusionSafetyChecker`]):
             Classification module that estimates whether generated images could be considered offensive or harmful.
             Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
-        feature_extractor ([`CLIPFeatureExtractor`]):
+        feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
@@ -439,7 +443,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__(
@@ -464,7 +468,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         unet: UNet2DConditionModel,
         scheduler: SchedulerMixin,
         safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPFeatureExtractor,
+        feature_extractor: CLIPImageProcessor,
     ):
         super().__init__(
             vae=vae,
@@ -538,6 +542,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                     " the batch size of `prompt`."
                 )
+
+        # textual inversion: process multi-vector tokens if necessary
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+            negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
 
         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
             pipe=self,
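This block is the change the commit title announces: before weighting, both prompts pass through `TextualInversionLoaderMixin.maybe_convert_prompt`, which expands a multi-vector textual inversion trigger such as `<concept>` into `<concept> <concept>_1 ...` so every learned embedding vector is tokenized, mirroring the stock `StableDiffusionPipeline`. The `isinstance` guard keeps the pipeline working on older diffusers versions that lack the mixin. A usage sketch, assuming a recent diffusers release (the embedding repo id is illustrative):

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="lpw_stable_diffusion",
    )
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # illustrative repo id
    image = pipe("a photo of a <cat-toy> on the beach").images[0]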
@@ -627,7 +636,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if image is None:
             shape = (
                 batch_size,
-                self.unet.in_channels,
+                self.unet.config.in_channels,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor,
             )
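Reading `in_channels` through `unet.config` avoids the deprecation warning recent diffusers releases emit for direct attribute access on models. The resulting latent shape, assuming the usual SD 1.x values (4 latent channels, VAE scale factor 8):

    batch_size, height, width = 1, 512, 512
    in_channels = 4        # self.unet.config.in_channels for SD 1.x
    vae_scale_factor = 8   # self.vae_scale_factor
    print((batch_size, in_channels, height // vae_scale_factor, width // vae_scale_factor))
    # (1, 4, 64, 64)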