Spaces:
Runtime error
Runtime error
| # thanks to MinusZoneAI: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/b98b98bd04621e4c85547866c12de2ec723ae98a/mz_enable_vae_encode_tiling.py | |
| from typing import Optional | |
| import torch | |
| from diffusers.utils.accelerate_utils import apply_forward_hook | |
| from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution | |
| from diffusers.models.modeling_outputs import AutoencoderKLOutput | |
def encode(
    self, x: torch.Tensor, return_dict: bool = True
):
    """
    Encode a batch of inputs into a latent Gaussian posterior.

    Args:
        x (`torch.Tensor`): Input batch; `x.shape[0]` is treated as the batch dimension.
        return_dict (`bool`, *optional*, defaults to `True`):
            When True, wrap the posterior in an [`~models.autoencoder_kl.AutoencoderKLOutput`];
            otherwise return it inside a one-element tuple.

    Returns:
        [`~models.autoencoder_kl.AutoencoderKLOutput`] holding the posterior when `return_dict` is True,
        else the tuple `(posterior,)`.
    """
    # With slicing enabled and more than one sample, run `_encode` on each sample separately
    # and stitch the results back together along the batch dimension.
    encode_per_sample = self.use_slicing and x.shape[0] > 1
    if encode_per_sample:
        moments = torch.cat([self._encode(sample) for sample in x.split(1)])
    else:
        moments = self._encode(x)

    posterior = DiagonalGaussianDistribution(moments)
    return AutoencoderKLOutput(latent_dist=posterior) if return_dict else (posterior,)
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of videos using a tiled encoder.

    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
    different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    output, but they should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of videos, indexed as `(batch, channels, frames, height, width)`
            (dims 2/3/4 are sliced as frames/height/width below).

    Returns:
        `torch.Tensor`:
            The latent representation of the encoded videos.
    """
    # For a rough memory estimate, take a look at the `tiled_decode` method.
    _, _, num_frames, height, width = x.shape

    # Stride between tile origins (sample space) and the blended / kept extents (latent space).
    overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_latent_min_height - blend_extent_height
    row_limit_width = self.tile_latent_min_width - blend_extent_width

    frame_batch_size = 4
    # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or
    # `frame_batch_size * k + 1` for some k. The `num_frames % frame_batch_size` leftover frames
    # are folded into the first chunk. At least one chunk is always encoded so that clips shorter
    # than `frame_batch_size` do not end in an empty `torch.cat`.
    num_batches = max(num_frames // frame_batch_size, 1)
    remaining_frames = num_frames % frame_batch_size

    # Split x into overlapping tiles and encode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            time = []
            conv_cache = None
            for k in range(num_batches):
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = x[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]
                # Newer diffusers encoders return `(latents, conv_cache)` and accept the cache back
                # on the next temporal chunk; older ones return a bare tensor and take no cache.
                if conv_cache is None:
                    out = self.encoder(tile)
                else:
                    out = self.encoder(tile, conv_cache=conv_cache)
                if isinstance(out, tuple):
                    tile = out[0]
                    conv_cache = out[1] if len(out) > 1 else None
                else:
                    tile = out
                # quant_conv must see the latent tensor, never the (latents, cache) tuple.
                if self.quant_conv is not None:
                    tile = self.quant_conv(tile)
                time.append(tile)
            # Reset per-module causal-conv state between tiles when the VAE exposes a method for
            # it. The original code guarded this call, so the method may be absent.
            clear_cache = getattr(self, "_clear_fake_context_parallel_cache", None)
            if clear_cache is not None:
                clear_cache()
            row.append(torch.cat(time, dim=2))
        rows.append(row)

    # Blend each tile with its (unblended) upper and left neighbours, then crop it to the
    # non-overlapping part before concatenating along width (dim 4) and height (dim 3).
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    return torch.cat(result_rows, dim=3)
def _encode(
    self, x: torch.Tensor, return_dict: bool = True
):
    """Encode a batch of videos into the tensor that ``encode`` wraps in a posterior.

    When encode tiling is enabled and the spatial size exceeds one tile, the work is delegated
    to ``self.tiled_encode``. Otherwise the clip is encoded in temporal chunks of
    ``frame_batch_size`` frames, and the chunk outputs are concatenated along the frame dimension.

    Args:
        x (`torch.Tensor`): Input batch of videos, indexed as
            ``(batch, channels, frames, height, width)``.
        return_dict (`bool`): Unused; kept for signature compatibility.

    Returns:
        `torch.Tensor`: Encoder output, passed through ``quant_conv`` when it is set.
    """
    _, _, num_frames, height, width = x.shape

    if self.use_encode_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
        return self.tiled_encode(x)

    frame_batch_size = 4
    # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or
    # `frame_batch_size * k + 1` for some k. The `num_frames % frame_batch_size` leftover frames
    # are folded into the first chunk. At least one chunk is always encoded, so a single frame or a
    # clip shorter than `frame_batch_size` never ends in an empty `torch.cat`.
    num_batches = max(num_frames // frame_batch_size, 1)
    remaining_frames = num_frames % frame_batch_size

    chunks = []
    conv_cache = None
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        chunk = x[:, :, start_frame:end_frame]
        # Newer diffusers encoders return `(latents, conv_cache)` and accept the cache back on the
        # next temporal chunk; older ones return a bare tensor and take no cache.
        if conv_cache is None:
            out = self.encoder(chunk)
        else:
            out = self.encoder(chunk, conv_cache=conv_cache)
        if isinstance(out, tuple):
            chunk = out[0]
            conv_cache = out[1] if len(out) > 1 else None
        else:
            chunk = out
        if self.quant_conv is not None:
            chunk = self.quant_conv(chunk)
        chunks.append(chunk)

    # Reset per-module causal-conv state when the VAE exposes a method for it. The original code
    # guarded this call, so the method may be absent.
    clear_cache = getattr(self, "_clear_fake_context_parallel_cache", None)
    if clear_cache is not None:
        clear_cache()

    return torch.cat(chunks, dim=2)
def enable_encode_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_overlap_factor_height: Optional[float] = None,
    tile_overlap_factor_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE encoding. When this option is enabled, ``_encode`` splits inputs larger than one tile into
    overlapping spatial tiles and encodes them separately (see ``tiled_encode``). This saves memory and allows
    processing larger inputs.

    Any argument left as ``None`` keeps the VAE's current value for that setting.

    Args:
        tile_sample_min_height (`int`, *optional*):
            Tile height in sample space. Inputs taller than this are split across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            Tile width in sample space. Inputs wider than this are split across the width dimension.
        tile_overlap_factor_height (`float`, *optional*):
            Fraction of a tile that overlaps its vertical neighbour, so seams can be blended away across the
            height dimension. Must be in ``[0, 1)``. Higher values process more tiles and slow encoding down.
        tile_overlap_factor_width (`float`, *optional*):
            Same as ``tile_overlap_factor_height``, for the width dimension.

    Raises:
        ValueError: If a tile size is not positive or an overlap factor lies outside ``[0, 1)``.
    """
    # Explicit `is None` checks: `value or default` would silently discard a legitimate 0.0 overlap.
    sample_height = self.tile_sample_min_height if tile_sample_min_height is None else tile_sample_min_height
    sample_width = self.tile_sample_min_width if tile_sample_min_width is None else tile_sample_min_width
    overlap_height = (
        self.tile_overlap_factor_height if tile_overlap_factor_height is None else tile_overlap_factor_height
    )
    overlap_width = self.tile_overlap_factor_width if tile_overlap_factor_width is None else tile_overlap_factor_width

    # Validate before mutating anything. A factor >= 1 gives `tiled_encode` a zero or negative tile stride.
    if sample_height <= 0 or sample_width <= 0:
        raise ValueError(f"Tile sizes must be positive, got height={sample_height}, width={sample_width}.")
    for name, factor in (
        ("tile_overlap_factor_height", overlap_height),
        ("tile_overlap_factor_width", overlap_width),
    ):
        if not 0 <= factor < 1:
            raise ValueError(f"{name} must be in [0, 1), got {factor}.")

    # Ratio between sample-space and latent-space tile sizes, derived from the number of encoder blocks.
    spatial_scale = 2 ** (len(self.config.block_out_channels) - 1)

    self.use_encode_tiling = True
    self.tile_sample_min_height = sample_height
    self.tile_sample_min_width = sample_width
    self.tile_latent_min_height = int(sample_height / spatial_scale)
    self.tile_latent_min_width = int(sample_width / spatial_scale)
    self.tile_overlap_factor_height = overlap_height
    self.tile_overlap_factor_width = overlap_width
