| | from collections import deque |
| | from dataclasses import dataclass |
| |
|
| | import torch |
| | import numpy as np |
| |
|
| |
|
@dataclass
class ChunkCum:
    """One atomic (indivisible) chunk unit of a token sequence.

    Produced by ``_visual_token_cums``: an image placeholder run, a
    ``focus_size``-frame video group, or a plain text span.  At most one of
    ``image_grid_thw`` / ``video_grid_thw`` is set; both are ``None`` for a
    text unit.
    """

    # Number of sequence tokens this unit contributes to a chunk.
    cum: int
    # (t, h, w) grid of the image occupying these tokens, if this unit is an image.
    image_grid_thw: tuple[int, int, int] | None = None
    # (focus_size, h, w) grid of the frame group, if this unit is part of a video.
    video_grid_thw: tuple[int, int, int] | None = None
| |
|
| |
|
def _visual_token_cums(
    sequence_idx: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[ChunkCum]:
    """Scan one batch row of ``input_ids`` and emit its atomic chunk units.

    Units are, in sequence order:
      * an image: all placeholder tokens of one image (``h*w // merge_size**2``),
      * a video focus group: ``focus_size`` consecutive frames plus any text
        tokens interleaved inside the group,
      * a single standalone text token.

    The ``cum`` fields of the returned units sum to the sequence length, so
    cumulative sums of them are valid token offsets into the sequence.

    Args:
        sequence_idx: Row of ``input_ids`` to scan.
        input_ids: Batch of token id sequences, indexable by row.
        image_token_id / video_token_id: Placeholder ids marking visual tokens.
        merge_size: Spatial merge factor; each image/frame spans
            ``h * w // merge_size**2`` placeholder tokens.
        focus_size: Frames per video group; every video's frame count must be
            divisible by it.
        image_grid_thw / video_grid_thw: Per-image / per-video (t, h, w)
            grids; required iff the corresponding placeholder id occurs.
        **kwargs: Ignored; accepted so callers can pass a shared config dict.

    Returns:
        The sequence's units as a list of ``ChunkCum``.
    """
    cums: deque[ChunkCum] = deque()

    video_idx = 0      # index into video_grid_thw
    frame_idx = 0      # frame position inside the current video
    image_idx = 0      # index into image_grid_thw
    token_idx = 0      # position inside the sequence
    in_video = False   # True while a focus group is open (not yet flushed)
    cum = 0            # tokens accumulated for the open focus group
    sequence = input_ids[sequence_idx].tolist()
    merge_area = merge_size ** 2  # hoisted loop invariant

    while token_idx < len(sequence):
        token = sequence[token_idx]
        if token == image_token_id:
            assert image_grid_thw is not None, "image_grid_thw must be provided when image_token_id is used"
            _, h, w = image_grid_thw[image_idx].tolist()
            num_tokens = h * w // merge_area
            # An image is one indivisible unit; skip all its placeholders.
            cums.append(ChunkCum(
                cum=num_tokens,
                image_grid_thw=(1, h, w),
                video_grid_thw=None,
            ))
            token_idx += num_tokens
            image_idx += 1
        elif token == video_token_id:
            assert video_grid_thw is not None, "video_grid_thw must be provided when video_token_id is used"
            t, h, w = video_grid_thw[video_idx].tolist()
            assert t % focus_size == 0, f"Number of frames {t} must be divisible by focus_size {focus_size}"
            num_tokens = h * w // merge_area
            cum += num_tokens

            # Flush one unit per focus_size frames; because t % focus_size == 0,
            # every video ends exactly on a flush, leaving cum == 0 afterwards.
            if (frame_idx + 1) % focus_size == 0:
                cums.append(ChunkCum(
                    cum=cum,
                    image_grid_thw=None,
                    video_grid_thw=(focus_size, h, w),
                ))
                cum = 0
                in_video = False
            else:
                in_video = True

            frame_idx += 1
            if frame_idx == t:
                video_idx += 1
                frame_idx = 0

            token_idx += num_tokens
        else:
            cum += 1
            if not in_video:
                # BUG FIX: the text token itself must be counted.  Previously
                # the stale accumulator (always 0 here, see flush invariant
                # above) was emitted unchanged, so standalone text tokens
                # added nothing to chunk sizes and the cumulative offsets
                # derived downstream no longer matched real token positions.
                cums.append(ChunkCum(cum=cum, image_grid_thw=None, video_grid_thw=None))
                cum = 0
            # else: text interleaved inside an open focus group is absorbed
            # into that group's unit via the accumulator.
            token_idx += 1

    return list(cums)
| |
|
| |
|
def visual_token_cums(
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[ChunkCum]]:
    """Compute atomic chunk units for every sequence in the batch.

    Thin batching wrapper around ``_visual_token_cums``; extra keyword
    arguments are accepted for caller convenience but not forwarded.

    Returns:
        One list of ``ChunkCum`` per row of ``input_ids``.
    """
    batch_size = input_ids.shape[0]
    per_sequence: list[list[ChunkCum]] = []
    for row in range(batch_size):
        row_cums = _visual_token_cums(
            sequence_idx=row,
            input_ids=input_ids,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            merge_size=merge_size,
            focus_size=focus_size,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
        )
        per_sequence.append(row_cums)
    return per_sequence
| |
|
| |
|
@dataclass
class Chunk:
    """A contiguous span of one sequence, packed from whole atomic units.

    ``start``/``end`` are token offsets derived from cumulative unit sizes
    (half-open: the chunk covers tokens ``start`` .. ``end - 1``).
    """

    # First token offset of the chunk within the sequence.
    start: int
    # One-past-last token offset of the chunk.
    end: int
    # (t, h, w) grids of the images whose tokens fall in this chunk.
    image_grid_thws: list[tuple[int, int, int]]
    # (focus_size, h, w) grids of the video frame groups in this chunk.
    video_grid_thws: list[tuple[int, int, int]]
| |
|
| |
|
def chunk_tokens(
    max_chunk_size: int,
    input_ids: torch.Tensor | np.ndarray,
    image_token_id: int,
    video_token_id: int,
    merge_size: int,
    focus_size: int,
    image_grid_thw: torch.Tensor | np.ndarray | None,
    video_grid_thw: torch.Tensor | np.ndarray | None,
    **kwargs,
) -> list[list[Chunk]]:
    """Greedily pack each sequence's atomic units into size-bounded chunks.

    Units (from ``visual_token_cums``) are never split: a chunk is closed as
    soon as adding the next unit would exceed ``max_chunk_size``.  A single
    unit larger than ``max_chunk_size`` gets a chunk of its own.  Finally all
    sequences are padded with empty trailing chunks so every sequence has the
    same number of chunks.

    Args:
        max_chunk_size: Upper bound on tokens per chunk (soft for oversized
            single units, see above).
        Remaining parameters are forwarded to ``visual_token_cums``.

    Returns:
        One list of ``Chunk`` per sequence, all lists equally long.
    """
    cums = visual_token_cums(
        input_ids=input_ids,
        image_token_id=image_token_id,
        video_token_id=video_token_id,
        merge_size=merge_size,
        focus_size=focus_size,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        **kwargs,
    )

    chunked_cums: list[list[Chunk]] = []

    for sequence_cums in cums:
        chunks: list[Chunk] = []
        current_chunk_start = 0
        current_chunk_size = 0
        current_image_grid_thws: list[tuple[int, int, int]] = []
        current_video_grid_thws: list[tuple[int, int, int]] = []

        for cum in sequence_cums:
            # BUG FIX: close the open chunk BEFORE recording this unit's
            # grids.  Previously the grids were appended first, so a unit
            # that triggered the flush left its grid_thw attached to the
            # previous chunk while its tokens landed in the next one.
            # The `current_chunk_size > 0` guard also avoids emitting a
            # degenerate empty chunk when a single unit exceeds the limit.
            if current_chunk_size > 0 and current_chunk_size + cum.cum > max_chunk_size:
                chunks.append(Chunk(
                    start=current_chunk_start,
                    end=current_chunk_start + current_chunk_size,
                    image_grid_thws=current_image_grid_thws,
                    video_grid_thws=current_video_grid_thws,
                ))
                current_chunk_start += current_chunk_size
                current_chunk_size = 0
                current_image_grid_thws = []
                current_video_grid_thws = []

            if cum.image_grid_thw is not None:
                current_image_grid_thws.append(cum.image_grid_thw)
            if cum.video_grid_thw is not None:
                current_video_grid_thws.append(cum.video_grid_thw)
            current_chunk_size += cum.cum

        # Flush whatever remains open at the end of the sequence.
        if current_chunk_size > 0:
            chunks.append(Chunk(
                start=current_chunk_start,
                end=current_chunk_start + current_chunk_size,
                image_grid_thws=current_image_grid_thws,
                video_grid_thws=current_video_grid_thws,
            ))

        chunked_cums.append(chunks)

    # Pad every sequence to the batch-wide maximum chunk count with empty
    # chunks so downstream code can treat the batch as rectangular.
    # default=0 keeps an empty batch from raising in max().
    num_chunks = max((len(chunks) for chunks in chunked_cums), default=0)
    for chunks in chunked_cums:
        # BUG FIX: a sequence with zero chunks used to crash on chunks[-1];
        # pad such a sequence with empty chunks anchored at offset 0.
        pad_offset = chunks[-1].end if chunks else 0
        while len(chunks) < num_chunks:
            chunks.append(Chunk(
                start=pad_offset,
                end=pad_offset,
                image_grid_thws=[],
                video_grid_thws=[],
            ))

    return chunked_cums
| |
|