|
|
|
|
|
import math |
|
|
import warnings |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import PIL.Image |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image, ImageFilter, ImageOps |
|
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate |
|
|
from diffusers.image_processor import VaeImageProcessor |
|
|
|
|
|
class IPAdapterMaskProcessor(VaeImageProcessor):
    """
    Image processor for IP Adapter image masks.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `True`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to `True`):
            Whether to convert the images to grayscale format.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        # Masks are single-channel 0/1 maps, so the defaults differ from the
        # base VaeImageProcessor: no [-1, 1] normalization, but binarize and
        # convert to grayscale.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
            num_queries (`int`):
                The number of queries.
            value_embed_dim (`int`):
                The dimensionality of the value embeddings.

        Returns:
            `torch.Tensor`:
                The downsampled mask tensor of shape `(batch_size, num_queries, value_embed_dim)`.
        """
        # NOTE(review): mask is indexed as (channels/batch, height, width) —
        # a 3-D tensor, consistent with preprocess() output; confirm at callers.
        o_h = mask.shape[1]
        o_w = mask.shape[2]
        ratio = o_w / o_h
        # Pick a (mask_h, mask_w) grid whose area covers num_queries while
        # approximating the mask's aspect ratio. math.sqrt on a plain float
        # replaces the previous torch.sqrt(torch.FloatTensor([...]))[0] round
        # trip, which allocated a throwaway tensor just to take a scalar root.
        mask_h = int(math.sqrt(num_queries / ratio))
        # Round up by one row when num_queries does not divide evenly.
        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
        mask_w = num_queries // mask_h

        # Interpolate expects a batch dimension; add and remove it around the call.
        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)

        # Broadcast a single mask across the whole batch if needed.
        if mask_downsample.shape[0] < batch_size:
            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)

        # Flatten the spatial dimensions into the query axis.
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)

        downsampled_area = mask_h * mask_w
        # If the grid area misses num_queries exactly, the mask/output aspect
        # ratios disagree: warn once, then zero-pad or truncate to fit.
        # (Previously the identical warning was duplicated in two mutually
        # exclusive branches; merged here with behavior unchanged.)
        if downsampled_area != num_queries:
            warnings.warn(
                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
                "Please update your masks or adjust the output size for optimal performance.",
                UserWarning,
            )
            if downsampled_area < num_queries:
                # Pad the trailing queries with zeros (no attention contribution).
                mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
            else:
                # Discard the surplus entries beyond num_queries.
                mask_downsample = mask_downsample[:, :num_queries]

        # Expand to (batch, num_queries, value_embed_dim) for use in attention.
        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
            1, 1, value_embed_dim
        )

        return mask_downsample