TYTTYTTYT committed on
Commit
0c9c5ce
·
verified ·
1 Parent(s): 496d1ea

add image and video chunk to each token chunk in the processor

Browse files
Files changed (2) hide show
  1. chunk_utils.py +65 -14
  2. processing_qwen3_vl.py +43 -5
chunk_utils.py CHANGED
@@ -1,8 +1,17 @@
1
  from collections import deque
 
2
 
3
  import torch
4
  import numpy as np
5
 
 
 
 
 
 
 
 
 
6
  def _visual_token_cums(
7
  sequence_idx: int,
8
  input_ids: torch.Tensor | np.ndarray,
@@ -13,8 +22,8 @@ def _visual_token_cums(
13
  image_grid_thw: torch.Tensor | np.ndarray | None,
14
  video_grid_thw: torch.Tensor | np.ndarray | None,
15
  **kwargs,
16
- ) -> list[int]:
17
- cums: deque[int] = deque()
18
 
19
  video_idx = 0
20
  frame_idx = 0
@@ -30,7 +39,12 @@ def _visual_token_cums(
30
  assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
31
  _, h, w = image_grid_thw[image_idx].tolist()
32
  num_tokens = h * w // (merge_size ** 2)
33
- cums.append(num_tokens)
 
 
 
 
 
34
  token_idx += num_tokens
35
  image_idx += 1
36
  elif token == video_token_id:
@@ -41,7 +55,11 @@ def _visual_token_cums(
41
  cum += num_tokens
42
 
43
  if (frame_idx + 1) % focus_size == 0:
44
- cums.append(cum)
 
 
 
 
45
  cum = 0
46
  in_video = False
47
  else:
@@ -56,13 +74,14 @@ def _visual_token_cums(
56
 
57
  else:
58
  if not in_video:
59
- cums.append(1)
60
  else:
61
  cum += 1
62
  token_idx += 1
63
 
64
  return list(cums)
65
 
 
66
  def visual_token_cums(
67
  input_ids: torch.Tensor | np.ndarray,
68
  image_token_id: int,
@@ -72,7 +91,7 @@ def visual_token_cums(
72
  image_grid_thw: torch.Tensor | np.ndarray | None,
73
  video_grid_thw: torch.Tensor | np.ndarray | None,
74
  **kwargs,
75
- ) -> list[list[int]]:
76
  return [
77
  _visual_token_cums(
78
  sequence_idx=i,
@@ -87,6 +106,15 @@ def visual_token_cums(
87
  for i in range(input_ids.shape[0])
88
  ]
89
 
 
 
 
 
 
 
 
 
 
90
  def chunk_tokens(
91
  max_chunk_size: int,
92
  input_ids: torch.Tensor | np.ndarray,
@@ -97,7 +125,7 @@ def chunk_tokens(
97
  image_grid_thw: torch.Tensor | np.ndarray | None,
98
  video_grid_thw: torch.Tensor | np.ndarray | None,
99
  **kwargs,
100
- ) -> list[list[tuple[int, int]]]:
101
  cums = visual_token_cums(
102
  input_ids=input_ids,
103
  image_token_id=image_token_id,
@@ -109,29 +137,52 @@ def chunk_tokens(
109
  **kwargs,
110
  )
111
 
112
- chunked_cums: list[list[tuple[int, int]]] = []
113
 
114
  for sequence_cums in cums:
115
- chunks = []
116
  current_chunk_start = 0
117
  current_chunk_size = 0
 
 
118
 
119
  for cum in sequence_cums:
120
- if current_chunk_size + cum > max_chunk_size:
121
- chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))
 
 
 
 
 
 
 
 
 
122
  current_chunk_start += current_chunk_size
123
  current_chunk_size = 0
 
 
124
 
125
- current_chunk_size += cum
126
 
127
  if current_chunk_size > 0:
128
- chunks.append((current_chunk_start, current_chunk_start + current_chunk_size))
 
 
 
 
 
129
 
130
  chunked_cums.append(chunks)
131
 
132
  num_chunks = max(len(chunks) for chunks in chunked_cums)
133
  for chunks in chunked_cums:
134
  while len(chunks) < num_chunks:
135
- chunks.append((chunks[-1][1], chunks[-1][1]))
 
 
 
 
 
136
 
137
  return chunked_cums
 
1
  from collections import deque
2
+ from dataclasses import dataclass
3
 
4
  import torch
5
  import numpy as np
6
 
7
+
8
+ @dataclass
9
+ class ChunkCum:
10
+ cum: int
11
+ image_grid_thw: tuple[int, int, int] | None = None
12
+ video_grid_thw: tuple[int, int, int] | None = None
13
+
14
+
15
  def _visual_token_cums(
16
  sequence_idx: int,
17
  input_ids: torch.Tensor | np.ndarray,
 
22
  image_grid_thw: torch.Tensor | np.ndarray | None,
23
  video_grid_thw: torch.Tensor | np.ndarray | None,
24
  **kwargs,
25
+ ) -> list[ChunkCum]:
26
+ cums: deque[ChunkCum] = deque()
27
 
28
  video_idx = 0
29
  frame_idx = 0
 
39
  assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
40
  _, h, w = image_grid_thw[image_idx].tolist()
41
  num_tokens = h * w // (merge_size ** 2)
42
+ cums.append(ChunkCum(
43
+ cum=num_tokens,
44
+ image_grid_thw=(1, h, w),
45
+ video_grid_thw=None
46
+ )
47
+ )
48
  token_idx += num_tokens
49
  image_idx += 1
50
  elif token == video_token_id:
 
55
  cum += num_tokens
56
 
57
  if (frame_idx + 1) % focus_size == 0:
58
+ cums.append(ChunkCum(
59
+ cum=cum,
60
+ image_grid_thw=None,
61
+ video_grid_thw=(focus_size, h, w),
62
+ ))
63
  cum = 0
64
  in_video = False
65
  else:
 
74
 
75
  else:
76
  if not in_video:
77
+ cums.append(ChunkCum(cum=cum, image_grid_thw=None, video_grid_thw=None))
78
  else:
79
  cum += 1
80
  token_idx += 1
81
 
82
  return list(cums)
83
 
84
+
85
  def visual_token_cums(
86
  input_ids: torch.Tensor | np.ndarray,
87
  image_token_id: int,
 
91
  image_grid_thw: torch.Tensor | np.ndarray | None,
92
  video_grid_thw: torch.Tensor | np.ndarray | None,
93
  **kwargs,
94
+ ) -> list[list[ChunkCum]]:
95
  return [
96
  _visual_token_cums(
97
  sequence_idx=i,
 
106
  for i in range(input_ids.shape[0])
107
  ]
108
 
109
+
110
@dataclass
class Chunk:
    """A contiguous token span [start, end) of one sequence, plus the visual
    grids whose tokens fall inside it.

    Produced by ``chunk_tokens``; the grid lists let the caller slice the
    matching rows out of ``image_grid_thw`` / ``video_grid_thw`` per chunk.
    """

    # Token index where the chunk begins (inclusive).
    start: int
    # Token index where the chunk ends (exclusive).
    end: int
    # (t, h, w) grids of the images whose tokens lie in this chunk.
    image_grid_thws: list[tuple[int, int, int]]
    # (t, h, w) grids of the video segments whose tokens lie in this chunk.
    video_grid_thws: list[tuple[int, int, int]]
116
+
117
+
118
  def chunk_tokens(
119
  max_chunk_size: int,
120
  input_ids: torch.Tensor | np.ndarray,
 
125
  image_grid_thw: torch.Tensor | np.ndarray | None,
126
  video_grid_thw: torch.Tensor | np.ndarray | None,
127
  **kwargs,
128
+ ) -> list[list[Chunk]]:
129
  cums = visual_token_cums(
130
  input_ids=input_ids,
131
  image_token_id=image_token_id,
 
137
  **kwargs,
138
  )
139
 
140
+ chunked_cums: list[list[Chunk]] = []
141
 
142
  for sequence_cums in cums:
143
+ chunks: list[Chunk] = []
144
  current_chunk_start = 0
145
  current_chunk_size = 0
146
+ current_image_grid_thws: list[tuple[int, int, int]] = []
147
+ current_video_grid_thws: list[tuple[int, int, int]] = []
148
 
149
  for cum in sequence_cums:
150
+ if cum.image_grid_thw is not None:
151
+ current_image_grid_thws.append(cum.image_grid_thw)
152
+ if cum.video_grid_thw is not None:
153
+ current_video_grid_thws.append(cum.video_grid_thw)
154
+ if current_chunk_size + cum.cum > max_chunk_size:
155
+ chunks.append(Chunk(
156
+ start=current_chunk_start,
157
+ end=current_chunk_start + current_chunk_size,
158
+ image_grid_thws=current_image_grid_thws,
159
+ video_grid_thws=current_video_grid_thws
160
+ ))
161
  current_chunk_start += current_chunk_size
162
  current_chunk_size = 0
163
+ current_image_grid_thws = []
164
+ current_video_grid_thws = []
165
 
166
+ current_chunk_size += cum.cum
167
 
168
  if current_chunk_size > 0:
169
+ chunks.append(Chunk(
170
+ start=current_chunk_start,
171
+ end=current_chunk_start + current_chunk_size,
172
+ image_grid_thws=current_image_grid_thws,
173
+ video_grid_thws=current_video_grid_thws,
174
+ ))
175
 
176
  chunked_cums.append(chunks)
177
 
178
  num_chunks = max(len(chunks) for chunks in chunked_cums)
179
  for chunks in chunked_cums:
180
  while len(chunks) < num_chunks:
181
+ chunks.append(Chunk(
182
+ start=chunks[-1].end,
183
+ end=chunks[-1].end,
184
+ image_grid_thws=[],
185
+ video_grid_thws=[],
186
+ ))
187
 
188
  return chunked_cums
processing_qwen3_vl.py CHANGED
@@ -1,7 +1,7 @@
1
- from typing import Optional, Union
2
 
 
3
  import numpy as np
4
- from transformers.feature_extraction_utils import BatchFeature
5
  from transformers.image_utils import ImageInput
6
  from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
7
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
@@ -14,6 +14,44 @@ from .chunk_utils import chunk_tokens
14
  logger = logging.get_logger(__name__)
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
18
  focus_size: Optional[int]
19
  max_chunk_size: Optional[int]
@@ -99,7 +137,7 @@ class ZFQwen3VLProcessor(ProcessorMixin):
99
  text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, # type: ignore
100
  videos: VideoInput = None, # type: ignore
101
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
102
- ) -> BatchFeature:
103
  """
104
  Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
105
  and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
@@ -125,7 +163,7 @@ class ZFQwen3VLProcessor(ProcessorMixin):
125
  - `'jax'`: Return JAX `jnp.ndarray` objects.
126
 
127
  Returns:
128
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
129
 
130
  - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
131
  - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
@@ -246,7 +284,7 @@ class ZFQwen3VLProcessor(ProcessorMixin):
246
  image_token_mask = image_token_mask * array_attention_mask
247
  video_token_mask = video_token_mask * array_attention_mask
248
 
249
- return BatchFeature(data={
250
  **text_inputs,
251
  **image_inputs,
252
  **videos_inputs,
 
1
+ from typing import Any, Optional, Union
2
 
3
+ import torch
4
  import numpy as np
 
5
  from transformers.image_utils import ImageInput
6
  from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
7
  from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
14
  logger = logging.get_logger(__name__)
15
 
16
 
17
+ class MMFeature(dict):
18
+ def __init__(self, data, tensor_type: str | None = None):
19
+ super().__init__(data)
20
+ self.tensor_type = tensor_type
21
+ self.convert_to_tensor()
22
+
23
+ def convert_to_tensor(self) -> "MMFeature":
24
+ if self.tensor_type is None:
25
+ return self
26
+
27
+ match self.tensor_type:
28
+ case "pt":
29
+ for k, v in self.items():
30
+ if not isinstance(v, torch.Tensor):
31
+ try:
32
+ self[k] = torch.tensor(v)
33
+ except Exception:
34
+ pass
35
+ case "np":
36
+ for k, v in self.items():
37
+ if not isinstance(v, np.ndarray):
38
+ try:
39
+ self[k] = np.array(v)
40
+ except Exception:
41
+ pass
42
+ case _:
43
+ raise ValueError(f"Unsupported tensor type: {self.tensor_type}")
44
+
45
+ return self
46
+
47
+ def to(self, target: Any) -> "MMFeature":
48
+ for k, v in self.items():
49
+ if isinstance(v, torch.Tensor):
50
+ self[k] = v.to(target)
51
+
52
+ return self
53
+
54
+
55
  class Qwen3VLVideosProcessorKwargs(VideosKwargs, total=False):
56
  focus_size: Optional[int]
57
  max_chunk_size: Optional[int]
 
137
  text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, # type: ignore
138
  videos: VideoInput = None, # type: ignore
139
  **kwargs: Unpack[Qwen3VLProcessorKwargs],
140
+ ) -> MMFeature:
141
  """
142
  Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
143
  and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
 
163
  - `'jax'`: Return JAX `jnp.ndarray` objects.
164
 
165
  Returns:
166
+ [`MMFeature`]: A [`MMFeature`] with the following fields:
167
 
168
  - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
169
  - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
 
284
  image_token_mask = image_token_mask * array_attention_mask
285
  video_token_mask = video_token_mask * array_attention_mask
286
 
287
+ return MMFeature(data={
288
  **text_inputs,
289
  **image_inputs,
290
  **videos_inputs,