# zf_qwen3_vl_processor / chunk_utils.py
# Last commit 0c9c5ce (verified) by TYTTYTTYT:
#   "add image and video chunk to each token chunk in the processor"
from collections import deque
from dataclasses import dataclass
import torch
import numpy as np
@dataclass
class ChunkCum:
cum: int
image_grid_thw: tuple[int, int, int] | None = None
video_grid_thw: tuple[int, int, int] | None = None
def _visual_token_cums(
    sequence_idx: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[ChunkCum]:
    """Split one row of ``input_ids`` into indivisible token runs.

    Walks ``input_ids[sequence_idx]`` left to right and emits one
    :class:`ChunkCum` per unit that must not be split across chunks:

    * an ordinary text token outside a video gives ``cum=1``;
    * an image gives ``cum = h * w // merge_size**2``, tagged with
      ``image_grid_thw=(1, h, w)``;
    * each group of ``focus_size`` consecutive video frames gives one run.
      Any non-video tokens that appear between the frames of that group are
      folded into the same run. The run is tagged with
      ``video_grid_thw=(focus_size, h, w)``.

    Run lengths for images and frames are derived from the grids. The ids are
    assumed, not checked, to contain exactly that many contiguous visual
    tokens. When that holds, the ``cum`` values sum to the row length. A video
    group left open at the end of the row (a truncated sequence) is dropped.

    Args:
        sequence_idx: Row of ``input_ids`` to scan. Presumably ``input_ids``
            is (batch, seq_len); TODO confirm.
        input_ids: Token ids (tensor or ndarray supporting ``.tolist()``).
        image_token_id: Id of the image placeholder token.
        video_token_id: Id of the video placeholder token.
        merge_size: Spatial merge factor; ``merge_size**2`` grid cells map to
            one token.
        focus_size: Number of consecutive frames kept together in one run.
        image_grid_thw: One (t, h, w) row per image, consumed in order.
        video_grid_thw: One (t, h, w) row per video, consumed in order; each
            row covers ``t`` frames.
        **kwargs: Accepted and ignored.

    Returns:
        The runs of the sequence, in order.

    Raises:
        AssertionError: If an image or video token appears but the matching
            grid is None, or a video's frame count is not divisible by
            ``focus_size``.
    """
    cums: deque[ChunkCum] = deque()
    video_idx = 0  # row of video_grid_thw for the video being scanned
    frame_idx = 0  # frame index inside that video
    image_idx = 0  # row of image_grid_thw for the next image
    token_idx = 0
    # True while a focus group has started but has fewer than focus_size frames.
    in_video = False
    cum = 0  # tokens accumulated so far for the open focus group
    sequence = input_ids[sequence_idx].tolist()
    while token_idx < len(sequence):
        token = sequence[token_idx]
        if token == image_token_id:
            assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
            _, h, w = image_grid_thw[image_idx].tolist()
            # h * w grid cells are merged merge_size**2 at a time into tokens.
            num_tokens = h * w // (merge_size ** 2)
            cums.append(ChunkCum(
                cum=num_tokens,
                image_grid_thw=(1, h, w),
                video_grid_thw=None,
            ))
            # Skip the whole image run at once.
            token_idx += num_tokens
            image_idx += 1
        elif token == video_token_id:
            assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used"
            t, h, w = video_grid_thw[video_idx].tolist()
            assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}"
            num_tokens = h * w // (merge_size ** 2)
            cum += num_tokens
            if (frame_idx + 1) % focus_size == 0:
                # This frame completes a focus group: emit it as one run.
                cums.append(ChunkCum(
                    cum=cum,
                    image_grid_thw=None,
                    video_grid_thw=(focus_size, h, w),
                ))
                cum = 0
                in_video = False
            else:
                in_video = True
            frame_idx += 1
            if frame_idx == t:
                # Every frame of this video has been consumed; move to the next grid row.
                video_idx += 1
                frame_idx = 0
            token_idx += num_tokens
        elif in_video:
            # A non-video token between the frames of an open focus group is
            # kept in that group's run so the group is never split.
            cum += 1
            token_idx += 1
        else:
            # A text token outside any video is its own one-token run. It must
            # count as 1 (not 0) so the chunk offsets built from the summed
            # cums stay aligned with positions in input_ids.
            cums.append(ChunkCum(cum=1, image_grid_thw=None, video_grid_thw=None))
            token_idx += 1
    return list(cums)
