import glob import os import re import tempfile import urllib.request from os import PathLike from typing import cast, Optional from urllib.parse import urlparse import cv2 import numpy as np import torch import transformers.image_transforms as image_transforms import transformers.image_utils as image_utils import transformers.video_utils as video_utils from PIL import Image from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2TokenizerFast from transformers.models.siglip import SiglipImageProcessor, SiglipImageProcessorFast from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs from transformers.tokenization_utils_base import BatchEncoding, TextInput from transformers.video_utils import VideoInput, VideoMetadata from autogaze.models.autogaze import AutoGaze from autogaze.models.autogaze import AutoGazeImageProcessor from autogaze.datasets.video_utils import transform_video_for_pytorch def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): """Find the closest aspect ratio from a set of target ratios. Referenced from https://github.com/OpenGVLab/InternVL and llava/mm_utils.py """ best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio class NVILAProcessorKwargs(ProcessingKwargs, total=False): _defaults = {} # type: ignore def _load_video_frames(video_path: str, num_frames: int = 8) -> list[Image]: """ Load video frames from a video file path. Similar to _load_video in llava/utils/media.py Args: video_path: Path to the video file or directory of frames num_frames: Number of frames to extract Returns: List of PIL Images representing video frames """ vidcap = cv2.VideoCapture(video_path) if not vidcap.isOpened(): raise ValueError(f"Failed to open video: {video_path}") frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) while frame_count > 0: vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) if vidcap.grab(): break frame_count -= 1 else: vidcap.release() raise ValueError(f"Video '{video_path}' has no frames.") indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) frames = {} for index in indices: if index in frames: continue vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) success, frame = vidcap.read() if not success: continue frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames[index] = Image.fromarray(frame) vidcap.release() frames_to_return = [frames[index] for index in indices if index in frames] if len(frames_to_return) < num_frames: if frames_to_return: frames_to_return = frames_to_return + [frames_to_return[-1]] * (num_frames - len(frames_to_return)) else: raise ValueError(f"Could not extract any frames from video: {video_path}") return frames_to_return class NVILAProcessor(ProcessorMixin): attributes = [ "image_processor", "tokenizer", ] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" _auto_class = "AutoProcessor" def __init__( self, image_processor: SiglipImageProcessor | SiglipImageProcessorFast, tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast, chat_template: str | None = None, autogaze_model_id: str | None = None, gazing_ratio_tile: list[float] | float = 0.75, gazing_ratio_thumbnail: float | None = 0.75, task_loss_requirement_tile: float = 0.7, task_loss_requirement_thumbnail: float | None = 0.7, target_scales: list[int] | None = None, target_patch_size: int | None = None, max_tiles_image: int = 12, num_video_frames: int = 8, max_tiles_video: int = 8, num_video_frames_thumbnail: int = 8, mm_projector_shuffle_num: int = 9, max_batch_size_autogaze: int = 32, **kwargs, ): super().__init__( image_processor, tokenizer, chat_template=chat_template, **kwargs, ) self.image_processor: SiglipImageProcessor | SiglipImageProcessorFast self.tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast # AutoGaze configuration self.autogaze_model_id = autogaze_model_id or "bfshi/AutoGaze" self.gazing_ratio_tile = gazing_ratio_tile self.gazing_ratio_thumbnail = gazing_ratio_thumbnail self.task_loss_requirement_tile = task_loss_requirement_tile self.task_loss_requirement_thumbnail = task_loss_requirement_thumbnail self.target_scales = target_scales or [56, 112, 224, 448] self.target_patch_size = target_patch_size or 16 # Image / video processing configuration self.max_tiles_image = max_tiles_image self.num_video_frames = num_video_frames self.max_tiles_video = max_tiles_video self.num_video_frames_thumbnail = num_video_frames_thumbnail self.mm_projector_shuffle_num = mm_projector_shuffle_num self.max_batch_size_autogaze = max_batch_size_autogaze # Load AutoGaze if available self._autogaze_model = None self._autogaze_model = AutoGaze.from_pretrained( self.autogaze_model_id, device_map=None, ) self._autogaze_model.to("cuda").eval() print("AutoGaze loaded successfully in processor") def __call__( self, *, text: TextInput | list[TextInput], images: ImageInput | None = None, videos: VideoInput | None = None, **kwargs: Unpack[NVILAProcessorKwargs], ) -> BatchFeature: normalized_text, normalized_images, normalized_videos = self._normalize_inputs( text=text, images=images, videos=videos, ) images_inputs, image_token_padding_strategy = ( self._preprocess_images( normalized_images, **kwargs, ) if len(normalized_images) > 0 else (BatchFeature(), []) ) videos_inputs = ( self._preprocess_videos( normalized_videos, **kwargs, ) if len(normalized_videos) > 0 else (BatchFeature(), []) ) # Run AutoGaze on preprocessed tiles/thumbnails and compute padding gazing_info = None video_token_padding_strategy = [] skip_tiles_gaze = self._should_gaze_all_patches(self.gazing_ratio_tile, self.task_loss_requirement_tile) skip_thumbs_gaze = self._should_gaze_all_patches(self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail) can_construct_without_autogaze = skip_tiles_gaze and skip_thumbs_gaze if len(normalized_videos) > 0 and (self._autogaze_model is not None or can_construct_without_autogaze): gazing_info = self._get_gazing_info_from_videos(videos_inputs) # Compute video padding strategy from gazing results. # Because the mm_projector uses TokenShuffle(9), each # "effective frame" is padded to a multiple of 9 before # projection, then divided by 9. So total tokens per # video = sum_over_frames(ceil(non_padded_per_frame / 9)). shuffle_num = self.mm_projector_shuffle_num ns_list = videos_inputs["num_spatial_tiles_each_video"] for vid_idx in range(len(gazing_info["if_padded_gazing_tiles"])): tiles_if_pad = gazing_info["if_padded_gazing_tiles"][vid_idx] # (num_tiles, N) tiles_num_gaze = gazing_info["num_gazing_each_frame_tiles"][vid_idx] # (num_tiles, T_tile) thumbs_if_pad = gazing_info["if_padded_gazing_thumbnails"][vid_idx] # (T_thumb, N') thumbs_num_gaze = gazing_info["num_gazing_each_frame_thumbnails"][vid_idx] # (T_thumb, 1) ns = ns_list[vid_idx] num_tiles = tiles_if_pad.shape[0] T_tile = tiles_num_gaze.shape[1] tc = num_tiles // ns # temporal chunks total_frames = tc * T_tile # Non-padded count per tile per frame tile_non_padded = [] # tile_non_padded[tile][frame] = int for t_idx in range(num_tiles): frame_sizes = tiles_num_gaze[t_idx].tolist() frame_pad_segs = tiles_if_pad[t_idx].split(frame_sizes) tile_non_padded.append( [int((~seg).sum().item()) for seg in frame_pad_segs] ) total_tokens = 0 # Tile effective frames (all spatial tiles for one temporal frame) for g in range(total_frames): chunk = g // T_tile f_in_chunk = g % T_tile frame_count = sum( tile_non_padded[chunk * ns + s][f_in_chunk] for s in range(ns) ) total_tokens += (frame_count + shuffle_num - 1) // shuffle_num # Thumbnail frames (each is 1 frame) for th_idx in range(thumbs_if_pad.shape[0]): frame_sizes = thumbs_num_gaze[th_idx].tolist() frame_pad_segs = thumbs_if_pad[th_idx].split(frame_sizes) non_pad = sum(int((~seg).sum().item()) for seg in frame_pad_segs) total_tokens += (non_pad + shuffle_num - 1) // shuffle_num video_token_padding_strategy.append([total_tokens]) else: video_token_padding_strategy = [[(self.num_video_frames + self.num_video_frames_thumbnail) * 118] * len(normalized_videos)] # Remove AutoGaze-processed pixel values — they were only needed # for computing gazing_info and should not be sent to the model. if len(normalized_videos) > 0: videos_inputs.pop("pixel_values_videos_tiles_autogaze", None) videos_inputs.pop("pixel_values_videos_thumbnails_autogaze", None) text_inputs = self._preprocess_text( normalized_text, image_token_padding_strategy=image_token_padding_strategy, video_token_padding_strategy=video_token_padding_strategy, **kwargs, ) # Combine all inputs batch_feature = BatchFeature( { **text_inputs, **images_inputs, **videos_inputs, } ) # Attach gazing_info so the model can use it downstream if gazing_info is not None: batch_feature["gazing_info"] = gazing_info return batch_feature def batch_decode(self, *args, **kwargs) -> list[str]: return self.tokenizer.batch_decode(*args, **kwargs) def _normalize_inputs( self, *, text: TextInput | list[TextInput], images: ImageInput | None, videos: VideoInput | None, ) -> tuple[list[str], list[Image], list[list[Image]]]: if isinstance(text, list): normalized_text = text else: normalized_text = [text] if images is not None and images != []: image_flat_list = cast(list, image_utils.make_flat_list_of_images(images)) normalized_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_flat_list] else: normalized_images = [] if videos is not None and videos != []: # Handle video inputs - can be file paths (str) or lists of PIL Images # videos can be a single item or a list if not isinstance(videos, (list, tuple)): videos = [videos] normalized_videos = [] # Use num_video_frames from processor config num_frames = self.num_video_frames for video_input in videos: if isinstance(video_input, str): parsed = urlparse(video_input) if parsed.scheme in ("http", "https"): suffix = os.path.splitext(parsed.path)[1] or ".mp4" tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) try: urllib.request.urlretrieve(video_input, tmp.name) video_frames = _load_video_frames(tmp.name, num_frames=num_frames) finally: tmp.close() os.unlink(tmp.name) else: video_frames = _load_video_frames(video_input, num_frames=num_frames) normalized_videos.append(video_frames) elif isinstance(video_input, (list, tuple)): # If it's already a list of images, convert them to PIL Images normalized_videos.append([ cast(Image, image_transforms.to_pil_image(image)) for image in video_input ]) else: # Try to use video_utils for other types try: video_list = cast(list[list], video_utils.make_batched_videos([video_input])) normalized_videos.extend([ [cast(Image, image_transforms.to_pil_image(image)) for image in video] for video in video_list ]) except Exception: raise ValueError( f"Unsupported video input type: {type(video_input)}. " "Expected str (file path) or list of PIL Images." ) else: normalized_videos = [] return normalized_text, normalized_images, normalized_videos def _preprocess_images( self, images: list[Image], **kwargs: Unpack[NVILAProcessorKwargs], ) -> tuple[BatchFeature, list[list[int]]]: """Preprocess images into spatial tiles plus a thumbnail. Each image is split into a grid of spatial tiles whose count is at most ``max_tiles_image``. A thumbnail (the whole image resized to ``image_size × image_size``) is appended. Every tile / thumbnail is a single-frame "video" of shape ``(1, C, H, W)``. No AutoGaze is applied — all patches are kept. Returns: A tuple ``(images_inputs, padding_strategy)`` where ``images_inputs`` is a ``BatchFeature`` with: - ``"pixel_values_images_tiles"`` – list of tensors, one per image, each ``(num_tiles_i, 1, C, H, W)``. - ``"pixel_values_images_thumbnails"`` – list of tensors, one per image, each ``(1, 1, C, H, W)``. - ``"num_spatial_tiles_each_image"`` – list of ints. ``padding_strategy`` is a list (one per image) of ``[total_tokens]`` used for text-token padding. """ merged_kwargs = self._merge_kwargs( NVILAProcessorKwargs, # type: ignore tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) if hasattr(self.image_processor, "size"): image_size = self.image_processor.size.get("height", 392) else: image_size = 392 shuffle_num = self.mm_projector_shuffle_num num_patches_each_scale = [ (s // self.target_patch_size) ** 2 for s in self.target_scales ] total_patches_per_frame = sum(num_patches_each_scale) pixel_values_images_tiles: list[torch.Tensor] = [] pixel_values_images_thumbnails: list[torch.Tensor] = [] num_spatial_tiles_each_image: list[int] = [] padding_strategy: list[list[int]] = [] for image in images: image = image.convert("RGB") orig_width, orig_height = image.size max_spatial_tiles = max(self.max_tiles_image, 1) aspect_ratio = orig_width / orig_height target_ratios = { (i, j) for n in range(1, max_spatial_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 1 <= i * j <= max_spatial_tiles } target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) target_aspect_ratio = _find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size ) target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] num_tiles = target_aspect_ratio[0] * target_aspect_ratio[1] num_cols = target_aspect_ratio[0] resized = image.resize((target_width, target_height)) # Spatial tiles + thumbnail (whole image resized) all_tile_images: list[Image] = [] for tile_idx in range(num_tiles): col = tile_idx % num_cols row = tile_idx // num_cols box = ( col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size, ) all_tile_images.append(resized.crop(box)) thumbnail = image.resize((image_size, image_size)) all_images_for_siglip = all_tile_images + [thumbnail] # SigLIP: process tiles + thumbnail at once → (num_tiles+1, C, H, W) siglip_processed = self.image_processor( all_images_for_siglip, **merged_kwargs["images_kwargs"], )["pixel_values"] if not isinstance(siglip_processed, torch.Tensor): siglip_processed = torch.tensor(np.array(siglip_processed)) # Split into tiles and thumbnail, add temporal dim tiles_pv = siglip_processed[:num_tiles].unsqueeze(1) # (num_tiles, 1, C, H, W) thumb_pv = siglip_processed[num_tiles:].unsqueeze(1) # (1, 1, C, H, W) pixel_values_images_tiles.append(tiles_pv) pixel_values_images_thumbnails.append(thumb_pv) num_spatial_tiles_each_image.append(num_tiles) # Padding: tiles effective frame + thumbnail effective frame tiles_tokens = (num_tiles * total_patches_per_frame + shuffle_num - 1) // shuffle_num thumb_tokens = (total_patches_per_frame + shuffle_num - 1) // shuffle_num padding_strategy.append([tiles_tokens + thumb_tokens]) images_inputs = BatchFeature({ "pixel_values_images_tiles": pixel_values_images_tiles, "pixel_values_images_thumbnails": pixel_values_images_thumbnails, "num_spatial_tiles_each_image": num_spatial_tiles_each_image, }) return images_inputs, padding_strategy def _preprocess_text( self, text: list[str], *, image_token_padding_strategy: list[list[int]], video_token_padding_strategy: list[list[int]], **kwargs: Unpack[NVILAProcessorKwargs], ) -> BatchEncoding: # Apply chat template to text messages = [[ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, {"role": "user", "content": t} ] for t in text] text = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Pad media tokens. assert isinstance(self.tokenizer.image_token, str) assert isinstance(self.tokenizer.video_token, str) for media_token, padding_strategy in ( (self.tokenizer.image_token, image_token_padding_strategy), (self.tokenizer.video_token, video_token_padding_strategy), ): assert sum([s.count(media_token) for s in text]) == len(padding_strategy) # Pad to number of tiles. pad_lens = [len(x) for x in padding_strategy] text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text] # Pad to number of features. pad_lens = [y for x in padding_strategy for y in x] text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text] merged_kwargs = self._merge_kwargs( NVILAProcessorKwargs, # type: ignore tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) text_inputs = self.tokenizer( text=text, **merged_kwargs["text_kwargs"], ) return text_inputs def _preprocess_videos( self, videos: list[list[Image]], **kwargs: Unpack[NVILAProcessorKwargs], ) -> BatchFeature: """Preprocess videos into spatiotemporal tiles and thumbnails. Each video is split into a grid of spatiotemporal tiles and a set of low-resolution thumbnail frames. Both SigLIP-processed and AutoGaze-processed copies are produced. Spatial tiling Every frame is resized so that its dimensions become a multiple of ``image_size`` (from the SigLIP image processor) and then cropped into ``(cols, rows)`` spatial tiles, where ``cols * rows <= max_tiles_video``. The best ``(cols, rows)`` is chosen by matching the original frame aspect ratio (same logic as ``dynamic_preprocess`` in ``llava/mm_utils.py``). Temporal chunking The T sampled frames are divided into ``T // max_num_frames`` consecutive chunks of ``max_num_frames`` frames each, where ``max_num_frames`` comes from the AutoGaze model config. ``T`` must be divisible by ``max_num_frames``. Tile ordering Tiles are ordered **temporal-chunk-first**: all spatial tiles for the first temporal chunk, then all spatial tiles for the second temporal chunk, and so on. Thumbnails Each frame is also resized to ``image_size × image_size`` to form a thumbnail. If the number of frames exceeds ``num_video_frames_thumbnail``, thumbnails are uniformly subsampled (every k-th frame) to that count. Each thumbnail is treated as a single-frame video (temporal dim = 1). Args: videos: List of videos, where each video is a list of PIL Images (one per frame). **kwargs: Additional keyword arguments forwarded to the SigLIP image processor. Returns: A tuple ``(videos_inputs, padding_strategy)`` where ``videos_inputs`` is a ``BatchFeature`` dict with the keys: - ``"pixel_values_videos_tiles"`` – list of tensors, one per video. Each tensor has shape ``(num_tiles, T_tile, C, H, W)`` where ``num_tiles = num_spatial_tiles * temporal_chunks``, ``T_tile = max_num_frames`` (from AutoGaze config), and ``H = W = image_size``. Processed by the SigLIP image processor. - ``"pixel_values_videos_thumbnails"`` – list of tensors, one per video. Each tensor has shape ``(T_thumbnail, 1, C, H, W)`` where ``T_thumbnail <= num_video_frames_thumbnail`` and ``H = W = image_size``. Processed by the SigLIP image processor. - ``"pixel_values_videos_tiles_autogaze"`` *(optional)* – same structure as ``pixel_values_videos_tiles`` but processed by the AutoGaze ``transform_video_for_pytorch`` transform. Only present when AutoGaze is available. - ``"pixel_values_videos_thumbnails_autogaze"`` *(optional)* – same structure as ``pixel_values_videos_thumbnails`` but processed by the AutoGaze transform. Only present when AutoGaze is available. ``padding_strategy`` is a list (one entry per video) of lists of ints used for text-token padding. Currently a placeholder; the final strategy depends on downstream gazing results. """ merged_kwargs = self._merge_kwargs( NVILAProcessorKwargs, # type: ignore tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs, ) # Get siglip image size (tile spatial resolution) if hasattr(self.image_processor, "size"): image_size = self.image_processor.size.get("height", 392) else: image_size = 392 # Get AutoGaze max_num_frames for temporal chunking if self._autogaze_model is not None: autogaze_max_num_frames = self._autogaze_model.config.max_num_frames else: autogaze_max_num_frames = 16 # default # Load AutoGaze transform if available autogaze_transform = None largest_scale = max(self.target_scales) autogaze_transform = AutoGazeImageProcessor.from_pretrained( self.autogaze_model_id, size=(largest_scale, largest_scale), ) pixel_values_videos_tiles = [] pixel_values_videos_thumbnails = [] pixel_values_videos_tiles_autogaze = [] pixel_values_videos_thumbnails_autogaze = [] num_spatial_tiles_each_video = [] for video in videos: video = [img.convert("RGB") for img in video] num_frames = len(video) orig_width, orig_height = video[0].size # --- Temporal chunking --- temporal_chunks = num_frames // autogaze_max_num_frames assert temporal_chunks >= 1 and num_frames % autogaze_max_num_frames == 0, ( f"Number of frames ({num_frames}) must be divisible by " f"AutoGaze max_num_frames ({autogaze_max_num_frames})" ) # --- Spatial tiling --- # max_tiles_video directly controls the max number of spatial tiles max_spatial_tiles = max(self.max_tiles_video, 1) # Use dynamic_preprocess-style approach for finding best spatial aspect ratio aspect_ratio = orig_width / orig_height target_ratios = { (i, j) for n in range(1, max_spatial_tiles + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 1 <= i * j <= max_spatial_tiles } target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) target_aspect_ratio = _find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size ) target_width = image_size * target_aspect_ratio[0] # cols * image_size target_height = image_size * target_aspect_ratio[1] # rows * image_size num_spatial_tiles = target_aspect_ratio[0] * target_aspect_ratio[1] num_cols = target_aspect_ratio[0] # --- Build per-frame spatial tiles and thumbnails --- # spatial_tile_frames[spatial_idx] = list of T PIL Images spatial_tile_frames = [[] for _ in range(num_spatial_tiles)] thumbnail_frames = [] for frame in video: # Resize frame for spatial tiling resized_frame = frame.resize((target_width, target_height)) # Split into spatial tiles for tile_idx in range(num_spatial_tiles): col = tile_idx % num_cols row = tile_idx // num_cols box = ( col * image_size, row * image_size, (col + 1) * image_size, (row + 1) * image_size, ) tile = resized_frame.crop(box) spatial_tile_frames[tile_idx].append(tile) # Thumbnail: resize whole frame to image_size x image_size thumbnail = frame.resize((image_size, image_size)) thumbnail_frames.append(thumbnail) # --- Assemble spatiotemporal tiles --- # Collect all tile images in flat order: temporal chunk (outer) × # spatial tile (inner) × frame-within-chunk (innermost). num_tiles = temporal_chunks * num_spatial_tiles T_tile = autogaze_max_num_frames all_tile_images = [] for t_chunk in range(temporal_chunks): for spatial_idx in range(num_spatial_tiles): start = t_chunk * T_tile end = start + T_tile all_tile_images.extend(spatial_tile_frames[spatial_idx][start:end]) # SigLIP: process all tile images at once → (num_tiles * T_tile, C, H, W) siglip_processed = self.image_processor( all_tile_images, **merged_kwargs["images_kwargs"], )["pixel_values"] if not isinstance(siglip_processed, torch.Tensor): siglip_processed = torch.tensor(np.array(siglip_processed)) video_tiles_siglip = siglip_processed.reshape(num_tiles, T_tile, *siglip_processed.shape[1:]) pixel_values_videos_tiles.append(video_tiles_siglip) # AutoGaze transform: process all tile images at once if autogaze_transform is not None: all_tile_np = np.stack([np.array(f) for f in all_tile_images]) # (num_tiles * T_tile, H, W, 3) autogaze_processed = transform_video_for_pytorch(all_tile_np, autogaze_transform) video_tiles_autogaze = autogaze_processed.reshape(num_tiles, T_tile, *autogaze_processed.shape[1:]) pixel_values_videos_tiles_autogaze.append(video_tiles_autogaze) # --- Assemble thumbnails --- # Subsample thumbnails if needed (keep every k-th frame) if len(thumbnail_frames) > self.num_video_frames_thumbnail: step = len(thumbnail_frames) // self.num_video_frames_thumbnail sampled_thumbnail_frames = thumbnail_frames[::step][: self.num_video_frames_thumbnail] else: sampled_thumbnail_frames = thumbnail_frames T_thumb = len(sampled_thumbnail_frames) # SigLIP: process all thumbnail images at once → (T_thumb, C, H, W) siglip_processed = self.image_processor( sampled_thumbnail_frames, **merged_kwargs["images_kwargs"], )["pixel_values"] if not isinstance(siglip_processed, torch.Tensor): siglip_processed = torch.tensor(np.array(siglip_processed)) # Each thumbnail is a single-frame video → (T_thumb, 1, C, H, W) video_thumbnails_siglip = siglip_processed.unsqueeze(1) pixel_values_videos_thumbnails.append(video_thumbnails_siglip) # AutoGaze transform: process all thumbnail images at once if autogaze_transform is not None: all_thumb_np = np.stack([np.array(f) for f in sampled_thumbnail_frames]) # (T_thumb, H, W, 3) autogaze_processed = transform_video_for_pytorch(all_thumb_np, autogaze_transform) video_thumbnails_autogaze = autogaze_processed.unsqueeze(1) # (T_thumb, 1, C, H, W) pixel_values_videos_thumbnails_autogaze.append(video_thumbnails_autogaze) num_spatial_tiles_each_video.append(num_spatial_tiles) print( f"Video tiling: {num_frames} frames @ {orig_width}x{orig_height} → " f"{num_spatial_tiles} spatial × {temporal_chunks} temporal = " f"{num_spatial_tiles * temporal_chunks} tiles, each " f"{autogaze_max_num_frames}×{image_size}×{image_size}; " f"{len(sampled_thumbnail_frames)} thumbnail frames" ) # Build output BatchFeature videos_inputs = BatchFeature( { "pixel_values_videos_tiles": pixel_values_videos_tiles, "pixel_values_videos_thumbnails": pixel_values_videos_thumbnails, "num_spatial_tiles_each_video": num_spatial_tiles_each_video, } ) if pixel_values_videos_tiles_autogaze: videos_inputs["pixel_values_videos_tiles_autogaze"] = pixel_values_videos_tiles_autogaze if pixel_values_videos_thumbnails_autogaze: videos_inputs["pixel_values_videos_thumbnails_autogaze"] = pixel_values_videos_thumbnails_autogaze return videos_inputs @staticmethod def _should_gaze_all_patches(gazing_ratio, task_loss_requirement) -> bool: """Return True when the gazing config means every patch is kept. This is the case when ``gazing_ratio`` is ``None`` (no gazing at all), or when ``gazing_ratio == 1`` (keep 100 %) **and** ``task_loss_requirement is None`` (no adaptive pruning). """ if gazing_ratio is None: return True if task_loss_requirement is not None: return False if isinstance(gazing_ratio, (list, tuple)): return all(r == 1 for r in gazing_ratio) return gazing_ratio == 1 @staticmethod def _sort_gazing_pos_per_frame( gazing_pos: torch.Tensor, if_padded: torch.Tensor, num_gazing_each_frame: torch.Tensor, ) -> torch.Tensor: """Sort non-padded gazing positions in ascending order within each frame. Padded positions are left untouched at the end of each frame's segment so that the total count (padded + non-padded) per frame is unchanged. Args: gazing_pos: ``(B, N)`` tensor of gazing patch indices. if_padded: ``(B, N)`` bool tensor (``True`` = padded / dummy). num_gazing_each_frame: ``(B, T)`` tensor giving the number of gazing positions (padded + non-padded) for each frame. Returns: A new ``(B, N)`` tensor with the same values as *gazing_pos* except that the non-padded entries within every frame are sorted. """ sorted_pos = gazing_pos.clone() B, _ = gazing_pos.shape T = num_gazing_each_frame.shape[1] for b in range(B): offset = 0 for t in range(T): count = int(num_gazing_each_frame[b, t].item()) frame_pos = gazing_pos[b, offset : offset + count] frame_pad = if_padded[b, offset : offset + count] # Indices of non-padded (real) positions within the frame segment real_mask = ~frame_pad real_pos = frame_pos[real_mask] # Sort the real positions real_pos_sorted = real_pos.sort()[0] # Write sorted values back at the correct locations real_indices = real_mask.nonzero(as_tuple=True)[0] sorted_pos[b, offset + real_indices] = real_pos_sorted offset += count return sorted_pos def _run_autogaze_batched( self, all_videos: torch.Tensor, autogaze_device: torch.device, cpu_device: torch.device, gazing_ratio, task_loss_requirement, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Run AutoGaze in minibatches and return combined results on CPU. Different minibatches may produce different per-frame gazing counts (e.g. when ``task_loss_requirement`` triggers adaptive pruning). This method pads each frame's segment to the *maximum* count across all minibatches so that the results can be concatenated along the batch dimension. Args: all_videos: ``(B, T, C, H, W)`` tensor of videos to process. autogaze_device: Device where AutoGaze runs (typically CUDA). cpu_device: Device for the returned tensors (typically CPU). gazing_ratio: Gazing ratio to pass to AutoGaze. task_loss_requirement: Task loss requirement to pass to AutoGaze. Returns: A tuple ``(gazing_pos, if_padded, num_gazing)`` where - ``gazing_pos`` is ``(B, N_max)`` on *cpu_device* - ``if_padded`` is ``(B, N_max)`` bool on *cpu_device* - ``num_gazing`` is ``(B, T)`` on *cpu_device* ``N_max = sum(max_per_frame)`` where ``max_per_frame[t]`` is the largest per-frame count across all minibatches. """ total = all_videos.shape[0] bs = self.max_batch_size_autogaze batch_results: list[dict] = [] with torch.inference_mode(): for start in range(0, total, bs): batch = all_videos[start : start + bs] gaze = self._autogaze_model( {"video": batch.to(autogaze_device)}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement, target_scales=self.target_scales, target_patch_size=self.target_patch_size, ) ng = gaze["num_gazing_each_frame"] if isinstance(ng, list): ng = torch.tensor(ng, device=cpu_device, dtype=torch.long) elif not isinstance(ng, torch.Tensor): ng = torch.tensor(ng, device=cpu_device, dtype=torch.long) else: ng = ng.to(cpu_device) if ng.dim() == 2: ng = ng[0] batch_results.append({ "gazing_pos": gaze["gazing_pos"].to(cpu_device), "if_padded": gaze["if_padded_gazing"].to(cpu_device), "num_gazing": ng, "batch_size": batch.shape[0], }) # Fast path: single minibatch — no cross-batch padding needed if len(batch_results) == 1: r = batch_results[0] num_gazing = r["num_gazing"].unsqueeze(0).expand(total, -1).contiguous() return r["gazing_pos"], r["if_padded"], num_gazing # Compute the max per-frame count across all minibatches all_ng = torch.stack([r["num_gazing"] for r in batch_results], dim=0) # (num_minibatches, T) max_per_frame = all_ng.max(dim=0).values # (T,) max_N = int(max_per_frame.sum().item()) T = max_per_frame.shape[0] padded_pos_list = [] padded_mask_list = [] for r in batch_results: src_pos = r["gazing_pos"] # (mini_B, N_src) src_pad = r["if_padded"] # (mini_B, N_src) src_ng = r["num_gazing"] # (T,) mini_B = r["batch_size"] if int(src_ng.sum().item()) == max_N: padded_pos_list.append(src_pos) padded_mask_list.append(src_pad) continue dst_pos = torch.zeros(mini_B, max_N, device=cpu_device, dtype=src_pos.dtype) dst_pad = torch.ones(mini_B, max_N, device=cpu_device, dtype=torch.bool) src_off = 0 dst_off = 0 for t in range(T): sc = int(src_ng[t].item()) dc = int(max_per_frame[t].item()) dst_pos[:, dst_off : dst_off + sc] = src_pos[:, src_off : src_off + sc] dst_pad[:, dst_off : dst_off + sc] = src_pad[:, src_off : src_off + sc] src_off += sc dst_off += dc padded_pos_list.append(dst_pos) padded_mask_list.append(dst_pad) gazing_pos = torch.cat(padded_pos_list, dim=0) if_padded = torch.cat(padded_mask_list, dim=0) num_gazing = max_per_frame.unsqueeze(0).expand(total, -1).contiguous() return gazing_pos, if_padded, num_gazing def _get_gazing_info_from_videos( self, videos_inputs: BatchFeature, ) -> Optional[dict]: """Run AutoGaze on the preprocessed tiles and thumbnails. All tiles from all videos are batched together (they share the same temporal dimension ``T_tile``). Similarly, all thumbnails are batched together (temporal dim = 1). AutoGaze is run once on each batch and the results are split back per-video. When a gazing ratio is 1 and the corresponding task_loss_requirement is None (or gazing_ratio is None), all patches are kept and AutoGaze is skipped for that component. If both tiles and thumbnails meet this condition, AutoGaze is not invoked at all. Args: videos_inputs: The ``BatchFeature`` returned by ``_preprocess_videos``, which must contain the keys ``pixel_values_videos_tiles_autogaze`` and ``pixel_values_videos_thumbnails_autogaze`` (unless the corresponding component can skip AutoGaze). Returns: A dict with the following keys (or ``None`` if AutoGaze is unavailable or the required inputs are missing): - ``"gazing_pos_tiles"`` – list of tensors, one per video, each shaped ``(num_tiles_i, N)``. - ``"num_gazing_each_frame_tiles"`` – list of tensors, one per video, each shaped ``(num_tiles_i, T_tile)``. - ``"if_padded_gazing_tiles"`` – list of bool tensors, one per video, each shaped ``(num_tiles_i, N)``. - ``"gazing_pos_thumbnails"`` – list of tensors, one per video, each shaped ``(T_thumb_i, N')``. - ``"num_gazing_each_frame_thumbnails"`` – list of tensors, one per video, each shaped ``(T_thumb_i, 1)``. - ``"if_padded_gazing_thumbnails"`` – list of bool tensors, one per video, each shaped ``(T_thumb_i, N')``. """ skip_tiles = self._should_gaze_all_patches( self.gazing_ratio_tile, self.task_loss_requirement_tile ) skip_thumbnails = self._should_gaze_all_patches( self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail ) need_autogaze = not skip_tiles or not skip_thumbnails if need_autogaze and self._autogaze_model is None: return None # Per-video tile/thumbnail counts from SigLIP tensors (always present) siglip_tiles = videos_inputs["pixel_values_videos_tiles"] siglip_thumbs = videos_inputs["pixel_values_videos_thumbnails"] num_tiles_per_video = [t.shape[0] for t in siglip_tiles] num_thumbs_per_video = [t.shape[0] for t in siglip_thumbs] device = torch.device("cpu") autogaze_device = torch.device("cuda") if torch.cuda.is_available() else device # Total patches per frame across all scales num_patches_each_scale = [ (s // self.target_patch_size) ** 2 for s in self.target_scales ] total_patches_per_frame = sum(num_patches_each_scale) # Ensure AutoGaze model is on GPU for inference if need_autogaze: current_device = next(self._autogaze_model.parameters()).device if current_device != autogaze_device: self._autogaze_model = self._autogaze_model.to(autogaze_device) # --- Tiles --- if skip_tiles: total_tiles = sum(num_tiles_per_video) T_tile = siglip_tiles[0].shape[1] per_frame_pos = torch.arange(total_patches_per_frame, device=device, dtype=torch.long) tiles_gazing_pos = per_frame_pos.repeat(T_tile).unsqueeze(0).expand(total_tiles, -1).contiguous() tiles_if_padded = torch.zeros( total_tiles, T_tile * total_patches_per_frame, device=device, dtype=torch.bool ) tiles_num_gazing = torch.full( (total_tiles, T_tile), total_patches_per_frame, device=device, dtype=torch.long ) else: tiles_autogaze = videos_inputs.get("pixel_values_videos_tiles_autogaze") if tiles_autogaze is None: return None all_tiles = torch.cat(tiles_autogaze, dim=0) tiles_gazing_pos, tiles_if_padded, tiles_num_gazing = self._run_autogaze_batched( all_tiles, autogaze_device, device, self.gazing_ratio_tile, self.task_loss_requirement_tile, ) tiles_gazing_pos = self._sort_gazing_pos_per_frame( tiles_gazing_pos, tiles_if_padded, tiles_num_gazing ) # --- Thumbnails --- if skip_thumbnails: total_thumbs = sum(num_thumbs_per_video) per_thumb_pos = torch.arange( total_patches_per_frame, device=device, dtype=torch.long ) thumbs_gazing_pos = per_thumb_pos.unsqueeze(0).expand(total_thumbs, -1).contiguous() thumbs_if_padded = torch.zeros_like(thumbs_gazing_pos, dtype=torch.bool) thumbs_num_gazing = torch.full( (total_thumbs, 1), total_patches_per_frame, device=device, dtype=torch.long, ) else: thumbs_autogaze = videos_inputs.get("pixel_values_videos_thumbnails_autogaze") if thumbs_autogaze is None: return None all_thumbs = torch.cat(thumbs_autogaze, dim=0) thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing = self._run_autogaze_batched( all_thumbs, autogaze_device, device, self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail, ) thumbs_gazing_pos = self._sort_gazing_pos_per_frame( thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing ) # --- Split results back per video --- tiles_gazing_pos_list = list(torch.split(tiles_gazing_pos, num_tiles_per_video, dim=0)) tiles_if_padded_list = list(torch.split(tiles_if_padded, num_tiles_per_video, dim=0)) tiles_num_gazing_list = list(torch.split(tiles_num_gazing, num_tiles_per_video, dim=0)) thumbs_gazing_pos_list = list(torch.split(thumbs_gazing_pos, num_thumbs_per_video, dim=0)) thumbs_if_padded_list = list(torch.split(thumbs_if_padded, num_thumbs_per_video, dim=0)) thumbs_num_gazing_list = list(torch.split(thumbs_num_gazing, num_thumbs_per_video, dim=0)) return { "gazing_pos_tiles": tiles_gazing_pos_list, "num_gazing_each_frame_tiles": tiles_num_gazing_list, "if_padded_gazing_tiles": tiles_if_padded_list, "gazing_pos_thumbnails": thumbs_gazing_pos_list, "num_gazing_each_frame_thumbnails": thumbs_num_gazing_list, "if_padded_gazing_thumbnails": thumbs_if_padded_list, }