Reverted hack to get sequential_cpu_offload working. Not perfect.
Browse files- pipeline.py +6 -67
pipeline.py
CHANGED
|
@@ -16,7 +16,7 @@ from diffusers import SchedulerMixin, StableDiffusionPipeline
|
|
| 16 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 17 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 18 |
from diffusers.utils import logging
|
| 19 |
-
|
| 20 |
|
| 21 |
try:
|
| 22 |
from diffusers.utils import PIL_INTERPOLATION
|
|
@@ -281,7 +281,6 @@ def get_weighted_text_embeddings(
|
|
| 281 |
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 282 |
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 283 |
"""
|
| 284 |
-
unet_device = torch.device('cpu') if pipe.unet.device == torch.device('meta') else pipe.unet.device
|
| 285 |
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 286 |
if isinstance(prompt, str):
|
| 287 |
prompt = [prompt]
|
|
@@ -330,7 +329,7 @@ def get_weighted_text_embeddings(
|
|
| 330 |
no_boseos_middle=no_boseos_middle,
|
| 331 |
chunk_length=pipe.tokenizer.model_max_length,
|
| 332 |
)
|
| 333 |
-
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=
|
| 334 |
if uncond_prompt is not None:
|
| 335 |
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 336 |
uncond_tokens,
|
|
@@ -341,7 +340,7 @@ def get_weighted_text_embeddings(
|
|
| 341 |
no_boseos_middle=no_boseos_middle,
|
| 342 |
chunk_length=pipe.tokenizer.model_max_length,
|
| 343 |
)
|
| 344 |
-
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=
|
| 345 |
|
| 346 |
# get the embeddings
|
| 347 |
text_embeddings = get_unweighted_text_embeddings(
|
|
@@ -350,8 +349,7 @@ def get_weighted_text_embeddings(
|
|
| 350 |
pipe.tokenizer.model_max_length,
|
| 351 |
no_boseos_middle=no_boseos_middle,
|
| 352 |
)
|
| 353 |
-
|
| 354 |
-
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=unet_device)
|
| 355 |
if uncond_prompt is not None:
|
| 356 |
uncond_embeddings = get_unweighted_text_embeddings(
|
| 357 |
pipe,
|
|
@@ -359,8 +357,7 @@ def get_weighted_text_embeddings(
|
|
| 359 |
pipe.tokenizer.model_max_length,
|
| 360 |
no_boseos_middle=no_boseos_middle,
|
| 361 |
)
|
| 362 |
-
|
| 363 |
-
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=unet_device)
|
| 364 |
|
| 365 |
# assign weights to the prompts and normalize in the sense of mean
|
| 366 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
|
@@ -484,59 +481,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 484 |
if not hasattr(self, "vae_scale_factor"):
|
| 485 |
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
| 486 |
|
| 487 |
-
def enable_sequential_cpu_offload(self, gpu_id=0):
|
| 488 |
-
r"""
|
| 489 |
-
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
| 490 |
-
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
| 491 |
-
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
| 492 |
-
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
| 493 |
-
`enable_model_cpu_offload`, but performance is lower.
|
| 494 |
-
"""
|
| 495 |
-
if is_accelerate_available():
|
| 496 |
-
from accelerate import cpu_offload
|
| 497 |
-
else:
|
| 498 |
-
raise ImportError("Please install accelerate via `pip install accelerate`")
|
| 499 |
-
|
| 500 |
-
device = torch.device(f"cuda:{gpu_id}")
|
| 501 |
-
|
| 502 |
-
if self.device.type != "cpu":
|
| 503 |
-
self.to("cpu", silence_dtype_warnings=True)
|
| 504 |
-
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
| 505 |
-
|
| 506 |
-
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
| 507 |
-
cpu_offload(cpu_offloaded_model, device)
|
| 508 |
-
|
| 509 |
-
if self.safety_checker is not None:
|
| 510 |
-
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
| 511 |
-
|
| 512 |
-
def enable_model_cpu_offload(self, gpu_id=0):
|
| 513 |
-
r"""
|
| 514 |
-
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
| 515 |
-
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
| 516 |
-
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
| 517 |
-
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
| 518 |
-
"""
|
| 519 |
-
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
| 520 |
-
from accelerate import cpu_offload_with_hook
|
| 521 |
-
else:
|
| 522 |
-
raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
|
| 523 |
-
|
| 524 |
-
device = torch.device(f"cuda:{gpu_id}")
|
| 525 |
-
|
| 526 |
-
if self.device.type != "cpu":
|
| 527 |
-
self.to("cpu", silence_dtype_warnings=True)
|
| 528 |
-
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
| 529 |
-
|
| 530 |
-
hook = None
|
| 531 |
-
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
| 532 |
-
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
| 533 |
-
|
| 534 |
-
if self.safety_checker is not None:
|
| 535 |
-
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
| 536 |
-
|
| 537 |
-
# We'll offload the last model manually.
|
| 538 |
-
self.final_offload_hook = hook
|
| 539 |
-
|
| 540 |
@property
|
| 541 |
def _execution_device(self):
|
| 542 |
r"""
|
|
@@ -544,8 +488,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 544 |
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
| 545 |
hooks.
|
| 546 |
"""
|
| 547 |
-
|
| 548 |
-
if not hasattr(self.unet, "_hf_hook"):
|
| 549 |
return self.device
|
| 550 |
for module in self.unet.modules():
|
| 551 |
if (
|
|
@@ -915,10 +858,6 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
| 915 |
if output_type == "pil":
|
| 916 |
image = self.numpy_to_pil(image)
|
| 917 |
|
| 918 |
-
# 12. Offload last model to CPU
|
| 919 |
-
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
| 920 |
-
self.final_offload_hook.offload()
|
| 921 |
-
|
| 922 |
if not return_dict:
|
| 923 |
return image, has_nsfw_concept
|
| 924 |
|
|
|
|
| 16 |
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 17 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 18 |
from diffusers.utils import logging
|
| 19 |
+
|
| 20 |
|
| 21 |
try:
|
| 22 |
from diffusers.utils import PIL_INTERPOLATION
|
|
|
|
| 281 |
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 282 |
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 283 |
"""
|
|
|
|
| 284 |
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 285 |
if isinstance(prompt, str):
|
| 286 |
prompt = [prompt]
|
|
|
|
| 329 |
no_boseos_middle=no_boseos_middle,
|
| 330 |
chunk_length=pipe.tokenizer.model_max_length,
|
| 331 |
)
|
| 332 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
| 333 |
if uncond_prompt is not None:
|
| 334 |
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 335 |
uncond_tokens,
|
|
|
|
| 340 |
no_boseos_middle=no_boseos_middle,
|
| 341 |
chunk_length=pipe.tokenizer.model_max_length,
|
| 342 |
)
|
| 343 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
| 344 |
|
| 345 |
# get the embeddings
|
| 346 |
text_embeddings = get_unweighted_text_embeddings(
|
|
|
|
| 349 |
pipe.tokenizer.model_max_length,
|
| 350 |
no_boseos_middle=no_boseos_middle,
|
| 351 |
)
|
| 352 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
|
|
|
| 353 |
if uncond_prompt is not None:
|
| 354 |
uncond_embeddings = get_unweighted_text_embeddings(
|
| 355 |
pipe,
|
|
|
|
| 357 |
pipe.tokenizer.model_max_length,
|
| 358 |
no_boseos_middle=no_boseos_middle,
|
| 359 |
)
|
| 360 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
|
|
|
| 361 |
|
| 362 |
# assign weights to the prompts and normalize in the sense of mean
|
| 363 |
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
|
|
|
| 481 |
if not hasattr(self, "vae_scale_factor"):
|
| 482 |
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
| 483 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
@property
|
| 485 |
def _execution_device(self):
|
| 486 |
r"""
|
|
|
|
| 488 |
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
| 489 |
hooks.
|
| 490 |
"""
|
| 491 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
|
|
|
| 492 |
return self.device
|
| 493 |
for module in self.unet.modules():
|
| 494 |
if (
|
|
|
|
| 858 |
if output_type == "pil":
|
| 859 |
image = self.numpy_to_pil(image)
|
| 860 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
if not return_dict:
|
| 862 |
return image, has_nsfw_concept
|
| 863 |
|