|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from attn_processor import AttnProcessor2_0, SkipAttnProcessor |
|
|
|
|
|
|
|
|
def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    """Replace the UNet's attention processors with adapter processors.

    Self-attention layers (names ending in "attn1.processor") get
    ``self_attn_cls`` (or the stock ``AttnProcessor2_0`` when it is None);
    cross-attention layers get ``cross_attn_cls``.

    Args:
        unet: A diffusers-style UNet exposing ``attn_processors``,
            ``set_attn_processor`` and ``config.block_out_channels`` /
            ``config.cross_attention_dim``.
        cross_attn_cls: Processor class for cross-attention layers.
        self_attn_cls: Optional processor class for self-attention layers.
        cross_attn_dim: Cross-attention dim; defaults to the UNet config value.
        **kwargs: Extra keyword arguments forwarded to every processor class.

    Returns:
        torch.nn.ModuleList of the newly installed processors.

    Raises:
        ValueError: If a processor name does not belong to a known block.
    """
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" is self-attention in diffusers naming; it has no cross dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # Previously this fell through silently and reused the hidden_size
            # from the prior iteration (or raised NameError on the first one).
            raise ValueError(f"Unexpected attention processor name: {name}")
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                # Keep the default (scaled-dot-product) processor for self-attention.
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)

    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
|
|
|
|
|
def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load text encoder, VAE, tokenizer and (best-effort) a custom UNet.

    Args:
        diffusion_model_name_or_path: Hub id or local path of a Stable
            Diffusion-style checkpoint with ``text_encoder``/``vae``/
            ``tokenizer``/``unet`` subfolders.
        unet_class: Optional class instantiated from the unet subfolder's
            ``config.json`` and loaded from ``diffusion_pytorch_model.bin``.

    Returns:
        Tuple ``(text_encoder, vae, tokenizer, unet)``; ``unet`` is None when
        the unet subfolder is missing/unreadable or ``unet_class`` is None.
    """
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        # Context manager closes the config file (the old json.load(open(...))
        # leaked the handle).
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
    except Exception:
        # Deliberate best-effort: a missing/invalid local unet folder or a
        # None unet_class simply yields unet=None. `except Exception` (not a
        # bare except) so KeyboardInterrupt/SystemExit still propagate.
        unet = None
    return text_encoder, vae, tokenizer, unet
|
|
|
|
|
def attn_of_unet(unet):
    """Collect every module of `unet` whose qualified name contains "attn1".

    Returns:
        torch.nn.ModuleList of the matching (self-attention) sub-modules.
    """
    return torch.nn.ModuleList(
        module for name, module in unet.named_modules() if "attn1" in name
    )
|
|
|
|
|
def get_trainable_module(unet, trainable_module_name):
    """Select the part of `unet` to train.

    Args:
        unet: The UNet model.
        trainable_module_name: One of "unet" (whole model), "transformer"
            (the `attentions` of every down/mid/up block), or "attention"
            (every sub-module whose name contains "attn1").

    Returns:
        The UNet itself, or a torch.nn.ModuleList of the selected modules.

    Raises:
        ValueError: For any other `trainable_module_name`.
    """
    if trainable_module_name == "unet":
        return unet

    if trainable_module_name == "transformer":
        collected = torch.nn.ModuleList()
        for section in (unet.down_blocks, unet.mid_block, unet.up_blocks):
            # mid_block carries `attentions` directly; down/up are iterables
            # of blocks that may or may not carry them.
            candidates = [section] if hasattr(section, "attentions") else list(section)
            for candidate in candidates:
                if hasattr(candidate, "attentions"):
                    collected.append(candidate.attentions)
        return collected

    if trainable_module_name == "attention":
        return torch.nn.ModuleList(
            module for name, module in unet.named_modules() if "attn1" in name
        )

    raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_vae_encodings(image, vae):
    """Encode image tensors into scaled VAE latents.

    Args:
        image: A tensor, or a list of tensors concatenated along dim 0.
        vae: A VAE exposing ``encode(...).latent_dist.sample()`` and
            ``config.scaling_factor``.

    Returns:
        Sampled latents multiplied by the VAE's scaling factor.
    """
    pixel_values = torch.cat(image, dim=0) if isinstance(image, list) else image
    latent_sample = vae.encode(pixel_values).latent_dist.sample()
    return latent_sample * vae.config.scaling_factor
|
|
|
|
|
|
|
|
def numpy_to_pil(images):
    """Convert a (batch of) float array(s) in [0, 1] to a list of PIL Images.

    A single 3-D image array is promoted to a batch of one.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in batch]
|
|
|
|
|
|
|
|
def prepare_image(image):
    """Convert a PIL image to a (1, C, H, W) float tensor scaled to [0, 1].

    Non-PIL inputs (e.g. tensors already prepared) are returned unchanged.
    """
    if isinstance(image, Image.Image):
        pixels = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        image = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    return image
|
|
|
|
|
|
|
|
def prepare_mask_image(mask_image):
    """Convert a PIL mask to a (1, 1, H, W) float tensor scaled to [0, 1].

    Non-PIL inputs are returned unchanged.
    """
    if isinstance(mask_image, Image.Image):
        gray = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
        mask_image = torch.from_numpy(gray).unsqueeze(0).unsqueeze(0)
    return mask_image
|
|
|
|
|
|
|
|
def resize_and_crop(image, size):
    """Center-crop `image` to the target aspect ratio, then resize to `size`.

    The original body did a plain `image.resize(size)`, which distorts the
    aspect ratio and contradicts this function's documented contract; this
    version crops the largest centered region matching `size`'s aspect ratio
    first, so the final resize is distortion-free.

    Args:
        image: A PIL image (other inputs are returned unchanged).
        size: Target (width, height).

    Returns:
        The cropped-and-resized PIL image, or the input unchanged if it is
        not a PIL image.
    """
    if isinstance(image, Image.Image):
        w, h = image.size
        target_w, target_h = size
        # Compare aspect ratios via cross-multiplication (avoids float division).
        if w * target_h > h * target_w:
            # Source is wider than target: crop width.
            new_w = h * target_w // target_h
            new_h = h
        else:
            # Source is taller (or equal): crop height.
            new_w = w
            new_h = w * target_h // target_w
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        image = image.crop((left, top, left + new_w, top + new_h))
        image = image.resize(size, Image.BICUBIC)
    return image
|
|
|
|
|
|
|
|
def resize_and_padding(image, size):
    """Shrink `image` to fit inside `size` (aspect preserved), pad to `size`.

    Uses `Image.thumbnail`, which mutates the input in place and never
    upscales; the result is centered on a black RGB canvas of the target
    size. Non-PIL inputs are returned unchanged.
    """
    if isinstance(image, Image.Image):
        image.thumbnail(size, Image.BICUBIC)
        canvas = Image.new("RGB", size)
        x_offset = (size[0] - image.size[0]) // 2
        y_offset = (size[1] - image.size[1]) // 2
        canvas.paste(image, (x_offset, y_offset))
        image = canvas
    return image
|
|
|
|
|
|