# chord/module/stable_diffusion.py
import torch
from torchvision.transforms import v2
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer
from . import register
from .base import Base

def apply_padding(model, mode):
    # Switch every Conv2d between circular padding (wrap-around at the spatial
    # borders, useful for tileable outputs) and the default zero padding.
    for layer in model.modules():
        if isinstance(layer, torch.nn.Conv2d):
            layer.padding_mode = 'circular' if mode == 'circular' else 'zeros'
    return model

def freeze(model):
    # Disable training: eval mode plus no gradients for any parameter.
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model
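
# Usage sketch (illustrative, not part of the module): circular padding makes a
# convolution wrap around the image borders, so filters see the texture as a torus.
#
#   conv = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
#   apply_padding(conv, 'circular')  # conv.padding_mode is now 'circular'
#   freeze(conv)                     # eval mode, requires_grad=False everywhere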

@register("stable_diffusion")
class StableDiffusion(Base):
    def setup(self):
        hf_key = self.config.get("hf_key", None)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fp16 = self.config.get("fp16", True)
        # Note: the "fp16" flag actually selects bfloat16, which has the same
        # memory footprint as float16 but a wider dynamic range.
        self.dtype = torch.bfloat16 if fp16 else torch.float32
        vae_padding = self.config.get("vae_padding", "zeros")
        self.sd_version = self.config.get("version", 2.1)
        local_files_only = False
        if hf_key is not None:
            print(f"[INFO] using Hugging Face custom model key: {hf_key}")
            model_key = hf_key
            local_files_only = True
        elif str(self.sd_version) == "2.1":
            # model_key = "stabilityai/stable-diffusion-2-1"
            # StabilityAI deleted the original 2.1 model from HF; use a community mirror.
            model_key = "RedbeardNZ/stable-diffusion-2-1-base"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )
        # Load components separately to avoid downloading unnecessary weights.
        # Note: load_config/from_config only fetch the architecture definitions, so
        # the modules below start with randomly initialized weights; the actual
        # checkpoint is presumably loaded elsewhere in the framework.
        # 1. UNet (diffusion backbone)
        unet_config = UNet2DConditionModel.load_config(model_key, subfolder="unet")
        self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only)
        self.unet.to(self.device, dtype=self.dtype).eval()
        # 2. VAE (image autoencoder)
        vae_config = AutoencoderKL.load_config(model_key, subfolder="vae")
        self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only)
        self.vae.to(self.device, dtype=self.dtype).eval()
        self.vae = apply_padding(freeze(self.vae), vae_padding)
        # 3. Text encoder (CLIP); also built from config, so weights are random here
        text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
        self.text_encoder = CLIPTextModel(text_encoder_config)
        self.text_encoder.to(self.device, dtype=self.dtype).eval()
        # 4. Tokenizer (needs from_pretrained because the vocabulary files must be fetched)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
        # 5. Scheduler. The three overrides below are the zero-terminal-SNR fixes
        # from Lin et al., "Common Diffusion Noise Schedules and Sample Steps Are
        # Flawed" (WACV 2024): v-prediction, trailing timestep spacing, and betas
        # rescaled so the final timestep has zero SNR.
        scheduler_config = DDIMScheduler.load_config(model_key, subfolder="scheduler")
        scheduler_config["prediction_type"] = "v_prediction"
        scheduler_config["timestep_spacing"] = "trailing"
        scheduler_config["rescale_betas_zero_snr"] = True
        self.scheduler = DDIMScheduler.from_config(scheduler_config)
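        # Example: with the default 1000 training steps and 50 sampling steps,
        # "trailing" spacing picks timesteps 999, 979, ..., 19, so sampling always
        # starts at the final (pure-noise) timestep; "leading" would start at 980.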

    def encode_text(self, prompt, padding_mode="do_not_pad"):
        # prompt: list of strings
        inputs = self.tokenizer(
            prompt,
            padding=padding_mode,
            max_length=self.tokenizer.model_max_length,
            truncation=True,  # without this, max_length is ignored and long prompts overflow CLIP's 77-token limit
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings
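
        # Usage sketch (illustrative, not from this file): classifier-free guidance
        # pairs a conditional with an unconditional (empty-prompt) embedding:
        #   cond = model.encode_text(["a seamless brick texture"])
        #   uncond = model.encode_text([""])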

    def decode_latents(self, latents):
        # Undo the VAE scaling, decode, and map images from [-1, 1] to [0, 1].
        latents = latents / self.vae.config.scaling_factor
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def encode_imgs(self, imgs):
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1  # map [0, 1] to [-1, 1], as the VAE expects
        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def encode_imgs_deterministic(self, imgs):
        # Same as encode_imgs, but uses the posterior mean instead of a sample,
        # so repeated calls return identical latents.
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        imgs = 2 * imgs - 1
        h = self.vae.encoder(imgs)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        latents = mean * self.vae.config.scaling_factor
        return latents
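
# Minimal usage sketch, not executed on import. Assumptions not confirmed by this
# file: Base accepts the config dict in its constructor and invokes setup(), and
# the surrounding framework loads pretrained weights before real use (the modules
# above are built from configs with random weights).
if __name__ == "__main__":
    sd = StableDiffusion({"fp16": True, "vae_padding": "circular", "version": "2.1"})
    text_emb = sd.encode_text(["a seamless brick texture"])
    imgs = torch.rand(1, 3, 512, 512, device=sd.device, dtype=sd.dtype)
    latents = sd.encode_imgs_deterministic(imgs)  # -> (1, 4, 64, 64) for 512x512 input
    recon = sd.decode_latents(latents)            # -> (1, 3, 512, 512), values in [0, 1]
    print(text_emb.shape, latents.shape, recon.shape)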