# NOTE: web-scrape residue from the HuggingFace Spaces page (Space status,
# file size 4,557 bytes, commit a846205, and a column of line numbers)
# removed — it was not part of the original Python source.
import torch
from torchvision.transforms import v2
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer
from . import register
from .base import Base
def apply_padding(model, mode):
    """Set the padding mode of every Conv2d layer in *model* in place.

    Args:
        model: any ``torch.nn.Module``; its Conv2d submodules are mutated.
        mode: ``'circular'`` selects circular padding; any other value
            falls back to ``'zeros'``.

    Returns:
        The same (mutated) model, for call chaining.
    """
    target = 'circular' if mode == 'circular' else 'zeros'
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            module.padding_mode = target
    return model
def freeze(model):
    """Switch *model* to eval mode and detach its parameters from autograd.

    Args:
        model: a ``torch.nn.Module``; mutated in place.

    Returns:
        The same frozen model, for call chaining.
    """
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model
@register("stable_diffusion")
class StableDiffusion(Base):
    """Stable Diffusion 2.1 wrapper: loads UNet, VAE, CLIP text encoder,
    tokenizer and DDIM scheduler as separate components and exposes
    text/image encoding and latent decoding helpers."""

    def setup(self):
        """Build all model components from the configured HuggingFace repo.

        Reads from ``self.config``: ``hf_key`` (optional custom local model
        key), ``fp16`` (default True), ``vae_padding`` (default "zeros"),
        ``version`` (default 2.1). Only version "2.1" (or a custom key)
        is supported; anything else raises ValueError.
        """
        hf_key = self.config.get("hf_key", None)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fp16 = self.config.get("fp16", True)
        # NOTE(review): the flag is named "fp16" but selects *bfloat16*,
        # not float16 — presumably intentional for numerical stability;
        # confirm before renaming.
        self.dtype = torch.bfloat16 if fp16 else torch.float32
        vae_padding = self.config.get("vae_padding", "zeros")
        self.sd_version = self.config.get("version", 2.1)
        local_files_only = False
        if hf_key is not None:
            # A custom key is assumed to point at an already-downloaded
            # local model, hence local_files_only=True.
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
            local_files_only = True
        elif str(self.sd_version) == "2.1":
            # model_key = "stabilityai/stable-diffusion-2-1"
            # StabilityAI deleted the original 2.1 model from HF, use a community version
            model_key = "RedbeardNZ/stable-diffusion-2-1-base"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )
        # Load components separately to avoid download unnecessary weights
        # NOTE(review): from_config builds architectures with randomly
        # initialized weights — pretrained weights are presumably loaded
        # elsewhere (e.g. by Base or a checkpoint restore); confirm.
        # 1. UNet (diffusion backbone)
        unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet")
        self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only)
        self.unet.to(self.device, dtype=self.dtype).eval()
        # 2. VAE (image autoencoder)
        vae_config = AutoencoderKL.load_config(model_key, subfolder="vae")
        self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only)
        self.vae.to(self.device, dtype=self.dtype).eval()
        # Freeze the VAE and apply the configured Conv2d padding mode
        # (e.g. "circular" for seamless/tileable textures).
        self.vae = apply_padding(freeze(self.vae), vae_padding)
        # 3. Text encoder (CLIP)
        text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
        self.text_encoder = CLIPTextModel(text_encoder_config)
        self.text_encoder.to(self.device, dtype=self.dtype).eval()
        # 4. Tokenizer (CLIP tokenizer, this one has vocab so from_pretrained is needed)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
        # 5. Scheduler — force v-prediction with trailing spacing and
        # zero-terminal-SNR betas (the SD 2.1-v style configuration).
        scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler")
        scheduler_config["prediction_type"] = "v_prediction"
        scheduler_config["timestep_spacing"] = "trailing"
        scheduler_config["rescale_betas_zero_snr"] = True
        self.scheduler = DDIMScheduler.from_config(scheduler_config)

    def encode_text(self, prompt, padding_mode="do_not_pad"):
        """Tokenize *prompt* and return CLIP text embeddings.

        Args:
            prompt: list of prompt strings.
            padding_mode: passed to the tokenizer's ``padding`` argument.

        Returns:
            Last-hidden-state embeddings on ``self.device``.
        """
        # prompt: [str]
        # NOTE(review): max_length is given but truncation is not enabled,
        # so over-length prompts are NOT truncated here — confirm callers
        # never exceed the model's max length.
        inputs = self.tokenizer(
            prompt,
            padding=padding_mode,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def decode_latents(self, latents):
        """Decode VAE latents to images clamped to [0, 1].

        Undoes the VAE scaling factor before decoding, then maps the
        decoder's [-1, 1] output range to [0, 1].
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        """Encode images in [0, 1] to scaled VAE latents (stochastic).

        Single-channel inputs are expanded to RGB; samples from the
        VAE posterior, so repeated calls differ.
        """
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1  # [0, 1] -> [-1, 1] expected by the VAE
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def encode_imgs_deterministic(self, imgs):
        """Encode images to scaled VAE latents using the posterior mean.

        Deterministic counterpart of ``encode_imgs``: bypasses
        ``vae.encode`` to run the encoder and quant_conv directly and
        takes the mean instead of sampling.
        """
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1  # [0, 1] -> [-1, 1] expected by the VAE
        h = self.vae.encoder(imgs)
        moments = self.vae.quant_conv(h)
        # moments holds concatenated (mean, logvar) along the channel dim
        mean, logvar = torch.chunk(moments, 2, dim=1)
        latents = mean * self.vae.config.scaling_factor
        return latents