from collections import deque from dataclasses import dataclass import torch import numpy as np @dataclass class ChunkCum: cum: int image_grid_thw: tuple[int, int, int] | None = None video_grid_thw: tuple[int, int, int] | None = None def _visual_token_cums( sequence_idx: int, input_ids: torch.Tensor | np.ndarray, image_token_id: int, video_token_id: int, merge_size: int, focus_size: int, image_grid_thw: torch.Tensor | np.ndarray | None, video_grid_thw: torch.Tensor | np.ndarray | None, **kwargs, ) -> list[ChunkCum]: cums: deque[ChunkCum] = deque() video_idx = 0 frame_idx = 0 image_idx = 0 token_idx = 0 in_video = False cum = 0 sequence = input_ids[sequence_idx].tolist() while token_idx < len(sequence): token = sequence[token_idx] if token == image_token_id: assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used" _, h, w = image_grid_thw[image_idx].tolist() num_tokens = h * w // (merge_size ** 2) cums.append(ChunkCum( cum=num_tokens, image_grid_thw=(1, h, w), video_grid_thw=None ) ) token_idx += num_tokens image_idx += 1 elif token == video_token_id: assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used" t, h, w = video_grid_thw[video_idx].tolist() assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}" num_tokens = h * w // (merge_size ** 2) cum += num_tokens if (frame_idx + 1) % focus_size == 0: cums.append(ChunkCum( cum=cum, image_grid_thw=None, video_grid_thw=(focus_size, h, w), )) cum = 0 in_video = False else: in_video = True frame_idx += 1 if frame_idx == t: video_idx += 1 frame_idx = 0 token_idx += num_tokens else: if not in_video: cums.append(ChunkCum(cum=cum, image_grid_thw=None, video_grid_thw=None)) else: cum += 1 token_idx += 1 return list(cums) def visual_token_cums( input_ids: torch.Tensor | np.ndarray, image_token_id: int, video_token_id: int, merge_size: int, focus_size: int, image_grid_thw: torch.Tensor | np.ndarray | None, video_grid_thw: torch.Tensor | np.ndarray | None, **kwargs, ) -> list[list[ChunkCum]]: return [ _visual_token_cums( sequence_idx=i, input_ids=input_ids, image_token_id=image_token_id, video_token_id=video_token_id, merge_size=merge_size, focus_size=focus_size, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, ) for i in range(input_ids.shape[0]) ] @dataclass class Chunk: start: int end: int image_grid_thws: list[tuple[int, int, int]] video_grid_thws: list[tuple[int, int, int]] def chunk_tokens( max_chunk_size: int, input_ids: torch.Tensor | np.ndarray, image_token_id: int, video_token_id: int, merge_size: int, focus_size: int, image_grid_thw: torch.Tensor | np.ndarray | None, video_grid_thw: torch.Tensor | np.ndarray | None, **kwargs, ) -> list[list[Chunk]]: cums = visual_token_cums( input_ids=input_ids, image_token_id=image_token_id, video_token_id=video_token_id, merge_size=merge_size, focus_size=focus_size, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, **kwargs, ) chunked_cums: list[list[Chunk]] = [] for sequence_cums in cums: chunks: list[Chunk] = [] current_chunk_start = 0 current_chunk_size = 0 current_image_grid_thws: list[tuple[int, int, int]] = [] current_video_grid_thws: list[tuple[int, int, int]] = [] for cum in sequence_cums: if cum.image_grid_thw is not None: current_image_grid_thws.append(cum.image_grid_thw) if cum.video_grid_thw is not None: current_video_grid_thws.append(cum.video_grid_thw) if current_chunk_size + cum.cum > max_chunk_size: chunks.append(Chunk( start=current_chunk_start, end=current_chunk_start + current_chunk_size, image_grid_thws=current_image_grid_thws, video_grid_thws=current_video_grid_thws )) current_chunk_start += current_chunk_size current_chunk_size = 0 current_image_grid_thws = [] current_video_grid_thws = [] current_chunk_size += cum.cum if current_chunk_size > 0: chunks.append(Chunk( start=current_chunk_start, end=current_chunk_start + current_chunk_size, image_grid_thws=current_image_grid_thws, video_grid_thws=current_video_grid_thws, )) chunked_cums.append(chunks) num_chunks = max(len(chunks) for chunks in chunked_cums) for chunks in chunked_cums: while len(chunks) < num_chunks: chunks.append(Chunk( start=chunks[-1].end, end=chunks[-1].end, image_grid_thws=[], video_grid_thws=[], )) return chunked_cums