leideng/QCFuse / srt /models /deepseek_ocr.py
leideng's picture
download
raw
52.4 kB
# Copyright 2025 The SwissAI Initiative
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only Apertus model compatible with HuggingFace weights."""
import copy
import logging
import math
from functools import partial
from typing import Iterable, List, Optional, Set, Tuple, Type, TypeAlias, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers.models.vitdet.modeling_vitdet import get_rel_pos
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek import DeepseekForCausalLM
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
from sglang.srt.models.transformers import maybe_prefix
# Recursive alias for multimodal embeddings: a single tensor, a list or tuple
# of tensors, or a list whose items are themselves NestedTensors.
NestedTensors: TypeAlias = Union[
list["NestedTensors"],
list["torch.Tensor"],
"torch.Tensor",
tuple["torch.Tensor", ...],
]
# Non-recursive variant: one tensor, or a flat list/tuple of tensors.
MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
    """
    Collapse a (possibly nested) collection of tensors into a single tensor.

    A tensor leaf has all of its dimensions except the last merged into one.
    For a list/tuple, every element is flattened recursively and the results
    are concatenated along dimension 0.
    """
    if not isinstance(embeddings, torch.Tensor):
        flattened_children = [_flatten_embeddings(child) for child in embeddings]
        return torch.cat(flattened_children)
    # Leaf: merge every leading dimension, keep the last one.
    return embeddings.flatten(0, -2)
def _embedding_count_expression(embeddings: NestedTensors) -> str:
    """
    Build a human-readable expression for the number of embeddings held in
    the NestedTensors (used in error messages).

    A tensor contributes its leading dimensions joined by " x "; the
    expressions of a collection's elements are joined by " + ".
    """
    if not isinstance(embeddings, torch.Tensor):
        parts = [_embedding_count_expression(child) for child in embeddings]
        return " + ".join(parts)
    leading_dims = embeddings.shape[:-1]
    return " x ".join(map(str, leading_dims))
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    is_multimodal: torch.Tensor,
) -> torch.Tensor:
    """
    Write `multimodal_embeddings` into the rows of `inputs_embeds` selected by
    the boolean mask `is_multimodal`.

    Args:
        inputs_embeds: embeddings to overwrite.
        multimodal_embeddings: nested tensors; flattened before scattering.
        is_multimodal: mask marking the placeholder positions.

    Returns:
        `inputs_embeds` (the same object that was passed in).

    Raises:
        ValueError: if the masked scatter raises a RuntimeError. When the
            number of embeddings differs from the number of placeholders the
            message reports both counts.

    Note:
        This updates `inputs_embeds` in place.
    """
    if len(multimodal_embeddings) == 0:
        return inputs_embeds

    flat_mm = _flatten_embeddings(multimodal_embeddings)
    target_dtype = inputs_embeds.dtype
    try:
        # NOTE: masked_scatter_ avoids a D2H sync (#22105), but it does not
        # raise when is_multimodal.sum() < len(flat_mm).
        placeholder_mask = is_multimodal.unsqueeze(-1)
        inputs_embeds.masked_scatter_(placeholder_mask, flat_mm.to(dtype=target_dtype))
    except RuntimeError as e:
        num_actual_tokens = len(flat_mm)
        num_expected_tokens = is_multimodal.sum().item()
        if num_actual_tokens == num_expected_tokens:
            # Counts agree, so the failure has some other cause.
            raise ValueError("Error during masked scatter operation") from e
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {num_actual_tokens} "
            f"multimodal tokens to {num_expected_tokens} placeholders"
        ) from e
    return inputs_embeds
def isin_list(
    elements: torch.Tensor,
    test_elements_list: list[int],
) -> torch.Tensor:
    """
    Element-wise membership test of `elements` against a Python list.

    Args:
        elements: tensor whose entries are tested.
        test_elements_list: values to test membership against.

    Returns:
        Boolean tensor with the shape of `elements`; True where the entry is
        contained in `test_elements_list`.
    """
    # Pinned host memory only helps asynchronous host-to-device copies and
    # needs an accelerator runtime; requesting it on a CPU-only host raises.
    # Pin only when the data is actually headed to a non-CPU device.
    use_pinned_memory = elements.device.type != "cpu"
    test_elements = torch.tensor(
        test_elements_list, pin_memory=use_pinned_memory
    ).to(device=elements.device, non_blocking=True)
    return torch.isin(elements, test_elements)
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int | list[int],
) -> torch.Tensor:
    """
    Overwrite the placeholder positions of `inputs_embeds` with
    `multimodal_embeddings`.

    Placeholder positions are the entries of `input_ids` equal to
    `placeholder_token_id`, or contained in it when a list is given.

    A list of ids (e.g. img_start, img_break and img_end) is merged as one
    slice rather than scattered per token, so the order of those tokens in
    `input_ids` MUST MATCH the order of their embeddings in
    `multimodal_embeddings`.

    For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
    - T is text token
    - S is image start token
    - I is image embedding token
    - B is image break token
    - E is image end token.

    then the image embeddings (the I's) coming from the vision encoder must be
    padded with the embeddings of S, B and E in that same order.

    Note:
        This updates `inputs_embeds` in place.
    """
    is_multimodal = (
        isin_list(input_ids, placeholder_token_id)
        if isinstance(placeholder_token_id, list)
        else input_ids == placeholder_token_id
    )
    return _merge_multimodal_embeddings(
        inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
class MlpProjector(nn.Module):
    """
    Projects vision features to `n_embed` channels.

    Args:
        projector_type: one of "identity", "linear", "mlp_gelu",
            "normlayer_downsample_mlp_gelu", "downsample_mlp_gelu",
            "low_high_hybrid_split_mlp_gelu", "hybrid_split_feature_mlp_gelu",
            "low_high_split_mlp_gelu".
        input_dim: input channel count. For "hybrid_split_feature_mlp_gelu"
            it is indexed as `input_dim[0]` / `input_dim[1]` (high / low).
        n_embed: output channel count.
        depth: number of stacked linear layers for the MLP variants.
        mlp_ratio: hidden-width multiplier for the downsample variants.
        downsample_ratio: spatial window size folded into channels by the
            downsample variants.

    Raises:
        ValueError: if `projector_type` is not one of the known types.
    """

    def __init__(
        self,
        projector_type,
        input_dim,
        n_embed,
        depth=1,
        mlp_ratio=1,
        downsample_ratio=4,
    ):
        super().__init__()
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        # forward() reads this for the *downsample* projector types; it was
        # previously never stored, which made those types fail at call time.
        self.downsample_ratio = downsample_ratio
        # Both flags are always False here; the layers their branches in
        # forward() would need are not created by this constructor.
        self.token_pooling = False
        self.conv_fusion_high_low_features = False

        if projector_type == "identity":
            modules = nn.Identity()
        elif projector_type == "linear":
            modules = nn.Linear(input_dim, n_embed)
        elif projector_type == "mlp_gelu":
            modules = [nn.Linear(input_dim, n_embed)]
            modules.extend(self._gelu_linear_tail(n_embed, depth))
            modules = nn.Sequential(*modules)
        elif projector_type in (
            "normlayer_downsample_mlp_gelu",
            "downsample_mlp_gelu",
        ):
            fused_dim = input_dim * downsample_ratio * downsample_ratio
            hidden_dim = n_embed * mlp_ratio
            modules = []
            if projector_type == "normlayer_downsample_mlp_gelu":
                modules.append(nn.LayerNorm(fused_dim))
            modules.append(nn.Linear(fused_dim, hidden_dim))
            for _ in range(1, depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(hidden_dim, hidden_dim))
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, n_embed))
            modules = nn.Sequential(*modules)
        elif projector_type == "low_high_hybrid_split_mlp_gelu":
            self.high_up_proj = nn.Linear(input_dim, n_embed // 2)
            self.low_up_proj = nn.Linear(input_dim, n_embed // 2)
            modules = nn.Sequential(*self._gelu_linear_tail(n_embed, depth))
        elif projector_type == "hybrid_split_feature_mlp_gelu":
            channel_div = 0.5
            self.high_up_proj = nn.Linear(input_dim[0], int(n_embed * channel_div))
            self.low_up_proj = nn.Linear(
                input_dim[1], n_embed - int(n_embed * channel_div)
            )
            modules = nn.Sequential(*self._gelu_linear_tail(n_embed, depth))
        elif projector_type == "low_high_split_mlp_gelu":
            modules = nn.Sequential(*self._gelu_linear_tail(n_embed // 2, depth))
            # high_layers wraps the very same submodules as `modules`;
            # low_layers is an independent deep copy.
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)
        else:
            raise ValueError(f"Unknown projector type: {projector_type}")
        self.layers = modules

    @staticmethod
    def _gelu_linear_tail(width, depth):
        """Return `depth - 1` (GELU, Linear(width, width)) pairs as a flat list."""
        tail = []
        for _ in range(1, depth):
            tail.append(nn.GELU())
            tail.append(nn.Linear(width, width))
        return tail

    def forward(self, x):
        """
        Apply the projector.

        For the downsample types `x` is unpacked as (batch, tokens, channels)
        and the token axis is treated as a square grid. For the
        "low_high_*" types `x[0]` / `x[1]` are the high / low inputs.
        """
        if self.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh**0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # Concatenate on channel dimension
            patches = patches.contiguous().view(
                batch_size, channels, h_patches * w_patches, -1
            )
            # Pass through linear layer
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
            x = self.token_pooling_layer(patches)

        if self.conv_fusion_high_low_features:
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.projector_type == "low_high_hybrid_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.projector_type == "hybrid_split_feature_mlp_gelu":
            high_x = x[..., : self.input_dim[0]]
            low_x = x[..., self.input_dim[0] :]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.projector_type == "low_high_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            # This type returns directly and never goes through self.layers.
            return torch.concat([high_x, low_x], dim=-1)

        if self.projector_type in (
            "downsample_mlp_gelu",
            "normlayer_downsample_mlp_gelu",
        ):
            ratio = self.downsample_ratio
            bs, hw, input_dim = x.shape
            h = w = int(hw**0.5)
            # Zero-pad the grid so its side is a multiple of the ratio.
            pad = (-h) % ratio
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
            # Fold every ratio x ratio window into the channel axis.
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(
                x,
                kernel_size=ratio,
                stride=ratio,
                padding=0,
            )  # B, C * ratio^2, (H / ratio) * (W / ratio)
            x = x.permute(0, 2, 1)

        return self.layers(x)
class LayerNorm2d(nn.Module):
    """
    Layer normalization over dimension 1 with a learnable per-channel
    scale and shift broadcast over two trailing dimensions.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(1, keepdim=True)
        centered = x - channel_mean
        # Biased variance along dimension 1.
        channel_var = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(channel_var + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift
class MLPBlock(nn.Module):
    """
    Two-layer feed-forward block: Linear -> activation -> Linear.

    Args:
        embedding_dim (int): input and output width.
        mlp_dim (int): hidden width.
        act (nn.Module): activation class, instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Unlike the referenced implementation, this does not add the terms to an
    attention map; it returns the height and width terms separately so the
    caller can combine them.

    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    Returns:
        rel_h (Tensor): height term with shape (B, q_h * q_w, k_h, 1).
        rel_w (Tensor): width term with shape (B, q_h * q_w, 1, k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # reshape inserts the singleton broadcast axis directly; element order is
    # unchanged because only adjacent dims are merged and a size-1 dim added.
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
    return rel_h, rel_w
class Attention(nn.Module):
    """Multi-head self-attention over an (H, W) token grid, optionally biased
    by decomposed relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, create relative position parameters
                and add them as an attention bias.
            rel_pos_zero_init (bool): Accepted but not read by this class; the
                relative position parameters are always created as zeros.
            input_size (tuple(int, int) or None): Input resolution used to size
                the relative position parameters. Required when `use_rel_pos`.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if not self.use_rel_pos:
            return

        assert (
            input_size is not None
        ), "Input size must be provided if using relative positional encoding."
        # One embedding row per possible relative offset along each axis.
        grid_h, grid_w = input_size[0], input_size[1]
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * grid_h - 1, head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * grid_w - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        heads = self.num_heads
        n_tokens = H * W

        # (3, B, heads, tokens, head_dim)
        qkv = self.qkv(x).reshape(B, n_tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        # Each of q, k, v: (B * heads, tokens, head_dim)
        q, k, v = qkv.reshape(3, B * heads, n_tokens, -1).unbind(0)

        attn_bias = None
        if self.use_rel_pos:
            rel_h, rel_w = add_decomposed_rel_pos(
                q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )
            rel_h = rel_h.view(B, heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            # Broadcast-add the two terms, then flatten the key grid.
            attn_bias = (rel_h + rel_w).view(
                B, heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
            )

        q = q.view(B, heads, n_tokens, -1)
        k = k.view(B, heads, n_tokens, -1)
        v = v.view(B, heads, n_tokens, -1)
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_bias
        )

        # (B, heads, H, W, head_dim) -> (B, H, W, heads * head_dim)
        merged = (
            attended.view(B, heads, H, W, -1)
            .permute(0, 2, 3, 1, 4)
            .reshape(B, H, W, -1)
        )
        return self.proj(merged)
def window_partition(
    x: torch.Tensor, window_size: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a token grid into non-overlapping square windows, zero-padding the
    bottom/right edges when the grid is not a multiple of the window size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        # F.pad takes pairs from the last dim backwards: C, then W, then H.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp = H + pad_h
    Wp = W + pad_w

    rows, cols = Hp // window_size, Wp // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)
def window_unpartition(
    windows: torch.Tensor,
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> torch.Tensor:
    """
    Reassemble windows into the full token grid and crop away any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image

    grid = windows.view(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        # Drop the rows/columns that exceed the original size.
        x = x[:, :H, :W, :].contiguous()
    return x
class Block(nn.Module):
    """Pre-norm transformer block whose attention runs either globally or
    inside non-overlapping windows."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Forwarded unchanged to the attention layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        # Windowed attention only ever sees window_size x window_size grids,
        # so that is what sizes its relative position parameters.
        attn_input_size = (
            input_size if window_size == 0 else (window_size, window_size)
        )
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=hidden_dim, act=act_layer)
        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm1(x)

        if self.window_size > 0:
            # Partition, attend inside each window, then stitch back together.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
            x = self.attn(x)
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        else:
            x = self.attn(x)

        x = residual + x
        return x + self.mlp(self.norm2(x))
class PatchEmbed(nn.Module):
    """
    Convolutional patch embedding that returns channels-last features.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        # B C H W -> B H W C
        return projected.permute(0, 2, 3, 1)
def get_abs_pos_sam(abs_pos, tgt_size):
    """
    Resize a channels-last positional embedding grid to `tgt_size` x `tgt_size`.

    `abs_pos` is returned unchanged when its dimension 1 already equals
    `tgt_size`. Otherwise it is interpolated (bicubic, antialiased) in float32
    and cast back to its original dtype.
    """
    if abs_pos.size(1) == tgt_size:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # F.interpolate works on channels-first input; run it in float32.
    channels_first = abs_pos.permute(0, 3, 1, 2).to(torch.float32)
    resized = F.interpolate(
        channels_first,
        size=(tgt_size, tgt_size),
        mode="bicubic",
        antialias=True,
        align_corners=False,
    ).to(orig_dtype)
    return resized.permute(0, 2, 3, 1)
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    """ViT image encoder with a convolutional neck followed by two stride-2
    convolutions (`net_2`, `net_3`)."""

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Channel count produced by the neck and consumed
                by `net_2`.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        grid_size = img_size // patch_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_size, grid_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes attend globally.
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(grid_size, grid_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # net_2 consumes the neck output, so its input width must follow
        # out_chans (previously hard-coded to 256, the default).
        self.net_2 = nn.Conv2d(
            out_chans, 512, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.net_3 = nn.Conv2d(
            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # Resize the stored embedding when the patch grid size differs.
            x = x + get_abs_pos_sam(self.pos_embed, x.size(1))

        for blk in self.blocks:
            x = blk(x)

        # The conv layers expect channels-first input.
        x = self.neck(x.permute(0, 3, 1, 2))
        x2 = self.net_2(x)
        x3 = self.net_3(x2.clone())
        return x3
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """
    Build an `ImageEncoderViT` in eval mode and optionally load its weights.

    Args:
        encoder_embed_dim: embedding width of the encoder.
        encoder_depth: number of transformer blocks.
        encoder_num_heads: attention heads per block.
        encoder_global_attn_indexes: block indexes that use global attention.
        checkpoint: optional path passed to `torch.load`. Only entries whose
            key contains "vision_tower_high" are loaded, with the first 30
            characters of each key dropped.

    Returns:
        The constructed image encoder.
    """
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=1024,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=16,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=256,
    )
    image_encoder.eval()

    if checkpoint is None:
        return image_encoder

    # NOTE(review): torch.load unpickles the file; only pass trusted paths.
    state_dict = torch.load(checkpoint)
    encoder_weights = {}
    for key, value in state_dict.items():
        if "vision_tower_high" in key:
            # presumably strips a fixed-length module prefix — verify against
            # the checkpoint's key layout.
            encoder_weights[key[30:]] = value
    image_encoder.load_state_dict(encoder_weights, strict=True)
    return image_encoder
def build_sam_vit_b(checkpoint=None):
    """Build the image encoder with width 768, 12 blocks, 12 heads and global
    attention in blocks 2, 5, 8 and 11; `checkpoint` is forwarded unchanged."""
    encoder_settings = {
        "encoder_embed_dim": 768,
        "encoder_depth": 12,
        "encoder_num_heads": 12,
        "encoder_global_attn_indexes": [2, 5, 8, 11],
    }
    return _build_sam(checkpoint=checkpoint, **encoder_settings)
def get_abs_pos(abs_pos, tgt_size):
    """
    Resize a positional embedding table whose first row is kept as-is (the
    class-token row) and whose remaining rows form a square grid.

    Args:
        abs_pos: embedding table; dimension 0 is squeezed away before use.
        tgt_size: target number of rows, i.e. M in the "L, C -> M, C" sense.
            The target grid side is int(sqrt(tgt_size)).

    Returns:
        `abs_pos` itself when the source and target grid sides match,
        otherwise a (1, side * side + 1, C) tensor in the original dtype.
    """
    embed_dim = abs_pos.size(-1)
    table = abs_pos.squeeze(0)
    cls_token, grid_embed = table[:1], table[1:]

    src_size = int(math.sqrt(table.shape[0] - 1))
    tgt_size = int(math.sqrt(tgt_size))
    if src_size == tgt_size:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # Lay the grid rows out as an image and interpolate in float32.
    grid_image = (
        grid_embed.view(1, src_size, src_size, embed_dim)
        .permute(0, 3, 1, 2)
        .contiguous()
        .to(torch.float32)
    )
    resized = F.interpolate(
        grid_image,
        size=(tgt_size, tgt_size),
        mode="bicubic",
        antialias=True,
        align_corners=False,
    ).to(orig_dtype)
    resized = resized.permute(0, 2, 3, 1)
    resized = resized.view(tgt_size * tgt_size, embed_dim)

    vision_pos_embed = torch.cat([cls_token, resized], dim=0)
    return vision_pos_embed.view(1, tgt_size * tgt_size + 1, embed_dim)
class CLIPVisionEmbeddings(nn.Module):
    """Builds the token sequence for the ViT: a class embedding followed by
    patch embeddings, plus (resized) position embeddings."""

    def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
        super().__init__()
        self.embed_dim = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size

        self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
        self.patch_embedding = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values, patch_embeds):
        """
        Args:
            pixel_values: image batch; used for the batch size and, when
                `patch_embeds` is None, as input to the patch convolution.
            patch_embeds: optional precomputed patch features that replace
                the patch convolution output.
        """
        batch_size = pixel_values.shape[0]
        if patch_embeds is None:
            patch_embeds = self.patch_embedding(pixel_values)

        # Flatten from dim 2 onward, then swap dims 1 and 2.
        patch_tokens = patch_embeds.flatten(2).transpose(1, 2)
        class_tokens = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_tokens, patch_tokens], dim=1)

        position_table = self.position_embedding(self.position_ids)
        return embeddings + get_abs_pos(position_table, embeddings.size(1))
class NoTPAttention(torch.nn.Module):
    """Multi-head self-attention with a fused QKV projection, configured from
    a dict-like `cfg`."""

    def __init__(self, cfg):
        super().__init__()
        hidden_size = cfg["hidden_size"]
        num_heads = cfg["num_attention_heads"]

        self.num_heads = num_heads
        self.n_local_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.max_seq_len = cfg["seq_length"]
        self.use_flash_attention = cfg["use_flash_attn"]

        self.qkv_proj = torch.nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=True)

        # Stored only; forward() does not apply dropout.
        self.attn_drop = cfg["attention_dropout"]

    def forward(
        self,
        x: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape
        fused = self.qkv_proj(x).view(bsz, seqlen, 3, self.num_heads, self.head_dim)

        # The computation is the same whether or not use_flash_attention is
        # set; both settings go through scaled_dot_product_attention.
        # Each of q, k, v: (B, num_head, S, head_size)
        xq, xk, xv = (part.permute(0, 2, 1, 3) for part in fused.unbind(dim=2))
        attended = torch.nn.functional.scaled_dot_product_attention(
            xq, xk, xv, attn_mask=None
        )

        # Back to (B, S, num_head * head_size) before the output projection.
        merged = attended.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
        return self.out_proj(merged)
@torch.jit.script
def quick_gelu(x):
    # Sigmoid-gated activation: the input scaled by sigmoid(1.702 * input).
    gate = torch.sigmoid(1.702 * x)
    return x * gate
class NoTPFeedForward(nn.Module):
    """Feed-forward block: Linear -> quick_gelu -> Linear.

    `cfg` is accepted but not read by this class.
    """

    def __init__(
        self,
        cfg,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        hidden = self.fc1(x)
        activated = quick_gelu(hidden)
        return self.fc2(activated)
class LayerNormfp32(torch.nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        # Run the parent LayerNorm on an upcast copy.
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)
class NoTPTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    `multiple_of` is accepted but not read by this class.
    """

    def __init__(self, cfg, layer_id: int, multiple_of=256):
        super().__init__()
        hidden_size = cfg["hidden_size"]
        num_heads = cfg["num_attention_heads"]

        self.n_heads = num_heads
        self.dim = hidden_size
        self.head_dim = hidden_size // num_heads
        self.self_attn = NoTPAttention(cfg)
        self.mlp = NoTPFeedForward(
            cfg, dim=hidden_size, hidden_dim=cfg["ffn_hidden_size"]
        )
        self.layer_id = layer_id

        norm_eps = cfg["layernorm_epsilon"]
        self.layer_norm1 = torch.nn.LayerNorm(hidden_size, eps=norm_eps)
        self.layer_norm2 = torch.nn.LayerNorm(hidden_size, eps=norm_eps)

    def forward(self, x: torch.Tensor):
        # Submodule .forward is called directly, as in the original code.
        attn_out = self.self_attn.forward(self.layer_norm1(x))
        h = x + attn_out
        mlp_out = self.mlp.forward(self.layer_norm2(h))
        return h + mlp_out
class NoTPTransformer(nn.Module):
    """Stack of `cfg["num_layers"]` NoTPTransformerBlock layers applied in order."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.num_layers = cfg["num_layers"]
        # Layer ids are 1-based.
        self.layers = torch.nn.ModuleList(
            NoTPTransformerBlock(cfg, layer_id)
            for layer_id in range(1, self.num_layers + 1)
        )

    def forward(
        self,
        hidden_states,
    ):
        output = hidden_states
        for layer in self.layers:
            output = layer(output)
        return output
class VitModel(nn.Module):
    """ViT made of embeddings, a pre-LayerNorm and a transformer stack.

    Args:
        cfg: dict of model settings ("hidden_size", "image_size",
            "patch_size", optional "fp32norm" and "pre_layernorm_epsilon",
            plus whatever the transformer stack reads).
        freeze_embed: if True, disable gradients for the embedding parameters.
        freeze_pre_norm: if True, disable gradients for the pre-LayerNorm.
    """

    def __init__(self, cfg, freeze_embed=False, freeze_pre_norm=False) -> None:
        super().__init__()
        hidden_size = cfg["hidden_size"]

        self.embeddings = CLIPVisionEmbeddings(
            hidden_size=hidden_size,
            image_size=cfg["image_size"],
            patch_size=cfg["patch_size"],
        )
        if freeze_embed:
            for param in self.embeddings.parameters():
                param.requires_grad = False

        self.transformer = NoTPTransformer(cfg=cfg)

        if cfg.get("fp32norm", False):
            logger.info("Load fp32 layernorm for ViT.")
            norm_cls = LayerNormfp32
        else:
            norm_cls = torch.nn.LayerNorm
        # The attribute name (including its spelling) is kept as-is because it
        # becomes part of the parameter names.
        self.pre_layrnorm = norm_cls(
            hidden_size,
            eps=cfg.get("pre_layernorm_epsilon", 1e-5),
        )
        if freeze_pre_norm:
            for param in self.pre_layrnorm.parameters():
                param.requires_grad = False

        # Tag every parameter with a `micro_dp` attribute.
        for p in self.parameters():
            p.micro_dp = True

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def set_input_tensor(self, input_tensor):
        # NOTE(review): NoTPTransformer in this file defines no
        # set_input_tensor method, so this call looks like it would raise
        # AttributeError — confirm whether this method is still used.
        first = input_tensor[0] if isinstance(input_tensor, list) else input_tensor
        self.transformer.set_input_tensor(first)

    def __str__(self) -> str:
        return "open_clip"

    def forward(self, x, patch_embeds):
        embedded = self.embeddings(x, patch_embeds)
        normed = self.pre_layrnorm(embedded)
        return self.transformer(normed)
# Settings passed to VitModel by build_clip_l().
vit_model_cfg = {
    "num_layers": 24,
    "hidden_size": 1024,
    "num_heads": 16,
    "num_attention_heads": 16,
    "ffn_hidden_size": 4096,
    "seq_length": 256,
    "max_position_embeddings": 256,
    "use_flash_attn": False,
    "understand_projector_stride": 2,
    "hidden_dropout": 0.0,
    "attention_dropout": 0.0,
    "no_persist_layer_norm": False,
    "layernorm_epsilon": 1e-5,
    "pre_layernorm_epsilon": 1e-5,
    "image_size": 224,
    "patch_size": 14,
    "recompute_list": [],
}
def build_clip_l():
    """Build a VitModel from `vit_model_cfg` with nothing frozen."""
    model = VitModel(
        cfg=vit_model_cfg,
        freeze_embed=False,
        freeze_pre_norm=False,
    )
    return model
class DeepseekOCRForCausalLM(nn.Module):
    """DeepSeek-OCR model: two vision encoders, a projector and a DeepSeek LM.

    Each image is encoded by ``sam_model`` and ``vision_model``; the two
    feature maps are concatenated channel-wise and projected to ``n_embed``
    channels. The projected tokens are laid out row by row, every row is
    terminated with the learned ``image_newline`` embedding and every image
    is closed with the learned ``view_seperator`` embedding. ``forward``
    delegates embedding/merging and decoding to ``general_mm_embed_routine``.
    """

    # Substrings identifying checkpoint tensors that live directly on this
    # wrapper (vision towers, projector, separator tokens) rather than on the
    # wrapped language model.
    _VISION_WEIGHT_TAGS = (
        "image_newline",
        ".projector",
        "vision_model",
        "sam_model",
        "view_seperator",
    )

    def __init__(
        self,
        *,
        config: DeepseekVLV2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the language model, the vision encoders and the projector.

        Args:
            config: Model configuration. ``vision_config``,
                ``projector_config``, ``text_config``, ``tile_tag`` and
                ``global_view_pos`` are read from it.
            quant_config: Optional quantization settings forwarded to the
                language model.
            prefix: Parameter-name prefix; the language model is created under
                ``maybe_prefix(prefix, "language")``.

        Raises:
            ValueError: If ``config.tile_tag`` is anything other than ``"2D"``.
        """
        super().__init__()

        self.config = config
        self.vision_config = config.vision_config
        self.projector_config = config.projector_config
        self.text_config = config.text_config

        # Width of the projected vision tokens and of the separator embeddings.
        n_embed = 1280

        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        if self.tile_tag != "2D":
            raise ValueError(
                f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
            )

        # special token for image token sequence format
        # <|view_separator|>, <|\n|>
        embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        # NOTE: the misspelling "seperator" is the attribute name that
        # load_weights matches against checkpoint tensor names.
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)

        # Pick the language-model implementation from the text config.
        if self.text_config.topk_method == "noaux_tc":
            language_model_cls = DeepseekV3ForCausalLM
        elif not self.text_config.use_mla:
            language_model_cls = DeepseekForCausalLM
        else:
            language_model_cls = DeepseekV2ForCausalLM
        self.model = language_model_cls(
            config=config.text_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "language"),
        )

        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()
        self.projector = MlpProjector(
            projector_type="linear",
            input_dim=2048,
            n_embed=n_embed,
        )

    @staticmethod
    def _is_all_zero(values) -> bool:
        """Return True when a tensor, or every tensor in a (nested) list, sums to 0."""
        if isinstance(values, torch.Tensor):
            return torch.sum(values).item() == 0
        return all(DeepseekOCRForCausalLM._is_all_zero(value) for value in values)

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract and type-check the image inputs from ``kwargs``.

        Reads ``pixel_values``, ``images_spatial_crop`` and ``images_crop``.

        Returns:
            ``None`` when ``pixel_values`` is missing or sums to zero (treated
            as "no image"), otherwise the list
            ``[pixel_values, images_crop, images_spatial_crop]``.

        Raises:
            ValueError: If any of the three inputs is neither a tensor nor a
                list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        images_spatial_crop = kwargs.pop("images_spatial_crop", None)
        images_crop = kwargs.pop("images_crop", None)

        if pixel_values is None:
            return None

        # Validate the type before touching the data: the zero test below
        # would otherwise fail with an unrelated TypeError on list input.
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
            )
        if self._is_all_zero(pixel_values):
            return None

        if not isinstance(images_spatial_crop, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of image sizes. "
                f"Got type: {type(images_spatial_crop)}"
            )
        if not isinstance(images_crop, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of image crop. " f"Got type: {type(images_crop)}"
            )

        return [pixel_values, images_crop, images_spatial_crop]

    def _encode_views(self, images: torch.Tensor) -> torch.Tensor:
        """Run both vision encoders on ``images`` and project the fused features.

        The first token of the ``vision_model`` output is dropped and the
        remaining tokens are concatenated channel-wise with the flattened
        ``sam_model`` feature map before projection.
        """
        sam_features = self.sam_model(images)
        vit_features = self.vision_model(images, sam_features)
        fused = torch.cat(
            (
                vit_features[:, 1:],
                sam_features.flatten(2).permute(0, 2, 1),
            ),
            dim=-1,
        )
        return self.projector(fused)

    def _append_newline(self, grid: torch.Tensor) -> torch.Tensor:
        """Append ``image_newline`` to every row of a ``(rows, cols, dim)`` grid.

        Returns the result flattened to ``(rows * (cols + 1), dim)``.
        """
        rows, _, n_dim = grid.shape
        newline = self.image_newline[None, None, :].expand(rows, 1, n_dim)
        return torch.cat([grid, newline], dim=1).view(-1, n_dim)

    def _layout_global_view(self, global_features: torch.Tensor) -> torch.Tensor:
        """Arrange global-view tokens as a square grid with row terminators."""
        _, hw, n_dim = global_features.shape
        h = w = int(hw**0.5)
        return self._append_newline(global_features.view(h, w, n_dim))

    def _layout_local_tiles(
        self,
        local_features: torch.Tensor,
        width_crop_num: int,
        height_crop_num: int,
    ) -> torch.Tensor:
        """Stitch per-tile tokens into one grid and add row terminators.

        ``local_features`` holds one square token grid per tile; tiles are
        arranged ``height_crop_num`` by ``width_crop_num`` so that tokens of
        horizontally adjacent tiles end up on the same row.
        """
        _, hw2, n_dim2 = local_features.shape
        h2 = w2 = int(hw2**0.5)
        stitched = (
            local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2)
            .permute(0, 2, 1, 3, 4)
            .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
        )
        return self._append_newline(stitched)

    def _pixel_values_to_embedding(
        self,
        pixel_values: torch.Tensor,
        images_crop: torch.Tensor,
        images_spatial_crop: torch.Tensor,
    ) -> NestedTensors:
        """Encode every image into its token-embedding sequence.

        Args:
            pixel_values: Global views,
                ``[n_image, batch_size, 3, height, width]``.
            images_crop: Local views (tiles),
                ``[n_image, batch_size, num_patches, 3, h, w]``.
            images_spatial_crop: Tile counts,
                ``[n_image, batch_size, [num_tiles_w, num_tiles_h]]``.

        Returns:
            One 2-D tensor per image: local tiles (only when the tile tensor
            is not all zeros), then the global view, then ``view_seperator``.
        """
        # split the pixel and image_crop, all batch_size = 1
        images_in_this_batch = []

        with torch.no_grad():
            for jdx in range(images_spatial_crop.size(0)):
                patches = images_crop[jdx][0].to(torch.bfloat16)
                image_ori = pixel_values[jdx]
                crop_shape = images_spatial_crop[jdx][0]

                segments = []
                # An all-zero tile tensor means the image has no local crops.
                if torch.sum(patches).item() != 0:
                    local_features = self._encode_views(patches)
                    width_crop_num = int(crop_shape[0])
                    height_crop_num = int(crop_shape[1])
                    segments.append(
                        self._layout_local_tiles(
                            local_features, width_crop_num, height_crop_num
                        )
                    )

                global_features = self._encode_views(image_ori)
                segments.append(self._layout_global_view(global_features))
                segments.append(self.view_seperator[None, :])

                images_in_this_batch.append(torch.cat(segments, dim=0))

        return images_in_this_batch

    def _process_image_input(self, mm_items: List[MultimodalDataItem]) -> torch.Tensor:
        """Turn multimodal items into one concatenated embedding tensor.

        Each item contributes ``feature`` (global view), ``images_crop``
        (local tiles) and ``images_spatial_crop`` (tile counts).

        Returns:
            All per-image embeddings concatenated along dim 0, in the dtype of
            ``self.vision_model``.
        """
        pixel_values = torch.stack([item.feature for item in mm_items], dim=0).type(
            self.vision_model.dtype
        )

        # images_crop carries image pixels that are fed to the encoders, so it
        # must stay floating point: an integer cast here would truncate
        # fractional pixel values before they reach the vision towers.
        images_crop = torch.stack([item.images_crop for item in mm_items], dim=0).to(
            device=pixel_values.device, dtype=pixel_values.dtype
        )

        images_spatial_crop = (
            torch.cat([item.images_spatial_crop for item in mm_items], dim=0)
            .type(torch.long)
            .to(device=pixel_values.device)
        )

        assert images_crop.dim() == 6
        assert images_spatial_crop.dim() == 3

        vision_feature_lists = self._pixel_values_to_embedding(
            pixel_values=pixel_values,
            images_crop=images_crop,
            images_spatial_crop=images_spatial_crop,
        )
        vision_features = torch.cat(vision_feature_lists, dim=0).type(
            self.vision_model.dtype
        )

        return vision_features

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> Optional[MultiModalEmbeddings]:
        """Compute image embeddings from raw keyword inputs.

        Returns ``None`` when ``kwargs`` carry no usable image (see
        ``_parse_and_validate_image_input``), otherwise the per-image
        embedding list produced by ``_pixel_values_to_embedding``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        # The parsed input is [pixel_values, images_crop, images_spatial_crop],
        # not a list of MultimodalDataItem, so it goes straight to the encoder
        # instead of through _process_image_input.
        # NOTE(review): no dtype/device conversion happens on this path;
        # callers presumably pass tensors already in the model dtype - confirm.
        pixel_values, images_crop, images_spatial_crop = image_input
        return self._pixel_values_to_embedding(
            pixel_values=pixel_values,
            images_crop=images_crop,
            images_spatial_crop=images_spatial_crop,
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and, if given, merge in the image embeddings."""
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            # NOTE(review): ``image_token_id`` is not assigned in ``__init__``
            # and ``merge_multimodal_embeddings`` is not among the imports at
            # the top of the file - confirm both exist before using this path.
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id
            )
        return inputs_embeds

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` via ``MultiModalityDataPaddingPatternMultimodalTokens``."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Return the concatenated image embeddings for ``items``."""
        vision_embeddings = self._process_image_input(items)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs: object,
    ):
        """Run the model by delegating to ``general_mm_embed_routine``.

        ``kwargs`` are accepted but not used. The routine's result is returned
        unchanged.
        """
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
        )

        return hidden_states

    @staticmethod
    def _remap_checkpoint_name(name: str) -> str:
        """Translate a checkpoint tensor name into this module's parameter name.

        * ``lm_head.weight`` lives on the wrapped language model.
        * ``model.<vision/projector/separator ...>`` lives directly on this
          wrapper, so the leading ``model.`` is dropped.
        * Any other ``model.*`` tensor belongs to the language model and gets
          one extra leading ``model.``. Only the prefix is rewritten; later
          occurrences of ``model.`` inside the name are left alone.
        * Names without the ``model.`` prefix are returned unchanged.
        """
        if name == "lm_head.weight":
            return "model.lm_head.weight"
        if not name.startswith("model."):
            return name
        if any(tag in name for tag in DeepseekOCRForCausalLM._VISION_WEIGHT_TAGS):
            return name[len("model.") :]
        return "model." + name

    @staticmethod
    def _should_skip_weight(name: str, params_dict) -> bool:
        """Return True for checkpoint tensors that have no local parameter by design.

        Covers extra biases (GPTQ models) and experts that are not assigned to
        this worker.
        """
        if name in params_dict:
            return False
        return (
            name.endswith(".bias")
            or "mlp.experts." in name
            or "mlp.shared_experts." in name
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            RuntimeError: If any parameter of this module was not loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            name = self._remap_checkpoint_name(name)

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if self._should_skip_weight(name, params_dict):
                    # Keep scanning; the loop then ends without ``break`` and
                    # the ``else`` branch below skips this tensor.
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if self._should_skip_weight(name, params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                f"Some weights are not initialized from checkpoints: {unloaded_params}"
            )
# Module-level list naming the model class defined in this file.
# NOTE(review): presumably read by the sglang model loader to register the
# architecture - confirm against the loader.
EntryClass = [DeepseekOCRForCausalLM]

Xet Storage Details

Size:
52.4 kB
·
Xet hash:
1ef18e20aaa6872befb47eb50076dce7cecd3384ad0f0f762eb018b364f52b71

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.