# Try-Space-Model / utils.py
import os
import json

import numpy as np
import torch
from PIL import Image

from attn_processor import AttnProcessor2_0, SkipAttnProcessor


def init_adapter(unet,
                 cross_attn_cls=SkipAttnProcessor,
                 self_attn_cls=None,
                 cross_attn_dim=None,
                 **kwargs):
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention; only "attn2" (cross-attention)
        # processors receive the text-conditioning dimension.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                # Retain the original attention processor for self-attention layers.
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules
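
# A minimal usage sketch for init_adapter (not part of the original call sites).
# The checkpoint id below is illustrative, not pinned by this repo; any
# diffusers UNet with the SD 1.x layout works.
def _example_init_adapter():
    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
    )
    # Replace every cross-attention processor with SkipAttnProcessor and
    # collect the swapped-in modules in one trainable container.
    adapter_modules = init_adapter(unet, cross_attn_cls=SkipAttnProcessor)
    print(sum(p.numel() for p in adapter_modules.parameters()))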

def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(
            torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"),
            strict=True,
        )
    except Exception:
        # Fall back to None when the local UNet config/weights are missing
        # or unet_class was not provided.
        unet = None
    return text_encoder, vae, tokenizer, unet
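
# Usage sketch for init_diffusion_model (illustrative, not from the original
# file). The checkpoint path is an assumption; the local "unet" subfolder is
# only read when it exists, otherwise unet comes back as None.
def _example_init_diffusion_model():
    from diffusers import UNet2DConditionModel
    text_encoder, vae, tokenizer, unet = init_diffusion_model(
        "runwayml/stable-diffusion-v1-5",  # assumed checkpoint id or local path
        unet_class=UNet2DConditionModel,
    )
    print(type(text_encoder).__name__, type(vae).__name__, unet is not None)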

def attn_of_unet(unet):
    """Collect every self-attention ("attn1") module of the UNet."""
    attn_blocks = torch.nn.ModuleList()
    for name, module in unet.named_modules():
        if "attn1" in name:
            attn_blocks.append(module)
    return attn_blocks

def get_trainable_module(unet, trainable_module_name):
    """Select which part of the UNet to optimize: the whole model, its
    transformer blocks, or only the self-attention layers."""
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):  # mid_block is a single block
                trainable_modules.append(blocks.attentions)
            else:  # down_blocks/up_blocks are lists of blocks
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        attn_blocks = torch.nn.ModuleList()
        for name, module in unet.named_modules():
            if "attn1" in name:
                attn_blocks.append(module)
        return attn_blocks
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")
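
# Sketch: freeze the UNet and unfreeze only its self-attention layers, the
# selection CatVTON-style training relies on. The checkpoint id is illustrative.
def _example_get_trainable_module():
    from diffusers import UNet2DConditionModel
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"  # assumed checkpoint
    )
    unet.requires_grad_(False)
    trainable = get_trainable_module(unet, "attention")
    trainable.requires_grad_(True)
    n = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    print(f"trainable params: {n / 1e6:.1f}M")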

# =====================================================
# Image and VAE utility functions used by CatVTONPipeline
# =====================================================

def compute_vae_encodings(image, vae):
    """Encode an image tensor using the model's VAE encoder."""
    if isinstance(image, list):
        image = torch.cat(image, dim=0)
    latents = vae.encode(image).latent_dist.sample()
    # Scale into the latent space the UNet was trained on.
    latents = latents * vae.config.scaling_factor
    return latents
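
# Sketch of the expected input for compute_vae_encodings: a float image batch
# in NCHW layout on the VAE's device/dtype. The checkpoint id is an assumption,
# and a random tensor stands in for a real preprocessed image.
def _example_compute_vae_encodings():
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="vae"  # assumed checkpoint
    )
    image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image
    with torch.no_grad():
        latents = compute_vae_encodings(image, vae)
    print(latents.shape)  # torch.Size([1, 4, 64, 64]) for an SD 1.x VAE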

def numpy_to_pil(images):
    """Convert numpy arrays to PIL Images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]
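
# Round-trip sketch for numpy_to_pil: float arrays in [0, 1] with shape
# (H, W, 3) or (N, H, W, 3) become a list of PIL images.
def _example_numpy_to_pil():
    images = np.random.rand(2, 64, 64, 3).astype(np.float32)
    pils = numpy_to_pil(images)
    print(len(pils), pils[0].size)  # 2 (64, 64)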

def prepare_image(image):
    """Convert a PIL image to a (1, 3, H, W) float tensor in [0, 1]."""
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    return image

def prepare_mask_image(mask_image):
    """Convert a PIL mask to a (1, 1, H, W) float tensor in [0, 1]."""
    if isinstance(mask_image, Image.Image):
        mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
        mask_image = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0)
    return mask_image
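
# Sketch tying the two helpers together: both take PIL inputs and return
# NCHW float tensors ready for the VAE path above.
def _example_prepare_inputs():
    image = Image.new("RGB", (512, 512), (128, 64, 32))
    mask = Image.new("L", (512, 512), 255)
    image_t = prepare_image(image)      # shape (1, 3, 512, 512), values in [0, 1]
    mask_t = prepare_mask_image(mask)   # shape (1, 1, 512, 512), values in [0, 1]
    print(image_t.shape, mask_t.shape)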

def resize_and_crop(image, size):
    """Resize to cover the target size (keeping aspect ratio), then center crop."""
    if isinstance(image, Image.Image):
        w, h = image.size
        scale = max(size[0] / w, size[1] / h)
        image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
        left, top = (image.size[0] - size[0]) // 2, (image.size[1] - size[1]) // 2
        image = image.crop((left, top, left + size[0], top + size[1]))
    return image
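
# Quick check of resize_and_crop's contract: a portrait source is scaled to
# cover 384x512 and then center-cropped, so the output matches the target.
def _example_resize_and_crop():
    src = Image.new("RGB", (400, 800), (200, 200, 200))
    out = resize_and_crop(src, (384, 512))
    print(out.size)  # (384, 512)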

def resize_and_padding(image, size):
    """Resize to fit inside the target size, then pad (letterbox) to match it."""
    if isinstance(image, Image.Image):
        # thumbnail() resizes in place, keeps the aspect ratio, and only shrinks.
        image.thumbnail(size, Image.BICUBIC)
        new_image = Image.new("RGB", size)  # black padding by default
        left = (size[0] - image.size[0]) // 2
        top = (size[1] - image.size[1]) // 2
        new_image.paste(image, (left, top))
        image = new_image
    return image
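
# Contrast with resize_and_crop: padding letterboxes the image instead of
# cropping it, so the full source stays visible inside the black borders.
def _example_resize_and_padding():
    src = Image.new("RGB", (400, 800), (200, 200, 200))
    out = resize_and_padding(src, (384, 512))
    print(out.size)  # (384, 512); the 256x512 resized image sits centered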