import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig

class VisualPromptEncodeConfig(PretrainedConfig):
    model_type = 'vision_encoder'
    _auto_class = 'AutoConfig'
    main_input_name = "visual_prompts"

    def __init__(self,
                 vision_hidden_size: int = None,
                 language_hidden_size: int = None,
                 patch_size: int = None,
                 downsample_ratio=None,
                 **kwargs):
        super().__init__(**kwargs)
        # Persist the declared fields so they serialize with the config.
        self.vision_hidden_size = vision_hidden_size
        self.language_hidden_size = language_hidden_size
        self.patch_size = patch_size
        self.downsample_ratio = downsample_ratio
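
# Hypothetical usage sketch -- the values below are illustrative, not taken
# from this checkpoint:
#   config = VisualPromptEncodeConfig(vision_hidden_size=1024,
#                                     language_hidden_size=4096,
#                                     patch_size=14,
#                                     downsample_ratio=0.5)
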
class VisualPromptEncodeModel(nn.Module):
def __init__(self,
in_channels: int,
vision_hidden_size: int,
language_hidden_size: int,
force_image_size: int,
patch_size: int,
downsample_ratio: int,
) -> None:
super().__init__()
self.in_channels = in_channels
self.vision_hidden_size = vision_hidden_size
self.language_hidden_size = language_hidden_size
self.force_image_size = force_image_size
self.patch_size = patch_size
self.downsample_ratio = downsample_ratio
        # Patchify the merged prompt image: every patch_size x patch_size tile
        # becomes one token in the vision feature space.
        self.patch_embedding = nn.Conv2d(
            in_channels=in_channels, out_channels=vision_hidden_size,
            kernel_size=patch_size, stride=patch_size
        )
        # Project mark embeddings from the language space into the vision space.
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(language_hidden_size),
            nn.Linear(language_hidden_size, vision_hidden_size),
            nn.GELU(),
            nn.Linear(vision_hidden_size, vision_hidden_size)
        )
        # Number of patch tokens along one image edge.
        self.patch_edge_token = force_image_size // patch_size
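
    # Expected input shapes for forward, inferred from the operations below
    # (treat these as assumptions rather than a documented contract):
    #   merged_visual_prompts: [sum(num_patches), in_channels, H, W] image tiles
    #                          with the visual prompts drawn onto them
    #   visual_prompts:        [sum(num_patches[i] * num_vprompts[i]), H, W] binary masks
    #   mark_embeddings:       [sum(num_vprompts), language_hidden_size]
    #   num_patches, num_vprompts: per-sample tile and prompt counts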
    def forward(self, merged_visual_prompts, visual_prompts, mark_embeddings, num_patches, num_vprompts):
        patch_embeds = self.patch_embedding(merged_visual_prompts)  # [sum(num_patches), vision_hidden_size, edge, edge]
        # Each sample contributes one mask per (visual prompt, image tile) pair.
        split_size = [npatch * nvprompt for (npatch, nvprompt) in zip(num_patches, num_vprompts)]
        # Downsample the pixel-level masks to the patch-token grid.
        resized_visual_prompts = F.interpolate(visual_prompts.unsqueeze(1),
                                               size=self.patch_edge_token,
                                               mode="nearest")
        resized_visual_prompts_per_batch = torch.split(resized_visual_prompts, split_size, dim=0)
        # Mark embeddings are grouped per sample: num_vprompts[i] rows each.
        mark_embeddings_per_batch = torch.split(mark_embeddings, list(num_vprompts), dim=0)
batch_vprompts_input = []
        for i, (per_visual_prompts, per_mark_embeddings) in enumerate(zip(
                resized_visual_prompts_per_batch, mark_embeddings_per_batch)):
            # [num_vprompts, num_patches, 1, edge, edge] binary masks.
            per_visual_prompts = per_visual_prompts.view(
                num_vprompts[i], num_patches[i], 1, self.patch_edge_token, self.patch_edge_token)
            per_background = torch.ones_like(per_visual_prompts) - per_visual_prompts
            # Background positions keep a zero embedding; positions covered by a
            # prompt receive that prompt's mark embedding.
            per_vprompts_input = torch.zeros(
                (num_vprompts[i], num_patches[i], self.language_hidden_size,
                 self.patch_edge_token, self.patch_edge_token),
                dtype=mark_embeddings.dtype, device=mark_embeddings.device)
            per_vprompts_input = per_vprompts_input * per_background + \
                per_mark_embeddings[:, None, :, None, None] * per_visual_prompts
            # TODO: numeric stability for multi-granularity prompts (one pixel
            # covered by multiple visual prompts).
            per_vprompts_input = torch.sum(per_vprompts_input, dim=0)
            batch_vprompts_input.append(per_vprompts_input)
        batch_vprompts_input = torch.cat(batch_vprompts_input, dim=0)
        # [sum(num_patches), edge * edge, language_hidden_size] for the MLP.
        batch_vprompts_input = batch_vprompts_input.permute(0, 2, 3, 1).flatten(1, 2)
        batch_vprompts_input = self.mlp1(batch_vprompts_input).view(
            -1, self.patch_edge_token, self.patch_edge_token, self.vision_hidden_size).permute(0, 3, 1, 2)
        # This version does not consider color prompts: the image patch embedding
        # is zeroed out, so only the projected mark-embedding map is returned.
        patch_embeds = patch_embeds * 0.0 + batch_vprompts_input
        return patch_embeds
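

# Minimal smoke test sketching the expected call pattern. Every size below
# (image size, patch size, hidden dims, per-sample counts) is an illustrative
# assumption, not a value shipped with this model.
if __name__ == "__main__":
    force_image_size, patch_size = 64, 16  # 4 x 4 patch-token grid
    vision_hidden_size, language_hidden_size = 32, 64
    model = VisualPromptEncodeModel(
        in_channels=3,
        vision_hidden_size=vision_hidden_size,
        language_hidden_size=language_hidden_size,
        force_image_size=force_image_size,
        patch_size=patch_size,
        downsample_ratio=1,
    )
    num_patches = [2, 1]    # image tiles per sample
    num_vprompts = [3, 2]   # visual prompts per sample
    merged_visual_prompts = torch.randn(sum(num_patches), 3, force_image_size, force_image_size)
    visual_prompts = torch.rand(
        sum(np_ * nv for np_, nv in zip(num_patches, num_vprompts)),
        force_image_size, force_image_size).round()  # binary masks
    mark_embeddings = torch.randn(sum(num_vprompts), language_hidden_size)
    out = model(merged_visual_prompts, visual_prompts, mark_embeddings, num_patches, num_vprompts)
    print(out.shape)  # torch.Size([3, 32, 4, 4]): one map per image tile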