# NVILA-8B-HD-Video / processing_nvila.py
# Author: Danny Yin — release branch, commit 73b433d
import glob
import os
import re
import tempfile
import urllib.request
from os import PathLike
from typing import cast, Optional
from urllib.parse import urlparse
import cv2
import numpy as np
import torch
import transformers.image_transforms as image_transforms
import transformers.image_utils as image_utils
import transformers.video_utils as video_utils
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2TokenizerFast
from transformers.models.siglip import SiglipImageProcessor, SiglipImageProcessorFast
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import BatchEncoding, TextInput
from transformers.video_utils import VideoInput, VideoMetadata
from autogaze.models.autogaze import AutoGaze
from autogaze.models.autogaze import AutoGazeImageProcessor
from autogaze.datasets.video_utils import transform_video_for_pytorch
def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
"""Find the closest aspect ratio from a set of target ratios.
Referenced from https://github.com/OpenGVLab/InternVL and llava/mm_utils.py
"""
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
class NVILAProcessorKwargs(ProcessingKwargs, total=False):
    """Typed kwargs accepted by :class:`NVILAProcessor` calls.

    Inherits the grouped kwargs (``text_kwargs``, ``images_kwargs``, ...)
    from ``ProcessingKwargs``; no NVILA-specific defaults are overridden.
    """
    _defaults = {}  # type: ignore
def _load_video_frames(video_path: str, num_frames: int = 8) -> list["Image.Image"]:
    """Load ``num_frames`` uniformly spaced RGB frames from a video file.

    Similar to ``_load_video`` in ``llava/utils/media.py``.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to extract.

    Returns:
        List of PIL Images (RGB). If some sampled frames cannot be decoded,
        the last decoded frame is repeated so exactly ``num_frames`` frames
        are returned.

    Raises:
        ValueError: If the video cannot be opened or no frame can be decoded.
    """
    vidcap = cv2.VideoCapture(video_path)
    if not vidcap.isOpened():
        raise ValueError(f"Failed to open video: {video_path}")
    # try/finally guarantees the capture is released even when decoding
    # raises partway through (the original leaked it on exceptions).
    try:
        # CAP_PROP_FRAME_COUNT is only an estimate for some containers:
        # walk back from the reported count to the last grabbable frame.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        while frame_count > 0:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            if vidcap.grab():
                break
            frame_count -= 1
        else:
            # Loop exhausted without a grabbable frame.
            raise ValueError(f"Video '{video_path}' has no frames.")
        # Uniformly sample indices over the decodable range; duplicates are
        # possible when num_frames > frame_count.
        indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
        frames: dict[int, "Image.Image"] = {}
        for index in indices:
            if index in frames:
                continue  # duplicate index — already decoded
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                continue  # undecodable frame; compensated for below
            # OpenCV decodes BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames[index] = Image.fromarray(frame)
    finally:
        vidcap.release()
    frames_to_return = [frames[index] for index in indices if index in frames]
    if len(frames_to_return) < num_frames:
        if frames_to_return:
            # Pad by repeating the last good frame so the count is exact.
            frames_to_return = frames_to_return + [frames_to_return[-1]] * (num_frames - len(frames_to_return))
        else:
            raise ValueError(f"Could not extract any frames from video: {video_path}")
    return frames_to_return
