add prefill chunking
Browse files- chunk_utils.py +137 -0
- processing_qwen3_vl.py +24 -2
- video_preprocessor_config.json +1 -0
chunk_utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def _visual_token_cums(
|
| 7 |
+
sequence_idx: int,
|
| 8 |
+
input_ids: torch.Tensor | np.ndarray,
|
| 9 |
+
image_token_id: int,
|
| 10 |
+
video_token_id: int,
|
| 11 |
+
merge_size: int,
|
| 12 |
+
focus_size: int,
|
| 13 |
+
image_grid_thw: torch.Tensor | np.ndarray | None,
|
| 14 |
+
video_grid_thw: torch.Tensor | np.ndarray | None,
|
| 15 |
+
**kwargs,
|
| 16 |
+
) -> list[int]:
|
| 17 |
+
cums: deque[int] = deque()
|
| 18 |
+
|
| 19 |
+
video_idx = 0
|
| 20 |
+
frame_idx = 0
|
| 21 |
+
image_idx = 0
|
| 22 |
+
token_idx = 0
|
| 23 |
+
in_video = False
|
| 24 |
+
cum = 0
|
| 25 |
+
sequence = input_ids[sequence_idx].tolist()
|
| 26 |
+
|
| 27 |
+
while token_idx < len(sequence):
|
| 28 |
+
token = sequence[token_idx]
|
| 29 |
+
if token == image_token_id:
|
| 30 |
+
assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
|
| 31 |
+
_, h, w = image_grid_thw[image_idx].tolist()
|
| 32 |
+
num_tokens = h * w // (merge_size ** 2)
|
| 33 |
+
cums.append(num_tokens)
|
| 34 |
+
token_idx += num_tokens
|
| 35 |
+
image_idx += 1
|
| 36 |
+
elif token == video_token_id:
|
| 37 |
+
assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used"
|
| 38 |
+
t, h, w = video_grid_thw[video_idx].tolist()
|
| 39 |
+
assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}"
|
| 40 |
+
num_tokens = h * w // (merge_size ** 2)
|
| 41 |
+
cum += num_tokens
|
| 42 |
+
|
| 43 |
+
if (frame_idx + 1) % focus_size == 0:
|
| 44 |
+
cums.append(cum)
|
| 45 |
+
cum = 0
|
| 46 |
+
in_video = False
|
| 47 |
+
else:
|
| 48 |
+
in_video = True
|
| 49 |
+
|
| 50 |
+
frame_idx += 1
|
| 51 |
+
if frame_idx == t:
|
| 52 |
+
video_idx += 1
|
| 53 |
+
frame_idx = 0
|
| 54 |
+
|
| 55 |
+
token_idx += num_tokens
|
| 56 |
+
|
| 57 |
+
else:
|
| 58 |
+
if not in_video:
|
| 59 |
+
cums.append(1)
|
| 60 |
+
else:
|
| 61 |
+
cum += 1
|
| 62 |
+
token_idx += 1
|
| 63 |
+
|
| 64 |
+
return list(cums)
|
| 65 |
+
|
| 66 |
+
def visual_token_cums(
|
| 67 |
+
input_ids: torch.Tensor | np.ndarray,
|
| 68 |
+
image_token_id: int,
|
| 69 |
+
video_token_id: int,
|
| 70 |
+
merge_size: int,
|
| 71 |
+
focus_size: int,
|
| 72 |
+
image_grid_thw: torch.Tensor | np.ndarray | None,
|
| 73 |
+
video_grid_thw: torch.Tensor | np.ndarray | None,
|
| 74 |
+
**kwargs,
|
| 75 |
+
) -> list[list[int]]:
|
| 76 |
+
return [
|
| 77 |
+
_visual_token_cums(
|
| 78 |
+
sequence_idx=i,
|
| 79 |
+
input_ids=input_ids,
|
| 80 |
+
image_token_id=image_token_id,
|
| 81 |
+
video_token_id=video_token_id,
|
| 82 |
+
merge_size=merge_size,
|
| 83 |
+
focus_size=focus_size,
|
| 84 |
+
image_grid_thw=image_grid_thw,
|
| 85 |
+
video_grid_thw=video_grid_thw,
|
| 86 |
+
)
|
| 87 |
+
for i in range(input_ids.shape[0])
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
def chunk_tokens(
|
| 91 |
+
max_chunk_size: int,
|
| 92 |
+
input_ids: torch.Tensor | np.ndarray,
|
| 93 |
+
image_token_id: int,
|
| 94 |
+
video_token_id: int,
|
| 95 |
+
merge_size: int,
|
| 96 |
+
focus_size: int,
|
| 97 |
+
image_grid_thw: torch.Tensor | np.ndarray | None,
|
| 98 |
+
video_grid_thw: torch.Tensor | np.ndarray | None,
|
| 99 |
+
**kwargs,
|
| 100 |
+
) -> list[list[tuple[int, int]]]:
|
| 101 |
+
cums = visual_token_cums(
|
| 102 |
+
input_ids=input_ids,
|
| 103 |
+
image_token_id=image_token_id,
|
| 104 |
+
video_token_id=video_token_id,
|
| 105 |
+
merge_size=merge_size,
|
| 106 |
+
focus_size=focus_size,
|
| 107 |
+
image_grid_thw=image_grid_thw,
|
| 108 |
+
video_grid_thw=video_grid_thw,
|
| 109 |
+
**kwargs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
chunked_cums: list[list[tuple[int, int]]] = []
|
| 113 |
+
|
| 114 |
+
for sequence_cums in cums:
|
| 115 |
+
chunks = []
|
| 116 |
+
current_chunk_start = 0
|
| 117 |
+
current_chunk_size = 0
|
| 118 |
+
|
| 119 |
+
for cum in sequence_cums:
|
| 120 |
+
if current_chunk_size + cum > max_chunk_size:
|
| 121 |
+
chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))
|
| 122 |
+
current_chunk_start += current_chunk_size
|
| 123 |
+
current_chunk_size = 0
|
| 124 |
+
|
| 125 |
+
current_chunk_size += cum
|
| 126 |
+
|
| 127 |
+
if current_chunk_size > 0:
|
| 128 |
+
chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))
|
| 129 |
+
|
| 130 |
+
chunked_cums.append(chunks)
|
| 131 |
+
|
| 132 |
+
num_chunks = max(len(chunks) for chunks in chunked_cums)
|
| 133 |
+
for chunks in chunked_cums:
|
| 134 |
+
while len(chunks) < num_chunks:
|
| 135 |
+
chunks.append((chunks[-1][1], chunks[-1][1]))
|
| 136 |
+
|
| 137 |
+
return chunked_cums
|
processing_qwen3_vl.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
from typing import Optional, Union
|
| 2 |
|
| 3 |
import numpy as np
|
| 4 |
-
|
| 5 |
from transformers.feature_extraction_utils import BatchFeature
|
| 6 |
from transformers.image_utils import ImageInput
|
| 7 |
from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
|
@@ -9,12 +8,15 @@ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
| 9 |
from transformers.utils import logging
|
| 10 |
from transformers.video_utils import VideoInput
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
logger = logging.get_logger(__name__)
|
| 14 |
|
| 15 |
|
| 16 |
class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
|
| 17 |
focus_size: Optional[int]
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class Qwen3VLImagesKwargs(ImagesKwargs):
|
|
@@ -225,7 +227,27 @@ class ZFQwen3VLProcessor(ProcessorMixin):
|
|
| 225 |
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 226 |
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 227 |
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
| 231 |
"""
|
|
|
|
| 1 |
from typing import Optional, Union
|
| 2 |
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from transformers.feature_extraction_utils import BatchFeature
|
| 5 |
from transformers.image_utils import ImageInput
|
| 6 |
from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
|
|
|
| 8 |
from transformers.utils import logging
|
| 9 |
from transformers.video_utils import VideoInput
|
| 10 |
|
| 11 |
+
from .chunk_utils import chunk_tokens
|
| 12 |
+
|
| 13 |
|
| 14 |
logger = logging.get_logger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
|
| 18 |
focus_size: Optional[int]
|
| 19 |
+
max_chunk_size: Optional[int]
|
| 20 |
|
| 21 |
|
| 22 |
class Qwen3VLImagesKwargs(ImagesKwargs):
|
|
|
|
| 227 |
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
| 228 |
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
| 229 |
|
| 230 |
+
chunks = chunk_tokens(
|
| 231 |
+
max_chunk_size=self.video_processor.max_chunk_size, # type: ignore
|
| 232 |
+
input_ids=np.array(text_inputs["input_ids"]),
|
| 233 |
+
image_token_id=self.image_token_id,
|
| 234 |
+
video_token_id=self.video_token_id,
|
| 235 |
+
merge_size=self.image_processor.merge_size, # type: ignore
|
| 236 |
+
focus_size=self.video_processor.focus_size, # type: ignore
|
| 237 |
+
image_grid_thw=image_grid_thw,
|
| 238 |
+
video_grid_thw=video_grid_thw,
|
| 239 |
+
)
|
| 240 |
+
image_token_mask = (text_inputs["input_ids"] == self.image_token_id)
|
| 241 |
+
video_token_mask = (text_inputs["input_ids"] == self.video_token_id)
|
| 242 |
+
|
| 243 |
+
return BatchFeature(data={
|
| 244 |
+
**text_inputs,
|
| 245 |
+
**image_inputs,
|
| 246 |
+
**videos_inputs,
|
| 247 |
+
"token_chunks": chunks,
|
| 248 |
+
"image_token_mask": image_token_mask,
|
| 249 |
+
"video_token_mask": video_token_mask,
|
| 250 |
+
}, tensor_type=return_tensors)
|
| 251 |
|
| 252 |
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
| 253 |
"""
|
video_preprocessor_config.json
CHANGED
|
@@ -26,6 +26,7 @@
|
|
| 26 |
0.5
|
| 27 |
],
|
| 28 |
"input_data_format": null,
|
|
|
|
| 29 |
"max_frames": 2048,
|
| 30 |
"merge_size": 2,
|
| 31 |
"min_frames": 4,
|
|
|
|
| 26 |
0.5
|
| 27 |
],
|
| 28 |
"input_data_format": null,
|
| 29 |
+
"max_chunk_size": 4096,
|
| 30 |
"max_frames": 2048,
|
| 31 |
"merge_size": 2,
|
| 32 |
"min_frames": 4,
|