|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch InfiniteVL model (built on top of Qwen2-VL/Qwen2.5-VL).""" |
|
|
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from transformers.image_utils import ImageInput |
|
|
from transformers.modeling_flash_attention_utils import is_flash_attn_available |
|
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
|
from transformers.processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs |
|
|
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput |
|
|
from transformers.utils import is_torchdynamo_compiling, logging |
|
|
from transformers.video_utils import VideoInput |
|
|
|
|
|
|
|
|
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig |
|
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import ( |
|
|
PatchEmbed, |
|
|
PatchMerger, |
|
|
Qwen2RMSNorm, |
|
|
Qwen2VLCausalLMOutputWithPast, |
|
|
Qwen2VLForConditionalGeneration, |
|
|
Qwen2VLModel, |
|
|
Qwen2VLModelOutputWithPast, |
|
|
Qwen2VLPreTrainedModel, |
|
|
TransformersKwargs, |
|
|
VisionAttention, |
|
|
VisionRotaryEmbedding, |
|
|
) |
|
|
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLImagesKwargs, Qwen2VLProcessor |
|
|
|
|
|
|
|
|
# NOTE(review): this branch is a no-op — its body is only ``pass``. It looks like a
# leftover placeholder for flash-attention-specific imports; the only effect is the
# ``is_flash_attn_available()`` call itself.
if is_flash_attn_available():
    pass


# Module-level logger for this file.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLVisionConfig(PretrainedConfig):
    """
    Configuration of the InfiniteVL vision encoder.

    The fields and their defaults follow the Qwen2.5-VL vision tower; the class
    is registered as the ``vision_config`` sub-config of :class:`InfiniteVLConfig`.

    Args:
        depth: Number of vision transformer blocks.
        hidden_size: Width of the vision transformer (token embedding size).
        hidden_act: Name of the activation used by the gated MLP.
        intermediate_size: Inner width of the gated MLP.
        num_heads: Number of attention heads per block.
        in_channels: Number of channels of the input pixels.
        patch_size: Spatial patch size.
        spatial_merge_size: Side length of the square group of patches merged
            into a single output token.
        temporal_patch_size: Number of frames folded into one temporal patch.
        tokens_per_second: Scale applied to video timestamps when building the
            temporal M-RoPE index.
        window_size: Attention window size; the encoder divides it by
            ``spatial_merge_size`` and ``patch_size`` to get the window side in
            merged tokens.
        out_hidden_size: Output width of the patch merger.
        fullatt_block_indexes: Indices of blocks that use full (non-windowed)
            attention. ``None`` selects ``[7, 15, 23, 31]``.
        initializer_range: Stored as ``initializer_range``; presumably consumed
            by the pretrained-model base class for weight init.
        **kwargs: Forwarded to :class:`PretrainedConfig`.
    """

    model_type = "infinite_vl"
    base_config_key = "vision_config"

    def __init__(
        self,
        depth: int = 32,
        hidden_size: int = 3584,
        hidden_act: str = "silu",
        intermediate_size: int = 3420,
        num_heads: int = 16,
        in_channels: int = 3,
        patch_size: int = 14,
        spatial_merge_size: int = 2,
        temporal_patch_size: int = 2,
        tokens_per_second: int = 4,
        window_size: int = 112,
        out_hidden_size: int = 3584,
        fullatt_block_indexes: Optional[List[int]] = None,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        # Base-class init runs first so the explicit fields below take precedence.
        super().__init__(**kwargs)

        # Copy into a fresh list so the caller's sequence is never aliased.
        full_attention_blocks = (
            [7, 15, 23, 31] if fullatt_block_indexes is None else list(fullatt_block_indexes)
        )

        self.depth = depth
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size
        self.tokens_per_second = tokens_per_second
        self.window_size = window_size
        self.fullatt_block_indexes = full_attention_blocks
        self.out_hidden_size = out_hidden_size
        self.initializer_range = initializer_range
|
|
|
|
|
|
|
|
class InfiniteVLTextConfig(Qwen2VLTextConfig):
    """
    Language-model configuration for InfiniteVL.

    Every field and default is inherited unchanged from
    :class:`Qwen2VLTextConfig`; the subclass only registers its own
    ``model_type`` (``"infinite_vl_text"``) so InfiniteVL text configs are
    distinguishable from Qwen2-VL ones.
    """

    model_type = "infinite_vl_text"
|
|
|
|
|
|
|
|
class InfiniteVLConfig(Qwen2VLConfig):
    """
    Top-level InfiniteVL configuration.

    Identical to :class:`Qwen2VLConfig` apart from ``model_type`` and the
    ``sub_configs`` mapping, which names :class:`InfiniteVLVisionConfig` for
    ``vision_config`` and :class:`InfiniteVLTextConfig` for ``text_config``.
    The parent class is expected to build the nested configs from this mapping.
    """

    model_type = "infinite_vl"
    sub_configs = dict(vision_config=InfiniteVLVisionConfig, text_config=InfiniteVLTextConfig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLMLP(nn.Module):
    """
    Gated feed-forward network of the InfiniteVL vision blocks.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    Args:
        config: Vision config supplying ``hidden_size``, ``intermediate_size``
            and ``hidden_act``.
        bias: Whether the three linear projections have a bias term.
    """

    def __init__(self, config: InfiniteVLVisionConfig, bias: bool = False):
        super().__init__()
        width, inner_width = config.hidden_size, config.intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner_width

        # Projections are created in gate -> up -> down order.
        self.gate_proj = nn.Linear(width, inner_width, bias=bias)
        self.up_proj = nn.Linear(width, inner_width, bias=bias)
        self.down_proj = nn.Linear(inner_width, width, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.gate_proj(hidden_state))
        value = self.up_proj(hidden_state)
        return self.down_proj(gate * value)
|
|
|
|
|
|
|
|
class InfiniteVisionPatchEmbed(PatchEmbed):
    """
    Patch embedder of the InfiniteVL vision encoder.

    No behavior is added: this is the Qwen2-VL :class:`PatchEmbed` under an
    InfiniteVL-specific name.
    """
|
|
|
|
|
|
|
|
class InfiniteVisionRotaryEmbedding(VisionRotaryEmbedding):
    """
    Rotary position embedding used by the InfiniteVL vision encoder.

    Inherits everything from the Qwen2-VL :class:`VisionRotaryEmbedding`;
    only the class name differs.
    """
|
|
|
|
|
|
|
|
class InfiniteVLPatchMerger(PatchMerger):
    """
    Patch merger whose ``ln_q`` normalization is a Qwen2-style RMSNorm.

    The parent :class:`PatchMerger` is built normally, then its ``ln_q``
    submodule is replaced by ``Qwen2RMSNorm(context_dim, eps=1e-6)``.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__(dim, context_dim, spatial_merge_size)
        # Overwrite whatever norm the parent registered under ``ln_q``.
        rms_norm = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.ln_q = rms_norm
