| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from typing import Tuple, Union |
| |
|
| | import torch |
| | import torch.fft as fft |
| |
|
| | from ..utils.torch_utils import randn_tensor |
| |
|
| |
|
class FreeInitMixin:
    r"""Mixin class for FreeInit."""

    def enable_free_init(
        self,
        num_iters: int = 3,
        use_fast_sampling: bool = False,
        method: str = "butterworth",
        order: int = 4,
        spatial_stop_frequency: float = 0.25,
        temporal_stop_frequency: float = 0.25,
    ):
        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.

        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).

        Args:
            num_iters (`int`, *optional*, defaults to `3`):
                Number of FreeInit noise re-initialization iterations.
            use_fast_sampling (`bool`, *optional*, defaults to `False`):
                Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
                the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
            method (`str`, *optional*, defaults to `butterworth`):
                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
                FreeInit low pass filter.
            order (`int`, *optional*, defaults to `4`):
                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
                whereas lower values lead to `gaussian` method behaviour.
            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
                the original implementation.
            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
                the original implementation.
        """
        self._free_init_num_iters = num_iters
        self._free_init_use_fast_sampling = use_fast_sampling
        self._free_init_method = method
        self._free_init_order = order
        self._free_init_spatial_stop_frequency = spatial_stop_frequency
        self._free_init_temporal_stop_frequency = temporal_stop_frequency

    def disable_free_init(self):
        """Disables the FreeInit mechanism if enabled."""
        self._free_init_num_iters = None

    @property
    def free_init_enabled(self):
        """Whether FreeInit is active, i.e. `_free_init_num_iters` has been set to a non-`None` value."""
        return getattr(self, "_free_init_num_iters", None) is not None

    def _get_free_init_freq_filter(
        self,
        shape: Tuple[int, ...],
        device: Union[str, torch.device],
        filter_type: str,
        order: float,
        spatial_stop_frequency: float,
        temporal_stop_frequency: float,
    ) -> torch.Tensor:
        r"""Returns the FreeInit filter based on filter type and other input conditions.

        Args:
            shape (`Tuple[int, ...]`):
                Shape of the returned mask. The last three entries are read as (time, height, width); any leading
                dimensions are filled by broadcasting.
            device (`str` or `torch.device`):
                Device the returned mask is moved to.
            filter_type (`str`):
                One of `butterworth`, `gaussian` or `ideal`.
            order (`float`):
                Exponent used by the `butterworth` filter; ignored by the other filter types.
            spatial_stop_frequency (`float`):
                Stop frequency applied along the height and width axes.
            temporal_stop_frequency (`float`):
                Stop frequency applied along the time axis.

        Returns:
            `torch.Tensor`: Mask of shape `shape` in the default floating point dtype, located on `device`. The mask
            is all zeros if either stop frequency is `0`.

        Raises:
            NotImplementedError: If `filter_type` is not one of the supported filter names.
        """
        time, height, width = shape[-3], shape[-2], shape[-1]
        mask = torch.zeros(shape)

        # A zero stop frequency means nothing passes; still honour `device` so callers can combine the mask with
        # tensors living on that device.
        if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
            return mask.to(device)

        if filter_type not in ("butterworth", "gaussian", "ideal"):
            raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

        # Coordinates of every index along each axis, mapped to [-1, 1). They are computed in float64 so the values
        # match plain Python float arithmetic before being stored in the default-dtype mask.
        t_coord = 2 * torch.arange(time, dtype=torch.float64) / time - 1
        h_coord = 2 * torch.arange(height, dtype=torch.float64) / height - 1
        w_coord = 2 * torch.arange(width, dtype=torch.float64) / width - 1

        # The time axis is rescaled so that a single (spatial) threshold can be used for all three axes.
        t_coord = (spatial_stop_frequency / temporal_stop_frequency) * t_coord

        # Squared distance from the centre of the (shifted) spectrum, shape (time, height, width).
        d_square = t_coord.view(-1, 1, 1) ** 2 + h_coord.view(1, -1, 1) ** 2 + w_coord.view(1, 1, -1) ** 2

        if filter_type == "butterworth":
            values = 1 / (1 + (d_square / spatial_stop_frequency**2) ** order)
        elif filter_type == "gaussian":
            values = torch.exp(-1 / (2 * spatial_stop_frequency**2) * d_square)
        else:
            # NOTE(review): the squared distance is compared against `spatial_stop_frequency * 2`, not the squared
            # stop frequency. Kept unchanged to preserve existing outputs; confirm against the reference
            # implementation before altering.
            values = d_square <= spatial_stop_frequency * 2

        # Broadcast over any leading dimensions and cast to the mask dtype.
        mask[...] = values

        return mask.to(device)

    def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
        r"""Noise reinitialization.

        Combines the low-frequency content of `x` with the high-frequency content of `noise` over the last three
        dimensions, where "low" and "high" are defined by `low_pass_filter` and its complement.

        Args:
            x (`torch.Tensor`): Tensor whose low-pass filtered spectrum is kept.
            noise (`torch.Tensor`): Tensor whose high-pass filtered spectrum is kept.
            low_pass_filter (`torch.Tensor`):
                Mask multiplied with the centred spectrum; must be broadcastable to the shape of `x`.

        Returns:
            `torch.Tensor`: Real part of the inverse transform of the mixed spectrum.
        """
        fft_dims = (-3, -2, -1)

        # FFT, shifted so the zero frequency sits at the centre like the mask expects
        x_freq = fft.fftshift(fft.fftn(x, dim=fft_dims), dim=fft_dims)
        noise_freq = fft.fftshift(fft.fftn(noise, dim=fft_dims), dim=fft_dims)

        # frequency mix
        high_pass_filter = 1 - low_pass_filter
        x_freq_low = x_freq * low_pass_filter
        noise_freq_high = noise_freq * high_pass_filter
        x_freq_mixed = x_freq_low + noise_freq_high

        # IFFT
        x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=fft_dims)
        x_mixed = fft.ifftn(x_freq_mixed, dim=fft_dims).real

        return x_mixed

    def _apply_free_init(
        self,
        latents: torch.Tensor,
        free_init_iteration: int,
        num_inference_steps: int,
        device: torch.device,
        dtype: torch.dtype,
        generator: torch.Generator,
    ):
        r"""Prepares the latents and timesteps for one FreeInit iteration.

        On iteration `0` a detached copy of `latents` is stored as the initial noise and `latents` is returned
        unchanged. On later iterations `latents` are re-noised through `self.scheduler.add_noise` using the stored
        initial noise at timestep `num_train_timesteps - 1`, and their high frequencies are replaced by those of
        freshly sampled noise.

        Args:
            latents (`torch.Tensor`): Latents from the previous iteration (the initial noise on iteration `0`).
            free_init_iteration (`int`): Zero-based index of the current FreeInit iteration.
            num_inference_steps (`int`): Number of inference steps requested for a full sampling run.
            device (`torch.device`): Device for the filter, the timesteps and the sampled noise.
            dtype (`torch.dtype`): Dtype the returned latents are cast to.
            generator (`torch.Generator`): Generator forwarded to `randn_tensor`.

        Returns:
            A tuple `(latents, timesteps)` where `timesteps` is `self.scheduler.timesteps`.
        """
        if free_init_iteration == 0:
            self._free_init_initial_noise = latents.detach().clone()
            return latents, self.scheduler.timesteps

        latent_shape = latents.shape

        # A single filter (leading dimension 1) is broadcast over the first latent dimension.
        free_init_filter_shape = (1, *latent_shape[1:])
        free_init_freq_filter = self._get_free_init_freq_filter(
            shape=free_init_filter_shape,
            device=device,
            filter_type=self._free_init_method,
            order=self._free_init_order,
            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
        )

        # Re-noise at the last training timestep, reusing the noise stored on iteration 0.
        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

        # The frequency mixing is carried out in float32 and cast back to `dtype` afterwards.
        z_t = self.scheduler.add_noise(
            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
        ).to(dtype=torch.float32)

        z_rand = randn_tensor(
            shape=latent_shape,
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
        latents = latents.to(dtype)

        # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
        if self._free_init_use_fast_sampling:
            # Clamp to at least one step: the truncating division yields 0 when `num_inference_steps` is smaller
            # than the number of FreeInit iterations.
            num_inference_steps = max(
                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
            )
            self.scheduler.set_timesteps(num_inference_steps, device=device)

        return latents, self.scheduler.timesteps
| |
|