File size: 4,033 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig


class VisualPromptEncodeConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters of the visual prompt encoder.

    Args:
        vision_hidden_size: Stored as ``self.vision_hidden_size``.
            Presumably the channel width of the vision features -- confirm
            against the model that consumes this config.
        language_hidden_size: Stored as ``self.language_hidden_size``.
            Presumably the width of the language-side embeddings -- confirm.
        patch_size: Stored as ``self.patch_size``.
        downsample_ratio: Stored as ``self.downsample_ratio``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = 'vision_encoder'
    _auto_class = 'AutoConfig'
    main_input_name = "visual_prompts"

    def __init__(self, 
                 vision_hidden_size: int,
                 language_hidden_size: int,
                 patch_size: int,
                 downsample_ratio,
                 **kwargs):
        super().__init__(**kwargs)

        # Keep the hyper-parameters on the instance. Previously they were
        # accepted and then dropped, so the config carried no information
        # and ``to_dict()`` could not include them.
        self.vision_hidden_size = vision_hidden_size
        self.language_hidden_size = language_hidden_size
        self.patch_size = patch_size
        self.downsample_ratio = downsample_ratio
        # NOTE(review): some PretrainedConfig helpers appear to instantiate
        # the config class without arguments (e.g. when diffing against
        # defaults); the four required parameters would break that. Consider
        # giving them defaults -- confirm against the transformers version
        # in use.


class VisualPromptEncodeModel(nn.Module):
    """Turn visual-prompt masks plus per-prompt mark embeddings into a feature map.

    Each prompt mask is resized to the patch grid, "painted" with its mark
    embedding, summed with the other prompts of the same sample, and passed
    through ``mlp1`` to obtain ``vision_hidden_size`` channels per grid cell.

    Args:
        in_channels: Input channels of the patch-embedding convolution.
        vision_hidden_size: Output channels of the convolution and of ``mlp1``.
        language_hidden_size: Size of one mark embedding (input of ``mlp1``).
        force_image_size: Image edge length used to derive the patch grid
            (``force_image_size // patch_size`` cells per edge).
        patch_size: Kernel size and stride of the patch-embedding convolution.
        downsample_ratio: Stored on the instance; not used by this class.
    """

    def __init__(self,
                 in_channels: int,
                 vision_hidden_size: int,
                 language_hidden_size: int,
                 force_image_size: int,
                 patch_size: int,
                 downsample_ratio: int,
                 ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.vision_hidden_size = vision_hidden_size
        self.language_hidden_size = language_hidden_size
        self.force_image_size = force_image_size
        self.patch_size = patch_size
        self.downsample_ratio = downsample_ratio

        # Non-overlapping patches: kernel size equals stride.
        self.patch_embedding = nn.Conv2d(
            in_channels=in_channels, out_channels=vision_hidden_size,
            kernel_size=patch_size, stride=patch_size
        )

        # Projects painted mark embeddings (language width) to vision width.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(language_hidden_size),
            nn.Linear(language_hidden_size, vision_hidden_size),
            nn.GELU(),
            nn.Linear(vision_hidden_size, vision_hidden_size)
        )

        # Number of patch-grid cells along one image edge.
        self.patch_edge_token = force_image_size // patch_size

    def forward(self, merged_visual_prompts, visual_prompts, mark_embeddings, num_patches, num_vprompts):
        """Compute the visual-prompt feature map.

        Args:
            merged_visual_prompts: Input of the patch-embedding convolution;
                assumed to be (sum(num_patches), in_channels, H, W) -- TODO confirm.
            visual_prompts: Prompt masks, one 2-D map per (prompt, patch) pair,
                stacked on dim 0 (a channel dim is added here). Values are
                presumably in {0, 1} -- confirm against the caller.
            mark_embeddings: One embedding of size ``language_hidden_size``
                per prompt, stacked on dim 0 over all samples.
            num_patches: Per-sample number of image patches.
            num_vprompts: Per-sample number of visual prompts.

        Returns:
            Tensor with the same shape as the patch-embedding output,
            containing only the projected prompt features.
        """
        patch_embeds = self.patch_embedding(merged_visual_prompts)  # shape = [*, channel, height, width]
        edge = self.patch_edge_token

        # Resize every mask to the patch grid, then cut the flat stack into
        # one chunk per sample (num_vprompts * num_patches masks each).
        resized_masks = F.interpolate(visual_prompts.unsqueeze(1), size=edge, mode="nearest")
        masks_per_sample = torch.split(
            resized_masks,
            [n_patch * n_prompt for n_patch, n_prompt in zip(num_patches, num_vprompts)],
            dim=0,
        )
        marks_per_sample = torch.split(mark_embeddings, list(num_vprompts), dim=0)

        batch_vprompts_input = []
        for n_patch, n_prompt, sample_masks, sample_marks in zip(
                num_patches, num_vprompts, masks_per_sample, marks_per_sample):
            sample_masks = sample_masks.view(n_prompt, n_patch, 1, edge, edge)
            # Broadcast each prompt's embedding over its mask. The previous
            # code also built a zeros tensor on the CPU, moved it to the
            # device and multiplied it by the background mask; that term is
            # always zero, so it is dropped.
            painted = sample_marks[:, None, :, None, None] * sample_masks
            #TODO numeric stability for multi-granularity prompts (one pixel covered by multiple visual prompts)
            batch_vprompts_input.append(painted.sum(dim=0))

        batch_vprompts_input = torch.cat(batch_vprompts_input, dim=0)
        # (N, C, H, W) -> (N, H*W, C) so mlp1 acts on the channel dimension.
        batch_vprompts_input = batch_vprompts_input.permute(0, 2, 3, 1).flatten(1, 2)
        batch_vprompts_input = self.mlp1(batch_vprompts_input).view(
            -1, edge, edge, self.vision_hidden_size).permute(0, 3, 1, 2)
        # this version not consider color prompt
        # NOTE(review): the zero-weighted patch_embeds term presumably keeps
        # patch_embedding in the autograd graph -- confirm before removing.
        patch_embeds = patch_embeds * 0.0 + batch_vprompts_input
        return patch_embeds