TYTTYTTYT committed on
Commit
e7818b4
·
verified ·
1 Parent(s): 12c3a2e

add prefill chunking

Browse files
chunk_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ def _visual_token_cums(
7
+ sequence_idx: int,
8
+ input_ids: torch.Tensor | np.ndarray,
9
+ image_token_id: int,
10
+ video_token_id: int,
11
+ merge_size: int,
12
+ focus_size: int,
13
+ image_grid_thw: torch.Tensor | np.ndarray | None,
14
+ video_grid_thw: torch.Tensor | np.ndarray | None,
15
+ **kwargs,
16
+ ) -> list[int]:
17
+ cums: deque[int] = deque()
18
+
19
+ video_idx = 0
20
+ frame_idx = 0
21
+ image_idx = 0
22
+ token_idx = 0
23
+ in_video = False
24
+ cum = 0
25
+ sequence = input_ids[sequence_idx].tolist()
26
+
27
+ while token_idx < len(sequence):
28
+ token = sequence[token_idx]
29
+ if token == image_token_id:
30
+ assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
31
+ _, h, w = image_grid_thw[image_idx].tolist()
32
+ num_tokens = h * w // (merge_size ** 2)
33
+ cums.append(num_tokens)
34
+ token_idx += num_tokens
35
+ image_idx += 1
36
+ elif token == video_token_id:
37
+ assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used"
38
+ t, h, w = video_grid_thw[video_idx].tolist()
39
+ assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}"
40
+ num_tokens = h * w // (merge_size ** 2)
41
+ cum += num_tokens
42
+
43
+ if (frame_idx + 1) % focus_size == 0:
44
+ cums.append(cum)
45
+ cum = 0
46
+ in_video = False
47
+ else:
48
+ in_video = True
49
+
50
+ frame_idx += 1
51
+ if frame_idx == t:
52
+ video_idx += 1
53
+ frame_idx = 0
54
+
55
+ token_idx += num_tokens
56
+
57
+ else:
58
+ if not in_video:
59
+ cums.append(1)
60
+ else:
61
+ cum += 1
62
+ token_idx += 1
63
+
64
+ return list(cums)
65
+
66
def visual_token_cums(
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[int]]:
    """Batched wrapper around :func:`_visual_token_cums`.

    Runs the per-sequence group-size computation for every row of
    ``input_ids`` and returns one list of group sizes per sequence.
    """
    batch_size = input_ids.shape[0]
    per_sequence: list[list[int]] = []
    for seq_idx in range(batch_size):
        per_sequence.append(
            _visual_token_cums(
                sequence_idx=seq_idx,
                input_ids=input_ids,
                image_token_id=image_token_id,
                video_token_id=video_token_id,
                merge_size=merge_size,
                focus_size=focus_size,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
            )
        )
    return per_sequence
89
+
90
def chunk_tokens(
    max_chunk_size: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[tuple[int, int]]]:
    """Greedily pack each sequence's atomic token groups into prefill chunks.

    Group sizes come from :func:`visual_token_cums`; groups are atomic and
    never split across chunks.  For every sequence, returns a list of
    ``(start, end)`` half-open token ranges whose sizes are at most
    ``max_chunk_size`` (a single group larger than the limit becomes one
    oversized chunk on its own).  All sequences are padded with empty
    trailing chunks ``(end, end)`` so the batch has a rectangular number
    of chunks.
    """
    cums = visual_token_cums(
        input_ids=input_ids,
        image_token_id=image_token_id,
        video_token_id=video_token_id,
        merge_size=merge_size,
        focus_size=focus_size,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        **kwargs,
    )

    chunked_cums: list[list[tuple[int, int]]] = []

    for sequence_cums in cums:
        chunks: list[tuple[int, int]] = []
        current_chunk_start = 0
        current_chunk_size = 0

        for cum in sequence_cums:
            # Close the open chunk when this group would overflow it.  The
            # `current_chunk_size > 0` guard avoids emitting a zero-length
            # chunk when a single group already exceeds max_chunk_size;
            # such a group simply becomes one oversized chunk of its own.
            if current_chunk_size > 0 and current_chunk_size + cum > max_chunk_size:
                chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))
                current_chunk_start += current_chunk_size
                current_chunk_size = 0

            current_chunk_size += cum

        if current_chunk_size > 0:
            chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))

        chunked_cums.append(chunks)

    # Pad every sequence with empty chunks to the batch-wide maximum so the
    # result is rectangular.  `default=0` keeps an empty batch safe, and the
    # 0 fallback keeps an empty sequence (no chunks at all) from raising
    # IndexError on chunks[-1].
    num_chunks = max((len(chunks) for chunks in chunked_cums), default=0)
    for chunks in chunked_cums:
        pad_at = chunks[-1][1] if chunks else 0
        while len(chunks) < num_chunks:
            chunks.append((pad_at, pad_at))

    return chunked_cums