class NVILAProcessor(ProcessorMixin):
attributes = [
"image_processor",
"tokenizer",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
_auto_class = "AutoProcessor"
def __init__(
    self,
    image_processor: SiglipImageProcessor | SiglipImageProcessorFast,
    tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast,
    chat_template: str | None = None,
    autogaze_model_id: str | None = None,
    gazing_ratio_tile: list[float] | float = 0.75,
    gazing_ratio_thumbnail: float | None = 0.75,
    task_loss_requirement_tile: float = 0.7,
    task_loss_requirement_thumbnail: float | None = 0.7,
    target_scales: list[int] | None = None,
    target_patch_size: int | None = None,
    max_tiles_image: int = 12,
    num_video_frames: int = 8,
    max_tiles_video: int = 8,
    num_video_frames_thumbnail: int = 8,
    mm_projector_shuffle_num: int = 9,
    max_batch_size_autogaze: int = 32,
    **kwargs,
):
    """Build the NVILA processor and eagerly load the AutoGaze model.

    Args:
        image_processor: SigLIP image processor used for tiles/thumbnails.
        tokenizer: Qwen2 tokenizer for the text side.
        chat_template: Optional chat template forwarded to ``ProcessorMixin``.
        autogaze_model_id: HF id of the AutoGaze pruning model
            (defaults to ``"bfshi/AutoGaze"``).
        gazing_ratio_tile: Fraction of patches kept for video tiles
            (``1`` or ``None`` disables pruning for that component).
        gazing_ratio_thumbnail: Same, for thumbnails.
        task_loss_requirement_tile: Adaptive-pruning threshold for tiles.
        task_loss_requirement_thumbnail: Same, for thumbnails.
        target_scales: Multi-scale resolutions for AutoGaze
            (default ``[56, 112, 224, 448]``).
        target_patch_size: AutoGaze patch size (default ``16``).
        max_tiles_image: Max spatial tiles per image.
        num_video_frames: Frames sampled per video for tiles.
        max_tiles_video: Max spatial tiles per video frame grid.
        num_video_frames_thumbnail: Max thumbnail frames per video.
        mm_projector_shuffle_num: TokenShuffle factor of the mm projector.
        max_batch_size_autogaze: Minibatch size for AutoGaze inference.
    """
    super().__init__(
        image_processor,
        tokenizer,
        chat_template=chat_template,
        **kwargs,
    )
    self.image_processor: SiglipImageProcessor | SiglipImageProcessorFast
    self.tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast
    # AutoGaze configuration
    self.autogaze_model_id = autogaze_model_id or "bfshi/AutoGaze"
    self.gazing_ratio_tile = gazing_ratio_tile
    self.gazing_ratio_thumbnail = gazing_ratio_thumbnail
    self.task_loss_requirement_tile = task_loss_requirement_tile
    self.task_loss_requirement_thumbnail = task_loss_requirement_thumbnail
    self.target_scales = target_scales or [56, 112, 224, 448]
    self.target_patch_size = target_patch_size or 16
    # Image / video processing configuration
    self.max_tiles_image = max_tiles_image
    self.num_video_frames = num_video_frames
    self.max_tiles_video = max_tiles_video
    self.num_video_frames_thumbnail = num_video_frames_thumbnail
    self.mm_projector_shuffle_num = mm_projector_shuffle_num
    self.max_batch_size_autogaze = max_batch_size_autogaze
    # Eagerly load AutoGaze. Prefer CUDA, but fall back to CPU so the
    # processor can still be constructed on machines without a GPU
    # (the previous hard-coded .to("cuda") crashed there). Behavior on
    # CUDA machines is unchanged.
    self._autogaze_model = AutoGaze.from_pretrained(
        self.autogaze_model_id,
        device_map=None,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self._autogaze_model.to(device).eval()
    print("AutoGaze loaded successfully in processor")
def __call__(
    self,
    *,
    text: TextInput | list[TextInput],
    images: ImageInput | None = None,
    videos: VideoInput | None = None,
    **kwargs: Unpack[NVILAProcessorKwargs],
) -> BatchFeature:
    """Preprocess text plus optional images/videos into a ``BatchFeature``.

    Runs media preprocessing, optionally AutoGaze-based pruning for videos,
    derives how many placeholder tokens each media item expands to, and
    finally applies the chat template and tokenizes the text.

    Args:
        text: A prompt or list of prompts containing media placeholder tokens.
        images: Optional image input(s).
        videos: Optional video input(s) (paths, URLs, or frame lists).
        **kwargs: Forwarded to the image processor / tokenizer.

    Returns:
        A ``BatchFeature`` combining text, image and video tensors, plus a
        ``"gazing_info"`` entry when AutoGaze results were computed.
    """
    normalized_text, normalized_images, normalized_videos = self._normalize_inputs(
        text=text,
        images=images,
        videos=videos,
    )
    images_inputs, image_token_padding_strategy = (
        self._preprocess_images(
            normalized_images,
            **kwargs,
        )
        if len(normalized_images) > 0
        else (BatchFeature(), [])
    )
    # Bug fix: _preprocess_videos returns a single BatchFeature (not a
    # tuple), so the no-video fallback must be a plain BatchFeature too —
    # otherwise `**videos_inputs` below raised for text/image-only calls.
    videos_inputs = (
        self._preprocess_videos(
            normalized_videos,
            **kwargs,
        )
        if len(normalized_videos) > 0
        else BatchFeature()
    )
    # Run AutoGaze on preprocessed tiles/thumbnails and compute padding
    gazing_info = None
    video_token_padding_strategy = []
    skip_tiles_gaze = self._should_gaze_all_patches(self.gazing_ratio_tile, self.task_loss_requirement_tile)
    skip_thumbs_gaze = self._should_gaze_all_patches(self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail)
    can_construct_without_autogaze = skip_tiles_gaze and skip_thumbs_gaze
    if len(normalized_videos) > 0 and (self._autogaze_model is not None or can_construct_without_autogaze):
        gazing_info = self._get_gazing_info_from_videos(videos_inputs)
        # Compute video padding strategy from gazing results.
        # Because the mm_projector uses TokenShuffle(9), each
        # "effective frame" is padded to a multiple of 9 before
        # projection, then divided by 9. So total tokens per
        # video = sum_over_frames(ceil(non_padded_per_frame / 9)).
        shuffle_num = self.mm_projector_shuffle_num
        ns_list = videos_inputs["num_spatial_tiles_each_video"]
        for vid_idx in range(len(gazing_info["if_padded_gazing_tiles"])):
            tiles_if_pad = gazing_info["if_padded_gazing_tiles"][vid_idx]  # (num_tiles, N)
            tiles_num_gaze = gazing_info["num_gazing_each_frame_tiles"][vid_idx]  # (num_tiles, T_tile)
            thumbs_if_pad = gazing_info["if_padded_gazing_thumbnails"][vid_idx]  # (T_thumb, N')
            thumbs_num_gaze = gazing_info["num_gazing_each_frame_thumbnails"][vid_idx]  # (T_thumb, 1)
            ns = ns_list[vid_idx]
            num_tiles = tiles_if_pad.shape[0]
            T_tile = tiles_num_gaze.shape[1]
            tc = num_tiles // ns  # temporal chunks
            total_frames = tc * T_tile
            # Non-padded count per tile per frame
            tile_non_padded = []  # tile_non_padded[tile][frame] = int
            for t_idx in range(num_tiles):
                frame_sizes = tiles_num_gaze[t_idx].tolist()
                frame_pad_segs = tiles_if_pad[t_idx].split(frame_sizes)
                tile_non_padded.append(
                    [int((~seg).sum().item()) for seg in frame_pad_segs]
                )
            total_tokens = 0
            # Tile effective frames (all spatial tiles for one temporal frame)
            for g in range(total_frames):
                chunk = g // T_tile
                f_in_chunk = g % T_tile
                frame_count = sum(
                    tile_non_padded[chunk * ns + s][f_in_chunk]
                    for s in range(ns)
                )
                total_tokens += (frame_count + shuffle_num - 1) // shuffle_num
            # Thumbnail frames (each is 1 frame)
            for th_idx in range(thumbs_if_pad.shape[0]):
                frame_sizes = thumbs_num_gaze[th_idx].tolist()
                frame_pad_segs = thumbs_if_pad[th_idx].split(frame_sizes)
                non_pad = sum(int((~seg).sum().item()) for seg in frame_pad_segs)
                total_tokens += (non_pad + shuffle_num - 1) // shuffle_num
            video_token_padding_strategy.append([total_tokens])
    else:
        # Bug fix: emit one single-entry list PER video, matching the shape
        # produced by the gazing branch above and the per-occurrence contract
        # asserted in _preprocess_text. The old code built `[[x] * n_videos]`
        # (a single entry holding n counts), which broke multi-video batches.
        fixed_tokens = (self.num_video_frames + self.num_video_frames_thumbnail) * 118
        video_token_padding_strategy = [[fixed_tokens] for _ in normalized_videos]
    # Remove AutoGaze-processed pixel values — they were only needed
    # for computing gazing_info and should not be sent to the model.
    if len(normalized_videos) > 0:
        videos_inputs.pop("pixel_values_videos_tiles_autogaze", None)
        videos_inputs.pop("pixel_values_videos_thumbnails_autogaze", None)
    text_inputs = self._preprocess_text(
        normalized_text,
        image_token_padding_strategy=image_token_padding_strategy,
        video_token_padding_strategy=video_token_padding_strategy,
        **kwargs,
    )
    # Combine all inputs
    batch_feature = BatchFeature(
        {
            **text_inputs,
            **images_inputs,
            **videos_inputs,
        }
    )
    # Attach gazing_info so the model can use it downstream
    if gazing_info is not None:
        batch_feature["gazing_info"] = gazing_info
    return batch_feature
def batch_decode(self, *args, **kwargs) -> list[str]:
return self.tokenizer.batch_decode(*args, **kwargs)
def _normalize_inputs(
    self,
    *,
    text: TextInput | list[TextInput],
    images: ImageInput | None,
    videos: VideoInput | None,
) -> tuple[list[str], list["Image.Image"], list[list["Image.Image"]]]:
    """Normalize raw ``text`` / ``images`` / ``videos`` arguments.

    Text is wrapped into a list; images become a flat list of PIL images;
    videos become a list of per-video frame lists. A video entry may be a
    local path, an HTTP(S) URL (downloaded to a temp file first), a list of
    images, or anything ``video_utils.make_batched_videos`` accepts.

    Returns:
        ``(texts, images, videos)`` where ``texts`` is ``list[str]``,
        ``images`` is ``list[PIL.Image.Image]``, and ``videos`` is
        ``list[list[PIL.Image.Image]]``.

    Raises:
        ValueError: If a video input is of an unsupported type.
    """
    if isinstance(text, list):
        normalized_text = text
    else:
        normalized_text = [text]
    if images is not None and images != []:
        image_flat_list = cast(list, image_utils.make_flat_list_of_images(images))
        normalized_images = [cast(Image, image_transforms.to_pil_image(image)) for image in image_flat_list]
    else:
        normalized_images = []
    if videos is not None and videos != []:
        # Handle video inputs - can be file paths (str) or lists of PIL Images
        # videos can be a single item or a list
        if not isinstance(videos, (list, tuple)):
            videos = [videos]
        normalized_videos = []
        # Use num_video_frames from processor config
        num_frames = self.num_video_frames
        for video_input in videos:
            if isinstance(video_input, str):
                parsed = urlparse(video_input)
                if parsed.scheme in ("http", "https"):
                    # Remote video: download to a named temp file so OpenCV
                    # can open/seek it; the file is always removed afterwards.
                    suffix = os.path.splitext(parsed.path)[1] or ".mp4"
                    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                    try:
                        urllib.request.urlretrieve(video_input, tmp.name)
                        video_frames = _load_video_frames(tmp.name, num_frames=num_frames)
                    finally:
                        # Cleanup order matters: close the handle before unlink.
                        tmp.close()
                        os.unlink(tmp.name)
                else:
                    video_frames = _load_video_frames(video_input, num_frames=num_frames)
                normalized_videos.append(video_frames)
            elif isinstance(video_input, (list, tuple)):
                # If it's already a list of images, convert them to PIL Images
                normalized_videos.append([
                    cast(Image, image_transforms.to_pil_image(image)) for image in video_input
                ])
            else:
                # Try to use video_utils for other types
                try:
                    video_list = cast(list[list], video_utils.make_batched_videos([video_input]))
                    normalized_videos.extend([
                        [cast(Image, image_transforms.to_pil_image(image)) for image in video]
                        for video in video_list
                    ])
                except Exception:
                    raise ValueError(
                        f"Unsupported video input type: {type(video_input)}. "
                        "Expected str (file path) or list of PIL Images."
                    )
    else:
        normalized_videos = []
    return normalized_text, normalized_images, normalized_videos
def _preprocess_images(
    self,
    images: list["Image.Image"],
    **kwargs: Unpack[NVILAProcessorKwargs],
) -> tuple[BatchFeature, list[list[int]]]:
    """Preprocess images into spatial tiles plus a thumbnail.

    Each image is split into a grid of spatial tiles whose count is at
    most ``max_tiles_image``. A thumbnail (the whole image resized to
    ``image_size × image_size``) is appended. Every tile / thumbnail
    is a single-frame "video" of shape ``(1, C, H, W)``. No AutoGaze
    is applied — all patches are kept.

    Returns:
        A tuple ``(images_inputs, padding_strategy)`` where
        ``images_inputs`` is a ``BatchFeature`` with:

        - ``"pixel_values_images_tiles"`` – list of tensors, one per
          image, each ``(num_tiles_i, 1, C, H, W)``.
        - ``"pixel_values_images_thumbnails"`` – list of tensors, one
          per image, each ``(1, 1, C, H, W)``.
        - ``"num_spatial_tiles_each_image"`` – list of ints.

        ``padding_strategy`` is a list (one per image) of
        ``[total_tokens]`` used for text-token padding.
    """
    merged_kwargs = self._merge_kwargs(
        NVILAProcessorKwargs,  # type: ignore
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    # Tile spatial resolution comes from the SigLIP processor config.
    if hasattr(self.image_processor, "size"):
        image_size = self.image_processor.size.get("height", 392)
    else:
        image_size = 392
    shuffle_num = self.mm_projector_shuffle_num
    # Token budget per frame: one patch per target_patch_size square at
    # each AutoGaze scale (all patches are kept for images).
    num_patches_each_scale = [
        (s // self.target_patch_size) ** 2 for s in self.target_scales
    ]
    total_patches_per_frame = sum(num_patches_each_scale)
    pixel_values_images_tiles: list[torch.Tensor] = []
    pixel_values_images_thumbnails: list[torch.Tensor] = []
    num_spatial_tiles_each_image: list[int] = []
    padding_strategy: list[list[int]] = []
    for image in images:
        image = image.convert("RGB")
        orig_width, orig_height = image.size
        max_spatial_tiles = max(self.max_tiles_image, 1)
        aspect_ratio = orig_width / orig_height
        # Candidate (cols, rows) grids with at most max_spatial_tiles tiles.
        target_ratios = {
            (i, j)
            for n in range(1, max_spatial_tiles + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if 1 <= i * j <= max_spatial_tiles
        }
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
        target_aspect_ratio = _find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        num_tiles = target_aspect_ratio[0] * target_aspect_ratio[1]
        num_cols = target_aspect_ratio[0]
        resized = image.resize((target_width, target_height))
        # Spatial tiles + thumbnail (whole image resized)
        all_tile_images: list["Image.Image"] = []
        for tile_idx in range(num_tiles):
            # Tiles are cropped in row-major order.
            col = tile_idx % num_cols
            row = tile_idx // num_cols
            box = (
                col * image_size,
                row * image_size,
                (col + 1) * image_size,
                (row + 1) * image_size,
            )
            all_tile_images.append(resized.crop(box))
        thumbnail = image.resize((image_size, image_size))
        all_images_for_siglip = all_tile_images + [thumbnail]
        # SigLIP: process tiles + thumbnail at once → (num_tiles+1, C, H, W)
        siglip_processed = self.image_processor(
            all_images_for_siglip, **merged_kwargs["images_kwargs"],
        )["pixel_values"]
        if not isinstance(siglip_processed, torch.Tensor):
            siglip_processed = torch.tensor(np.array(siglip_processed))
        # Split into tiles and thumbnail, add temporal dim
        tiles_pv = siglip_processed[:num_tiles].unsqueeze(1)  # (num_tiles, 1, C, H, W)
        thumb_pv = siglip_processed[num_tiles:].unsqueeze(1)  # (1, 1, C, H, W)
        pixel_values_images_tiles.append(tiles_pv)
        pixel_values_images_thumbnails.append(thumb_pv)
        num_spatial_tiles_each_image.append(num_tiles)
        # Padding: tiles effective frame + thumbnail effective frame
        # Each group is rounded up to a multiple of the TokenShuffle factor.
        tiles_tokens = (num_tiles * total_patches_per_frame + shuffle_num - 1) // shuffle_num
        thumb_tokens = (total_patches_per_frame + shuffle_num - 1) // shuffle_num
        padding_strategy.append([tiles_tokens + thumb_tokens])
    images_inputs = BatchFeature({
        "pixel_values_images_tiles": pixel_values_images_tiles,
        "pixel_values_images_thumbnails": pixel_values_images_thumbnails,
        "num_spatial_tiles_each_image": num_spatial_tiles_each_image,
    })
    return images_inputs, padding_strategy
def _preprocess_text(
    self,
    text: list[str],
    *,
    image_token_padding_strategy: list[list[int]],
    video_token_padding_strategy: list[list[int]],
    **kwargs: Unpack[NVILAProcessorKwargs],
) -> BatchEncoding:
    """Apply the chat template, expand media placeholder tokens, and tokenize.

    Each image/video placeholder in the text is expanded in two passes
    driven by its ``padding_strategy`` entry: first to one token per tile
    group (``len(entry)``), then each of those to the number of feature
    tokens the model will consume for that group. There must be exactly
    one strategy entry per placeholder occurrence across the batch.
    """
    # Apply chat template to text
    messages = [[
        {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
        {"role": "user", "content": t}
    ] for t in text]
    text = self.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Pad media tokens.
    assert isinstance(self.tokenizer.image_token, str)
    assert isinstance(self.tokenizer.video_token, str)
    for media_token, padding_strategy in (
        (self.tokenizer.image_token, image_token_padding_strategy),
        (self.tokenizer.video_token, video_token_padding_strategy),
    ):
        # One padding_strategy entry per placeholder occurrence in the batch.
        assert sum([s.count(media_token) for s in text]) == len(padding_strategy)
        # Pad to number of tiles.
        # NOTE: re.sub invokes the lambda once per match, and pad_lens is
        # shared across the batch, so pop(0) consumes entries in reading
        # order across all texts — the assertion above guarantees the counts line up.
        pad_lens = [len(x) for x in padding_strategy]
        text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text]
        # Pad to number of features.
        pad_lens = [y for x in padding_strategy for y in x]
        text = [re.sub(rf"({re.escape(media_token)})", lambda _: media_token * pad_lens.pop(0), s) for s in text]
    merged_kwargs = self._merge_kwargs(
        NVILAProcessorKwargs,  # type: ignore
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    text_inputs = self.tokenizer(
        text=text,
        **merged_kwargs["text_kwargs"],
    )
    return text_inputs
def _preprocess_videos(
    self,
    videos: list[list["Image.Image"]],
    **kwargs: Unpack[NVILAProcessorKwargs],
) -> BatchFeature:
    """Preprocess videos into spatiotemporal tiles and thumbnails.

    Each video is split into a grid of spatiotemporal tiles and a set of
    low-resolution thumbnail frames. Both SigLIP-processed and
    AutoGaze-processed copies are produced.

    Spatial tiling
        Every frame is resized so that its dimensions become a multiple of
        ``image_size`` (from the SigLIP image processor) and then cropped
        into ``(cols, rows)`` spatial tiles, where ``cols * rows <=
        max_tiles_video``. The best ``(cols, rows)`` is chosen by matching
        the original frame aspect ratio (same logic as
        ``dynamic_preprocess`` in ``llava/mm_utils.py``).

    Temporal chunking
        The T sampled frames are divided into ``T // max_num_frames``
        consecutive chunks of ``max_num_frames`` frames each, where
        ``max_num_frames`` comes from the AutoGaze model config.
        ``T`` must be divisible by ``max_num_frames``.

    Tile ordering
        Tiles are ordered **temporal-chunk-first**: all spatial tiles for
        the first temporal chunk, then all spatial tiles for the second
        temporal chunk, and so on.

    Thumbnails
        Each frame is also resized to ``image_size × image_size`` to form a
        thumbnail. If the number of frames exceeds
        ``num_video_frames_thumbnail``, thumbnails are uniformly subsampled
        (every k-th frame) to that count. Each thumbnail is treated as a
        single-frame video (temporal dim = 1).

    Args:
        videos: List of videos, where each video is a list of PIL Images
            (one per frame).
        **kwargs: Additional keyword arguments forwarded to the SigLIP
            image processor.

    Returns:
        A ``BatchFeature`` with the keys:

        - ``"pixel_values_videos_tiles"`` – list of tensors, one per video.
          Each tensor has shape ``(num_tiles, T_tile, C, H, W)`` where
          ``num_tiles = num_spatial_tiles * temporal_chunks``,
          ``T_tile = max_num_frames`` (from AutoGaze config),
          and ``H = W = image_size``.
          Processed by the SigLIP image processor.
        - ``"pixel_values_videos_thumbnails"`` – list of tensors, one per
          video. Each tensor has shape
          ``(T_thumbnail, 1, C, H, W)`` where ``T_thumbnail <=
          num_video_frames_thumbnail`` and ``H = W = image_size``.
          Processed by the SigLIP image processor.
        - ``"pixel_values_videos_tiles_autogaze"`` *(optional)* – same
          structure as ``pixel_values_videos_tiles`` but processed by the
          AutoGaze ``transform_video_for_pytorch`` transform.
          Only present when AutoGaze is available.
        - ``"pixel_values_videos_thumbnails_autogaze"`` *(optional)* – same
          structure as ``pixel_values_videos_thumbnails`` but processed by
          the AutoGaze transform. Only present when AutoGaze is available.
        - ``"num_spatial_tiles_each_video"`` – list of ints, one per video.

    Raises:
        AssertionError: If a video's frame count is not divisible by the
            AutoGaze ``max_num_frames``.
    """
    merged_kwargs = self._merge_kwargs(
        NVILAProcessorKwargs,  # type: ignore
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    # Get siglip image size (tile spatial resolution)
    if hasattr(self.image_processor, "size"):
        image_size = self.image_processor.size.get("height", 392)
    else:
        image_size = 392
    # Get AutoGaze max_num_frames for temporal chunking
    if self._autogaze_model is not None:
        autogaze_max_num_frames = self._autogaze_model.config.max_num_frames
    else:
        autogaze_max_num_frames = 16  # default
    # NOTE(review): the None initializer below is vestigial — the transform
    # is unconditionally loaded right after, so the `is not None` checks
    # further down always pass.
    autogaze_transform = None
    largest_scale = max(self.target_scales)
    autogaze_transform = AutoGazeImageProcessor.from_pretrained(
        self.autogaze_model_id,
        size=(largest_scale, largest_scale),
    )
    pixel_values_videos_tiles: list[torch.Tensor] = []
    pixel_values_videos_thumbnails: list[torch.Tensor] = []
    pixel_values_videos_tiles_autogaze: list[torch.Tensor] = []
    pixel_values_videos_thumbnails_autogaze: list[torch.Tensor] = []
    num_spatial_tiles_each_video: list[int] = []
    for video in videos:
        video = [img.convert("RGB") for img in video]
        num_frames = len(video)
        orig_width, orig_height = video[0].size
        # --- Temporal chunking ---
        temporal_chunks = num_frames // autogaze_max_num_frames
        assert temporal_chunks >= 1 and num_frames % autogaze_max_num_frames == 0, (
            f"Number of frames ({num_frames}) must be divisible by "
            f"AutoGaze max_num_frames ({autogaze_max_num_frames})"
        )
        # --- Spatial tiling ---
        # max_tiles_video directly controls the max number of spatial tiles
        max_spatial_tiles = max(self.max_tiles_video, 1)
        # Use dynamic_preprocess-style approach for finding best spatial aspect ratio
        aspect_ratio = orig_width / orig_height
        target_ratios = {
            (i, j)
            for n in range(1, max_spatial_tiles + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if 1 <= i * j <= max_spatial_tiles
        }
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
        target_aspect_ratio = _find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size
        )
        target_width = image_size * target_aspect_ratio[0]  # cols * image_size
        target_height = image_size * target_aspect_ratio[1]  # rows * image_size
        num_spatial_tiles = target_aspect_ratio[0] * target_aspect_ratio[1]
        num_cols = target_aspect_ratio[0]
        # --- Build per-frame spatial tiles and thumbnails ---
        # spatial_tile_frames[spatial_idx] = list of T PIL Images
        spatial_tile_frames = [[] for _ in range(num_spatial_tiles)]
        thumbnail_frames = []
        for frame in video:
            # Resize frame for spatial tiling
            resized_frame = frame.resize((target_width, target_height))
            # Split into spatial tiles (row-major order)
            for tile_idx in range(num_spatial_tiles):
                col = tile_idx % num_cols
                row = tile_idx // num_cols
                box = (
                    col * image_size,
                    row * image_size,
                    (col + 1) * image_size,
                    (row + 1) * image_size,
                )
                tile = resized_frame.crop(box)
                spatial_tile_frames[tile_idx].append(tile)
            # Thumbnail: resize whole frame to image_size x image_size
            thumbnail = frame.resize((image_size, image_size))
            thumbnail_frames.append(thumbnail)
        # --- Assemble spatiotemporal tiles ---
        # Collect all tile images in flat order: temporal chunk (outer) ×
        # spatial tile (inner) × frame-within-chunk (innermost).
        num_tiles = temporal_chunks * num_spatial_tiles
        T_tile = autogaze_max_num_frames
        all_tile_images = []
        for t_chunk in range(temporal_chunks):
            for spatial_idx in range(num_spatial_tiles):
                start = t_chunk * T_tile
                end = start + T_tile
                all_tile_images.extend(spatial_tile_frames[spatial_idx][start:end])
        # SigLIP: process all tile images at once → (num_tiles * T_tile, C, H, W)
        siglip_processed = self.image_processor(
            all_tile_images, **merged_kwargs["images_kwargs"],
        )["pixel_values"]
        if not isinstance(siglip_processed, torch.Tensor):
            siglip_processed = torch.tensor(np.array(siglip_processed))
        video_tiles_siglip = siglip_processed.reshape(num_tiles, T_tile, *siglip_processed.shape[1:])
        pixel_values_videos_tiles.append(video_tiles_siglip)
        # AutoGaze transform: process all tile images at once
        if autogaze_transform is not None:
            all_tile_np = np.stack([np.array(f) for f in all_tile_images])  # (num_tiles * T_tile, H, W, 3)
            autogaze_processed = transform_video_for_pytorch(all_tile_np, autogaze_transform)
            video_tiles_autogaze = autogaze_processed.reshape(num_tiles, T_tile, *autogaze_processed.shape[1:])
            pixel_values_videos_tiles_autogaze.append(video_tiles_autogaze)
        # --- Assemble thumbnails ---
        # Subsample thumbnails if needed (keep every k-th frame)
        if len(thumbnail_frames) > self.num_video_frames_thumbnail:
            step = len(thumbnail_frames) // self.num_video_frames_thumbnail
            sampled_thumbnail_frames = thumbnail_frames[::step][: self.num_video_frames_thumbnail]
        else:
            sampled_thumbnail_frames = thumbnail_frames
        T_thumb = len(sampled_thumbnail_frames)
        # SigLIP: process all thumbnail images at once → (T_thumb, C, H, W)
        siglip_processed = self.image_processor(
            sampled_thumbnail_frames, **merged_kwargs["images_kwargs"],
        )["pixel_values"]
        if not isinstance(siglip_processed, torch.Tensor):
            siglip_processed = torch.tensor(np.array(siglip_processed))
        # Each thumbnail is a single-frame video → (T_thumb, 1, C, H, W)
        video_thumbnails_siglip = siglip_processed.unsqueeze(1)
        pixel_values_videos_thumbnails.append(video_thumbnails_siglip)
        # AutoGaze transform: process all thumbnail images at once
        if autogaze_transform is not None:
            all_thumb_np = np.stack([np.array(f) for f in sampled_thumbnail_frames])  # (T_thumb, H, W, 3)
            autogaze_processed = transform_video_for_pytorch(all_thumb_np, autogaze_transform)
            video_thumbnails_autogaze = autogaze_processed.unsqueeze(1)  # (T_thumb, 1, C, H, W)
            pixel_values_videos_thumbnails_autogaze.append(video_thumbnails_autogaze)
        num_spatial_tiles_each_video.append(num_spatial_tiles)
        # NOTE(review): debug print — consider routing through `logging`.
        print(
            f"Video tiling: {num_frames} frames @ {orig_width}x{orig_height} → "
            f"{num_spatial_tiles} spatial × {temporal_chunks} temporal = "
            f"{num_spatial_tiles * temporal_chunks} tiles, each "
            f"{autogaze_max_num_frames}×{image_size}×{image_size}; "
            f"{len(sampled_thumbnail_frames)} thumbnail frames"
        )
    # Build output BatchFeature
    videos_inputs = BatchFeature(
        {
            "pixel_values_videos_tiles": pixel_values_videos_tiles,
            "pixel_values_videos_thumbnails": pixel_values_videos_thumbnails,
            "num_spatial_tiles_each_video": num_spatial_tiles_each_video,
        }
    )
    if pixel_values_videos_tiles_autogaze:
        videos_inputs["pixel_values_videos_tiles_autogaze"] = pixel_values_videos_tiles_autogaze
    if pixel_values_videos_thumbnails_autogaze:
        videos_inputs["pixel_values_videos_thumbnails_autogaze"] = pixel_values_videos_thumbnails_autogaze
    return videos_inputs
@staticmethod
def _should_gaze_all_patches(gazing_ratio, task_loss_requirement) -> bool:
"""Return True when the gazing config means every patch is kept.
This is the case when ``gazing_ratio`` is ``None`` (no gazing at all),
or when ``gazing_ratio == 1`` (keep 100 %) **and**
``task_loss_requirement is None`` (no adaptive pruning).
"""
if gazing_ratio is None:
return True
if task_loss_requirement is not None:
return False
if isinstance(gazing_ratio, (list, tuple)):
return all(r == 1 for r in gazing_ratio)
return gazing_ratio == 1
@staticmethod
def _sort_gazing_pos_per_frame(
gazing_pos: torch.Tensor,
if_padded: torch.Tensor,
num_gazing_each_frame: torch.Tensor,
) -> torch.Tensor:
"""Sort non-padded gazing positions in ascending order within each frame.
Padded positions are left untouched at the end of each frame's segment
so that the total count (padded + non-padded) per frame is unchanged.
Args:
gazing_pos: ``(B, N)`` tensor of gazing patch indices.
if_padded: ``(B, N)`` bool tensor (``True`` = padded / dummy).
num_gazing_each_frame: ``(B, T)`` tensor giving the number of
gazing positions (padded + non-padded) for each frame.
Returns:
A new ``(B, N)`` tensor with the same values as *gazing_pos*
except that the non-padded entries within every frame are sorted.
"""
sorted_pos = gazing_pos.clone()
B, _ = gazing_pos.shape
T = num_gazing_each_frame.shape[1]
for b in range(B):
offset = 0
for t in range(T):
count = int(num_gazing_each_frame[b, t].item())
frame_pos = gazing_pos[b, offset : offset + count]
frame_pad = if_padded[b, offset : offset + count]
# Indices of non-padded (real) positions within the frame segment
real_mask = ~frame_pad
real_pos = frame_pos[real_mask]
# Sort the real positions
real_pos_sorted = real_pos.sort()[0]
# Write sorted values back at the correct locations
real_indices = real_mask.nonzero(as_tuple=True)[0]
sorted_pos[b, offset + real_indices] = real_pos_sorted
offset += count
return sorted_pos
def _run_autogaze_batched(
    self,
    all_videos: torch.Tensor,
    autogaze_device: torch.device,
    cpu_device: torch.device,
    gazing_ratio,
    task_loss_requirement,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run AutoGaze in minibatches and return combined results on CPU.

    Different minibatches may produce different per-frame gazing counts
    (e.g. when ``task_loss_requirement`` triggers adaptive pruning).
    This method pads each frame's segment to the *maximum* count across
    all minibatches so that the results can be concatenated along the
    batch dimension.

    Args:
        all_videos: ``(B, T, C, H, W)`` tensor of videos to process.
        autogaze_device: Device where AutoGaze runs (typically CUDA).
        cpu_device: Device for the returned tensors (typically CPU).
        gazing_ratio: Gazing ratio to pass to AutoGaze.
        task_loss_requirement: Task loss requirement to pass to AutoGaze.

    Returns:
        A tuple ``(gazing_pos, if_padded, num_gazing)`` where
        - ``gazing_pos`` is ``(B, N_max)`` on *cpu_device*
        - ``if_padded`` is ``(B, N_max)`` bool on *cpu_device*
        - ``num_gazing`` is ``(B, T)`` on *cpu_device*
        ``N_max = sum(max_per_frame)`` where ``max_per_frame[t]`` is the
        largest per-frame count across all minibatches.
    """
    total = all_videos.shape[0]
    bs = self.max_batch_size_autogaze
    batch_results: list[dict] = []
    with torch.inference_mode():
        for start in range(0, total, bs):
            batch = all_videos[start : start + bs]
            # AutoGaze returns a dict; the keys read below are
            # "gazing_pos", "if_padded_gazing", and "num_gazing_each_frame".
            gaze = self._autogaze_model(
                {"video": batch.to(autogaze_device)},
                gazing_ratio=gazing_ratio,
                task_loss_requirement=task_loss_requirement,
                target_scales=self.target_scales,
                target_patch_size=self.target_patch_size,
            )
            # Normalize num_gazing_each_frame to a 1-D long tensor on CPU.
            # NOTE(review): the first two branches are identical; kept as-is.
            ng = gaze["num_gazing_each_frame"]
            if isinstance(ng, list):
                ng = torch.tensor(ng, device=cpu_device, dtype=torch.long)
            elif not isinstance(ng, torch.Tensor):
                ng = torch.tensor(ng, device=cpu_device, dtype=torch.long)
            else:
                ng = ng.to(cpu_device)
            if ng.dim() == 2:
                # Assumes all rows in the minibatch share the same per-frame
                # counts, so row 0 is representative — TODO confirm upstream.
                ng = ng[0]
            batch_results.append({
                "gazing_pos": gaze["gazing_pos"].to(cpu_device),
                "if_padded": gaze["if_padded_gazing"].to(cpu_device),
                "num_gazing": ng,
                "batch_size": batch.shape[0],
            })
    # Fast path: single minibatch — no cross-batch padding needed
    if len(batch_results) == 1:
        r = batch_results[0]
        num_gazing = r["num_gazing"].unsqueeze(0).expand(total, -1).contiguous()
        return r["gazing_pos"], r["if_padded"], num_gazing
    # Compute the max per-frame count across all minibatches
    all_ng = torch.stack([r["num_gazing"] for r in batch_results], dim=0)  # (num_minibatches, T)
    max_per_frame = all_ng.max(dim=0).values  # (T,)
    max_N = int(max_per_frame.sum().item())
    T = max_per_frame.shape[0]
    padded_pos_list = []
    padded_mask_list = []
    for r in batch_results:
        src_pos = r["gazing_pos"]  # (mini_B, N_src)
        src_pad = r["if_padded"]  # (mini_B, N_src)
        src_ng = r["num_gazing"]  # (T,)
        mini_B = r["batch_size"]
        if int(src_ng.sum().item()) == max_N:
            # Already at the maximum width — take as-is.
            padded_pos_list.append(src_pos)
            padded_mask_list.append(src_pad)
            continue
        # Widen each frame segment to max_per_frame[t]; new slots start as
        # pos=0 / padded=True and stay that way unless overwritten below.
        dst_pos = torch.zeros(mini_B, max_N, device=cpu_device, dtype=src_pos.dtype)
        dst_pad = torch.ones(mini_B, max_N, device=cpu_device, dtype=torch.bool)
        src_off = 0
        dst_off = 0
        for t in range(T):
            sc = int(src_ng[t].item())
            dc = int(max_per_frame[t].item())
            dst_pos[:, dst_off : dst_off + sc] = src_pos[:, src_off : src_off + sc]
            dst_pad[:, dst_off : dst_off + sc] = src_pad[:, src_off : src_off + sc]
            src_off += sc
            dst_off += dc
        padded_pos_list.append(dst_pos)
        padded_mask_list.append(dst_pad)
    gazing_pos = torch.cat(padded_pos_list, dim=0)
    if_padded = torch.cat(padded_mask_list, dim=0)
    num_gazing = max_per_frame.unsqueeze(0).expand(total, -1).contiguous()
    return gazing_pos, if_padded, num_gazing
def _get_gazing_info_from_videos(
self,
videos_inputs: BatchFeature,
) -> Optional[dict]:
"""Run AutoGaze on the preprocessed tiles and thumbnails.
All tiles from all videos are batched together (they share the same
temporal dimension ``T_tile``). Similarly, all thumbnails are batched
together (temporal dim = 1). AutoGaze is run once on each batch and
the results are split back per-video.
When a gazing ratio is 1 and the corresponding task_loss_requirement is
None (or gazing_ratio is None), all patches are kept and AutoGaze is
skipped for that component. If both tiles and thumbnails meet this
condition, AutoGaze is not invoked at all.
Args:
videos_inputs: The ``BatchFeature`` returned by
``_preprocess_videos``, which must contain the keys
``pixel_values_videos_tiles_autogaze`` and
``pixel_values_videos_thumbnails_autogaze`` (unless the
corresponding component can skip AutoGaze).
Returns:
A dict with the following keys (or ``None`` if AutoGaze is
unavailable or the required inputs are missing):
- ``"gazing_pos_tiles"`` – list of tensors, one per video, each
shaped ``(num_tiles_i, N)``.
- ``"num_gazing_each_frame_tiles"`` – list of tensors, one per
video, each shaped ``(num_tiles_i, T_tile)``.
- ``"if_padded_gazing_tiles"`` – list of bool tensors, one per
video, each shaped ``(num_tiles_i, N)``.
- ``"gazing_pos_thumbnails"`` – list of tensors, one per video,
each shaped ``(T_thumb_i, N')``.
- ``"num_gazing_each_frame_thumbnails"`` – list of tensors, one per
video, each shaped ``(T_thumb_i, 1)``.
- ``"if_padded_gazing_thumbnails"`` – list of bool tensors, one per
video, each shaped ``(T_thumb_i, N')``.
"""
skip_tiles = self._should_gaze_all_patches(
self.gazing_ratio_tile, self.task_loss_requirement_tile
)
skip_thumbnails = self._should_gaze_all_patches(
self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail
)
need_autogaze = not skip_tiles or not skip_thumbnails
if need_autogaze and self._autogaze_model is None:
return None
# Per-video tile/thumbnail counts from SigLIP tensors (always present)
siglip_tiles = videos_inputs["pixel_values_videos_tiles"]
siglip_thumbs = videos_inputs["pixel_values_videos_thumbnails"]
num_tiles_per_video = [t.shape[0] for t in siglip_tiles]
num_thumbs_per_video = [t.shape[0] for t in siglip_thumbs]
device = torch.device("cpu")
autogaze_device = torch.device("cuda") if torch.cuda.is_available() else device
# Total patches per frame across all scales
num_patches_each_scale = [
(s // self.target_patch_size) ** 2 for s in self.target_scales
]
total_patches_per_frame = sum(num_patches_each_scale)
# Ensure AutoGaze model is on GPU for inference
if need_autogaze:
current_device = next(self._autogaze_model.parameters()).device
if current_device != autogaze_device:
self._autogaze_model = self._autogaze_model.to(autogaze_device)
# --- Tiles ---
if skip_tiles:
total_tiles = sum(num_tiles_per_video)
T_tile = siglip_tiles[0].shape[1]
per_frame_pos = torch.arange(total_patches_per_frame, device=device, dtype=torch.long)
tiles_gazing_pos = per_frame_pos.repeat(T_tile).unsqueeze(0).expand(total_tiles, -1).contiguous()
tiles_if_padded = torch.zeros(
total_tiles, T_tile * total_patches_per_frame, device=device, dtype=torch.bool
)
tiles_num_gazing = torch.full(
(total_tiles, T_tile), total_patches_per_frame, device=device, dtype=torch.long
)
else:
tiles_autogaze = videos_inputs.get("pixel_values_videos_tiles_autogaze")
if tiles_autogaze is None:
return None
all_tiles = torch.cat(tiles_autogaze, dim=0)
tiles_gazing_pos, tiles_if_padded, tiles_num_gazing = self._run_autogaze_batched(
all_tiles, autogaze_device, device,
self.gazing_ratio_tile, self.task_loss_requirement_tile,
)
tiles_gazing_pos = self._sort_gazing_pos_per_frame(
tiles_gazing_pos, tiles_if_padded, tiles_num_gazing
)
# --- Thumbnails ---
if skip_thumbnails:
total_thumbs = sum(num_thumbs_per_video)
per_thumb_pos = torch.arange(
total_patches_per_frame, device=device, dtype=torch.long
)
thumbs_gazing_pos = per_thumb_pos.unsqueeze(0).expand(total_thumbs, -1).contiguous()
thumbs_if_padded = torch.zeros_like(thumbs_gazing_pos, dtype=torch.bool)
thumbs_num_gazing = torch.full(
(total_thumbs, 1), total_patches_per_frame,
device=device, dtype=torch.long,
)
else:
thumbs_autogaze = videos_inputs.get("pixel_values_videos_thumbnails_autogaze")
if thumbs_autogaze is None:
return None
all_thumbs = torch.cat(thumbs_autogaze, dim=0)
thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing = self._run_autogaze_batched(
all_thumbs, autogaze_device, device,
self.gazing_ratio_thumbnail, self.task_loss_requirement_thumbnail,
)
thumbs_gazing_pos = self._sort_gazing_pos_per_frame(
thumbs_gazing_pos, thumbs_if_padded, thumbs_num_gazing
)
# --- Split results back per video ---
tiles_gazing_pos_list = list(torch.split(tiles_gazing_pos, num_tiles_per_video, dim=0))
tiles_if_padded_list = list(torch.split(tiles_if_padded, num_tiles_per_video, dim=0))
tiles_num_gazing_list = list(torch.split(tiles_num_gazing, num_tiles_per_video, dim=0))
thumbs_gazing_pos_list = list(torch.split(thumbs_gazing_pos, num_thumbs_per_video, dim=0))
thumbs_if_padded_list = list(torch.split(thumbs_if_padded, num_thumbs_per_video, dim=0))
thumbs_num_gazing_list = list(torch.split(thumbs_num_gazing, num_thumbs_per_video, dim=0))
return {
"gazing_pos_tiles": tiles_gazing_pos_list,
"num_gazing_each_frame_tiles": tiles_num_gazing_list,
"if_padded_gazing_tiles": tiles_if_padded_list,
"gazing_pos_thumbnails": thumbs_gazing_pos_list,
"num_gazing_each_frame_thumbnails": thumbs_num_gazing_list,
"if_padded_gazing_thumbnails": thumbs_if_padded_list,
}