| from types import MethodType | |
def enable_vae_encode_tiling(vae):
    """Patch ``vae`` in place with the tiling-aware encode path and switch encode tiling on.

    The module-level ``encode``, ``_encode``, ``tiled_encode`` and ``enable_encode_tiling`` functions are
    bound to ``vae`` as instance methods. Then ``enable_encode_tiling()`` is called with no arguments, so the
    VAE's existing tile settings are kept.

    Args:
        vae: The VAE instance to patch. Presumably a diffusers CogVideoX VAE (TODO confirm against callers).
            It must provide the attributes the bound functions read: ``encoder``, ``quant_conv``,
            ``use_slicing``, ``blend_v``/``blend_h``, ``config.block_out_channels`` and the
            ``tile_sample_min_*`` / ``tile_overlap_factor_*`` settings.

    Returns:
        The same ``vae`` object, to allow chaining.
    """
    # The instance attribute shadows the class's own `encode`, which diffusers decorates with
    # `apply_forward_hook`. Re-apply that wrapper so accelerate offload hooks still run before encoding.
    vae.encode = MethodType(apply_forward_hook(encode), vae)
    vae._encode = MethodType(_encode, vae)
    vae.tiled_encode = MethodType(tiled_encode, vae)
    vae.enable_encode_tiling = MethodType(enable_encode_tiling, vae)
    vae.use_encode_tiling = True
    vae.enable_encode_tiling()
    return vae