Added freeu functions to unet
Browse files- pipeline.py +26 -0
pipeline.py
CHANGED
|
@@ -864,6 +864,32 @@ class StableDiffusionLongPromptWeightingPipeline(
|
|
| 864 |
latents = init_latents
|
| 865 |
return latents, init_latents_orig, noise
|
| 866 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 867 |
@torch.no_grad()
|
| 868 |
def __call__(
|
| 869 |
self,
|
|
|
|
| 864 |
latents = init_latents
|
| 865 |
return latents, init_latents_orig, noise
|
| 866 |
|
| 867 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
| 868 |
+
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
| 869 |
+
|
| 870 |
+
The suffixes after the scaling factors represent the stages where they are being applied.
|
| 871 |
+
|
| 872 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
| 873 |
+
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
| 874 |
+
|
| 875 |
+
Args:
|
| 876 |
+
s1 (`float`):
|
| 877 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
| 878 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
| 879 |
+
s2 (`float`):
|
| 880 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
| 881 |
+
mitigate "oversmoothing effect" in the enhanced denoising process.
|
| 882 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
| 883 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
| 884 |
+
"""
|
| 885 |
+
if not hasattr(self, "unet"):
|
| 886 |
+
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
| 887 |
+
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
| 888 |
+
|
| 889 |
+
def disable_freeu(self):
|
| 890 |
+
"""Disables the FreeU mechanism if enabled."""
|
| 891 |
+
self.unet.disable_freeu()
|
| 892 |
+
|
| 893 |
@torch.no_grad()
|
| 894 |
def __call__(
|
| 895 |
self,
|