def visual_token_cums(
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[ChunkCum]]:
    """Compute the indivisible token runs for every row of ``input_ids``.

    Calls ``_visual_token_cums`` once per row index in
    ``range(input_ids.shape[0])``, passing the same grids and ids each time.
    Extra ``**kwargs`` are accepted but not forwarded.

    Returns:
        One list of ``ChunkCum`` per row, in row order.
    """
    batch_size = input_ids.shape[0]
    per_row: list[list[ChunkCum]] = []
    for row in range(batch_size):
        row_cums = _visual_token_cums(
            sequence_idx=row,
            input_ids=input_ids,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            merge_size=merge_size,
            focus_size=focus_size,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
        )
        per_row.append(row_cums)
    return per_row
@dataclass
class Chunk:
    """A contiguous slice of one sequence, plus the visual grids assigned to it.

    ``start`` and ``end`` are cumulative token offsets. Consecutive chunks
    share a boundary, with the next chunk starting at the previous ``end``.
    ``end - start`` is the chunk's token count, and ``start == end`` marks an
    empty (padding) chunk.
    """

    # Offset at which this chunk begins.
    start: int
    # Offset one past the end of this chunk.
    end: int
    # (t, h, w) grids of the images assigned to this chunk, in sequence order.
    image_grid_thws: list[tuple[int, int, int]]
    # (t, h, w) grids of the video frame groups assigned to this chunk, in order.
    video_grid_thws: list[tuple[int, int, int]]
def chunk_tokens(
    max_chunk_size: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[Chunk]]:
    """Split every sequence of a batch into chunks of at most ``max_chunk_size`` tokens.

    Each sequence is first broken into indivisible runs by
    ``visual_token_cums``. A run is a text token, a whole image, or a whole
    video focus group. The runs are then packed greedily, left to right, and
    a new chunk starts when the next run would push the current chunk past
    ``max_chunk_size``. A run larger than ``max_chunk_size`` on its own is
    never split; it gets a chunk of its own, which therefore exceeds the
    limit. Every chunk records the grids of exactly the image and video runs
    whose tokens it holds.

    Finally, sequences with fewer chunks are padded with empty chunks
    (``start == end`` at the sequence's last offset), so every sequence
    ends up with the same number of chunks.

    Args:
        max_chunk_size: Soft upper bound on tokens per chunk.
        input_ids, image_token_id, video_token_id, merge_size, focus_size,
        image_grid_thw, video_grid_thw: Forwarded to ``visual_token_cums``.
        **kwargs: Forwarded to ``visual_token_cums``, which ignores them.

    Returns:
        One list of ``Chunk`` per sequence, all of the same length.
    """
    cums = visual_token_cums(
        input_ids=input_ids,
        image_token_id=image_token_id,
        video_token_id=video_token_id,
        merge_size=merge_size,
        focus_size=focus_size,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        **kwargs,
    )
    chunked_cums: list[list[Chunk]] = []
    for sequence_cums in cums:
        chunks: list[Chunk] = []
        current_chunk_start = 0
        current_chunk_size = 0
        current_image_grid_thws: list[tuple[int, int, int]] = []
        current_video_grid_thws: list[tuple[int, int, int]] = []
        for cum in sequence_cums:
            # Close the running chunk *before* recording this run's grid, so
            # the grid lands in the same chunk as the run's tokens. Never
            # close an empty chunk, even if this single run is oversized.
            if current_chunk_size > 0 and current_chunk_size + cum.cum > max_chunk_size:
                chunks.append(Chunk(
                    start=current_chunk_start,
                    end=current_chunk_start + current_chunk_size,
                    image_grid_thws=current_image_grid_thws,
                    video_grid_thws=current_video_grid_thws,
                ))
                current_chunk_start += current_chunk_size
                current_chunk_size = 0
                current_image_grid_thws = []
                current_video_grid_thws = []
            if cum.image_grid_thw is not None:
                current_image_grid_thws.append(cum.image_grid_thw)
            if cum.video_grid_thw is not None:
                current_video_grid_thws.append(cum.video_grid_thw)
            current_chunk_size += cum.cum
        if current_chunk_size > 0:
            chunks.append(Chunk(
                start=current_chunk_start,
                end=current_chunk_start + current_chunk_size,
                image_grid_thws=current_image_grid_thws,
                video_grid_thws=current_video_grid_thws,
            ))
        chunked_cums.append(chunks)

    # Pad so that every sequence has the same number of chunks. default=0
    # covers an empty batch; a sequence with no chunks is padded from offset 0.
    num_chunks = max((len(chunks) for chunks in chunked_cums), default=0)
    for chunks in chunked_cums:
        pad_offset = chunks[-1].end if chunks else 0
        while len(chunks) < num_chunks:
            chunks.append(Chunk(
                start=pad_offset,
                end=pad_offset,
                image_grid_thws=[],
                video_grid_thws=[],
            ))
    return chunked_cums