|
|
|
|
|
|
|
|
class InfiniteVLVisionAttention(VisionAttention):
    """
    Self-attention of the InfiniteVL vision encoder.

    All projections and the forward pass come from :class:`VisionAttention`;
    after the parent is initialized, ``self.dim`` is set to
    ``config.hidden_size`` (overriding any value the parent stored).
    """

    def __init__(self, config: InfiniteVLVisionConfig) -> None:
        super().__init__(config)
        vision_width = config.hidden_size
        self.dim = vision_width
|
|
|
|
|
|
|
|
class InfiniteVLVisionBlock(GradientCheckpointingLayer):
    """
    Pre-norm transformer block of the InfiniteVL vision encoder.

    ``x = x + attn(norm1(x))`` then ``x = x + mlp(norm2(x))``, where both norms
    are ``Qwen2RMSNorm(hidden_size, eps=1e-6)`` and the MLP is a biased
    :class:`InfiniteVLMLP`.

    Args:
        config: Vision config.
        attn_implementation: Accepted for signature compatibility; this block
            does not read it.
    """

    def __init__(self, config: InfiniteVLVisionConfig, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        width = config.hidden_size
        self.norm1 = Qwen2RMSNorm(width, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(width, eps=1e-6)
        self.attn = InfiniteVLVisionAttention(config=config)
        self.mlp = InfiniteVLMLP(config, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attended = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        residual = hidden_states + attended

        # Feed-forward sub-layer with residual connection.
        return residual + self.mlp(self.norm2(residual))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLPreTrainedModel(Qwen2VLPreTrainedModel):
    """
    Common base class for InfiniteVL models.

    Adds nothing to :class:`Qwen2VLPreTrainedModel`; it exists so InfiniteVL
    modules share the same pretrained-model utilities as Qwen2-VL.
    """
|
|
|
|
|
|
|
|
class InfiniteVisionTransformerPretrainedModel(InfiniteVLPreTrainedModel):
    """
    InfiniteVL vision encoder: a windowed ViT in the style of Qwen2.5-VL.

    ``forward`` proceeds as follows:

    1. ``patch_embed`` maps the flattened pixel patches to tokens.
    2. Tokens and their rotary embeddings are permuted into attention-window
       order, one merge group (``spatial_merge_unit`` consecutive tokens) at a
       time, using the permutation returned by :meth:`get_window_index`.
    3. Blocks whose index is in ``config.fullatt_block_indexes`` attend within
       each temporal slice (``h * w`` patches) of each image/video
       (``cu_seqlens``); all other blocks attend within windows
       (``cu_window_seqlens``).
    4. ``merger`` produces one row per merge group, and the rows are put back
       into their original order with ``argsort(window_index)``.
    """

    config: InfiniteVLVisionConfig
    _no_split_modules = ["InfiniteVLVisionBlock"]

    def __init__(self, config: InfiniteVLVisionConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # Number of patch tokens that end up in one merged output token.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = InfiniteVisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        head_dim = config.hidden_size // config.num_heads
        # Sized head_dim // 2: ``rot_pos_emb`` concatenates a row lookup and a
        # column lookup per token, and ``forward`` duplicates the result.
        self.rotary_pos_emb = InfiniteVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([InfiniteVLVisionBlock(config) for _ in range(config.depth)])
        self.merger = InfiniteVLPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Build per-token 2D rotary embeddings for every image/video in ``grid_thw``.

        For each ``(t, h, w)`` entry, row and column indices are laid out so that
        each ``spatial_merge_size x spatial_merge_size`` group of patches is
        contiguous (matching the merger's grouping), then repeated for the ``t``
        temporal steps. The rotary table is looked up with both indices and the
        two lookups are flattened together per token.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, reordered into merge-group-contiguous order.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            # Column index of every patch, with the same reordering.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        # One table large enough for the largest height/width in the batch.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # pos_ids is (N, 2): gather row and column entries, then merge them per token.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
        """
        Compute the window-order permutation of merge groups and window boundaries.

        Returns:
            window_index: 1D tensor; ``window_index[k]`` is the (global) index of
                the merge group placed at position ``k`` in window order.
            cu_window_seqlens: cumulative token counts (in patch tokens, i.e.
                multiplied by ``spatial_merge_unit``) at each window boundary,
                starting with 0. May contain repeated values for empty windows.
        """
        window_index: List[torch.Tensor] = []
        cu_window_seqlens: List[int] = [0]
        window_index_id = 0
        # Window side length measured in merged tokens.
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad up to a whole number of windows using -100 as a "no token" sentinel.
            # NOTE: when the size is already a multiple, this adds one fully padded
            # window; its zero-length entry is removed by ``unique_consecutive`` in forward.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            # Split into (t, windows_h, win, windows_w, win) and bring window axes together.
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merge groups in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            # Offset so indices are global across all images/videos in the batch.
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index_tensor = torch.cat(window_index, dim=0)

        return window_index_tensor, cu_window_seqlens

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Encode flattened pixel patches into merged vision tokens.

        Args:
            hidden_states (`torch.Tensor`):
                Flattened pixel patches for all images/videos, as consumed by
                ``patch_embed``; after embedding they are treated as
                ``(seq_len, hidden_size)``.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width patch-grid sizes of each image/video.
            **kwargs: Forwarded to every vision block.

        Returns:
            `torch.Tensor`: the merger output, one row per merge group, in the
            original (pre-window) order.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens_tensor = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries produced by empty (fully padded) windows.
        cu_window_seqlens_tensor = torch.unique_consecutive(cu_window_seqlens_tensor)

        # Permute tokens into window order, one merge group at a time.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        # Apply the same permutation to the rotary embeddings, then build cos/sin.
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one segment of h*w patches per temporal step.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Selected blocks see the whole frame; the rest attend within windows.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens_tensor

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.merger(hidden_states)
        # Undo the window permutation so outputs follow the original token order.
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLModelOutputWithPast(Qwen2VLModelOutputWithPast):
    """
    Output type of :class:`InfiniteVLModel`.

    Adds no fields: it reuses :class:`Qwen2VLModelOutputWithPast` as-is, including
    its ``rope_deltas`` field, under an InfiniteVL-specific name.
    """
|
|
|
|
|
|
|
|
class InfiniteVLModel(Qwen2VLModel):
    """
    InfiniteVL multimodal backbone.

    Reuses :class:`Qwen2VLModel` (token embeddings, language model, image/video
    feature helpers) and installs :class:`InfiniteVisionTransformerPretrainedModel`
    as ``self.visual``. Position ids are 3D (temporal, height, width) M-RoPE
    indices computed by :meth:`get_rope_index`, which also scales video time
    steps by ``second_per_grid_ts`` and ``vision_config.tokens_per_second``.
    """

    config: InfiniteVLConfig
    base_model_prefix = ""
    _no_split_modules = ["InfiniteVLDecoderLayer", "InfiniteVLVisionBlock"]

    # Advertises that loss kwargs are not consumed (read by e.g. the HF Trainer);
    # presumably mirrors upstream Qwen2.5-VL.
    accepts_loss_kwargs = False

    def __init__(self, config: InfiniteVLConfig):
        super().__init__(config)
        # Override whatever ``visual`` module the parent constructor set up.
        self.visual = InfiniteVisionTransformerPretrainedModel._from_config(config.vision_config)

    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute 3D (temporal, height, width) M-RoPE position ids.

        Text tokens receive the same position on all three axes, increasing by
        one per token. Each image/video placeholder run (``t * (h // m) * (w // m)``
        tokens, ``m = spatial_merge_size``) receives temporal/height/width grid
        indices offset by the current text position; the temporal index of a
        video is ``frame * second_per_grid_t * tokens_per_second``. Text after a
        vision run continues from the largest position used so far.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            image_grid_thw: ``(num_images, 3)`` patch-grid sizes, consumed in order.
            video_grid_thw: ``(num_videos, 3)`` patch-grid sizes, consumed in order.
            second_per_grid_ts: Seconds per temporal grid step, one per video;
                defaults to 1.0 when omitted.
            attention_mask: ``(batch, seq_len)``; only positions equal to 1 get
                computed positions (the others are left at 1).

        Returns:
            ``(position_ids, mrope_position_deltas)``: ``position_ids`` has shape
            ``(3, batch, seq_len)``; ``mrope_position_deltas`` has shape
            ``(batch, 1)`` and equals ``max_position + 1 - seq_len`` per row, used
            to continue positions during cached decoding.
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []

        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is not None:
                # Boolean mask on ``input_ids``' device so it can index token rows and
                # ``position_ids`` directly even if the caller's mask lives elsewhere.
                attention_mask = attention_mask.to(input_ids.device) == 1
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            for i, input_ids_row in enumerate(total_input_ids):
                if attention_mask is not None:
                    input_ids_row = input_ids_row[attention_mask[i]]

                # The token right after each vision-start marker says whether the
                # segment is an image or a video.
                vision_start_indices = torch.argwhere(input_ids_row == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids_row[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids_row.tolist()

                llm_pos_ids_list: List[torch.Tensor] = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    # Find the next image and video placeholders; the earlier one is handled.
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video

                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    # Text tokens before this vision segment: same position on all 3 axes.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                    # NOTE(review): casting to range_tensor's integer dtype truncates
                    # fractional values (e.g. 0.5 -> 0). Kept as-is to preserve the
                    # positions the model was built with — confirm this is intended.
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t,
                        dtype=range_tensor.dtype,
                        device=range_tensor.device,
                    )

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Trailing text after the last vision segment.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                if attention_mask is not None:
                    position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
                else:
                    position_ids[..., i, :] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))

            mrope_position_deltas_tensor = torch.tensor(mrope_position_deltas).unsqueeze(1).to(
                device=input_ids.device
            )
            return position_ids, mrope_position_deltas_tensor

        # No vision inputs: plain 1D positions replicated on all three axes.
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, InfiniteVLModelOutputWithPast]:
        r"""
        Embed text, scatter image/video features into placeholder positions, build
        3D M-RoPE position ids if none are given, and run the language model.

        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The RoPE index difference between sequence length and multimodal RoPE.
            Not read by this method; the cached ``self.rope_deltas`` is used instead.
        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
            The time interval (in seconds) for each grid along the temporal dimension
            in the 3D position IDs.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            # Compute the 3D RoPE index once per generation, in the pre-fill stage.
            # Under torch.compile tensor values cannot be inspected, so a sequence
            # length != 1 is taken as the pre-fill signal.
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            else:
                # Decoding: continue from the cached per-row deltas.
                batch_size, seq_length, _ = inputs_embeds.shape
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
                if cache_position is not None:
                    delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                else:
                    delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
                # Expand along the batch axis (dim=0) when the batch grew after pre-fill
                # (e.g. beam search); repeating along dim=1 broke broadcasting.
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids + delta.to(position_ids.device)

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        output = InfiniteVLModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
        return output if return_dict else output.to_tuple()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast):
    """
    Output type of :class:`InfiniteVLQwen2_5_VLForConditionalGeneration`.

    Field-for-field identical to :class:`Qwen2VLCausalLMOutputWithPast`.
    """
|
|
|
|
|
|
|
|
class InfiniteVLQwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """
    InfiniteVL causal language-model head on top of the multimodal backbone.

    ``forward`` runs ``self.model``, projects the last hidden states with
    ``lm_head`` and optionally computes the LM loss.
    ``prepare_inputs_for_generation`` adds 3D M-RoPE position ids (prefixed by
    text-only positions) to the generic generation inputs.
    """

    # Advertises that loss kwargs are not consumed (read by e.g. the HF Trainer).
    accepts_loss_kwargs = False

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, InfiniteVLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in
            ``[0, ..., config.vocab_size]`` or ``-100`` (see ``input_ids`` docstring). Tokens with indices set to
            ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in
            ``[0, ..., config.vocab_size]``.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The RoPE index difference between sequence length and multimodal RoPE.
        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an int ``k > 0``, only the last ``k`` positions get logits; ``0`` keeps
            all. A tensor is used directly as the sequence index.
        return_dict (`bool`, *optional*):
            Accepted explicitly so callers such as ``generate`` (which pass
            ``return_dict=True``) do not collide with the inner model call. Only an
            explicit ``False`` returns a plain tuple; otherwise the output dataclass
            is returned.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only materialize logits for the requested positions (0 -> slice(0, None) keeps all).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        output = InfiniteVLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=outputs.rope_deltas,
        )
        if return_dict is not None and not return_dict:
            return output.to_tuple()
        return output

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """
        Build model inputs for one generation step.

        When no ``position_ids`` are supplied and the generic preparation produced
        text positions, the result's ``position_ids`` becomes a
        ``(4, batch, seq_len)`` tensor: text-only positions stacked on top of the
        3D M-RoPE positions. On pre-fill (or when no deltas are cached) the 3D
        positions come from ``self.model.get_rope_index``, which also refreshes
        ``self.model.rope_deltas``; afterwards they are derived from
        ``cache_position`` and the cached deltas. If no text positions are
        available, ``position_ids`` is left untouched so the model's ``forward``
        computes them. Pixel inputs are dropped after the first step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        if position_ids is None:
            # Text-only positions from the generic preparation; may be absent or None
            # (e.g. when no attention mask was available).
            text_positions = model_inputs.get("position_ids")
            vision_positions = None

            if cache_position[0] == 0 or self.model.rope_deltas is None:
                # Pre-fill: compute the 3D index and cache the per-row deltas.
                vision_positions, rope_deltas = self.model.get_rope_index(
                    model_inputs.get("input_ids", None),
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.model.rope_deltas = rope_deltas
            elif text_positions is not None:
                # Decoding: shift a plain range by the cached deltas.
                batch_size, seq_length = text_positions.shape
                device = text_positions.device
                step_positions = torch.arange(seq_length, device=device)
                step_positions = step_positions.view(1, 1, -1).expand(3, batch_size, -1)
                delta = cache_position[0] + self.model.rope_deltas
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                vision_positions = step_positions + delta.expand_as(step_positions)

            # Concatenate "text + vision" positions into [4, bs, seq-len].
            if text_positions is not None and vision_positions is not None:
                model_inputs["position_ids"] = torch.cat([text_positions[None, ...], vision_positions], dim=0)

        if cache_position[0] != 0:
            # Vision features were already merged into the cache during pre-fill.
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfiniteVLVideosProcessorKwargs(VideosKwargs, total=False):
    """Video keyword arguments accepted by ``InfiniteVLProcessor``.

    Adds ``fps`` on top of ``VideosKwargs``. It may be a single frame rate shared
    by every video, or a list with one frame rate per video.
    ``InfiniteVLProcessor.__call__`` divides the video processor's
    ``temporal_patch_size`` by it to produce ``second_per_grid_ts``.
    """

    fps: Union[list[float], float]
|
|
|
|
|
|
|
|
class InfiniteVLImagesKwargs(Qwen2VLImagesKwargs):
    """Image keyword arguments for ``InfiniteVLProcessor``; adds nothing to ``Qwen2VLImagesKwargs``."""
|
|
|
|
|
|
|
|
class InfiniteVLProcessorKwargs(ProcessingKwargs, total=False):
    """Per-modality keyword-argument schema for ``InfiniteVLProcessor``.

    ``images_kwargs`` / ``videos_kwargs`` are typed with the InfiniteVL-specific
    kwargs classes. ``_defaults`` holds processor-level defaults. It is passed to
    ``_merge_kwargs`` by ``InfiniteVLProcessor.__call__``, and its per-modality
    entries are read by ``_get_num_multimodal_tokens``.
    """

    images_kwargs: InfiniteVLImagesKwargs
    videos_kwargs: InfiniteVLVideosProcessorKwargs

    # Text defaults: no padding, and no ``mm_token_type_ids`` unless explicitly requested.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
    }
|
|
|
|
|
|
|
|
class InfiniteVLProcessor(Qwen2VLProcessor):
    r"""
    Constructs an InfiniteVL processor which wraps a Qwen2-VL image processor
    and a Qwen2 tokenizer into a single processor.

    :class:`InfiniteVLProcessor` offers all the functionalities of
    :class:`Qwen2VLImageProcessor` and :class:`Qwen2TokenizerFast`. See
    :meth:`InfiniteVLProcessor.__call__` and :meth:`InfiniteVLProcessor.decode`
    for more information.

    Args:
        image_processor (:class:`Qwen2VLImageProcessor`, *optional*):
            The image processor is a required input.
        tokenizer (:class:`Qwen2TokenizerFast`, *optional*):
            The tokenizer is a required input.
        video_processor (:class:`InfiniteVLVideoProcessor`, *optional*):
            The video processor is a required input.
        chat_template (`str`, *optional*):
            A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    # The image processor class is given by name; presumably resolved through
    # ``AutoImageProcessor`` when the processor is loaded (TODO confirm).
    image_processor_class = "AutoImageProcessor"

    @property
    def model_input_names(self):
        """Input names produced by this processor.

        Returns the tokenizer's and image processor's input names, de-duplicated
        with first-seen order preserved. ``"second_per_grid_ts"`` is appended,
        which ``__call__`` adds for video inputs.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # dict.fromkeys removes duplicates while keeping insertion order.
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor + ["second_per_grid_ts"]

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: Optional[VideoInput] = None,
        **kwargs: Unpack[InfiniteVLProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequence(s) and image(s).

        This method forwards the ``text`` and ``kwargs`` arguments to
        :class:`Qwen2TokenizerFast.__call__` if ``text`` is not ``None``
        to encode the text. To prepare the vision inputs, this method
        forwards the ``images`` / ``videos`` and ``kwargs`` arguments to
        :class:`Qwen2VLImageProcessor.__call__` and the corresponding
        video processor when they are not ``None``.

        Each image/video token in ``text`` is expanded into one token per merged
        patch of the matching ``*_grid_thw`` entry before tokenization.

        Raises:
            ValueError: if ``fps`` is a sequence whose length differs from the
                number of videos, or is neither a number nor a sequence.
        """
        output_kwargs = self._merge_kwargs(
            InfiniteVLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Two distinct dicts (a chained ``a = b = {}`` would alias a single object).
        image_inputs = {}
        videos_inputs = {}

        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]

        if videos is not None:
            fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]
            second_per_grid_ts = self._compute_second_per_grid_ts(fps, video_grid_thw)
            videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})

        if not isinstance(text, list):
            text = [text]
        # Expand on a copy so the caller's list is never modified.
        text = text.copy()

        if images is not None:
            self._expand_mm_placeholders(
                text, self.image_token, image_grid_thw, self.image_processor.merge_size**2
            )
        if videos is not None:
            self._expand_mm_placeholders(
                text, self.video_token, video_grid_thw, self.video_processor.merge_size**2
            )

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            text_inputs["mm_token_type_ids"] = self._compute_mm_token_type_ids(text_inputs["input_ids"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

    def _compute_second_per_grid_ts(self, fps, video_grid_thw) -> List[float]:
        """Convert ``fps`` (scalar or one value per video) into seconds per temporal grid step.

        Raises:
            ValueError: if ``fps`` is neither a number nor a sequence whose
                length equals ``len(video_grid_thw)``.
        """
        if isinstance(fps, (int, float)):
            return [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
        if hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
            return [self.video_processor.temporal_patch_size / tmp for tmp in fps]
        raise ValueError(
            f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the "
            f"length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
        )

    @staticmethod
    def _expand_mm_placeholders(text, token, grid_thw, merge_length) -> None:
        """Expand each ``token`` in ``text`` (modified in place) to one token per merged patch.

        Occurrences are paired with rows of ``grid_thw`` in reading order,
        continuing across samples. A temporary marker stands in for inserted
        tokens so the ``while`` loop does not match them again.
        """
        placeholder = "<|placeholder|>"
        index = 0
        for i, sample in enumerate(text):
            while token in sample:
                num_tokens = grid_thw[index].prod() // merge_length
                sample = sample.replace(token, placeholder * num_tokens, 1)
                index += 1
            text[i] = sample.replace(placeholder, token)

    def _compute_mm_token_type_ids(self, input_ids) -> List[List[int]]:
        """Mark image tokens with 1 and every other token with 0, row by row.

        Built per row rather than via ``np.array``: with the default
        ``padding=False`` the rows can have different lengths, which NumPy
        rejects as a ragged array.
        """
        image_token_id = self.image_token_id
        return [[1 if token_id == image_token_id else 0 for token_id in row] for row in input_ids]

    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs) -> MultiModalData:
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.

        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
            video_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (num_frames, height, width) per each video.

        Returns:
            :class:`MultiModalData`: A :class:`MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """
        vision_data = {}

        if image_sizes is not None:
            # Copy before merging so class-level ``_defaults`` are never mutated.
            images_kwargs = {**InfiniteVLProcessorKwargs._defaults.get("images_kwargs", {}), **kwargs}
            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size

            num_image_patches = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]
            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        if video_sizes is not None:
            # Copy before merging so class-level ``_defaults`` are never mutated.
            videos_kwargs = {**InfiniteVLProcessorKwargs._defaults.get("videos_kwargs", {}), **kwargs}
            # Resolved independently of the image branch so it works when only videos are given.
            video_merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size

            num_video_patches = [
                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
                for video_size in video_sizes
            ]
            num_video_tokens = [(num_patches // video_merge_size**2) for num_patches in num_video_patches]
            vision_data["num_video_tokens"] = num_video_tokens

        return MultiModalData(**vision_data)
|
|
|
|
|
|
|
|
# Public API of this module (config, model, and processor classes).
__all__ = [
    "InfiniteVLConfig",
    "InfiniteVLTextConfig",
    "InfiniteVLQwen2_5_VLForConditionalGeneration",
    "InfiniteVLModel",
    "InfiniteVLPreTrainedModel",
    "InfiniteVLProcessor",
]