| import glob |
| import os |
| import re |
| import tempfile |
| import urllib.request |
| from os import PathLike |
| from typing import cast, Optional |
| from urllib.parse import urlparse |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import transformers.image_transforms as image_transforms |
| import transformers.image_utils as image_utils |
| import transformers.video_utils as video_utils |
| from PIL import Image |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.image_utils import ImageInput |
| from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2TokenizerFast |
| from transformers.models.siglip import SiglipImageProcessor, SiglipImageProcessorFast |
| from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs |
| from transformers.tokenization_utils_base import BatchEncoding, TextInput |
| from transformers.video_utils import VideoInput, VideoMetadata |
|
|
| from autogaze.models.autogaze import AutoGaze |
| from autogaze.models.autogaze import AutoGazeImageProcessor |
| from autogaze.datasets.video_utils import transform_video_for_pytorch |
|
|
|
|
| def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
| """Find the closest aspect ratio from a set of target ratios. |
| |
| Referenced from https://github.com/OpenGVLab/InternVL and llava/mm_utils.py |
| """ |
| best_ratio_diff = float("inf") |
| best_ratio = (1, 1) |
| area = width * height |
| for ratio in target_ratios: |
| target_aspect_ratio = ratio[0] / ratio[1] |
| ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
| if ratio_diff < best_ratio_diff: |
| best_ratio_diff = ratio_diff |
| best_ratio = ratio |
| elif ratio_diff == best_ratio_diff: |
| if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
| best_ratio = ratio |
| return best_ratio |
|
|
|
|
| class NVILAProcessorKwargs(ProcessingKwargs, total=False): |
| _defaults = {} |
|
|
|
|
| def _load_video_frames(video_path: str, num_frames: int = 8) -> list[Image]: |
| """ |
| Load video frames from a video file path. |
| Similar to _load_video in llava/utils/media.py |
| |
| Args: |
| video_path: Path to the video file or directory of frames |
| num_frames: Number of frames to extract |
| |
| Returns: |
| List of PIL Images representing video frames |
| """ |
| vidcap = cv2.VideoCapture(video_path) |
| |
| if not vidcap.isOpened(): |
| raise ValueError(f"Failed to open video: {video_path}") |
| |
| frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| while frame_count > 0: |
| vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) |
| if vidcap.grab(): |
| break |
| frame_count -= 1 |
| else: |
| vidcap.release() |
| raise ValueError(f"Video '{video_path}' has no frames.") |
| |
| indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) |
| frames = {} |
| for index in indices: |
| if index in frames: |
| continue |
| vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) |
| success, frame = vidcap.read() |
| if not success: |
| continue |
| frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| frames[index] = Image.fromarray(frame) |
| |
| vidcap.release() |
| |
| frames_to_return = [frames[index] for index in indices if index in frames] |
| if len(frames_to_return) < num_frames: |
| if frames_to_return: |
| frames_to_return = frames_to_return + [frames_to_return[-1]] * (num_frames - len(frames_to_return)) |
| else: |
| raise ValueError(f"Could not extract any frames from video: {video_path}") |
| |
| return frames_to_return |
|
|
|
|
| class NVILAProcessor(ProcessorMixin): |
| attributes = [ |
| "image_processor", |
| "tokenizer", |
| ] |
| image_processor_class = "AutoImageProcessor" |
| tokenizer_class = "AutoTokenizer" |
| _auto_class = "AutoProcessor" |
|
|
| def __init__( |
| self, |
| image_processor: SiglipImageProcessor | SiglipImageProcessorFast, |
| tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast, |
| chat_template: str | None = None, |
| autogaze_model_id: str | None = None, |
| gazing_ratio_tile: list[float] | float = 0.75, |
| gazing_ratio_thumbnail: float | None = 0.75, |
| task_loss_requirement_tile: float = 0.7, |
| task_loss_requirement_thumbnail: float | None = 0.7, |
| target_scales: list[int] | None = None, |
| target_patch_size: int | None = None, |
| max_tiles_image: int = 12, |
| num_video_frames: int = 8, |
| max_tiles_video: int = 8, |
| num_video_frames_thumbnail: int = 8, |
| mm_projector_shuffle_num: int = 9, |
| max_batch_size_autogaze: int = 32, |
| **kwargs, |
| ): |
| super().__init__( |
| image_processor, |
| tokenizer, |
| chat_template=chat_template, |
| **kwargs, |
| ) |
|
|
| self.image_processor: SiglipImageProcessor | SiglipImageProcessorFast |
| self.tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast |
| |
| |
| self.autogaze_model_id = autogaze_model_id or "bfshi/AutoGaze" |
| self.gazing_ratio_tile = gazing_ratio_tile |
| self.gazing_ratio_thumbnail = gazing_ratio_thumbnail |
| self.task_loss_requirement_tile = task_loss_requirement_tile |
| self.task_loss_requirement_thumbnail = task_loss_requirement_thumbnail |
| self.target_scales = target_scales or [56, 112, 224, 448] |
| self.target_patch_size = target_patch_size or 16 |
| |
| |
| self.max_tiles_image = max_tiles_image |
| self.num_video_frames = num_video_frames |
| self.max_tiles_video = max_tiles_video |
| self.num_video_frames_thumbnail = num_video_frames_thumbnail |
| self.mm_projector_shuffle_num = mm_projector_shuffle_num |
| self.max_batch_size_autogaze = max_batch_size_autogaze |
| |
| |
| self._autogaze_model = None |
| self._autogaze_model = AutoGaze.from_pretrained( |
| self.autogaze_model_id, |
| device_map=None, |
| ) |
| self._autogaze_model.to("cuda").eval() |
| print("AutoGaze loaded successfully in processor") |
|
|
| def __call__( |
| self, |
| *, |
| text: TextInput | list[TextInput], |
| images: ImageInput | None = None, |
| videos: VideoInput | None = None, |
| **kwargs: Unpack[NVILAProcessorKwargs], |
| ) -> BatchFeature: |
| normalized_text, normalized_images, normalized_videos = self._normalize_inputs( |
| text=text, |
| images=images, |
| videos=videos, |
| ) |
|
|
| images_inputs, image_token_padding_strategy = ( |
| self._preprocess_images( |
| normalized_images, |
| **kwargs, |
| ) |
| if len(normalized_images) > 0 |
| else (BatchFeature(), []) |
| ) |
|
|
| videos_inputs = ( |
| self._preprocess_videos( |
| normalized_videos, |
| **kwargs, |
| ) |
| if len(normalized_videos) > 0 |
| else (BatchFeature(), []) |
| ) |
|
|
| |
| gazing_info = None |
| video_token_padding_strategy = [] |
| skip_tiles_gaze = self._should_gaze_all_patches(self.gazing_ratio_tile, self.task_loss_requirement_tile) |
| skip_thumbs_gaze = self._should_gaze_all_patches(self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail) |
| can_construct_without_autogaze = skip_tiles_gaze and skip_thumbs_gaze |
| if len(normalized_videos) > 0 and (self._autogaze_model is not None or can_construct_without_autogaze): |
| gazing_info = self._get_gazing_info_from_videos(videos_inputs) |
| |
| |
| |
| |
| |
| shuffle_num = self.mm_projector_shuffle_num |
| ns_list = videos_inputs["num_spatial_tiles_each_video"] |
|
|
| for vid_idx in range(len(gazing_info["if_padded_gazing_tiles"])): |
| tiles_if_pad = gazing_info["if_padded_gazing_tiles"][vid_idx] |
| tiles_num_gaze = gazing_info["num_gazing_each_frame_tiles"][vid_idx] |
| thumbs_if_pad = gazing_info["if_padded_gazing_thumbnails"][vid_idx] |
| thumbs_num_gaze = gazing_info["num_gazing_each_frame_thumbnails"][vid_idx] |
|
|
| ns = ns_list[vid_idx] |
| num_tiles = tiles_if_pad.shape[0] |
| T_tile = tiles_num_gaze.shape[1] |
| tc = num_tiles // ns |
| total_frames = tc * T_tile |
|
|
| |
| tile_non_padded = [] |
| for t_idx in range(num_tiles): |
| frame_sizes = tiles_num_gaze[t_idx].tolist() |
| frame_pad_segs = tiles_if_pad[t_idx].split(frame_sizes) |
| tile_non_padded.append( |
| [int((~seg).sum().item()) for seg in frame_pad_segs] |
| ) |
|
|
| total_tokens = 0 |
|
|
| |
| for g in range(total_frames): |
| chunk = g // T_tile |
| f_in_chunk = g % T_tile |
| frame_count = sum( |
| tile_non_padded[chunk * ns + s][f_in_chunk] |
| for s in range(ns) |
| ) |
| total_tokens += (frame_count + shuffle_num - 1) // shuffle_num |
|
|
| |
| for th_idx in range(thumbs_if_pad.shape[0]): |
| frame_sizes = thumbs_num_gaze[th_idx].tolist() |
| frame_pad_segs = thumbs_if_pad[th_idx].split(frame_sizes) |
| non_pad = sum(int((~seg).sum().item()) for seg in frame_pad_segs) |
| total_tokens += (non_pad + shuffle_num - 1) // shuffle_num |
|
|
| video_token_padding_strategy.append([total_tokens]) |
| else: |
| video_token_padding_strategy = [[(self.num_video_frames + self.num_video_frames_thumbnail) * 118] * len(normalized_videos)] |
|
|
| |
| |
| if len(normalized_videos) > 0: |
| videos_inputs.pop("pixel_values_videos_tiles_autogaze", None) |
| videos_inputs.pop("pixel_values_videos_thumbnails_autogaze", None) |
|
|
| text_inputs = self._preprocess_text( |
| normalized_text, |
| image_token_padding_strategy=image_token_padding_strategy, |
| video_token_padding_strategy=video_token_padding_strategy, |
| **kwargs, |
| ) |
|
|
| |
| batch_feature = BatchFeature( |
| { |
| **text_inputs, |
| **images_inputs, |
| **videos_inputs, |
| } |
| ) |
|
|
| |
| if gazing_info is not None: |
| batch_feature["gazing_info"] = gazing_info |
|
|
| return batch_feature |
|
|
| def batch_decode(self, *args, **kwargs) -> list[str]: |
| return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
| def _normalize_inputs( |
| self, |
| *, |
| text: TextInput | list[TextInput], |
| images: ImageInput | None, |
| videos: VideoInput | None, |
| ) -> tuple[list[str], list[Image], list[list[Image]]]: |
| if isinstance(text, list): |
| normalized_text = text |
| else: |
| normalized_text = [text] |
|
|
| if images is not None and images != []: |
| image_flat_list = cast(list, image_utils.make_flat_list_of_images(images)) |
| normalized_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_flat_list] |
| else: |
| normalized_images = [] |
|
|
| if videos is not None and videos != []: |
| |
| |
| if not isinstance(videos, (list, tuple)): |
| videos = [videos] |
| |
| normalized_videos = [] |
| |
| num_frames = self.num_video_frames |
| for video_input in videos: |
| if isinstance(video_input, str): |
| parsed = urlparse(video_input) |
| if parsed.scheme in ("http", "https"): |
| suffix = os.path.splitext(parsed.path)[1] or ".mp4" |
| tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) |
| try: |
| urllib.request.urlretrieve(video_input, tmp.name) |
| video_frames = _load_video_frames(tmp.name, num_frames=num_frames) |
| finally: |
| tmp.close() |
| os.unlink(tmp.name) |
| else: |
| video_frames = _load_video_frames(video_input, num_frames=num_frames) |
| normalized_videos.append(video_frames) |
| elif isinstance(video_input, (list, tuple)): |
| |
| normalized_videos.append([ |
| cast(Image, image_transforms.to_pil_image(image)) for image in video_input |
| ]) |
| else: |
| |
| try: |
| video_list = cast(list[list], video_utils.make_batched_videos([video_input])) |
| normalized_videos.extend([ |
| [cast(Image, image_transforms.to_pil_image(image)) for image in video] |
| for video in video_list |
| ]) |
| except Exception: |
| raise ValueError( |
| f"Unsupported video input type: {type(video_input)}. " |
| "Expected str (file path) or list of PIL Images." |
| ) |
| else: |
| normalized_videos = [] |
|
|
| return normalized_text, normalized_images, normalized_videos |
|
|
| def _preprocess_images( |
| self, |
| images: list[Image], |
| **kwargs: Unpack[NVILAProcessorKwargs], |
| ) -> tuple[BatchFeature, list[list[int]]]: |
| """Preprocess images into spatial tiles plus a thumbnail. |
| |
| Each image is split into a grid of spatial tiles whose count is at |
| most ``max_tiles_image``. A thumbnail (the whole image resized to |
| ``image_size × image_size``) is appended. Every tile / thumbnail |
| is a single-frame "video" of shape ``(1, C, H, W)``. No AutoGaze |
| is applied — all patches are kept. |
| |
| Returns: |
| A tuple ``(images_inputs, padding_strategy)`` where |
| ``images_inputs`` is a ``BatchFeature`` with: |
| |
| - ``"pixel_values_images_tiles"`` – list of tensors, one per |
| image, each ``(num_tiles_i, 1, C, H, W)``. |
| - ``"pixel_values_images_thumbnails"`` – list of tensors, one |
| per image, each ``(1, 1, C, H, W)``. |
| - ``"num_spatial_tiles_each_image"`` – list of ints. |
| |
| ``padding_strategy`` is a list (one per image) of |
| ``[total_tokens]`` used for text-token padding. |
| """ |
| merged_kwargs = self._merge_kwargs( |
| NVILAProcessorKwargs, |
| tokenizer_init_kwargs=self.tokenizer.init_kwargs, |
| **kwargs, |
| ) |
|
|
| if hasattr(self.image_processor, "size"): |
| image_size = self.image_processor.size.get("height", 392) |
| else: |
| image_size = 392 |
|
|
| shuffle_num = self.mm_projector_shuffle_num |
|
|
| num_patches_each_scale = [ |
| (s // self.target_patch_size) ** 2 for s in self.target_scales |
| ] |
| total_patches_per_frame = sum(num_patches_each_scale) |
|
|
| pixel_values_images_tiles: list[torch.Tensor] = [] |
| pixel_values_images_thumbnails: list[torch.Tensor] = [] |
| num_spatial_tiles_each_image: list[int] = [] |
| padding_strategy: list[list[int]] = [] |
|
|
| for image in images: |
| image = image.convert("RGB") |
| orig_width, orig_height = image.size |
|
|
| max_spatial_tiles = max(self.max_tiles_image, 1) |
| aspect_ratio = orig_width / orig_height |
|
|
| target_ratios = { |
| (i, j) |
| for n in range(1, max_spatial_tiles + 1) |
| for i in range(1, n + 1) |
| for j in range(1, n + 1) |
| if 1 <= i * j <= max_spatial_tiles |
| } |
| target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
| target_aspect_ratio = _find_closest_aspect_ratio( |
| aspect_ratio, target_ratios, orig_width, orig_height, image_size |
| ) |
|
|
| target_width = image_size * target_aspect_ratio[0] |
| target_height = image_size * target_aspect_ratio[1] |
| num_tiles = target_aspect_ratio[0] * target_aspect_ratio[1] |
| num_cols = target_aspect_ratio[0] |
|
|
| resized = image.resize((target_width, target_height)) |
|
|
| |
| all_tile_images: list[Image] = [] |
| for tile_idx in range(num_tiles): |
| col = tile_idx % num_cols |
| row = tile_idx // num_cols |
| box = ( |
| col * image_size, |
| row * image_size, |
| (col + 1) * image_size, |
| (row + 1) * image_size, |
| ) |
| all_tile_images.append(resized.crop(box)) |
|
|
| thumbnail = image.resize((image_size, image_size)) |
| all_images_for_siglip = all_tile_images + [thumbnail] |
|
|
| |
| siglip_processed = self.image_processor( |
| all_images_for_siglip, **merged_kwargs["images_kwargs"], |
| )["pixel_values"] |
| if not isinstance(siglip_processed, torch.Tensor): |
| siglip_processed = torch.tensor(np.array(siglip_processed)) |
|
|
| |
| tiles_pv = siglip_processed[:num_tiles].unsqueeze(1) |
| thumb_pv = siglip_processed[num_tiles:].unsqueeze(1) |
|
|
| pixel_values_images_tiles.append(tiles_pv) |
| pixel_values_images_thumbnails.append(thumb_pv) |
| num_spatial_tiles_each_image.append(num_tiles) |
|
|
| |
| tiles_tokens = (num_tiles * total_patches_per_frame + shuffle_num - 1) // shuffle_num |
| thumb_tokens = (total_patches_per_frame + shuffle_num - 1) // shuffle_num |
| padding_strategy.append([tiles_tokens + thumb_tokens]) |
|
|
| images_inputs = BatchFeature({ |
| "pixel_values_images_tiles": pixel_values_images_tiles, |
| "pixel_values_images_thumbnails": pixel_values_images_thumbnails, |
| "num_spatial_tiles_each_image": num_spatial_tiles_each_image, |
| }) |
|
|
| return images_inputs, padding_strategy |
|
|
| def _preprocess_text( |
| self, |
| text: list[str], |
| *, |
| image_token_padding_strategy: list[list[int]], |
| video_token_padding_strategy: list[list[int]], |
| **kwargs: Unpack[NVILAProcessorKwargs], |
| ) -> BatchEncoding: |
| |
| messages = [[ |
| {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, |
| {"role": "user", "content": t} |
| ] for t in text] |
| text = self.tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True |
| ) |
|
|
| |
| assert isinstance(self.tokenizer.image_token, str) |
| assert isinstance(self.tokenizer.video_token, str) |
|
|
| for media_token, padding_strategy in ( |
| (self.tokenizer.image_token, image_token_padding_strategy), |
| (self.tokenizer.video_token, video_token_padding_strategy), |
| ): |
| assert sum([s.count(media_token) for s in text]) == len(padding_strategy) |
|
|
| |
| pad_lens = [len(x) for x in padding_strategy] |
| text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text] |
|
|
| |
| pad_lens = [y for x in padding_strategy for y in x] |
| text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text] |
|
|
| merged_kwargs = self._merge_kwargs( |
| NVILAProcessorKwargs, |
| tokenizer_init_kwargs=self.tokenizer.init_kwargs, |
| **kwargs, |
| ) |
|
|
| text_inputs = self.tokenizer( |
| text=text, |
| **merged_kwargs["text_kwargs"], |
| ) |
|
|
| return text_inputs |
|
|
| def _preprocess_videos( |
| self, |
| videos: list[list[Image]], |
| **kwargs: Unpack[NVILAProcessorKwargs], |
| ) -> BatchFeature: |
| """Preprocess videos into spatiotemporal tiles and thumbnails. |
| |
| Each video is split into a grid of spatiotemporal tiles and a set of |
| low-resolution thumbnail frames. Both SigLIP-processed and |
| AutoGaze-processed copies are produced. |
| |
| Spatial tiling |
| Every frame is resized so that its dimensions become a multiple of |
| ``image_size`` (from the SigLIP image processor) and then cropped |
| into ``(cols, rows)`` spatial tiles, where ``cols * rows <= |
| max_tiles_video``. The best ``(cols, rows)`` is chosen by matching |
| the original frame aspect ratio (same logic as |
| ``dynamic_preprocess`` in ``llava/mm_utils.py``). |
| |
| Temporal chunking |
| The T sampled frames are divided into ``T // max_num_frames`` |
| consecutive chunks of ``max_num_frames`` frames each, where |
| ``max_num_frames`` comes from the AutoGaze model config. |
| ``T`` must be divisible by ``max_num_frames``. |
| |
| Tile ordering |
| Tiles are ordered **temporal-chunk-first**: all spatial tiles for |
| the first temporal chunk, then all spatial tiles for the second |
| temporal chunk, and so on. |
| |
| Thumbnails |
| Each frame is also resized to ``image_size × image_size`` to form a |
| thumbnail. If the number of frames exceeds |
| ``num_video_frames_thumbnail``, thumbnails are uniformly subsampled |
| (every k-th frame) to that count. Each thumbnail is treated as a |
| single-frame video (temporal dim = 1). |
| |
| Args: |
| videos: List of videos, where each video is a list of PIL Images |
| (one per frame). |
| **kwargs: Additional keyword arguments forwarded to the SigLIP |
| image processor. |
| |
| Returns: |
| A tuple ``(videos_inputs, padding_strategy)`` where |
| |
| ``videos_inputs`` is a ``BatchFeature`` dict with the keys: |
| |
| - ``"pixel_values_videos_tiles"`` – list of tensors, one per video. |
| Each tensor has shape ``(num_tiles, T_tile, C, H, W)`` where |
| ``num_tiles = num_spatial_tiles * temporal_chunks``, |
| ``T_tile = max_num_frames`` (from AutoGaze config), |
| and ``H = W = image_size``. |
| Processed by the SigLIP image processor. |
| - ``"pixel_values_videos_thumbnails"`` – list of tensors, one per |
| video. Each tensor has shape |
| ``(T_thumbnail, 1, C, H, W)`` where ``T_thumbnail <= |
| num_video_frames_thumbnail`` and ``H = W = image_size``. |
| Processed by the SigLIP image processor. |
| - ``"pixel_values_videos_tiles_autogaze"`` *(optional)* – same |
| structure as ``pixel_values_videos_tiles`` but processed by the |
| AutoGaze ``transform_video_for_pytorch`` transform. |
| Only present when AutoGaze is available. |
| - ``"pixel_values_videos_thumbnails_autogaze"`` *(optional)* – same |
| structure as ``pixel_values_videos_thumbnails`` but processed by |
| the AutoGaze transform. Only present when AutoGaze is available. |
| |
| ``padding_strategy`` is a list (one entry per video) of lists of |
| ints used for text-token padding. Currently a placeholder; the |
| final strategy depends on downstream gazing results. |
| """ |
| merged_kwargs = self._merge_kwargs( |
| NVILAProcessorKwargs, |
| tokenizer_init_kwargs=self.tokenizer.init_kwargs, |
| **kwargs, |
| ) |
|
|
| |
| if hasattr(self.image_processor, "size"): |
| image_size = self.image_processor.size.get("height", 392) |
| else: |
| image_size = 392 |
|
|
| |
| if self._autogaze_model is not None: |
| autogaze_max_num_frames = self._autogaze_model.config.max_num_frames |
| else: |
| autogaze_max_num_frames = 16 |
|
|
| |
| autogaze_transform = None |
| largest_scale = max(self.target_scales) |
| autogaze_transform = AutoGazeImageProcessor.from_pretrained( |
| self.autogaze_model_id, |
| size=(largest_scale, largest_scale), |
| ) |
|
|
| pixel_values_videos_tiles = [] |
| pixel_values_videos_thumbnails = [] |
| pixel_values_videos_tiles_autogaze = [] |
| pixel_values_videos_thumbnails_autogaze = [] |
| num_spatial_tiles_each_video = [] |
|
|
| for video in videos: |
| video = [img.convert("RGB") for img in video] |
| num_frames = len(video) |
| orig_width, orig_height = video[0].size |
|
|
| |
| temporal_chunks = num_frames // autogaze_max_num_frames |
| assert temporal_chunks >= 1 and num_frames % autogaze_max_num_frames == 0, ( |
| f"Number of frames ({num_frames}) must be divisible by " |
| f"AutoGaze max_num_frames ({autogaze_max_num_frames})" |
| ) |
|
|
| |
| |
| max_spatial_tiles = max(self.max_tiles_video, 1) |
|
|
| |
| aspect_ratio = orig_width / orig_height |
|
|
| target_ratios = { |
| (i, j) |
| for n in range(1, max_spatial_tiles + 1) |
| for i in range(1, n + 1) |
| for j in range(1, n + 1) |
| if 1 <= i * j <= max_spatial_tiles |
| } |
| target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
| target_aspect_ratio = _find_closest_aspect_ratio( |
| aspect_ratio, target_ratios, orig_width, orig_height, image_size |
| ) |
|
|
| target_width = image_size * target_aspect_ratio[0] |
| target_height = image_size * target_aspect_ratio[1] |
| num_spatial_tiles = target_aspect_ratio[0] * target_aspect_ratio[1] |
| num_cols = target_aspect_ratio[0] |
|
|
| |
| |
| spatial_tile_frames = [[] for _ in range(num_spatial_tiles)] |
| thumbnail_frames = [] |
|
|
| for frame in video: |
| |
| resized_frame = frame.resize((target_width, target_height)) |
|
|
| |
| for tile_idx in range(num_spatial_tiles): |
| col = tile_idx % num_cols |
| row = tile_idx // num_cols |
| box = ( |
| col * image_size, |
| row * image_size, |
| (col + 1) * image_size, |
| (row + 1) * image_size, |
| ) |
| tile = resized_frame.crop(box) |
| spatial_tile_frames[tile_idx].append(tile) |
|
|
| |
| thumbnail = frame.resize((image_size, image_size)) |
| thumbnail_frames.append(thumbnail) |
|
|
| |
| |
| |
| num_tiles = temporal_chunks * num_spatial_tiles |
| T_tile = autogaze_max_num_frames |
| all_tile_images = [] |
| for t_chunk in range(temporal_chunks): |
| for spatial_idx in range(num_spatial_tiles): |
| start = t_chunk * T_tile |
| end = start + T_tile |
| all_tile_images.extend(spatial_tile_frames[spatial_idx][start:end]) |
|
|
| |
| siglip_processed = self.image_processor( |
| all_tile_images, **merged_kwargs["images_kwargs"], |
| )["pixel_values"] |
| if not isinstance(siglip_processed, torch.Tensor): |
| siglip_processed = torch.tensor(np.array(siglip_processed)) |
| video_tiles_siglip = siglip_processed.reshape(num_tiles, T_tile, *siglip_processed.shape[1:]) |
| pixel_values_videos_tiles.append(video_tiles_siglip) |
|
|
| |
| if autogaze_transform is not None: |
| all_tile_np = np.stack([np.array(f) for f in all_tile_images]) |
| autogaze_processed = transform_video_for_pytorch(all_tile_np, autogaze_transform) |
| video_tiles_autogaze = autogaze_processed.reshape(num_tiles, T_tile, *autogaze_processed.shape[1:]) |
| pixel_values_videos_tiles_autogaze.append(video_tiles_autogaze) |
|
|
| |
| |
| if len(thumbnail_frames) > self.num_video_frames_thumbnail: |
| step = len(thumbnail_frames) // self.num_video_frames_thumbnail |
| sampled_thumbnail_frames = thumbnail_frames[::step][: self.num_video_frames_thumbnail] |
| else: |
| sampled_thumbnail_frames = thumbnail_frames |
|
|
| T_thumb = len(sampled_thumbnail_frames) |
|
|
| |
| siglip_processed = self.image_processor( |
| sampled_thumbnail_frames, **merged_kwargs["images_kwargs"], |
| )["pixel_values"] |
| if not isinstance(siglip_processed, torch.Tensor): |
| siglip_processed = torch.tensor(np.array(siglip_processed)) |
| |
| video_thumbnails_siglip = siglip_processed.unsqueeze(1) |
| pixel_values_videos_thumbnails.append(video_thumbnails_siglip) |
|
|
| |
| if autogaze_transform is not None: |
| all_thumb_np = np.stack([np.array(f) for f in sampled_thumbnail_frames]) |
| autogaze_processed = transform_video_for_pytorch(all_thumb_np, autogaze_transform) |
| video_thumbnails_autogaze = autogaze_processed.unsqueeze(1) |
| pixel_values_videos_thumbnails_autogaze.append(video_thumbnails_autogaze) |
|
|
| num_spatial_tiles_each_video.append(num_spatial_tiles) |
|
|
| print( |
| f"Video tiling: {num_frames} frames @ {orig_width}x{orig_height} → " |
| f"{num_spatial_tiles} spatial × {temporal_chunks} temporal = " |
| f"{num_spatial_tiles * temporal_chunks} tiles, each " |
| f"{autogaze_max_num_frames}×{image_size}×{image_size}; " |
| f"{len(sampled_thumbnail_frames)} thumbnail frames" |
| ) |
|
|
| |
| videos_inputs = BatchFeature( |
| { |
| "pixel_values_videos_tiles": pixel_values_videos_tiles, |
| "pixel_values_videos_thumbnails": pixel_values_videos_thumbnails, |
| "num_spatial_tiles_each_video": num_spatial_tiles_each_video, |
| } |
| ) |
| if pixel_values_videos_tiles_autogaze: |
| videos_inputs["pixel_values_videos_tiles_autogaze"] = pixel_values_videos_tiles_autogaze |
| if pixel_values_videos_thumbnails_autogaze: |
| videos_inputs["pixel_values_videos_thumbnails_autogaze"] = pixel_values_videos_thumbnails_autogaze |
|
|
| return videos_inputs |
| |
| @staticmethod |
| def _should_gaze_all_patches(gazing_ratio, task_loss_requirement) -> bool: |
| """Return True when the gazing config means every patch is kept. |
| |
| This is the case when ``gazing_ratio`` is ``None`` (no gazing at all), |
| or when ``gazing_ratio == 1`` (keep 100 %) **and** |
| ``task_loss_requirement is None`` (no adaptive pruning). |
| """ |
| if gazing_ratio is None: |
| return True |
| if task_loss_requirement is not None: |
| return False |
| if isinstance(gazing_ratio, (list, tuple)): |
| return all(r == 1 for r in gazing_ratio) |
| return gazing_ratio == 1 |
|
|
| @staticmethod |
| def _sort_gazing_pos_per_frame( |
| gazing_pos: torch.Tensor, |
| if_padded: torch.Tensor, |
| num_gazing_each_frame: torch.Tensor, |
| ) -> torch.Tensor: |
| """Sort non-padded gazing positions in ascending order within each frame. |
| |
| Padded positions are left untouched at the end of each frame's segment |
| so that the total count (padded + non-padded) per frame is unchanged. |
| |
| Args: |
| gazing_pos: ``(B, N)`` tensor of gazing patch indices. |
| if_padded: ``(B, N)`` bool tensor (``True`` = padded / dummy). |
| num_gazing_each_frame: ``(B, T)`` tensor giving the number of |
| gazing positions (padded + non-padded) for each frame. |
| |
| Returns: |
| A new ``(B, N)`` tensor with the same values as *gazing_pos* |
| except that the non-padded entries within every frame are sorted. |
| """ |
| sorted_pos = gazing_pos.clone() |
| B, _ = gazing_pos.shape |
| T = num_gazing_each_frame.shape[1] |
|
|
| for b in range(B): |
| offset = 0 |
| for t in range(T): |
| count = int(num_gazing_each_frame[b, t].item()) |
| frame_pos = gazing_pos[b, offset : offset + count] |
| frame_pad = if_padded[b, offset : offset + count] |
|
|
| |
| real_mask = ~frame_pad |
| real_pos = frame_pos[real_mask] |
|
|
| |
| real_pos_sorted = real_pos.sort()[0] |
|
|
| |
| real_indices = real_mask.nonzero(as_tuple=True)[0] |
| sorted_pos[b, offset + real_indices] = real_pos_sorted |
|
|
| offset += count |
|
|
| return sorted_pos |
|
|
| def _run_autogaze_batched( |
| self, |
| all_videos: torch.Tensor, |
| autogaze_device: torch.device, |
| cpu_device: torch.device, |
| gazing_ratio, |
| task_loss_requirement, |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """Run AutoGaze in minibatches and return combined results on CPU. |
| |
| Different minibatches may produce different per-frame gazing counts |
| (e.g. when ``task_loss_requirement`` triggers adaptive pruning). |
| This method pads each frame's segment to the *maximum* count across |
| all minibatches so that the results can be concatenated along the |
| batch dimension. |
| |
| Args: |
| all_videos: ``(B, T, C, H, W)`` tensor of videos to process. |
| autogaze_device: Device where AutoGaze runs (typically CUDA). |
| cpu_device: Device for the returned tensors (typically CPU). |
| gazing_ratio: Gazing ratio to pass to AutoGaze. |
| task_loss_requirement: Task loss requirement to pass to AutoGaze. |
| |
| Returns: |
| A tuple ``(gazing_pos, if_padded, num_gazing)`` where |
| |
| - ``gazing_pos`` is ``(B, N_max)`` on *cpu_device* |
| - ``if_padded`` is ``(B, N_max)`` bool on *cpu_device* |
| - ``num_gazing`` is ``(B, T)`` on *cpu_device* |
| |
| ``N_max = sum(max_per_frame)`` where ``max_per_frame[t]`` is the |
| largest per-frame count across all minibatches. |
| """ |
| total = all_videos.shape[0] |
| bs = self.max_batch_size_autogaze |
|
|
| batch_results: list[dict] = [] |
|
|
| with torch.inference_mode(): |
| for start in range(0, total, bs): |
| batch = all_videos[start : start + bs] |
|
|
| gaze = self._autogaze_model( |
| {"video": batch.to(autogaze_device)}, |
| gazing_ratio=gazing_ratio, |
| task_loss_requirement=task_loss_requirement, |
| target_scales=self.target_scales, |
| target_patch_size=self.target_patch_size, |
| ) |
|
|
| ng = gaze["num_gazing_each_frame"] |
| if isinstance(ng, list): |
| ng = torch.tensor(ng, device=cpu_device, dtype=torch.long) |
| elif not isinstance(ng, torch.Tensor): |
| ng = torch.tensor(ng, device=cpu_device, dtype=torch.long) |
| else: |
| ng = ng.to(cpu_device) |
| if ng.dim() == 2: |
| ng = ng[0] |
|
|
| batch_results.append({ |
| "gazing_pos": gaze["gazing_pos"].to(cpu_device), |
| "if_padded": gaze["if_padded_gazing"].to(cpu_device), |
| "num_gazing": ng, |
| "batch_size": batch.shape[0], |
| }) |
|
|
| |
| if len(batch_results) == 1: |
| r = batch_results[0] |
| num_gazing = r["num_gazing"].unsqueeze(0).expand(total, -1).contiguous() |
| return r["gazing_pos"], r["if_padded"], num_gazing |
|
|
| |
| all_ng = torch.stack([r["num_gazing"] for r in batch_results], dim=0) |
| max_per_frame = all_ng.max(dim=0).values |
| max_N = int(max_per_frame.sum().item()) |
| T = max_per_frame.shape[0] |
|
|
| padded_pos_list = [] |
| padded_mask_list = [] |
|
|
| for r in batch_results: |
| src_pos = r["gazing_pos"] |
| src_pad = r["if_padded"] |
| src_ng = r["num_gazing"] |
| mini_B = r["batch_size"] |
|
|
| if int(src_ng.sum().item()) == max_N: |
| padded_pos_list.append(src_pos) |
| padded_mask_list.append(src_pad) |
| continue |
|
|
| dst_pos = torch.zeros(mini_B, max_N, device=cpu_device, dtype=src_pos.dtype) |
| dst_pad = torch.ones(mini_B, max_N, device=cpu_device, dtype=torch.bool) |
|
|
| src_off = 0 |
| dst_off = 0 |
| for t in range(T): |
| sc = int(src_ng[t].item()) |
| dc = int(max_per_frame[t].item()) |
| dst_pos[:, dst_off : dst_off + sc] = src_pos[:, src_off : src_off + sc] |
| dst_pad[:, dst_off : dst_off + sc] = src_pad[:, src_off : src_off + sc] |
| src_off += sc |
| dst_off += dc |
|
|
| padded_pos_list.append(dst_pos) |
| padded_mask_list.append(dst_pad) |
|
|
| gazing_pos = torch.cat(padded_pos_list, dim=0) |
| if_padded = torch.cat(padded_mask_list, dim=0) |
| num_gazing = max_per_frame.unsqueeze(0).expand(total, -1).contiguous() |
|
|
| return gazing_pos, if_padded, num_gazing |
|
|
| def _get_gazing_info_from_videos( |
| self, |
| videos_inputs: BatchFeature, |
| ) -> Optional[dict]: |
| """Run AutoGaze on the preprocessed tiles and thumbnails. |
| |
| All tiles from all videos are batched together (they share the same |
| temporal dimension ``T_tile``). Similarly, all thumbnails are batched |
| together (temporal dim = 1). AutoGaze is run once on each batch and |
| the results are split back per-video. |
| |
| When a gazing ratio is 1 and the corresponding task_loss_requirement is |
| None (or gazing_ratio is None), all patches are kept and AutoGaze is |
| skipped for that component. If both tiles and thumbnails meet this |
| condition, AutoGaze is not invoked at all. |
| |
| Args: |
| videos_inputs: The ``BatchFeature`` returned by |
| ``_preprocess_videos``, which must contain the keys |
| ``pixel_values_videos_tiles_autogaze`` and |
| ``pixel_values_videos_thumbnails_autogaze`` (unless the |
| corresponding component can skip AutoGaze). |
| |
| Returns: |
| A dict with the following keys (or ``None`` if AutoGaze is |
| unavailable or the required inputs are missing): |
| |
| - ``"gazing_pos_tiles"`` – list of tensors, one per video, each |
| shaped ``(num_tiles_i, N)``. |
| - ``"num_gazing_each_frame_tiles"`` – list of tensors, one per |
| video, each shaped ``(num_tiles_i, T_tile)``. |
| - ``"if_padded_gazing_tiles"`` – list of bool tensors, one per |
| video, each shaped ``(num_tiles_i, N)``. |
| - ``"gazing_pos_thumbnails"`` – list of tensors, one per video, |
| each shaped ``(T_thumb_i, N')``. |
| - ``"num_gazing_each_frame_thumbnails"`` – list of tensors, one per |
| video, each shaped ``(T_thumb_i, 1)``. |
| - ``"if_padded_gazing_thumbnails"`` – list of bool tensors, one per |
| video, each shaped ``(T_thumb_i, N')``. |
| """ |
| skip_tiles = self._should_gaze_all_patches( |
| self.gazing_ratio_tile, self.task_loss_requirement_tile |
| ) |
| skip_thumbnails = self._should_gaze_all_patches( |
| self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail |
| ) |
| need_autogaze = not skip_tiles or not skip_thumbnails |
|
|
| if need_autogaze and self._autogaze_model is None: |
| return None |
|
|
| |
| siglip_tiles = videos_inputs["pixel_values_videos_tiles"] |
| siglip_thumbs = videos_inputs["pixel_values_videos_thumbnails"] |
| num_tiles_per_video = [t.shape[0] for t in siglip_tiles] |
| num_thumbs_per_video = [t.shape[0] for t in siglip_thumbs] |
|
|
| device = torch.device("cpu") |
| autogaze_device = torch.device("cuda") if torch.cuda.is_available() else device |
|
|
| |
| num_patches_each_scale = [ |
| (s // self.target_patch_size) ** 2 for s in self.target_scales |
| ] |
| total_patches_per_frame = sum(num_patches_each_scale) |
|
|
| |
| if need_autogaze: |
| current_device = next(self._autogaze_model.parameters()).device |
| if current_device != autogaze_device: |
| self._autogaze_model = self._autogaze_model.to(autogaze_device) |
|
|
| |
| if skip_tiles: |
| total_tiles = sum(num_tiles_per_video) |
| T_tile = siglip_tiles[0].shape[1] |
| per_frame_pos = torch.arange(total_patches_per_frame, device=device, dtype=torch.long) |
| tiles_gazing_pos = per_frame_pos.repeat(T_tile).unsqueeze(0).expand(total_tiles, -1).contiguous() |
| tiles_if_padded = torch.zeros( |
| total_tiles, T_tile * total_patches_per_frame, device=device, dtype=torch.bool |
| ) |
| tiles_num_gazing = torch.full( |
| (total_tiles, T_tile), total_patches_per_frame, device=device, dtype=torch.long |
| ) |
| else: |
| tiles_autogaze = videos_inputs.get("pixel_values_videos_tiles_autogaze") |
| if tiles_autogaze is None: |
| return None |
|
|
| all_tiles = torch.cat(tiles_autogaze, dim=0) |
| tiles_gazing_pos, tiles_if_padded, tiles_num_gazing = self._run_autogaze_batched( |
| all_tiles, autogaze_device, device, |
| self.gazing_ratio_tile, self.task_loss_requirement_tile, |
| ) |
| tiles_gazing_pos = self._sort_gazing_pos_per_frame( |
| tiles_gazing_pos, tiles_if_padded, tiles_num_gazing |
| ) |
|
|
| |
| if skip_thumbnails: |
| total_thumbs = sum(num_thumbs_per_video) |
| per_thumb_pos = torch.arange( |
| total_patches_per_frame, device=device, dtype=torch.long |
| ) |
| thumbs_gazing_pos = per_thumb_pos.unsqueeze(0).expand(total_thumbs, -1).contiguous() |
| thumbs_if_padded = torch.zeros_like(thumbs_gazing_pos, dtype=torch.bool) |
| thumbs_num_gazing = torch.full( |
| (total_thumbs, 1), total_patches_per_frame, |
| device=device, dtype=torch.long, |
| ) |
| else: |
| thumbs_autogaze = videos_inputs.get("pixel_values_videos_thumbnails_autogaze") |
| if thumbs_autogaze is None: |
| return None |
|
|
| all_thumbs = torch.cat(thumbs_autogaze, dim=0) |
| thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing = self._run_autogaze_batched( |
| all_thumbs, autogaze_device, device, |
| self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail, |
| ) |
| thumbs_gazing_pos = self._sort_gazing_pos_per_frame( |
| thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing |
| ) |
|
|
| |
| tiles_gazing_pos_list = list(torch.split(tiles_gazing_pos, num_tiles_per_video, dim=0)) |
| tiles_if_padded_list = list(torch.split(tiles_if_padded, num_tiles_per_video, dim=0)) |
| tiles_num_gazing_list = list(torch.split(tiles_num_gazing, num_tiles_per_video, dim=0)) |
|
|
| thumbs_gazing_pos_list = list(torch.split(thumbs_gazing_pos, num_thumbs_per_video, dim=0)) |
| thumbs_if_padded_list = list(torch.split(thumbs_if_padded, num_thumbs_per_video, dim=0)) |
| thumbs_num_gazing_list = list(torch.split(thumbs_num_gazing, num_thumbs_per_video, dim=0)) |
|
|
| return { |
| "gazing_pos_tiles": tiles_gazing_pos_list, |
| "num_gazing_each_frame_tiles": tiles_num_gazing_list, |
| "if_padded_gazing_tiles": tiles_if_padded_list, |
| "gazing_pos_thumbnails": thumbs_gazing_pos_list, |
| "num_gazing_each_frame_thumbnails": thumbs_num_gazing_list, |
| "if_padded_gazing_thumbnails": thumbs_if_padded_list, |
| } |