import os
import json

import numpy as np
import torch
from PIL import Image

from attn_processor import AttnProcessor2_0, SkipAttnProcessor


def init_adapter(unet, cross_attn_cls=SkipAttnProcessor, self_attn_cls=None, cross_attn_dim=None, **kwargs):
    """Replace the UNet's attention processors and return them as trainable modules."""
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # "attn1" processors are self-attention; everything else is cross-attention.
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            if self_attn_cls is not None:
                attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
            else:
                # Retain the original attention processor.
                attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
        else:
            attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
    return adapter_modules


def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load the text encoder, VAE, tokenizer, and (optionally) a custom UNet."""
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")
    try:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        with open(os.path.join(unet_folder, "config.json"), "r") as f:
            unet_configs = json.load(f)
        unet = unet_class(**unet_configs)
        unet.load_state_dict(torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu"), strict=True)
    except Exception:
        # Fall back gracefully when no local UNet checkpoint (or unet_class) is available.
        unet = None
    return text_encoder, vae, tokenizer, unet


def attn_of_unet(unet):
    """Collect all self-attention ("attn1") modules of the UNet."""
    attn_blocks = torch.nn.ModuleList()
    for name, module in unet.named_modules():
        if "attn1" in name:
            attn_blocks.append(module)
    return attn_blocks


def get_trainable_module(unet, trainable_module_name):
    """Select which part of the UNet to fine-tune: the whole model, its transformer blocks, or its self-attention layers."""
    if trainable_module_name == "unet":
        return unet
    elif trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            if hasattr(blocks, "attentions"):  # mid_block is a single block
                trainable_modules.append(blocks.attentions)
            else:  # down_blocks / up_blocks are lists of blocks
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    elif trainable_module_name == "attention":
        return attn_of_unet(unet)
    else:
        raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")


# =====================================================
# Image and VAE utility functions used by CatVTONPipeline
# =====================================================


def compute_vae_encodings(image, vae):
    """Encode an image tensor (NCHW, values in [-1, 1]) into scaled VAE latents."""
    if isinstance(image, list):
        image = torch.cat(image, dim=0)
    # Match the VAE's device and dtype before encoding.
    image = image.to(vae.device, dtype=vae.dtype)
    latents = vae.encode(image).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    return latents
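
# =====================================================
# Usage sketch (illustration, not part of the original module): one way the
# helpers above fit together. The checkpoint path, the _demo_* name, and the
# optimizer setup are assumptions; substitute your own base diffusion model.
# =====================================================


def _demo_adapter_and_vae(model_path="runwayml/stable-diffusion-inpainting"):
    """Hypothetical example: bypass text cross-attention, then encode images."""
    from diffusers import AutoencoderKL, UNet2DConditionModel

    # Swap every cross-attention processor for SkipAttnProcessor (the text
    # branch is skipped); self-attention keeps the default AttnProcessor2_0.
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    init_adapter(unet, cross_attn_cls=SkipAttnProcessor)

    # Fine-tune only the self-attention layers, leaving the rest frozen.
    trainable = get_trainable_module(unet, "attention")
    optimizer = torch.optim.AdamW(trainable.parameters(), lr=1e-5)

    # Encode a dummy batch (NCHW, values in [-1, 1]) into latent space; the SD
    # VAE downsamples by 8x, so (2, 3, 512, 384) -> (2, 4, 64, 48).
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").eval()
    with torch.no_grad():
        latents = compute_vae_encodings(torch.randn(2, 3, 512, 384).clamp(-1, 1), vae)
    return optimizer, latents
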
def numpy_to_pil(images):
    """Convert a numpy image (HWC) or batch (NHWC) in [0, 1] to PIL Images."""
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    return [Image.fromarray(image) for image in images]


def prepare_image(image):
    """Convert a PIL image to a (1, 3, H, W) float tensor normalized to [-1, 1]."""
    if isinstance(image, Image.Image):
        # Scale to [-1, 1]: Stable Diffusion VAEs expect inputs in this range.
        image = np.array(image.convert("RGB")).astype(np.float32) / 127.5 - 1.0
        image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
    return image


def prepare_mask_image(mask_image):
    """Convert a PIL mask to a (1, 1, H, W) tensor with values in [0, 1]."""
    if isinstance(mask_image, Image.Image):
        mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
        mask_image = torch.from_numpy(mask_image).unsqueeze(0).unsqueeze(0)
    return mask_image


def resize_and_crop(image, size):
    """Resize to cover the target size while keeping aspect ratio, then center-crop."""
    if isinstance(image, Image.Image):
        target_w, target_h = size
        w, h = image.size
        # Scale so the image fully covers the target box, then crop the overflow.
        scale = max(target_w / w, target_h / h)
        image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
        left = (image.size[0] - target_w) // 2
        top = (image.size[1] - target_h) // 2
        image = image.crop((left, top, left + target_w, top + target_h))
    return image


def resize_and_padding(image, size):
    """Resize to fit inside the target size keeping aspect ratio, then pad to size."""
    if isinstance(image, Image.Image):
        image = image.copy()  # thumbnail() mutates in place; don't touch the caller's image
        image.thumbnail(size, Image.BICUBIC)
        new_image = Image.new("RGB", size)  # pads with black by default
        left = (size[0] - image.size[0]) // 2
        top = (size[1] - image.size[1]) // 2
        new_image.paste(image, (left, top))
        image = new_image
    return image
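
# =====================================================
# Usage sketch (illustration, not part of the original module): a typical
# preprocessing chain before VAE encoding. "person.jpg" / "mask.png" are
# placeholder paths and _demo_preprocess is a hypothetical name.
# =====================================================


def _demo_preprocess(image_path="person.jpg", mask_path="mask.png", size=(384, 512)):
    """Hypothetical example: load, align, and normalize an image/mask pair."""
    image = resize_and_crop(Image.open(image_path), size)  # PIL size is (W, H)
    mask = resize_and_crop(Image.open(mask_path), size)    # same transform keeps them aligned
    image_tensor = prepare_image(image)                    # (1, 3, 512, 384), values in [-1, 1]
    mask_tensor = prepare_mask_image(mask)                 # (1, 1, 512, 384), values in [0, 1]
    # Inpainting pipelines usually blank out the region to be repainted.
    masked_image = image_tensor * (mask_tensor < 0.5)
    return image_tensor, mask_tensor, masked_image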