processing_qwen3_vl.py CHANGED
@@ -1,7 +1,6 @@
1
  from typing import Optional, Union
2
 
3
  import numpy as np
4
-
5
  from transformers.feature_extraction_utils import BatchFeature
6
  from transformers.image_utils import ImageInput
7
  from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
@@ -9,12 +8,15 @@ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
9
  from transformers.utils import logging
10
  from transformers.video_utils import VideoInput
11
 
 
 
12
 
13
  logger = logging.get_logger(__name__)
14
 
15
 
16
  class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
17
  focus_size: Optional[int]
 
18
 
19
 
20
  class Qwen3VLImagesKwargs(ImagesKwargs):
@@ -225,7 +227,27 @@ class ZFQwen3VLProcessor(ProcessorMixin):
225
  mm_token_type_ids[array_ids == self.image_token_id] = 1
226
  text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
227
 
228
- return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
231
  """
 
1
  from typing import Optional, Union
2
 
3
  import numpy as np
 
4
  from transformers.feature_extraction_utils import BatchFeature
5
  from transformers.image_utils import ImageInput
6
  from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
 
8
  from transformers.utils import logging
9
  from transformers.video_utils import VideoInput
10
 
11
+ from .chunk_utils import chunk_tokens
12
+
13
 
14
  logger = logging.get_logger(__name__)
15
 
16
 
17
  class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
18
  focus_size: Optional[int]
19
+ max_chunk_size: Optional[int]
20
 
21
 
22
  class Qwen3VLImagesKwargs(ImagesKwargs):
 
227
  mm_token_type_ids[array_ids == self.image_token_id] = 1
228
  text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
229
 
230
+ chunks = chunk_tokens(
231
+ max_chunk_size=self.video_processor.max_chunk_size, # type: ignore
232
+ input_ids=np.array(text_inputs["input_ids"]),
233
+ image_token_id=self.image_token_id,
234
+ video_token_id=self.video_token_id,
235
+ merge_size=self.image_processor.merge_size, # type: ignore
236
+ focus_size=self.video_processor.focus_size, # type: ignore
237
+ image_grid_thw=image_grid_thw,
238
+ video_grid_thw=video_grid_thw,
239
+ )
240
+ image_token_mask = (text_inputs["input_ids"] == self.image_token_id)
241
+ video_token_mask = (text_inputs["input_ids"] == self.video_token_id)
242
+
243
+ return BatchFeature(data={
244
+ **text_inputs,
245
+ **image_inputs,
246
+ **videos_inputs,
247
+ "token_chunks": chunks,
248
+ "image_token_mask": image_token_mask,
249
+ "video_token_mask": video_token_mask,
250
+ }, tensor_type=return_tensors)
251
 
252
  def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
253
  """
video_preprocessor_config.json CHANGED
@@ -26,6 +26,7 @@
26
  0.5
27
  ],
28
  "input_data_format": null,
 
29
  "max_frames": 2048,
30
  "merge_size": 2,
31
  "min_frames": 4,
 
26
  0.5
27
  ],
28
  "input_data_format": null,
29
+ "max_chunk_size": 4096,
30
  "max_frames": 2048,
31
  "merge_size": 2,
32
  "min_frames": 4,