import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BartTokenizer, BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
                                             BartParallelLMHead,
                                             BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
                                        EncDecMultiModalProcessor,
                                        PromptIndexTargets, PromptInsertion,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                         SupportsV0Only)
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
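
# Florence-2 pairs a DaViT vision encoder with a BART-style encoder-decoder
# language model. The file follows that split: the DaViT tower and its
# building blocks come first, then the language model and the multimodal
# processing glue that ties them together for vLLM.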


class Florence2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(
            num_pos, embedding_dim - (embedding_dim // 2))

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim)
        """
        if len(pixel_values.shape) != 4:
            raise ValueError('pixel_values must be a 4D tensor')
        height, width = pixel_values.shape[1:3]
        width_values = torch.arange(width, device=pixel_values.device)
        height_values = torch.arange(height, device=pixel_values.device)
        x_emb = self.column_embeddings(width_values)
        y_emb = self.row_embeddings(height_values)
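
        # Broadcast the two 1D embeddings over the grid: column embeddings
        # repeat down the rows, row embeddings repeat across the columns,
        # and concatenating along the channel dim yields one embedding per
        # (row, column) position.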
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(height, 1, 1),
            y_emb.unsqueeze(1).repeat(1, width, 1)
        ],
                        dim=-1)

        pos = pos.permute(2, 0, 1)
        pos = pos.unsqueeze(0)
        pos = pos.repeat(pixel_values.shape[0], 1, 1, 1)
        pos = pos.permute(0, 2, 3, 1)
        return pos


class PositionalEmbeddingCosine1D(nn.Module):
    """
    This class implements a very simple positional encoding. It follows
    closely the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum length to precompute the positional
            encodings.
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        factor = math.log(10000)
        denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
                                self.embed_dim)
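
        # Standard sinusoidal table: PE[pos, 2i] = sin(pos / 10000^(2i/d))
        # and PE[pos, 2i + 1] = cos(pos / 10000^(2i/d)), precomputed for
        # every position up to max_seq_len.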
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))

        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        pos_idx_to_embed[:, 1::2] = torch.cos(frequencies)

        self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
                                             requires_grad=False)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is
                the frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns a tensor with the same dimensions as the input: i.e.,
        [1, T, D] or [T, D].
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :]

        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds


class MySequential(nn.Sequential):
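    """An nn.Sequential whose submodules consume and return an (x, size)
    tuple: tuple outputs are unpacked into the next module's arguments."""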

    def forward(self, *inputs):
        for module in self._modules.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


class PreNorm(nn.Module):

    def __init__(self, norm, fn):
        super().__init__()
        self.norm = norm
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        shortcut = x
        if self.norm is not None:
            x, size = self.fn(self.norm(x), *args, **kwargs)
        else:
            x, size = self.fn(x, *args, **kwargs)

        x = shortcut + x

        return x, size


class Mlp(nn.Module):

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.net = nn.Sequential(
            OrderedDict([("fc1", nn.Linear(in_features, hidden_features)),
                         ("act", act_layer()),
                         ("fc2", nn.Linear(hidden_features, out_features))]))

    def forward(self, x, size):
        return self.net(x), size


class DepthWiseConv2d(nn.Module):

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        self.dw = nn.Conv2d(dim_in,
                            dim_in,
                            kernel_size=kernel_size,
                            padding=padding,
                            groups=dim_in,
                            stride=stride,
                            bias=bias)

    def forward(self, x, size):
        B, N, C = x.shape
        H, W = size
        assert N == H * W

        x = self.dw(x.transpose(1, 2).view(B, C, H, W))
        size = (x.size(-2), x.size(-1))
        x = x.flatten(2).transpose(1, 2)
        return x, size


class ConvEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None,
                 pre_norm=True):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=stride,
                              padding=padding)

        dim_norm = in_chans if pre_norm else embed_dim
        self.norm = norm_layer(dim_norm) if norm_layer else None

        self.pre_norm = pre_norm

    def forward(self, x, size):
        H, W = size
        if len(x.size()) == 3:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        x = self.proj(x)

        _, _, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (H, W)


class ChannelAttention(nn.Module):

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()

        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.groups,
                                  C // self.groups).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
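
        # "Channel attention": queries/keys are transposed so the C//groups
        # channels attend to one another, with the token count N acting as
        # the feature dimension; hence the N**-0.5 scaling below.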
        q = q * (float(N)**-0.5)
        attention = q.transpose(-1, -2) @ k
        attention = attention.softmax(dim=-1)
        x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x, size


class ChannelBlock(nn.Module):

    def __init__(self,
                 dim,
                 groups,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        self.conv1 = PreNorm(None, DepthWiseConv2d(
            dim, 3, 1, 1)) if conv_at_attn else None
        self.channel_attn = PreNorm(
            norm_layer(dim),
            ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
                                                   1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer),
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size


def window_partition(x, window_size: int):
    B, H, W, C = x.shape
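    # Tile the feature map into non-overlapping windows:
    # (B, H, W, C) -> (B * H//ws * W//ws, ws, ws, C), with ws = window_size.
    # H and W are assumed to be multiples of window_size; the caller pads
    # beforehand (see WindowAttention.forward).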
    x = x.view(B, H // window_size, window_size, W // window_size, window_size,
               C)
    windows = x.permute(0, 1, 3, 2, 4,
                        5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    B = batch_size

    x = windows.view(B, H // window_size, W // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = float(head_dim)**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
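
        # Pad on the right/bottom so H and W become multiples of
        # window_size, then run standard multi-head self-attention
        # independently within each window.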
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        x = window_partition(x, self.window_size)
        x = x.view(-1, self.window_size * self.window_size, C)

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)

        x = x.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(x, B, self.window_size, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        return x, size


class SpatialBlock(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        self.conv1 = PreNorm(None, DepthWiseConv2d(
            dim, 3, 1, 1)) if conv_at_attn else None
        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
        )
        self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1,
                                                   1)) if conv_at_ffn else None
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer),
        )

    def forward(self, x, size):
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size


class DaViT(nn.Module):

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        num_stages = len(embed_dims)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate,
                                             sum(depths) * 2)
        ]
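
        # Two drop-path rates per depth unit: each unit stacks one
        # SpatialBlock (window attention) and one ChannelBlock (grouped
        # channel attention), so the schedule has sum(depths) * 2 entries.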
        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i])
            convs.append(conv_embed)

            block = MySequential(*[
                MySequential(
                    OrderedDict([('spatial_block',
                                  SpatialBlock(
                                      embed_dims[i],
                                      num_heads[i],
                                      window_size,
                                      drop_path_rate=dpr[depth_offset + j * 2],
                                      qkv_bias=qkv_bias,
                                      mlp_ratio=mlp_ratio,
                                      conv_at_attn=conv_at_attn,
                                      conv_at_ffn=conv_at_ffn,
                                  )),
                                 ('channel_block',
                                  ChannelBlock(
                                      embed_dims[i],
                                      num_groups[i],
                                      drop_path_rate=dpr[depth_offset + j * 2 +
                                                         1],
                                      qkv_bias=qkv_bias,
                                      mlp_ratio=mlp_ratio,
                                      conv_at_attn=conv_at_attn,
                                      conv_at_ffn=conv_at_ffn,
                                  ))])) for j in range(depths[i])
            ])
            blocks.append(block)
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        # `forward_features` and `forward` below reference `self.norms` and
        # `self.head`; their definitions were missing here and are restored
        # following the reference DaViT implementation.
        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (nn.Linear(self.embed_dims[-1], num_classes)
                     if num_classes > 0 else nn.Identity())

    @property
    def dim_out(self):
        return self.embed_dims[-1]

    def forward_features_unpool(self, x):
        """
        forward until avg pooling
        Args:
            x (_type_): input image tensor
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        x = self.forward_features_unpool(x)

        x = self.avgpool(x.transpose(1, 2))

        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )


class Florence2LanguageModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
        self.encoder = BartEncoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config=cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

        if self.config.tie_word_embeddings:
            self.encoder.embed_tokens.weight = self.shared.weight
            self.decoder.embed_tokens.weight = self.shared.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                Indices of *decoder* input sequence tokens in the vocabulary.
                Padding will be ignored by default should you provide it.
            positions
                Positions of *decoder* input sequence tokens.
            encoder_input_ids
                Indices of *encoder* input sequence tokens in the vocabulary.
            encoder_positions
                Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """

        encoder_hidden_states = None
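
        # The encoder only runs when encoder tokens are scheduled (i.e.
        # during prefill); on decode steps the decoder reuses the cached
        # cross-attention states instead.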
        if inputs_embeds is not None or encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 inputs_embeds=inputs_embeds)

        decoder_outputs = self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states)

        return decoder_outputs


class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config

        self.config = config
        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.vocab_size = config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices.
        Returns:
            Output torch.Tensor
        """

        return self.model(input_ids,
                          positions,
                          encoder_input_ids,
                          encoder_positions,
                          inputs_embeds=inputs_embeds)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.encoder.embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
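            # Checkpoint fusion map, (param_name, shard_name, shard_id):
            # separate q/k/v projection weights in the checkpoint are
            # loaded into the single stacked qkv_proj parameter.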
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "final_logits_bias" in name:
                    continue
                if self.config.tie_word_embeddings and "embed_tokens" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Florence2ProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self):
        return self.ctx.get_hf_processor()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_num_image_tokens(self) -> int:
        processor_config = self.ctx.get_hf_image_processor_config()
        return processor_config["image_seq_length"]


class Florence2DummyInputsBuilder(
        BaseDummyInputsBuilder[Florence2ProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width = target_height = self.info.get_hf_config().projection_dim

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


class Florence2MultiModalProcessor(
        EncDecMultiModalProcessor[Florence2ProcessingInfo]):

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        return [self.info.get_hf_config().eos_token_id]

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        hf_processor = self.info.get_hf_processor()
        tokenizer: BartTokenizer = hf_processor.tokenizer
        prompt_text = tokenizer.decode(prompt_tokens)

        prompt_text = hf_processor._construct_prompts([prompt_text])[0]
        prompt_tokens = tokenizer.encode(prompt_text, add_special_tokens=False)
        return prompt_tokens

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            processed_outputs = super()._call_hf_processor(
                prompt, mm_data, mm_kwargs)
        else:
            hf_processor = self.info.get_hf_processor()
            tokenizer = hf_processor.tokenizer
            prompt = hf_processor._construct_prompts([prompt])[0]
            processed_outputs = tokenizer(prompt,
                                          add_special_tokens=True,
                                          return_tensors="pt")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        pad_token_id = hf_config.pad_token_id
        num_image_tokens = self.info.get_num_image_tokens()
        image_tokens = [pad_token_id] * num_image_tokens
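
        # Pad tokens inserted at the start of the encoder prompt act as
        # placeholders for the image; `get_input_embeddings` later replaces
        # them with the projected vision features.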
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.start(),
                insertion=image_tokens,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Florence2MultiModalProcessor,
    info=Florence2ProcessingInfo,
    dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsV0Only):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        processor_config = vllm_config.model_config.hf_image_processor_config

        self.config = config
        self.vision_config = config.vision_config
        self.processor_config = processor_config
        assert config.vision_config.model_type == 'davit', (
            'only DaViT is supported for now')
        self.vision_tower = DaViT.from_config(config=config.vision_config)
        self._build_image_projection_layers(config)
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=f"{prefix}.language_model",
        )
        self.pad_token_id = config.pad_token_id

    def _build_image_projection_layers(self, config: PretrainedConfig):
        image_dim_out = config.vision_config.dim_embed[-1]
        dim_projection = config.vision_config.projection_dim
        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection))
        self.image_proj_norm = nn.LayerNorm(dim_projection)
        image_pos_embed_config = config.vision_config.image_pos_embed
        if image_pos_embed_config['type'] == 'learned_abs_2d':
            self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
                embedding_dim=image_dim_out,
                num_pos=image_pos_embed_config['max_pos_embeddings'])
        else:
            raise NotImplementedError("Florence2 only supports learned_abs_2d "
                                      "as image position embedding.")

        self.image_feature_source = config.vision_config.image_feature_source

        visual_temporal_embedding_config = (
            self.vision_config.visual_temporal_embedding)
        if visual_temporal_embedding_config['type'] == 'COSINE':
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim_out,
                max_seq_len=visual_temporal_embedding_config[
                    'max_temporal_embeddings'])
        else:
            raise NotImplementedError(
                'Florence2 only supports COSINE as temporal embedding.')
def _validate_pixel_values( |
|
|
self, data: Union[torch.Tensor, list[torch.Tensor]] |
|
|
) -> Union[torch.Tensor, list[torch.Tensor]]: |
|
|
|
|
|
size = self.processor_config["size"] |
|
|
h, w = size["height"], size["width"] |
|
|
expected_dims = (3, h, w) |
|
|
|
|
|
def _validate_shape(d: torch.Tensor): |
|
|
actual_dims = tuple(d.shape) |
|
|
|
|
|
if actual_dims != expected_dims: |
|
|
expected_expr = tuple(*map(str, expected_dims)) |
|
|
raise ValueError( |
|
|
"The expected shape of pixel values per batch " |
|
|
f"is {expected_expr}. You supplied {tuple(d.shape)}.") |
|
|
|
|
|
for d in data: |
|
|
_validate_shape(d) |
|
|
|
|
|
return data |

    def _parse_and_validate_image_input(self, **kwargs: object):
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            return Florence2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        dtype = next(self.vision_tower.parameters()).dtype
        pixel_values = pixel_values.to(dtype)

        batch_size, T = pixel_values.size(0), 1
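        # T is the temporal length; vLLM feeds single images, so T == 1,
        # but the positional/temporal embedding code below keeps the
        # original video-capable layout.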
        x = self.vision_tower.forward_features_unpool(pixel_values)
        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
            assert h * w == num_tokens, (
                'only support square feature maps for now')
            x = x.view(batch_size * T, h, w, x.shape[-1])
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            visual_temporal_embed = self.visual_temporal_embed(
                x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
            x = x.view(batch_size, T, -1,
                       x.shape[-1]) + visual_temporal_embed.view(
                           1, T, 1, x.shape[-1])

        x_feat_dict = {}

        spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
        x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x

        temporal_avg_pool_x = x.view(batch_size, T, -1,
                                     x.shape[-1]).mean(dim=1)
        x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x

        x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
        x_feat_dict['last_frame'] = x

        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError('invalid image feature source: {}'.format(
                    _image_feature_source))
            new_x.append(x_feat_dict[_image_feature_source])

        x = torch.cat(new_x, dim=1)

        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x

    def _process_image_input(
            self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
        assert image_input["type"] == "pixel_values"
        pixel_values = image_input["data"]
        return self._encode_image(pixel_values)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.pad_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids
                torch.Tensor of *decoder* input token ids.
            positions
                torch.Tensor of *decoder* position indices.
            encoder_input_ids
                torch.Tensor of *encoder* input token ids.
            encoder_positions
                torch.Tensor of *encoder* position indices.
        Returns:
            Output torch.Tensor
        """
        vision_embeddings = self.get_multimodal_embeddings(**kwargs)
        if encoder_input_ids.numel() > 0 or vision_embeddings is not None:
            inputs_embeds = self.get_input_embeddings(encoder_input_ids,
                                                      vision_embeddings)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            encoder_input_ids,
                                            encoder_positions,
                                            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
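
# A minimal offline-inference sketch (an illustration, not part of this
# module): it assumes vLLM's explicit encoder/decoder prompt format and a
# Florence-2 checkpoint name; adjust both to your setup.
#
#   from vllm import LLM, SamplingParams
#   from PIL import Image
#
#   llm = LLM(model="microsoft/Florence-2-base", trust_remote_code=True)
#   prompt = {
#       "encoder_prompt": {
#           "prompt": "<CAPTION>",
#           "multi_modal_data": {"image": Image.open("example.jpg")},
#       },
#       "decoder_prompt": "",
#   }
#   outputs = llm.generate(prompt, SamplingParams(max_tokens=64))
#   print(outputs[0].outputs[0].text)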