File size: 4,557 Bytes
a846205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from torchvision.transforms import v2

from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTokenizer

from . import register
from .base import Base


def apply_padding(model, mode):
    """Set the padding mode of every ``torch.nn.Conv2d`` found inside ``model``.

    Args:
        model: module whose sub-modules (as yielded by ``modules()``) are scanned.
        mode: requested padding mode. Only the exact value ``'circular'`` is
            honoured; any other value resets the layers to ``'zeros'``.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Anything that is not exactly 'circular' falls back to zero padding.
    resolved_mode = 'circular' if mode == 'circular' else 'zeros'
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            module.padding_mode = resolved_mode
    return model

def freeze(model):
    """Switch ``model`` to evaluation mode and turn off gradients for its parameters.

    Args:
        model: module exposing ``eval()`` and ``parameters()``.

    Returns:
        Whatever ``model.eval()`` returns, with ``requires_grad`` set to
        ``False`` on every parameter.
    """
    frozen = model.eval()
    for weight in frozen.parameters():
        weight.requires_grad = False
    return frozen

@register("stable_diffusion")
class StableDiffusion(Base):
    """Bundle of Stable Diffusion components registered as ``"stable_diffusion"``.

    ``setup`` creates a UNet, a VAE, a CLIP text encoder with its tokenizer and
    a DDIM scheduler. The three networks are instantiated from their
    configuration only (``from_config`` / ``CLIPTextModel(config)``), so no
    pretrained weights are fetched here.
    NOTE(review): weights are presumably restored from a checkpoint elsewhere
    -- confirm against the caller.
    """

    def setup(self):
        """Build all model components from ``self.config``.

        Config keys read (with defaults):
            hf_key (None): custom model id or path. When set, loading is
                restricted to local files.
            fp16 (True): when truthy the modules are cast to ``torch.bfloat16``
                (not float16); otherwise ``torch.float32``.
            vae_padding ("zeros"): padding mode forwarded to ``apply_padding``
                for the VAE convolutions.
            version (2.1): Stable Diffusion version; only "2.1" is supported
                when ``hf_key`` is not given.

        Raises:
            ValueError: if no ``hf_key`` is given and the version is not "2.1".
        """
        hf_key = self.config.get("hf_key", None)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        fp16 = self.config.get("fp16", True)
        # The flag is named fp16 but the reduced-precision dtype used is bfloat16.
        self.dtype = torch.bfloat16 if fp16 else torch.float32
        vae_padding = self.config.get("vae_padding", "zeros")

        self.sd_version = self.config.get("version", 2.1)
        local_files_only = False
        if hf_key is not None:
            print(f"[INFO] using hugging face custom model key: {hf_key}")
            model_key = hf_key
            local_files_only = True
        elif str(self.sd_version) == "2.1":
            # model_key = "stabilityai/stable-diffusion-2-1"
            # StabilityAI deleted the original 2.1 model from HF, use a community version
            model_key = "RedbeardNZ/stable-diffusion-2-1-base"
        else:
            raise ValueError(
                f"Stable-diffusion version {self.sd_version} not supported."
            )

        # Load components separately to avoid downloading unnecessary weights.
        # Every config lookup receives local_files_only so that a custom
        # hf_key never triggers a network request for only some components.
        # 1. UNet (diffusion backbone)
        unet_config = UNet2DConditionModel.load_config(
            model_key, subfolder="unet", local_files_only=local_files_only
        )
        self.unet = UNet2DConditionModel.from_config(unet_config, local_files_only=local_files_only)
        self.unet.to(self.device, dtype=self.dtype).eval()
        # 2. VAE (image autoencoder)
        vae_config = AutoencoderKL.load_config(
            model_key, subfolder="vae", local_files_only=local_files_only
        )
        self.vae = AutoencoderKL.from_config(vae_config, local_files_only=local_files_only)
        self.vae.to(self.device, dtype=self.dtype).eval()
        self.vae = apply_padding(freeze(self.vae), vae_padding)
        # 3. Text encoder (CLIP)
        text_encoder_config = CLIPTextConfig.from_pretrained(model_key, subfolder="text_encoder", local_files_only=local_files_only)
        self.text_encoder = CLIPTextModel(text_encoder_config)
        self.text_encoder.to(self.device, dtype=self.dtype).eval()
        # 4. Tokenizer (CLIP tokenizer, this one has vocab so from_pretrained is needed)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer", local_files_only=local_files_only)
        # 5. Scheduler: start from the published config, then override the
        #    prediction type, timestep spacing and zero-SNR beta rescaling.
        scheduler_config = DDIMScheduler.load_config(
            model_key, subfolder="scheduler", local_files_only=local_files_only
        )
        scheduler_config["prediction_type"] = "v_prediction"
        scheduler_config["timestep_spacing"] = "trailing"
        scheduler_config["rescale_betas_zero_snr"] = True
        self.scheduler = DDIMScheduler.from_config(scheduler_config)

    def encode_text(self, prompt, padding_mode="do_not_pad"):
        """Tokenize ``prompt`` and return the text encoder's first output.

        Args:
            prompt: list of strings (``[str]``).
            padding_mode: value passed as the tokenizer's ``padding`` argument.

        Returns:
            ``self.text_encoder(input_ids)[0]`` computed on ``self.device``.
        """
        # prompt: [str]
        inputs = self.tokenizer(
            prompt,
            padding=padding_mode,
            max_length=self.tokenizer.model_max_length,
            # max_length alone does not shorten the input; without explicit
            # truncation an over-long prompt overflows the encoder's context.
            truncation=True,
            return_tensors="pt",
        )
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
        return embeddings

    def decode_latents(self, latents):
        """Decode VAE latents into images clamped to [0, 1].

        Args:
            latents: latents scaled by ``vae.config.scaling_factor``; the
                scaling is undone before decoding.

        Returns:
            The decoded sample mapped with ``x / 2 + 0.5`` and clamped to [0, 1].
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        return imgs

    def _prepare_vae_input(self, imgs):
        """Expand single-channel images to RGB and rescale with ``2 * x - 1``."""
        if imgs.shape[1] == 1:  # for grayscale maps
            imgs = v2.functional.grayscale_to_rgb(imgs)
        return 2 * imgs - 1

    def encode_imgs(self, imgs):
        """Encode images into scaled latents by sampling the VAE posterior.

        Args:
            imgs: image batch whose dim 1 is the channel axis; a channel size
                of 1 is converted to RGB first.
                NOTE(review): values are presumably in [0, 1] -- confirm.

        Returns:
            A random sample of the posterior times ``vae.config.scaling_factor``.
        """
        posterior = self.vae.encode(self._prepare_vae_input(imgs)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents

    def encode_imgs_deterministic(self, imgs):
        """Encode images into scaled latents using the posterior mode (no sampling).

        Args:
            imgs: image batch whose dim 1 is the channel axis; a channel size
                of 1 is converted to RGB first.
                NOTE(review): values are presumably in [0, 1] -- confirm.

        Returns:
            The posterior mode (its mean) times ``vae.config.scaling_factor``.
        """
        # Go through the public encode() path, like encode_imgs, instead of
        # calling encoder/quant_conv by hand; mode() yields the mean directly.
        posterior = self.vae.encode(self._prepare_vae_input(imgs)).latent_dist
        latents = posterior.mode() * self.vae.config.scaling_factor
        return latents