import os
import json

import torch

from attn_processor import AttnProcessor2_0, SkipAttnProcessor


def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    """Replace the UNet's attention processors: cross-attention layers get
    cross_attn_cls, self-attention layers get self_attn_cls (or keep the
    default AttnProcessor2_0). Returns the installed processors as a ModuleList."""
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention; everything else is cross-attention.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                # Retain the original attention processor for self-attention.
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_procesors.values() if False else unet.attn_processors.values())
    return adapter_modules
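
# A minimal usage sketch (not part of the original file): wiring init_adapter into a
# diffusers UNet. "base_model_path" is a placeholder for whatever checkpoint the caller
# uses; the returned ModuleList gathers the installed processors so they can be saved
# or optimized when a processor class carries weights.
def _example_init_adapter(base_model_path):
    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(base_model_path, subfolder="unet")
    # Skip text cross-attention, keep the default SDPA processor on self-attention.
    adapter_modules = init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
    return unet, adapter_modules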

def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load the text encoder, VAE and tokenizer; build the UNet from its local
    config/weights with unet_class if possible, otherwise return None for it."""
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
    except Exception:
        unet = None
    return text_encoder, vae, tokenizer, unet
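
# A hedged sketch (not in the original file): init_diffusion_model only builds a UNet
# when unet_class is given and the local config/weights can be loaded; otherwise it
# returns None, so a caller may fall back to diffusers' own loader as below.
def _example_load_with_fallback(model_path):
    from diffusers import UNet2DConditionModel
    text_encoder, vae, tokenizer, unet = init_diffusion_model(model_path, unet_class=UNet2DConditionModel)
    if unet is None:
        unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    return text_encoder, vae, tokenizer, unet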

def attn_of_unet(unet):
    """Collect every self-attention ("attn1") module of the UNet."""
    attn_blocks = torch.nn.ModuleList()
    for name, module in unet.named_modules():
        if "attn1" in name:
            attn_blocks.append(module)
    return attn_blocks

def get_trainable_module(unet, trainable_module_name):
    """Return the part of the UNet to train: the whole UNet, its transformer
    blocks, or only the self-attention ("attn1") modules."""
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):  # mid_block is a single block
                trainable_modules.append(blocks.attentions)
            else:  # down_blocks / up_blocks are lists of blocks
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        attn_blocks = torch.nn.ModuleList()
        for name, module in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(module)
        return attn_blocks
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
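
# A hedged sketch (not in the original file): freezing the UNet and optimizing only the
# submodules selected by get_trainable_module. The learning rate is an illustrative value.
def _example_select_trainable(unet, lr=1e-5):
    unet.requires_grad_(False)
    trainable = get_trainable_module(unet, "attention")
    trainable.requires_grad_(True)
    optimizer = torch.optim.AdamW(trainable.parameters(), lr=lr)
    return trainable, optimizer
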
import torch
import numpy as np
from PIL import Image
# =====================================================
# Image and VAE utility functions used by CatVTONPipeline
# =====================================================

def compute_vae_encodings(image, vae):
    """Encode an image tensor using the model's VAE encoder."""
    if isinstance(image, list):
        image = torch.cat(image, dim=0)
    latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    return latents
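
# A hedged sketch (not in the original file): round-tripping an image through the VAE to
# show how the scaled latents above are decoded back to pixel space. prepare_image and
# numpy_to_pil are the helpers defined later in this file.
def _example_vae_roundtrip(pil_image, vae):
    image = prepare_image(pil_image).to(vae.device, dtype=vae.dtype)
    latents = compute_vae_encodings(image, vae)
    decoded = vae.decode(latents / vae.config.scaling_factor).sample
    decoded = decoded.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy()
    return numpy_to_pil(decoded)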

def numpy_to_pil(images):
    """Convert numpy arrays to PIL Images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]

def prepare_image(image):
    """Convert PIL image to normalized torch tensor."""
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    return image

def prepare_mask_image(mask_image):
    """Convert PIL mask to tensor in [0,1] range."""
    if isinstance(mask_image, Image.Image):
        mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
    mask_image = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0)
    return mask_image

def resize_and_crop(image, size):
    """Resize image keeping aspect ratio, then center crop to the target size."""
    if isinstance(image, Image.Image):
        (target_w, target_h), (w, h) = size, image.size
        scale = max(target_w / w, target_h / h)  # cover the target, crop the overflow
        image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
        left, top = (image.size[0] - target_w) // 2, (image.size[1] - target_h) // 2
        image = image.crop((left, top, left + target_w, top + target_h))
    return image

def resize_and_padding(image, size):
    """Resize to fit within the target size, then pad to match it."""
    if isinstance(image, Image.Image):
        image = image.copy()  # thumbnail() resizes in place; don't mutate the caller's image
        image.thumbnail(size, Image.BICUBIC)
        new_image = Image.new("RGB", size)
        left = (size[0] - image.size[0]) // 2
        top = (size[1] - image.size[1]) // 2
        new_image.paste(image, (left, top))
        image = new_image
    return image
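
# A hedged end-to-end sketch (not in the original file): a typical preprocessing path
# for a person image, a garment image and an inpainting mask before VAE encoding.
# The (width, height) target of (768, 1024) is only an illustrative choice.
def _example_preprocess(person_pil, cloth_pil, mask_pil, size=(768, 1024)):
    person = prepare_image(resize_and_crop(person_pil, size))
    cloth = prepare_image(resize_and_padding(cloth_pil, size))
    mask = prepare_mask_image(resize_and_crop(mask_pil, size))
    # Zero out the masked region of the person image, as inpainting pipelines expect.
    masked_person = person * (mask < 0.5)
    return person, cloth, mask, masked_person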