import torch
from torchvision.transforms import v2

from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer

from . import register
from .base import Base

def apply_padding(model, mode):
    # Switch every Conv2d between 'circular' padding (wraps around at the
    # borders, useful for seamlessly tileable outputs) and default 'zeros'.
    for layer in model.modules():
        if isinstance(layer, torch.nn.Conv2d):
            layer.padding_mode = 'circular' if mode == 'circular' else 'zeros'
    return model

def freeze(model):
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model

class StableDiffusion(Base):
    def setup(self):
        hf_key = self.config.get("hf_key", None)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Note: the 'fp16' flag selects bfloat16 (not float16) as the reduced-precision dtype.
        fp16 = self.config.get("fp16", True)
        self.dtype = torch.bfloat16 if fp16 else torch.float32
        vae_padding = self.config.get("vae_padding", "zeros")
        self.sd_version = self.config.get("version", 2.1)

        local_files_only = False
        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
            local_files_only = True
        elif str(self.sd_version) == "2.1":
            # model_key = "stabilityai/stable-diffusion-2-1"
            # StabilityAI deleted the original 2.1 model from HF; use a community version
            model_key = "RedbeardNZ/stable-diffusion-2-1-base"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )
        # Load components separately to avoid downloading unnecessary weights.
        # from_config builds each module from its config without fetching
        # pretrained weights; local_files_only goes to load_config, the only
        # call here that touches the Hub.
        # 1. UNet (diffusion backbone)
        unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet", local_files_only=local_files_only)
        self.unet = UNet2DConditionModel.from_config(unet_config)
        self.unet.to(self.device, dtype=self.dtype).eval()
        # 2. VAE (image autoencoder)
        vae_config = AutoencoderKL.load_config(model_key, subfolder="vae", local_files_only=local_files_only)
        self.vae = AutoencoderKL.from_config(vae_config)
        self.vae.to(self.device, dtype=self.dtype).eval()
        self.vae = apply_padding(freeze(self.vae), vae_padding)
        # 3. Text encoder (CLIP)
        text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
        self.text_encoder = CLIPTextModel(text_encoder_config)
        self.text_encoder.to(self.device, dtype=self.dtype).eval()
        # 4. Tokenizer (CLIP tokenizer; this one ships a vocab, so from_pretrained is needed)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
        # 5. Scheduler, with the zero-terminal-SNR fixes: v-prediction,
        #    trailing timestep spacing, and rescaled betas
        scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler", local_files_only=local_files_only)
        scheduler_config["prediction_type"] = "v_prediction"
        scheduler_config["timestep_spacing"] = "trailing"
        scheduler_config["rescale_betas_zero_snr"] = True
        self.scheduler = DDIMScheduler.from_config(scheduler_config)
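
    # Sketch of how these pieces compose into one denoising step (illustrative,
    # not executed here; `latents` and `text_emb` are assumed to exist):
    #
    #   self.scheduler.set_timesteps(50, device=self.device)
    #   for t in self.scheduler.timesteps:
    #       noise_pred = self.unet(latents, t, encoder_hidden_states=text_emb).sample
    #       latents = self.scheduler.step(noise_pred, t, latents).prev_sample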

    def encode_text(self, prompt, padding_mode="do_not_pad"):
        # prompt: [str]
        inputs = self.tokenizer(
            prompt,
            padding=padding_mode,
            max_length=self.tokenizer.model_max_length,
            truncation=True,  # clip prompts longer than the model's 77-token limit
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings
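
    # A common pattern built on encode_text (sketch): classifier-free guidance
    # embeddings. padding_mode="max_length" pads both prompts to the same
    # length so the two embeddings can be concatenated:
    #
    #   cond = self.encode_text(["a photo of a cat"], padding_mode="max_length")
    #   uncond = self.encode_text([""], padding_mode="max_length")
    #   text_emb = torch.cat([uncond, cond])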

    def decode_latents(self, latents):
        latents = latents / self.vae.config.scaling_factor
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def encode_imgs_deterministic(self, imgs):
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1
        # Take the mode (= mean) of the posterior instead of sampling, so the
        # same image always maps to the same latent.
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.mode() * self.vae.config.scaling_factor
        return latents
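

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; assumes `Base.__init__` accepts the config dict
# read in setup(), which may differ in the actual code base):
#
#   sd = StableDiffusion(config={"version": "2.1", "fp16": True})
#   imgs = torch.rand(1, 3, 512, 512, device=sd.device, dtype=sd.dtype)
#   latents = sd.encode_imgs(imgs)       # [1, 4, 64, 64] for 512x512 inputs
#   recon = sd.decode_latents(latents)   # back to images in [0, 1]
# ---------------------------------------------------------------------------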