Spaces:
Running
Running
| from enum import Enum, auto | |
| import torch | |
| from huggingface_hub import ( # pyright: ignore[reportMissingTypeStubs] | |
| hf_hub_download, # pyright: ignore[reportUnknownVariableType] | |
| ) | |
| from PIL import Image | |
| from refiners.fluxion.utils import load_from_safetensors, tensor_to_image | |
| from refiners.foundationals.clip import CLIPTextEncoderL | |
| from refiners.foundationals.latent_diffusion import SD1UNet | |
| from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder | |
| from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight | |
def load_ic_light(device: torch.device, dtype: torch.dtype) -> ICLight:
    """Assemble an IC-Light (``fc`` variant) model on top of Realistic Vision v5.1 weights.

    Each checkpoint is downloaded from the Hugging Face Hub at a pinned commit
    revision. The UNet, CLIP text encoder and autoencoder are instantiated on
    ``device`` with ``dtype`` before their weights are loaded. Everything is then
    handed to ``ICLight`` together with the IC-Light patch weights.

    Args:
        device: Device used for every sub-model and for the ``ICLight`` wrapper.
        dtype: Tensor dtype used for every sub-model and for the ``ICLight`` wrapper.

    Returns:
        The assembled ``ICLight`` instance.
    """

    def fetch(repo_id: str, revision: str):
        # Every checkpoint used here is stored under the same file name in its repo.
        return hf_hub_download(
            repo_id=repo_id,
            filename="model.safetensors",
            revision=revision,
        )

    # Downloads and model constructions happen in the same order as the keyword
    # arguments of the final ICLight(...) call would evaluate them.
    ic_light_weights = load_from_safetensors(
        path=fetch("refiners/sd15.ic_light.fc", "ea10b4403e97c786a98afdcbdf0e0fec794ea542"),
    )
    denoiser = SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
        tensors_path=fetch("refiners/sd15.realistic_vision.v5_1.unet", "94f74be7adfd27bee330ea1071481c0254c29989")
    )
    text_encoder = CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=fetch(
            "refiners/sd15.realistic_vision.v5_1.text_encoder", "7f6fa1e870c8f197d34488e14b89e63fb8d7fd6e"
        )
    )
    autoencoder = SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
        tensors_path=fetch(
            "refiners/sd15.realistic_vision.v5_1.autoencoder", "99f089787a6e1a852a0992da1e286a19fcbbaa50"
        )
    )

    return ICLight(
        patch_weights=ic_light_weights,
        unet=denoiser,
        clip_text_encoder=text_encoder,
        lda=autoencoder,
        device=device,
        dtype=dtype,
    )
def resize_modulo_8(
    image: "Image.Image",
    size: int = 768,
    resample: "Image.Resampling | None" = None,
    on_short: bool = True,
) -> "Image.Image":
    """Resize an image, keeping its aspect ratio, so that both sides are multiples of 8.

    The reference side is the shortest side when ``on_short`` is True, otherwise the
    longest side. It is resized to exactly ``size``. The other side is scaled by the
    same factor, rounded down to a multiple of 8, and never made smaller than 8.

    Args:
        image: The image to resize.
        size: Target length of the reference side. Must be a multiple of 8, the
            latent compression factor.
        resample: Pillow resampling filter. ``None`` means ``Image.Resampling.LANCZOS``.
        on_short: Whether ``size`` applies to the shortest side (True) or to the
            longest side (False).

    Returns:
        The resized image.

    Raises:
        AssertionError: If ``size`` is not a multiple of 8 (only when assertions
            are enabled).
    """
    assert size % 8 == 0, "Size must be a multiple of 8 because this is the latent compression size."
    side_size = min(image.size) if on_short else max(image.size)

    def scaled(length: int) -> int:
        # Exact integer arithmetic. The float form `int(length * (size / (side_size * 8)))`
        # can truncate an exact integer: with size=8 and a 49px side, 49 * (1/49) is
        # 0.9999999999999999, so the reference side collapsed to 0.
        # Clamp to one latent cell so extreme aspect ratios never yield a zero-sized side.
        return max(1, length * size // (side_size * 8)) * 8

    new_size = (scaled(image.width), scaled(image.height))
    # Test against None explicitly: `resample or LANCZOS` would silently replace
    # Image.Resampling.NEAREST, whose value is 0 (falsy), with LANCZOS.
    resampling_filter = Image.Resampling.LANCZOS if resample is None else resample
    return image.resize(new_size, resample=resampling_filter)
| class LightingPreference(str, Enum): | |
| LEFT = auto() | |
| RIGHT = auto() | |
| TOP = auto() | |
| BOTTOM = auto() | |
| NONE = auto() | |
| def get_init_image(self, width: int, height: int, interval: tuple[float, float] = (0.0, 1.0)) -> Image.Image | None: | |
| """Generate an image with a linear gradient based on the lighting preference. | |
| In the original code, interval is always (0., 1.) ; we added it as a parameter to make the function more | |
| flexible and allow for less contrasted images with a smaller interval. | |
| see https://github.com/lllyasviel/IC-Light/blob/7886874/gradio_demo.py#L242 | |
| """ | |
| start, end = interval | |
| match self: | |
| case LightingPreference.LEFT: | |
| tensor = torch.linspace(end, start, width).repeat(1, 1, height, 1) | |
| case LightingPreference.RIGHT: | |
| tensor = torch.linspace(start, end, width).repeat(1, 1, height, 1) | |
| case LightingPreference.TOP: | |
| tensor = torch.linspace(end, start, height).repeat(1, 1, width, 1).transpose(2, 3) | |
| case LightingPreference.BOTTOM: | |
| tensor = torch.linspace(start, end, height).repeat(1, 1, width, 1).transpose(2, 3) | |
| case LightingPreference.NONE: | |
| return None | |
| return tensor_to_image(tensor).convert("RGB") | |
| def from_str(cls, value: str): | |
| match value.lower(): | |
| case "left": | |
| return LightingPreference.LEFT | |
| case "right": | |
| return LightingPreference.RIGHT | |
| case "top": | |
| return LightingPreference.TOP | |
| case "bottom": | |
| return LightingPreference.BOTTOM | |
| case "none": | |
| return LightingPreference.NONE | |
| case _: | |
| raise ValueError(f"Invalid lighting preference: {value}") | |