| | import glob |
| | import os |
| | import re |
| | import tempfile |
| | import urllib.request |
| | from os import PathLike |
| | from typing import cast, Optional |
| | from urllib.parse import urlparse |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import transformers.image_transforms as image_transforms |
| | import transformers.image_utils as image_utils |
| | import transformers.video_utils as video_utils |
| | from PIL import Image |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.image_utils import ImageInput |
| | from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2TokenizerFast |
| | from transformers.models.siglip import SiglipImageProcessor, SiglipImageProcessorFast |
| | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs |
| | from transformers.tokenization_utils_base import BatchEncoding, TextInput |
| | from transformers.video_utils import VideoInput, VideoMetadata |
| |
|
| | from autogaze.models.autogaze import AutoGaze |
| | from autogaze.models.autogaze import AutoGazeImageProcessor |
| | from autogaze.datasets.video_utils import transform_video_for_pytorch |
| |
|
| |
|
| | def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
| | """Find the closest aspect ratio from a set of target ratios. |
| | |
| | Referenced from https://github.com/OpenGVLab/InternVL and llava/mm_utils.py |
| | """ |
| | best_ratio_diff = float("inf") |
| | best_ratio = (1, 1) |
| | area = width * height |
| | for ratio in target_ratios: |
| | target_aspect_ratio = ratio[0] / ratio[1] |
| | ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
| | if ratio_diff < best_ratio_diff: |
| | best_ratio_diff = ratio_diff |
| | best_ratio = ratio |
| | elif ratio_diff == best_ratio_diff: |
| | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
| | best_ratio = ratio |
| | return best_ratio |
| |
|
| |
|
class NVILAProcessorKwargs(ProcessingKwargs, total=False):
    """Typed kwargs container for :class:`NVILAProcessor` calls.

    No per-modality defaults are overridden yet; the empty ``_defaults``
    keeps the transformers kwargs-merging machinery happy.
    """

    _defaults = {}
| |
|
| |
|
def _load_video_frames(video_path: str, num_frames: int = 8) -> list[Image.Image]:
    """Load ``num_frames`` uniformly spaced RGB frames from a video file.

    Similar to ``_load_video`` in ``llava/utils/media.py``.

    Args:
        video_path: Path to the video file or directory of frames.
        num_frames: Number of frames to extract.

    Returns:
        List of exactly ``num_frames`` PIL images; if some indices fail to
        decode, the last successfully read frame is repeated as padding.

    Raises:
        ValueError: If the video cannot be opened, contains no grabbable
            frame, or no frame at all can be decoded.
    """
    vidcap = cv2.VideoCapture(video_path)

    if not vidcap.isOpened():
        raise ValueError(f"Failed to open video: {video_path}")

    # CAP_PROP_FRAME_COUNT can over-report; walk backwards from the reported
    # count until a frame is actually grabbable to find the usable length.
    frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    while frame_count > 0:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
        if vidcap.grab():
            break
        frame_count -= 1
    else:
        # while/else: loop exhausted without break, i.e. no frame is readable.
        vidcap.release()
        raise ValueError(f"Video '{video_path}' has no frames.")

    # Uniformly sample indices; short videos may produce duplicate indices,
    # which are deduplicated via the `frames` dict below.
    indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
    frames = {}
    for index in indices:
        if index in frames:
            continue
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
        success, frame = vidcap.read()
        if not success:
            continue
        # OpenCV decodes BGR; convert to RGB for PIL.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames[index] = Image.fromarray(frame)

    vidcap.release()

    frames_to_return = [frames[index] for index in indices if index in frames]
    if len(frames_to_return) < num_frames:
        if frames_to_return:
            # Pad with the last good frame so the count is always num_frames.
            frames_to_return = frames_to_return + [frames_to_return[-1]] * (num_frames - len(frames_to_return))
        else:
            raise ValueError(f"Could not extract any frames from video: {video_path}")

    return frames_to_return
