File size: 3,299 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch


class RegionalIPData:
    """A class to manage the data for regional IP-Adapter conditioning."""

    def __init__(
        self,
        image_prompt_embeds: list[torch.Tensor],
        scales: list[float],
        masks: list[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `IPAdapterConditioningData` object."""
        assert len(image_prompt_embeds) == len(scales) == len(masks)

        # The image prompt embeddings.
        # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        self.image_prompt_embeds = image_prompt_embeds

        # The scales for the IP-Adapter attention.
        # scales[i] contains the attention scale for the i'th IP-Adapter.
        self.scales = scales

        # The IP-Adapter masks.
        # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
        # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
        # regions and 0.0 for excluded regions.
        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)

    def _prepare_masks(
        self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
    ) -> dict[int, torch.Tensor]:
        """Prepare the masks for the IP-Adapter attention."""
        # Concatenate the masks so that they can be processed more efficiently.
        mask_tensor = torch.cat(masks, dim=1)

        mask_tensor = mask_tensor.to(device=device, dtype=dtype)

        masks_by_seq_len: dict[int, torch.Tensor] = {}

        # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
        downscale_factor = 1
        while downscale_factor <= max_downscale_factor:
            b, num_ip_adapters, h, w = mask_tensor.shape
            # Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
            assert b == 1

            # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
            # the spatial features.
            query_seq_len = h * w

            masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))

            downscale_factor *= 2
            if downscale_factor <= max_downscale_factor:
                # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
                # regions to be lost entirely.
                #
                # ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
                #
                # TODO(ryand): In the future, we may want to experiment with other downsampling methods.
                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)

        return masks_by_seq_len

    def get_masks(self, query_seq_len: int) -> torch.Tensor:
        """Get the mask for the given query sequence length."""
        return self._masks_by_seq_len[query_seq_len]