Update pipeline.py
Browse files- pipeline.py +134 -66
pipeline.py
CHANGED
|
@@ -5,14 +5,55 @@ from typing import Callable, List, Optional, Union
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
|
|
|
|
| 8 |
import PIL
|
| 9 |
from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
|
| 10 |
-
from diffusers.onnx_utils import
|
| 11 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
| 12 |
-
from diffusers.utils import
|
|
|
|
| 13 |
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 17 |
|
| 18 |
re_attention = re.compile(
|
|
@@ -390,30 +431,59 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|
| 390 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 391 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 392 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
| 417 |
self.unet_in_channels = 4
|
| 418 |
self.vae_scale_factor = 8
|
| 419 |
|
|
@@ -741,49 +811,47 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
|
| 741 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 742 |
|
| 743 |
# 8. Denoising loop
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
)
|
| 758 |
-
noise_pred = noise_pred[0]
|
| 759 |
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 766 |
-
scheduler_output = self.scheduler.step(
|
| 767 |
-
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
| 768 |
-
)
|
| 769 |
-
latents = scheduler_output.prev_sample.numpy()
|
| 770 |
-
|
| 771 |
-
if mask is not None:
|
| 772 |
-
# masking
|
| 773 |
-
init_latents_proper = self.scheduler.add_noise(
|
| 774 |
-
torch.from_numpy(init_latents_orig),
|
| 775 |
-
torch.from_numpy(noise),
|
| 776 |
-
t,
|
| 777 |
-
).numpy()
|
| 778 |
-
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 779 |
-
|
| 780 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 781 |
-
progress_bar.update()
|
| 782 |
-
if i % callback_steps == 0:
|
| 783 |
-
if callback is not None:
|
| 784 |
-
callback(i, t, latents)
|
| 785 |
-
if is_cancelled_callback is not None and is_cancelled_callback():
|
| 786 |
-
return None
|
| 787 |
# 9. Post-processing
|
| 788 |
image = self.decode_latents(latents)
|
| 789 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
|
| 8 |
+
import diffusers
|
| 9 |
import PIL
|
| 10 |
from diffusers import OnnxStableDiffusionPipeline, SchedulerMixin
|
| 11 |
+
from diffusers.onnx_utils import OnnxRuntimeModel
|
| 12 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
| 13 |
+
from diffusers.utils import deprecate, logging
|
| 14 |
+
from packaging import version
|
| 15 |
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
| 16 |
|
| 17 |
|
| 18 |
+
try:
|
| 19 |
+
from diffusers.onnx_utils import ORT_TO_NP_TYPE
|
| 20 |
+
except ImportError:
|
| 21 |
+
ORT_TO_NP_TYPE = {
|
| 22 |
+
"tensor(bool)": np.bool_,
|
| 23 |
+
"tensor(int8)": np.int8,
|
| 24 |
+
"tensor(uint8)": np.uint8,
|
| 25 |
+
"tensor(int16)": np.int16,
|
| 26 |
+
"tensor(uint16)": np.uint16,
|
| 27 |
+
"tensor(int32)": np.int32,
|
| 28 |
+
"tensor(uint32)": np.uint32,
|
| 29 |
+
"tensor(int64)": np.int64,
|
| 30 |
+
"tensor(uint64)": np.uint64,
|
| 31 |
+
"tensor(float16)": np.float16,
|
| 32 |
+
"tensor(float)": np.float32,
|
| 33 |
+
"tensor(double)": np.float64,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from diffusers.utils import PIL_INTERPOLATION
|
| 38 |
+
except ImportError:
|
| 39 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 40 |
+
PIL_INTERPOLATION = {
|
| 41 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
| 42 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
| 43 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
| 44 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
| 45 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
| 46 |
+
}
|
| 47 |
+
else:
|
| 48 |
+
PIL_INTERPOLATION = {
|
| 49 |
+
"linear": PIL.Image.LINEAR,
|
| 50 |
+
"bilinear": PIL.Image.BILINEAR,
|
| 51 |
+
"bicubic": PIL.Image.BICUBIC,
|
| 52 |
+
"lanczos": PIL.Image.LANCZOS,
|
| 53 |
+
"nearest": PIL.Image.NEAREST,
|
| 54 |
+
}
|
| 55 |
+
# ------------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 58 |
|
| 59 |
re_attention = re.compile(
|
|
|
|
| 431 |
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 432 |
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 433 |
"""
|
| 434 |
+
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
| 435 |
+
|
| 436 |
+
def __init__(
|
| 437 |
+
self,
|
| 438 |
+
vae_encoder: OnnxRuntimeModel,
|
| 439 |
+
vae_decoder: OnnxRuntimeModel,
|
| 440 |
+
text_encoder: OnnxRuntimeModel,
|
| 441 |
+
tokenizer: CLIPTokenizer,
|
| 442 |
+
unet: OnnxRuntimeModel,
|
| 443 |
+
scheduler: SchedulerMixin,
|
| 444 |
+
safety_checker: OnnxRuntimeModel,
|
| 445 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 446 |
+
requires_safety_checker: bool = True,
|
| 447 |
+
):
|
| 448 |
+
super().__init__(
|
| 449 |
+
vae_encoder=vae_encoder,
|
| 450 |
+
vae_decoder=vae_decoder,
|
| 451 |
+
text_encoder=text_encoder,
|
| 452 |
+
tokenizer=tokenizer,
|
| 453 |
+
unet=unet,
|
| 454 |
+
scheduler=scheduler,
|
| 455 |
+
safety_checker=safety_checker,
|
| 456 |
+
feature_extractor=feature_extractor,
|
| 457 |
+
requires_safety_checker=requires_safety_checker,
|
| 458 |
+
)
|
| 459 |
+
self.__init__additional__()
|
| 460 |
|
| 461 |
+
else:
|
| 462 |
+
|
| 463 |
+
def __init__(
|
| 464 |
+
self,
|
| 465 |
+
vae_encoder: OnnxRuntimeModel,
|
| 466 |
+
vae_decoder: OnnxRuntimeModel,
|
| 467 |
+
text_encoder: OnnxRuntimeModel,
|
| 468 |
+
tokenizer: CLIPTokenizer,
|
| 469 |
+
unet: OnnxRuntimeModel,
|
| 470 |
+
scheduler: SchedulerMixin,
|
| 471 |
+
safety_checker: OnnxRuntimeModel,
|
| 472 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 473 |
+
):
|
| 474 |
+
super().__init__(
|
| 475 |
+
vae_encoder=vae_encoder,
|
| 476 |
+
vae_decoder=vae_decoder,
|
| 477 |
+
text_encoder=text_encoder,
|
| 478 |
+
tokenizer=tokenizer,
|
| 479 |
+
unet=unet,
|
| 480 |
+
scheduler=scheduler,
|
| 481 |
+
safety_checker=safety_checker,
|
| 482 |
+
feature_extractor=feature_extractor,
|
| 483 |
+
)
|
| 484 |
+
self.__init__additional__()
|
| 485 |
+
|
| 486 |
+
def __init__additional__(self):
|
| 487 |
self.unet_in_channels = 4
|
| 488 |
self.vae_scale_factor = 8
|
| 489 |
|
|
|
|
| 811 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 812 |
|
| 813 |
# 8. Denoising loop
|
| 814 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
| 815 |
+
# expand the latents if we are doing classifier free guidance
|
| 816 |
+
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
| 817 |
+
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
|
| 818 |
+
latent_model_input = latent_model_input.numpy()
|
| 819 |
+
|
| 820 |
+
# predict the noise residual
|
| 821 |
+
noise_pred = self.unet(
|
| 822 |
+
sample=latent_model_input,
|
| 823 |
+
timestep=np.array([t], dtype=timestep_dtype),
|
| 824 |
+
encoder_hidden_states=text_embeddings,
|
| 825 |
+
)
|
| 826 |
+
noise_pred = noise_pred[0]
|
|
|
|
|
|
|
| 827 |
|
| 828 |
+
# perform guidance
|
| 829 |
+
if do_classifier_free_guidance:
|
| 830 |
+
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
| 831 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 832 |
+
|
| 833 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 834 |
+
scheduler_output = self.scheduler.step(
|
| 835 |
+
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
|
| 836 |
+
)
|
| 837 |
+
latents = scheduler_output.prev_sample.numpy()
|
| 838 |
+
|
| 839 |
+
if mask is not None:
|
| 840 |
+
# masking
|
| 841 |
+
init_latents_proper = self.scheduler.add_noise(
|
| 842 |
+
torch.from_numpy(init_latents_orig),
|
| 843 |
+
torch.from_numpy(noise),
|
| 844 |
+
t,
|
| 845 |
+
).numpy()
|
| 846 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 847 |
+
|
| 848 |
+
# call the callback, if provided
|
| 849 |
+
if i % callback_steps == 0:
|
| 850 |
+
if callback is not None:
|
| 851 |
+
callback(i, t, latents)
|
| 852 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
| 853 |
+
return None
|
| 854 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
# 9. Post-processing
|
| 856 |
image = self.decode_latents(latents)
|
| 857 |
|