| |
|
| |
|
class NVILAProcessor(ProcessorMixin):
    """Processor pairing a SigLIP image processor with a Qwen2 tokenizer,
    plus an AutoGaze model used to compute gazing (token-pruning) info for
    videos.
    """

    # Sub-components managed by ProcessorMixin (saved/loaded together).
    attributes = [
        "image_processor",
        "tokenizer",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    # Registered so `AutoProcessor.from_pretrained` can resolve this class.
    _auto_class = "AutoProcessor"
| |
|
| | def __init__( |
| | self, |
| | image_processor: SiglipImageProcessor | SiglipImageProcessorFast, |
| | tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast, |
| | chat_template: str | None = None, |
| | autogaze_model_id: str | None = None, |
| | gazing_ratio_tile: list[float] | float = 0.75, |
| | gazing_ratio_thumbnail: float | None = 0.75, |
| | task_loss_requirement_tile: float = 0.7, |
| | task_loss_requirement_thumbnail: float | None = 0.7, |
| | target_scales: list[int] | None = None, |
| | target_patch_size: int | None = None, |
| | max_tiles_image: int = 12, |
| | num_video_frames: int = 8, |
| | max_tiles_video: int = 8, |
| | num_video_frames_thumbnail: int = 8, |
| | mm_projector_shuffle_num: int = 9, |
| | max_batch_size_autogaze: int = 32, |
| | **kwargs, |
| | ): |
| | super().__init__( |
| | image_processor, |
| | tokenizer, |
| | chat_template=chat_template, |
| | **kwargs, |
| | ) |
| |
|
| | self.image_processor: SiglipImageProcessor | SiglipImageProcessorFast |
| | self.tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast |
| | |
| | |
| | self.autogaze_model_id = autogaze_model_id or "bfshi/AutoGaze" |
| | self.gazing_ratio_tile = gazing_ratio_tile |
| | self.gazing_ratio_thumbnail = gazing_ratio_thumbnail |
| | self.task_loss_requirement_tile = task_loss_requirement_tile |
| | self.task_loss_requirement_thumbnail = task_loss_requirement_thumbnail |
| | self.target_scales = target_scales or [56, 112, 224, 448] |
| | self.target_patch_size = target_patch_size or 16 |
| | |
| | |
| | self.max_tiles_image = max_tiles_image |
| | self.num_video_frames = num_video_frames |
| | self.max_tiles_video = max_tiles_video |
| | self.num_video_frames_thumbnail = num_video_frames_thumbnail |
| | self.mm_projector_shuffle_num = mm_projector_shuffle_num |
| | self.max_batch_size_autogaze = max_batch_size_autogaze |
| | |
| | |
| | self._autogaze_model = None |
| | self._autogaze_model = AutoGaze.from_pretrained( |
| | self.autogaze_model_id, |
| | device_map=None, |
| | ) |
| | self._autogaze_model.to("cuda").eval() |
| | print("AutoGaze loaded successfully in processor") |
| |
|
| | def __call__( |
| | self, |
| | *, |
| | text: TextInput | list[TextInput], |
| | images: ImageInput | None = None, |
| | videos: VideoInput | None = None, |
| | **kwargs: Unpack[NVILAProcessorKwargs], |
| | ) -> BatchFeature: |
| | normalized_text, normalized_images, normalized_videos = self._normalize_inputs( |
| | text=text, |
| | images=images, |
| | videos=videos, |
| | ) |
| |
|
| | images_inputs, image_token_padding_strategy = ( |
| | self._preprocess_images( |
| | normalized_images, |
| | **kwargs, |
| | ) |
| | if len(normalized_images) > 0 |
| | else (BatchFeature(), []) |
| | ) |
| |
|
| | videos_inputs = ( |
| | self._preprocess_videos( |
| | normalized_videos, |
| | **kwargs, |
| | ) |
| | if len(normalized_videos) > 0 |
| | else (BatchFeature(), []) |
| | ) |
| |
|
| | |
| | gazing_info = None |
| | video_token_padding_strategy = [] |
| | skip_tiles_gaze = self._should_gaze_all_patches(self.gazing_ratio_tile, self.task_loss_requirement_tile) |
| | skip_thumbs_gaze = self._should_gaze_all_patches(self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail) |
| | can_construct_without_autogaze = skip_tiles_gaze and skip_thumbs_gaze |
| | if len(normalized_videos) > 0 and (self._autogaze_model is not None or can_construct_without_autogaze): |
| | gazing_info = self._get_gazing_info_from_videos(videos_inputs) |
| | |
| | |
| | |
| | |
| | |
| | shuffle_num = self.mm_projector_shuffle_num |
| | ns_list = videos_inputs["num_spatial_tiles_each_video"] |
| |
|
| | for vid_idx in range(len(gazing_info["if_padded_gazing_tiles"])): |
| | tiles_if_pad = gazing_info["if_padded_gazing_tiles"][vid_idx] |
| | tiles_num_gaze = gazing_info["num_gazing_each_frame_tiles"][vid_idx] |
| | thumbs_if_pad = gazing_info["if_padded_gazing_thumbnails"][vid_idx] |
| | thumbs_num_gaze = gazing_info["num_gazing_each_frame_thumbnails"][vid_idx] |
| |
|
| | ns = ns_list[vid_idx] |
| | num_tiles = tiles_if_pad.shape[0] |
| | T_tile = tiles_num_gaze.shape[1] |
| | tc = num_tiles // ns |
| | total_frames = tc * T_tile |
| |
|
| | |
| | tile_non_padded = [] |
| | for t_idx in range(num_tiles): |
| | frame_sizes = tiles_num_gaze[t_idx].tolist() |
| | frame_pad_segs = tiles_if_pad[t_idx].split(frame_sizes) |
| | tile_non_padded.append( |
| | [int((~seg).sum().item()) for seg in frame_pad_segs] |
| | ) |
| |
|
| | total_tokens = 0 |
| |
|
| | |
| | for g in range(total_frames): |
| | chunk = g // T_tile |
| | f_in_chunk = g % T_tile |
| | frame_count = sum( |
| | tile_non_padded[chunk * ns + s][f_in_chunk] |
| | for s in range(ns) |
| | ) |
| | total_tokens += (frame_count + shuffle_num - 1) // shuffle_num |
| |
|
| | |
| | for th_idx in range(thumbs_if_pad.shape[0]): |
| | frame_sizes = thumbs_num_gaze[th_idx].tolist() |
| | frame_pad_segs = thumbs_if_pad[th_idx].split(frame_sizes) |
| | non_pad = sum(int((~seg).sum().item()) for seg in frame_pad_segs) |
| | total_tokens += (non_pad + shuffle_num - 1) // shuffle_num |
| |
|
| | video_token_padding_strategy.append([total_tokens]) |
| | else: |
| | video_token_padding_strategy = [[(self.num_video_frames + self.num_video_frames_thumbnail) * 118] * len(normalized_videos)] |
| |
|
| | |
| | |
| | if len(normalized_videos) > 0: |
| | videos_inputs.pop("pixel_values_videos_tiles_autogaze", None) |
| | videos_inputs.pop("pixel_values_videos_thumbnails_autogaze", None) |
| |
|
| | text_inputs = self._preprocess_text( |
| | normalized_text, |
| | image_token_padding_strategy=image_token_padding_strategy, |
| | video_token_padding_strategy=video_token_padding_strategy, |
| | **kwargs, |
| | ) |
| |
|
| | |
| | batch_feature = BatchFeature( |
| | { |
| | **text_inputs, |
| | **images_inputs, |
| | **videos_inputs, |
| | } |
| | ) |
| |
|
| | |
| | if gazing_info is not None: |
| | batch_feature["gazing_info"] = gazing_info |
| |
|
| | return batch_feature |
| |
|
    def batch_decode(self, *args, **kwargs) -> list[str]:
        """Delegate batch decoding of token-id sequences to the tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)
| |
|
| | def _normalize_inputs( |
| | self, |
| | *, |
| | text: TextInput | list[TextInput], |
| | images: ImageInput | None, |
| | videos: VideoInput | None, |
| | ) -> tuple[list[str], list[Image], list[list[Image]]]: |
| | if isinstance(text, list): |
| | normalized_text = text |
| | else: |
| | normalized_text = [text] |
| |
|
| | if images is not None and images != []: |
| | image_flat_list = cast(list, image_utils.make_flat_list_of_images(images)) |
| | normalized_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_flat_list] |
| | else: |
| | normalized_images = [] |
| |
|
| | if videos is not None and videos != []: |
| | |
| | |
| | if not isinstance(videos, (list, tuple)): |
| | videos = [videos] |
| | |
| | normalized_videos = [] |
| | |
| | num_frames = self.num_video_frames |
| | for video_input in videos: |
| | if isinstance(video_input, str): |
| | parsed = urlparse(video_input) |
| | if parsed.scheme in ("http", "https"): |
| | suffix = os.path.splitext(parsed.path)[1] or ".mp4" |
| | tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) |
| | try: |
| | urllib.request.urlretrieve(video_input, tmp.name) |
| | video_frames = _load_video_frames(tmp.name, num_frames=num_frames) |
| | finally: |
| | tmp.close() |
| | os.unlink(tmp.name) |
| | else: |
| | video_frames = _load_video_frames(video_input, num_frames=num_frames) |
| | normalized_videos.append(video_frames) |
| | elif isinstance(video_input, (list, tuple)): |
| | |
| | normalized_videos.append([ |
| | cast(Image, image_transforms.to_pil_image(image)) for image in video_input |
| | ]) |
| | else: |
| | |
| | try: |
| | video_list = cast(list[list], video_utils.make_batched_videos([video_input])) |
| | normalized_videos.extend([ |
| | [cast(Image, image_transforms.to_pil_image(image)) for image in video] |
| | for video in video_list |
| | ]) |
| | except Exception: |
| | raise ValueError( |
| | f"Unsupported video input type: {type(video_input)}. " |
| | "Expected str (file path) or list of PIL Images." |
| | ) |
| | else: |
| | normalized_videos = [] |
| |
|
| | return normalized_text, normalized_images, normalized_videos |
| |
|
    def _preprocess_images(
        self,
        images: list[Image.Image],
        **kwargs: Unpack[NVILAProcessorKwargs],
    ) -> tuple[BatchFeature, list[list[int]]]:
        """Preprocess images into spatial tiles plus a thumbnail.

        Each image is split into a grid of spatial tiles whose count is at
        most ``max_tiles_image``. A thumbnail (the whole image resized to
        ``image_size × image_size``) is appended. Every tile / thumbnail
        is a single-frame "video" of shape ``(1, C, H, W)``. No AutoGaze
        is applied — all patches are kept.

        Returns:
            A tuple ``(images_inputs, padding_strategy)`` where
            ``images_inputs`` is a ``BatchFeature`` with:

            - ``"pixel_values_images_tiles"`` – list of tensors, one per
              image, each ``(num_tiles_i, 1, C, H, W)``.
            - ``"pixel_values_images_thumbnails"`` – list of tensors, one
              per image, each ``(1, 1, C, H, W)``.
            - ``"num_spatial_tiles_each_image"`` – list of ints.

            ``padding_strategy`` is a list (one per image) of
            ``[total_tokens]`` used for text-token padding.
        """
        merged_kwargs = self._merge_kwargs(
            NVILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Tile side length comes from the SigLIP processor config (392 default).
        if hasattr(self.image_processor, "size"):
            image_size = self.image_processor.size.get("height", 392)
        else:
            image_size = 392

        shuffle_num = self.mm_projector_shuffle_num

        # Patch budget per frame: one patch per cell at every target scale.
        num_patches_each_scale = [
            (s // self.target_patch_size) ** 2 for s in self.target_scales
        ]
        total_patches_per_frame = sum(num_patches_each_scale)

        pixel_values_images_tiles: list[torch.Tensor] = []
        pixel_values_images_thumbnails: list[torch.Tensor] = []
        num_spatial_tiles_each_image: list[int] = []
        padding_strategy: list[list[int]] = []

        for image in images:
            image = image.convert("RGB")
            orig_width, orig_height = image.size

            max_spatial_tiles = max(self.max_tiles_image, 1)
            aspect_ratio = orig_width / orig_height

            # Candidate (cols, rows) grids with at most max_spatial_tiles tiles.
            target_ratios = {
                (i, j)
                for n in range(1, max_spatial_tiles + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if 1 <= i * j <= max_spatial_tiles
            }
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            target_aspect_ratio = _find_closest_aspect_ratio(
                aspect_ratio, target_ratios, orig_width, orig_height, image_size
            )

            target_width = image_size * target_aspect_ratio[0]
            target_height = image_size * target_aspect_ratio[1]
            num_tiles = target_aspect_ratio[0] * target_aspect_ratio[1]
            num_cols = target_aspect_ratio[0]

            resized = image.resize((target_width, target_height))

            # Crop the resized image into a row-major grid of tiles.
            all_tile_images: list[Image.Image] = []
            for tile_idx in range(num_tiles):
                col = tile_idx % num_cols
                row = tile_idx // num_cols
                box = (
                    col * image_size,
                    row * image_size,
                    (col + 1) * image_size,
                    (row + 1) * image_size,
                )
                all_tile_images.append(resized.crop(box))

            # Whole-image thumbnail is appended after the tiles so both go
            # through the SigLIP processor in one batch.
            thumbnail = image.resize((image_size, image_size))
            all_images_for_siglip = all_tile_images + [thumbnail]

            siglip_processed = self.image_processor(
                all_images_for_siglip, **merged_kwargs["images_kwargs"],
            )["pixel_values"]
            if not isinstance(siglip_processed, torch.Tensor):
                siglip_processed = torch.tensor(np.array(siglip_processed))

            # Add a singleton temporal dim: each tile is a 1-frame "video".
            tiles_pv = siglip_processed[:num_tiles].unsqueeze(1)
            thumb_pv = siglip_processed[num_tiles:].unsqueeze(1)

            pixel_values_images_tiles.append(tiles_pv)
            pixel_values_images_thumbnails.append(thumb_pv)
            num_spatial_tiles_each_image.append(num_tiles)

            # Ceil-divide patch counts by the projector shuffle factor to get
            # the number of text-side placeholder tokens.
            tiles_tokens = (num_tiles * total_patches_per_frame + shuffle_num - 1) // shuffle_num
            thumb_tokens = (total_patches_per_frame + shuffle_num - 1) // shuffle_num
            padding_strategy.append([tiles_tokens + thumb_tokens])

        images_inputs = BatchFeature({
            "pixel_values_images_tiles": pixel_values_images_tiles,
            "pixel_values_images_thumbnails": pixel_values_images_thumbnails,
            "num_spatial_tiles_each_image": num_spatial_tiles_each_image,
        })

        return images_inputs, padding_strategy
| |
|
    def _preprocess_text(
        self,
        text: list[str],
        *,
        image_token_padding_strategy: list[list[int]],
        video_token_padding_strategy: list[list[int]],
        **kwargs: Unpack[NVILAProcessorKwargs],
    ) -> BatchEncoding:
        """Apply the chat template, expand media tokens, and tokenize.

        Each occurrence of the image/video placeholder token in the batch is
        matched (in order) with one entry of the corresponding padding
        strategy; the entry is a list of ints describing how many copies of
        the token that occurrence should expand to.

        Args:
            text: One prompt string per batch element.
            image_token_padding_strategy: One list per image-token occurrence.
            video_token_padding_strategy: One list per video-token occurrence.

        Returns:
            The tokenizer's ``BatchEncoding`` for the expanded prompts.
        """
        # Wrap each prompt in the default Qwen system/user chat template and
        # append the generation prompt.
        messages = [[
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
            {"role": "user", "content": t}
        ] for t in text]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # The placeholder tokens must be plain strings for counting/re.sub.
        assert isinstance(self.tokenizer.image_token, str)
        assert isinstance(self.tokenizer.video_token, str)

        for media_token, padding_strategy in (
            (self.tokenizer.image_token, image_token_padding_strategy),
            (self.tokenizer.video_token, video_token_padding_strategy),
        ):
            # Exactly one strategy entry per media-token occurrence, in order
            # of appearance across the whole batch.
            assert sum([s.count(media_token) for s in text]) == len(padding_strategy)

            # Pass 1: replicate each occurrence len(entry) times, one copy per
            # int in its strategy entry. The lambda pops entries in occurrence
            # order, matching re.sub's left-to-right scan.
            pad_lens = [len(x) for x in padding_strategy]
            text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text]

            # Pass 2: expand each replicated copy by its actual token count
            # (the flattened ints of all strategy entries, in the same order).
            pad_lens = [y for x in padding_strategy for y in x]
            text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text]

        merged_kwargs = self._merge_kwargs(
            NVILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        text_inputs = self.tokenizer(
            text=text,
            **merged_kwargs["text_kwargs"],
        )

        return text_inputs
| |
|
    def _preprocess_videos(
        self,
        videos: list[list[Image.Image]],
        **kwargs: Unpack[NVILAProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess videos into spatiotemporal tiles and thumbnails.

        Each video is split into a grid of spatiotemporal tiles and a set of
        low-resolution thumbnail frames. Both SigLIP-processed and
        AutoGaze-processed copies are produced.

        Spatial tiling
            Every frame is resized so that its dimensions become a multiple of
            ``image_size`` (from the SigLIP image processor) and then cropped
            into ``(cols, rows)`` spatial tiles, where ``cols * rows <=
            max_tiles_video``. The best ``(cols, rows)`` is chosen by matching
            the original frame aspect ratio (same logic as
            ``dynamic_preprocess`` in ``llava/mm_utils.py``).

        Temporal chunking
            The T sampled frames are divided into ``T // max_num_frames``
            consecutive chunks of ``max_num_frames`` frames each, where
            ``max_num_frames`` comes from the AutoGaze model config.
            ``T`` must be divisible by ``max_num_frames``.

        Tile ordering
            Tiles are ordered **temporal-chunk-first**: all spatial tiles for
            the first temporal chunk, then all spatial tiles for the second
            temporal chunk, and so on.

        Thumbnails
            Each frame is also resized to ``image_size × image_size`` to form a
            thumbnail. If the number of frames exceeds
            ``num_video_frames_thumbnail``, thumbnails are uniformly subsampled
            (every k-th frame) to that count. Each thumbnail is treated as a
            single-frame video (temporal dim = 1).

        Args:
            videos: List of videos, where each video is a list of PIL Images
                (one per frame).
            **kwargs: Additional keyword arguments forwarded to the SigLIP
                image processor.

        Returns:
            A ``BatchFeature`` dict with the keys:

            - ``"pixel_values_videos_tiles"`` – list of tensors, one per video.
              Each tensor has shape ``(num_tiles, T_tile, C, H, W)`` where
              ``num_tiles = num_spatial_tiles * temporal_chunks``,
              ``T_tile = max_num_frames`` (from AutoGaze config),
              and ``H = W = image_size``.
              Processed by the SigLIP image processor.
            - ``"pixel_values_videos_thumbnails"`` – list of tensors, one per
              video. Each tensor has shape
              ``(T_thumbnail, 1, C, H, W)`` where ``T_thumbnail <=
              num_video_frames_thumbnail`` and ``H = W = image_size``.
              Processed by the SigLIP image processor.
            - ``"pixel_values_videos_tiles_autogaze"`` *(optional)* – same
              structure as ``pixel_values_videos_tiles`` but processed by the
              AutoGaze ``transform_video_for_pytorch`` transform.
            - ``"pixel_values_videos_thumbnails_autogaze"`` *(optional)* – same
              structure as ``pixel_values_videos_thumbnails`` but processed by
              the AutoGaze transform.

            ``num_spatial_tiles_each_video`` (list of ints) is also included.
        """
        merged_kwargs = self._merge_kwargs(
            NVILAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Tile side length comes from the SigLIP processor config (392 default).
        if hasattr(self.image_processor, "size"):
            image_size = self.image_processor.size.get("height", 392)
        else:
            image_size = 392

        # Temporal chunk length — from AutoGaze config when loaded, else 16.
        if self._autogaze_model is not None:
            autogaze_max_num_frames = self._autogaze_model.config.max_num_frames
        else:
            autogaze_max_num_frames = 16

        # NOTE(review): this None is immediately overwritten below, so the
        # later `if autogaze_transform is not None` guards are always true.
        autogaze_transform = None
        largest_scale = max(self.target_scales)
        autogaze_transform = AutoGazeImageProcessor.from_pretrained(
            self.autogaze_model_id,
            size=(largest_scale, largest_scale),
        )

        pixel_values_videos_tiles = []
        pixel_values_videos_thumbnails = []
        pixel_values_videos_tiles_autogaze = []
        pixel_values_videos_thumbnails_autogaze = []
        num_spatial_tiles_each_video = []

        for video in videos:
            video = [img.convert("RGB") for img in video]
            num_frames = len(video)
            orig_width, orig_height = video[0].size

            # Frames must split evenly into AutoGaze-sized temporal chunks.
            temporal_chunks = num_frames // autogaze_max_num_frames
            assert temporal_chunks >= 1 and num_frames % autogaze_max_num_frames == 0, (
                f"Number of frames ({num_frames}) must be divisible by "
                f"AutoGaze max_num_frames ({autogaze_max_num_frames})"
            )

            max_spatial_tiles = max(self.max_tiles_video, 1)

            aspect_ratio = orig_width / orig_height

            # Candidate (cols, rows) grids with at most max_spatial_tiles tiles.
            target_ratios = {
                (i, j)
                for n in range(1, max_spatial_tiles + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if 1 <= i * j <= max_spatial_tiles
            }
            target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

            target_aspect_ratio = _find_closest_aspect_ratio(
                aspect_ratio, target_ratios, orig_width, orig_height, image_size
            )

            target_width = image_size * target_aspect_ratio[0]
            target_height = image_size * target_aspect_ratio[1]
            num_spatial_tiles = target_aspect_ratio[0] * target_aspect_ratio[1]
            num_cols = target_aspect_ratio[0]

            # Per-spatial-tile frame lists plus full-frame thumbnails.
            spatial_tile_frames = [[] for _ in range(num_spatial_tiles)]
            thumbnail_frames = []

            for frame in video:
                resized_frame = frame.resize((target_width, target_height))

                # Crop each resized frame into the row-major tile grid.
                for tile_idx in range(num_spatial_tiles):
                    col = tile_idx % num_cols
                    row = tile_idx // num_cols
                    box = (
                        col * image_size,
                        row * image_size,
                        (col + 1) * image_size,
                        (row + 1) * image_size,
                    )
                    tile = resized_frame.crop(box)
                    spatial_tile_frames[tile_idx].append(tile)

                thumbnail = frame.resize((image_size, image_size))
                thumbnail_frames.append(thumbnail)

            # Flatten tiles temporal-chunk-first: chunk 0's spatial tiles,
            # then chunk 1's, etc. Each tile contributes T_tile frames.
            num_tiles = temporal_chunks * num_spatial_tiles
            T_tile = autogaze_max_num_frames
            all_tile_images = []
            for t_chunk in range(temporal_chunks):
                for spatial_idx in range(num_spatial_tiles):
                    start = t_chunk * T_tile
                    end = start + T_tile
                    all_tile_images.extend(spatial_tile_frames[spatial_idx][start:end])

            # SigLIP copy of the tiles, reshaped to (num_tiles, T_tile, C, H, W).
            siglip_processed = self.image_processor(
                all_tile_images, **merged_kwargs["images_kwargs"],
            )["pixel_values"]
            if not isinstance(siglip_processed, torch.Tensor):
                siglip_processed = torch.tensor(np.array(siglip_processed))
            video_tiles_siglip = siglip_processed.reshape(num_tiles, T_tile, *siglip_processed.shape[1:])
            pixel_values_videos_tiles.append(video_tiles_siglip)

            # AutoGaze copy of the tiles (same layout, different transform).
            if autogaze_transform is not None:
                all_tile_np = np.stack([np.array(f) for f in all_tile_images])
                autogaze_processed = transform_video_for_pytorch(all_tile_np, autogaze_transform)
                video_tiles_autogaze = autogaze_processed.reshape(num_tiles, T_tile, *autogaze_processed.shape[1:])
                pixel_values_videos_tiles_autogaze.append(video_tiles_autogaze)

            # Uniformly subsample thumbnails down to num_video_frames_thumbnail.
            if len(thumbnail_frames) > self.num_video_frames_thumbnail:
                step = len(thumbnail_frames) // self.num_video_frames_thumbnail
                sampled_thumbnail_frames = thumbnail_frames[::step][: self.num_video_frames_thumbnail]
            else:
                sampled_thumbnail_frames = thumbnail_frames

            T_thumb = len(sampled_thumbnail_frames)

            # SigLIP copy of the thumbnails, each as a single-frame "video".
            siglip_processed = self.image_processor(
                sampled_thumbnail_frames, **merged_kwargs["images_kwargs"],
            )["pixel_values"]
            if not isinstance(siglip_processed, torch.Tensor):
                siglip_processed = torch.tensor(np.array(siglip_processed))

            video_thumbnails_siglip = siglip_processed.unsqueeze(1)
            pixel_values_videos_thumbnails.append(video_thumbnails_siglip)

            # AutoGaze copy of the thumbnails.
            if autogaze_transform is not None:
                all_thumb_np = np.stack([np.array(f) for f in sampled_thumbnail_frames])
                autogaze_processed = transform_video_for_pytorch(all_thumb_np, autogaze_transform)
                video_thumbnails_autogaze = autogaze_processed.unsqueeze(1)
                pixel_values_videos_thumbnails_autogaze.append(video_thumbnails_autogaze)

            num_spatial_tiles_each_video.append(num_spatial_tiles)

            print(
                f"Video tiling: {num_frames} frames @ {orig_width}x{orig_height} → "
                f"{num_spatial_tiles} spatial × {temporal_chunks} temporal = "
                f"{num_spatial_tiles * temporal_chunks} tiles, each "
                f"{autogaze_max_num_frames}×{image_size}×{image_size}; "
                f"{len(sampled_thumbnail_frames)} thumbnail frames"
            )

        videos_inputs = BatchFeature(
            {
                "pixel_values_videos_tiles": pixel_values_videos_tiles,
                "pixel_values_videos_thumbnails": pixel_values_videos_thumbnails,
                "num_spatial_tiles_each_video": num_spatial_tiles_each_video,
            }
        )
        if pixel_values_videos_tiles_autogaze:
            videos_inputs["pixel_values_videos_tiles_autogaze"] = pixel_values_videos_tiles_autogaze
        if pixel_values_videos_thumbnails_autogaze:
            videos_inputs["pixel_values_videos_thumbnails_autogaze"] = pixel_values_videos_thumbnails_autogaze

        return videos_inputs
| | |
| | @staticmethod |
| | def _should_gaze_all_patches(gazing_ratio, task_loss_requirement) -> bool: |
| | """Return True when the gazing config means every patch is kept. |
| | |
| | This is the case when ``gazing_ratio`` is ``None`` (no gazing at all), |
| | or when ``gazing_ratio == 1`` (keep 100 %) **and** |
| | ``task_loss_requirement is None`` (no adaptive pruning). |
| | """ |
| | if gazing_ratio is None: |
| | return True |
| | if task_loss_requirement is not None: |
| | return False |
| | if isinstance(gazing_ratio, (list, tuple)): |
| | return all(r == 1 for r in gazing_ratio) |
| | return gazing_ratio == 1 |
| |
|
| | @staticmethod |
| | def _sort_gazing_pos_per_frame( |
| | gazing_pos: torch.Tensor, |
| | if_padded: torch.Tensor, |
| | num_gazing_each_frame: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """Sort non-padded gazing positions in ascending order within each frame. |
| | |
| | Padded positions are left untouched at the end of each frame's segment |
| | so that the total count (padded + non-padded) per frame is unchanged. |
| | |
| | Args: |
| | gazing_pos: ``(B, N)`` tensor of gazing patch indices. |
| | if_padded: ``(B, N)`` bool tensor (``True`` = padded / dummy). |
| | num_gazing_each_frame: ``(B, T)`` tensor giving the number of |
| | gazing positions (padded + non-padded) for each frame. |
| | |
| | Returns: |
| | A new ``(B, N)`` tensor with the same values as *gazing_pos* |
| | except that the non-padded entries within every frame are sorted. |
| | """ |
| | sorted_pos = gazing_pos.clone() |
| | B, _ = gazing_pos.shape |
| | T = num_gazing_each_frame.shape[1] |
| |
|
| | for b in range(B): |
| | offset = 0 |
| | for t in range(T): |
| | count = int(num_gazing_each_frame[b, t].item()) |
| | frame_pos = gazing_pos[b, offset : offset + count] |
| | frame_pad = if_padded[b, offset : offset + count] |
| |
|
| | |
| | real_mask = ~frame_pad |
| | real_pos = frame_pos[real_mask] |
| |
|
| | |
| | real_pos_sorted = real_pos.sort()[0] |
| |
|
| | |
| | real_indices = real_mask.nonzero(as_tuple=True)[0] |
| | sorted_pos[b, offset + real_indices] = real_pos_sorted |
| |
|
| | offset += count |
| |
|
| | return sorted_pos |
| |
|
    def _run_autogaze_batched(
        self,
        all_videos: torch.Tensor,
        autogaze_device: torch.device,
        cpu_device: torch.device,
        gazing_ratio,
        task_loss_requirement,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run AutoGaze in minibatches and return combined results on CPU.

        Different minibatches may produce different per-frame gazing counts
        (e.g. when ``task_loss_requirement`` triggers adaptive pruning).
        This method pads each frame's segment to the *maximum* count across
        all minibatches so that the results can be concatenated along the
        batch dimension.

        Args:
            all_videos: ``(B, T, C, H, W)`` tensor of videos to process.
            autogaze_device: Device where AutoGaze runs (typically CUDA).
            cpu_device: Device for the returned tensors (typically CPU).
            gazing_ratio: Gazing ratio to pass to AutoGaze.
            task_loss_requirement: Task loss requirement to pass to AutoGaze.

        Returns:
            A tuple ``(gazing_pos, if_padded, num_gazing)`` where

            - ``gazing_pos`` is ``(B, N_max)`` on *cpu_device*
            - ``if_padded`` is ``(B, N_max)`` bool on *cpu_device*
            - ``num_gazing`` is ``(B, T)`` on *cpu_device*

            ``N_max = sum(max_per_frame)`` where ``max_per_frame[t]`` is the
            largest per-frame count across all minibatches.
        """
        total = all_videos.shape[0]
        bs = self.max_batch_size_autogaze

        batch_results: list[dict] = []

        with torch.inference_mode():
            for start in range(0, total, bs):
                batch = all_videos[start : start + bs]

                gaze = self._autogaze_model(
                    {"video": batch.to(autogaze_device)},
                    gazing_ratio=gazing_ratio,
                    task_loss_requirement=task_loss_requirement,
                    target_scales=self.target_scales,
                    target_patch_size=self.target_patch_size,
                )

                # Normalize num_gazing_each_frame to a 1-D long tensor on CPU.
                # NOTE(review): the first two branches are identical; kept
                # byte-for-byte as-is.
                ng = gaze["num_gazing_each_frame"]
                if isinstance(ng, list):
                    ng = torch.tensor(ng, device=cpu_device, dtype=torch.long)
                elif not isinstance(ng, torch.Tensor):
                    ng = torch.tensor(ng, device=cpu_device, dtype=torch.long)
                else:
                    ng = ng.to(cpu_device)
                if ng.dim() == 2:
                    # Assumes per-frame counts are identical across the batch
                    # dimension — TODO confirm against AutoGaze's output.
                    ng = ng[0]

                batch_results.append({
                    "gazing_pos": gaze["gazing_pos"].to(cpu_device),
                    "if_padded": gaze["if_padded_gazing"].to(cpu_device),
                    "num_gazing": ng,
                    "batch_size": batch.shape[0],
                })

        # Fast path: a single minibatch needs no cross-batch re-padding.
        if len(batch_results) == 1:
            r = batch_results[0]
            num_gazing = r["num_gazing"].unsqueeze(0).expand(total, -1).contiguous()
            return r["gazing_pos"], r["if_padded"], num_gazing

        # Target layout: per-frame maximum counts across all minibatches.
        all_ng = torch.stack([r["num_gazing"] for r in batch_results], dim=0)
        max_per_frame = all_ng.max(dim=0).values
        max_N = int(max_per_frame.sum().item())
        T = max_per_frame.shape[0]

        padded_pos_list = []
        padded_mask_list = []

        for r in batch_results:
            src_pos = r["gazing_pos"]
            src_pad = r["if_padded"]
            src_ng = r["num_gazing"]
            mini_B = r["batch_size"]

            # Already at the target width — reuse the tensors unchanged.
            if int(src_ng.sum().item()) == max_N:
                padded_pos_list.append(src_pos)
                padded_mask_list.append(src_pad)
                continue

            # Destination starts fully padded (pad mask all True, pos zeros);
            # each frame's real segment is then copied into place.
            dst_pos = torch.zeros(mini_B, max_N, device=cpu_device, dtype=src_pos.dtype)
            dst_pad = torch.ones(mini_B, max_N, device=cpu_device, dtype=torch.bool)

            src_off = 0
            dst_off = 0
            for t in range(T):
                sc = int(src_ng[t].item())
                dc = int(max_per_frame[t].item())
                dst_pos[:, dst_off : dst_off + sc] = src_pos[:, src_off : src_off + sc]
                dst_pad[:, dst_off : dst_off + sc] = src_pad[:, src_off : src_off + sc]
                src_off += sc
                dst_off += dc

            padded_pos_list.append(dst_pos)
            padded_mask_list.append(dst_pad)

        gazing_pos = torch.cat(padded_pos_list, dim=0)
        if_padded = torch.cat(padded_mask_list, dim=0)
        num_gazing = max_per_frame.unsqueeze(0).expand(total, -1).contiguous()

        return gazing_pos, if_padded, num_gazing
| |
|
| | def _get_gazing_info_from_videos( |
| | self, |
| | videos_inputs: BatchFeature, |
| | ) -> Optional[dict]: |
| | """Run AutoGaze on the preprocessed tiles and thumbnails. |
| | |
| | All tiles from all videos are batched together (they share the same |
| | temporal dimension ``T_tile``). Similarly, all thumbnails are batched |
| | together (temporal dim = 1). AutoGaze is run once on each batch and |
| | the results are split back per-video. |
| | |
| | When a gazing ratio is 1 and the corresponding task_loss_requirement is |
| | None (or gazing_ratio is None), all patches are kept and AutoGaze is |
| | skipped for that component. If both tiles and thumbnails meet this |
| | condition, AutoGaze is not invoked at all. |
| | |
| | Args: |
| | videos_inputs: The ``BatchFeature`` returned by |
| | ``_preprocess_videos``, which must contain the keys |
| | ``pixel_values_videos_tiles_autogaze`` and |
| | ``pixel_values_videos_thumbnails_autogaze`` (unless the |
| | corresponding component can skip AutoGaze). |
| | |
| | Returns: |
| | A dict with the following keys (or ``None`` if AutoGaze is |
| | unavailable or the required inputs are missing): |
| | |
| | - ``"gazing_pos_tiles"`` – list of tensors, one per video, each |
| | shaped ``(num_tiles_i, N)``. |
| | - ``"num_gazing_each_frame_tiles"`` – list of tensors, one per |
| | video, each shaped ``(num_tiles_i, T_tile)``. |
| | - ``"if_padded_gazing_tiles"`` – list of bool tensors, one per |
| | video, each shaped ``(num_tiles_i, N)``. |
| | - ``"gazing_pos_thumbnails"`` – list of tensors, one per video, |
| | each shaped ``(T_thumb_i, N')``. |
| | - ``"num_gazing_each_frame_thumbnails"`` – list of tensors, one per |
| | video, each shaped ``(T_thumb_i, 1)``. |
| | - ``"if_padded_gazing_thumbnails"`` – list of bool tensors, one per |
| | video, each shaped ``(T_thumb_i, N')``. |
| | """ |
| | skip_tiles = self._should_gaze_all_patches( |
| | self.gazing_ratio_tile, self.task_loss_requirement_tile |
| | ) |
| | skip_thumbnails = self._should_gaze_all_patches( |
| | self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail |
| | ) |
| | need_autogaze = not skip_tiles or not skip_thumbnails |
| |
|
| | if need_autogaze and self._autogaze_model is None: |
| | return None |
| |
|
| | |
| | siglip_tiles = videos_inputs["pixel_values_videos_tiles"] |
| | siglip_thumbs = videos_inputs["pixel_values_videos_thumbnails"] |
| | num_tiles_per_video = [t.shape[0] for t in siglip_tiles] |
| | num_thumbs_per_video = [t.shape[0] for t in siglip_thumbs] |
| |
|
| | device = torch.device("cpu") |
| | autogaze_device = torch.device("cuda") if torch.cuda.is_available() else device |
| |
|
| | |
| | num_patches_each_scale = [ |
| | (s // self.target_patch_size) ** 2 for s in self.target_scales |
| | ] |
| | total_patches_per_frame = sum(num_patches_each_scale) |
| |
|
| | |
| | if need_autogaze: |
| | current_device = next(self._autogaze_model.parameters()).device |
| | if current_device != autogaze_device: |
| | self._autogaze_model = self._autogaze_model.to(autogaze_device) |
| |
|
| | |
| | if skip_tiles: |
| | total_tiles = sum(num_tiles_per_video) |
| | T_tile = siglip_tiles[0].shape[1] |
| | per_frame_pos = torch.arange(total_patches_per_frame, device=device, dtype=torch.long) |
| | tiles_gazing_pos = per_frame_pos.repeat(T_tile).unsqueeze(0).expand(total_tiles, -1).contiguous() |
| | tiles_if_padded = torch.zeros( |
| | total_tiles, T_tile * total_patches_per_frame, device=device, dtype=torch.bool |
| | ) |
| | tiles_num_gazing = torch.full( |
| | (total_tiles, T_tile), total_patches_per_frame, device=device, dtype=torch.long |
| | ) |
| | else: |
| | tiles_autogaze = videos_inputs.get("pixel_values_videos_tiles_autogaze") |
| | if tiles_autogaze is None: |
| | return None |
| |
|
| | all_tiles = torch.cat(tiles_autogaze, dim=0) |
| | tiles_gazing_pos, tiles_if_padded, tiles_num_gazing = self._run_autogaze_batched( |
| | all_tiles, autogaze_device, device, |
| | self.gazing_ratio_tile, self.task_loss_requirement_tile, |
| | ) |
| | tiles_gazing_pos = self._sort_gazing_pos_per_frame( |
| | tiles_gazing_pos, tiles_if_padded, tiles_num_gazing |
| | ) |
| |
|
| | |
| | if skip_thumbnails: |
| | total_thumbs = sum(num_thumbs_per_video) |
| | per_thumb_pos = torch.arange( |
| | total_patches_per_frame, device=device, dtype=torch.long |
| | ) |
| | thumbs_gazing_pos = per_thumb_pos.unsqueeze(0).expand(total_thumbs, -1).contiguous() |
| | thumbs_if_padded = torch.zeros_like(thumbs_gazing_pos, dtype=torch.bool) |
| | thumbs_num_gazing = torch.full( |
| | (total_thumbs, 1), total_patches_per_frame, |
| | device=device, dtype=torch.long, |
| | ) |
| | else: |
| | thumbs_autogaze = videos_inputs.get("pixel_values_videos_thumbnails_autogaze") |
| | if thumbs_autogaze is None: |
| | return None |
| |
|
| | all_thumbs = torch.cat(thumbs_autogaze, dim=0) |
| | thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing = self._run_autogaze_batched( |
| | all_thumbs, autogaze_device, device, |
| | self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail, |
| | ) |
| | thumbs_gazing_pos = self._sort_gazing_pos_per_frame( |
| | thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing |
| | ) |
| |
|
| | |
| | tiles_gazing_pos_list = list(torch.split(tiles_gazing_pos, num_tiles_per_video, dim=0)) |
| | tiles_if_padded_list = list(torch.split(tiles_if_padded, num_tiles_per_video, dim=0)) |
| | tiles_num_gazing_list = list(torch.split(tiles_num_gazing, num_tiles_per_video, dim=0)) |
| |
|
| | thumbs_gazing_pos_list = list(torch.split(thumbs_gazing_pos, num_thumbs_per_video, dim=0)) |
| | thumbs_if_padded_list = list(torch.split(thumbs_if_padded, num_thumbs_per_video, dim=0)) |
| | thumbs_num_gazing_list = list(torch.split(thumbs_num_gazing, num_thumbs_per_video, dim=0)) |
| |
|
| | return { |
| | "gazing_pos_tiles": tiles_gazing_pos_list, |
| | "num_gazing_each_frame_tiles": tiles_num_gazing_list, |
| | "if_padded_gazing_tiles": tiles_if_padded_list, |
| | "gazing_pos_thumbnails": thumbs_gazing_pos_list, |
| | "num_gazing_each_frame_thumbnails": thumbs_num_gazing_list, |
| | "if_padded_gazing_thumbnails": thumbs_if_padded_list, |
| | } |