# NVILA-8B-HD-Video / modeling_nvila.py
# Author: Danny Yin (release 73b433d)
import contextlib
import sys
from pathlib import Path
from typing import Optional
import einops
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from transformers import Qwen2ForCausalLM
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from autogaze.vision_encoders.siglip.modeling_siglip import SiglipVisionModel
from .configuration_nvila import NVILAConfig
MM_HIDDEN_SIZE = 1152
class TokenShuffle(nn.Module):
    """Token shuffle module that groups tokens and concatenates their features.

    Every ``shuffle_num`` consecutive tokens along the sequence axis are
    merged into one token whose channel dimension is the concatenation of
    the group's features, shortening the sequence by a factor of
    ``shuffle_num``.
    """

    def __init__(self, shuffle_num: int):
        super().__init__()
        self.shuffle_num = shuffle_num

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, N, C) tensor where B is batch size, N is sequence length, C is hidden size
        Returns:
            (B, N', C * shuffle_num) tensor where N' = ceil(N / shuffle_num)
        """
        k = self.shuffle_num
        remainder = x.shape[1] % k
        if remainder != 0:
            # Pad by repeating the last token so N becomes divisible by k.
            pad_size = k - remainder
            x = torch.cat([x, x[:, -1:].repeat(1, pad_size, 1)], dim=1)
        # (B, N, C) -> (B, N//k, k*C). Equivalent to
        # einops.rearrange(x, "b (n k) c -> b n (k c)", k=k): grouping k
        # adjacent tokens and flattening them into channels matches the
        # row-major memory layout, so a plain reshape suffices (reshape
        # copies if x is non-contiguous, so views are handled too).
        batch, seq_len, channels = x.shape
        return x.reshape(batch, seq_len // k, k * channels)
class NVILAMultiModalProjector(nn.Module):
    """Multi-modal projector using mlp_shuffle_9 architecture."""

    def __init__(self, config: NVILAConfig):
        super().__init__()
        llm_hidden = config.text_config.hidden_size
        shuffled = MM_HIDDEN_SIZE * 9
        reduced = MM_HIDDEN_SIZE * 3
        # 9-way token shuffle followed by a two-stage MLP that maps the
        # concatenated vision features down to the LLM hidden size.
        self.layers = nn.Sequential(
            TokenShuffle(9),
            nn.LayerNorm(shuffled),
            nn.Linear(shuffled, reduced),
            nn.GELU(),
            nn.LayerNorm(reduced),
            nn.Linear(reduced, llm_hidden),
            nn.GELU(),
            nn.Linear(llm_hidden, llm_hidden),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Project vision features (B, N, C) to LLM embedding space."""
        return self.layers(x)
class NVILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """NVILA vision-language model.

    Composes a SigLIP vision tower, an mlp_shuffle_9 multi-modal projector,
    and a Qwen2 causal LM. Image/video placeholder tokens in ``input_ids``
    are replaced in-place by projected vision features before the LLM runs.
    """

    config_class = NVILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NVILAConfig):
        """Build the vision tower, projector, and LLM under ``config.torch_dtype``."""
        super().__init__(config)
        self.config: NVILAConfig

        @contextlib.contextmanager
        def default_torch_dtype(dtype: torch.dtype):
            # Temporarily switch the global default dtype so submodule
            # parameters are created directly in the target dtype.
            original_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                yield
            finally:
                torch.set_default_dtype(original_dtype)

        with default_torch_dtype(config.torch_dtype):
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILAMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)
        self.post_init()

    def forward(
        self,
        *,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_images_tiles: list[Tensor] | None = None,
        pixel_values_images_thumbnails: list[Tensor] | None = None,
        num_spatial_tiles_each_image: list[int] | None = None,
        pixel_values_videos_tiles: list[Tensor] | None = None,
        pixel_values_videos_thumbnails: list[Tensor] | None = None,
        gazing_info: dict | None = None,
        num_spatial_tiles_each_video: list[int] | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the LLM, splicing in vision embeddings on the prefill pass.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``input_ids`` contains image/video placeholder tokens, the multimodal
        inputs are embedded via ``_embed`` and passed as ``inputs_embeds``;
        otherwise (e.g. decode steps) the ids go straight to the LLM.
        Remaining ``kwargs`` are forwarded to ``self.llm``.
        """
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."
        # Pop processor-only fields that the LLM should not see
        kwargs.pop("pixel_values_videos_tiles_autogaze", None)
        kwargs.pop("pixel_values_videos_thumbnails_autogaze", None)
        kwargs.pop("pixel_values_videos", None)
        # NOTE(review): the outer torch.any() wraps a 0-dim result of
        # `.any()`, so it is redundant (harmless).
        if input_ids is not None and torch.any(
            torch.isin(
                input_ids,
                torch.tensor(
                    [self.config.image_token_id, self.config.video_token_id],
                    device=input_ids.device,
                ),
            ).any()
        ): # Prefill
            # Extract fields from kwargs if not passed as explicit args
            # (generate() may route them through model_kwargs).
            if gazing_info is None:
                gazing_info = kwargs.pop("gazing_info", None)
            if pixel_values_images_tiles is None:
                pixel_values_images_tiles = kwargs.pop("pixel_values_images_tiles", None)
            if pixel_values_images_thumbnails is None:
                pixel_values_images_thumbnails = kwargs.pop("pixel_values_images_thumbnails", None)
            if num_spatial_tiles_each_image is None:
                num_spatial_tiles_each_image = kwargs.pop("num_spatial_tiles_each_image", None)
            if pixel_values_videos_tiles is None:
                pixel_values_videos_tiles = kwargs.pop("pixel_values_videos_tiles", None)
            if pixel_values_videos_thumbnails is None:
                pixel_values_videos_thumbnails = kwargs.pop("pixel_values_videos_thumbnails", None)
            if num_spatial_tiles_each_video is None:
                num_spatial_tiles_each_video = kwargs.pop("num_spatial_tiles_each_video", None)
            inputs_embeds = self._embed(
                input_ids=input_ids,
                pixel_values=pixel_values,
                pixel_values_images_tiles=pixel_values_images_tiles,
                pixel_values_images_thumbnails=pixel_values_images_thumbnails,
                num_spatial_tiles_each_image=num_spatial_tiles_each_image,
                pixel_values_videos_tiles=pixel_values_videos_tiles,
                pixel_values_videos_thumbnails=pixel_values_videos_thumbnails,
                gazing_info=gazing_info,
                num_spatial_tiles_each_video=num_spatial_tiles_each_video,
            )
            # Hand the LLM only the embeddings (ids already consumed).
            input_ids = None
        outputs = self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        return outputs

    def _embed(
        self,
        *,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_images_tiles: list[Tensor] | None,
        pixel_values_images_thumbnails: list[Tensor] | None,
        num_spatial_tiles_each_image: list[int] | None,
        pixel_values_videos_tiles: list[Tensor] | None,
        pixel_values_videos_thumbnails: list[Tensor] | None,
        gazing_info: dict | None = None,
        num_spatial_tiles_each_video: list[int] | None = None,
    ) -> Tensor:
        """Embed text tokens, then overwrite image/video placeholder positions
        with projected vision features.

        Returns:
            ``inputs_embeds`` of shape ``input_ids.shape + (llm_hidden,)``.

        Raises:
            AssertionError: if the number of vision feature rows does not
                equal the number of placeholder tokens in ``input_ids``.
        """
        inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)
        # Handle images
        if pixel_values_images_tiles is not None and len(pixel_values_images_tiles) > 0:
            per_image_features = self._encode_images(
                pixel_values_images_tiles=pixel_values_images_tiles,
                pixel_values_images_thumbnails=pixel_values_images_thumbnails,
                num_spatial_tiles_each_image=num_spatial_tiles_each_image,
            )
            all_features = torch.cat(per_image_features, dim=0)
            image_token_mask = input_ids == self.config.image_token_id
            num_image_tokens = image_token_mask.sum().item()
            num_image_features = all_features.shape[0]
            assert num_image_features == num_image_tokens, (
                f"Number of image features {num_image_features} does not match "
                f"number of image tokens {num_image_tokens}"
            )
            # In-place scatter of vision features into the placeholder slots.
            inputs_embeds[image_token_mask] = all_features.to(inputs_embeds.dtype)
        # Handle videos
        # NOTE(review): unlike the image branch above, this guard does not
        # also check `len(...) > 0` — an empty list would hit torch.cat([])
        # below; confirm the processor never emits an empty tiles list.
        if pixel_values_videos_tiles is not None:
            per_video_features = self._encode_vision(
                pixel_values_videos_tiles=pixel_values_videos_tiles,
                pixel_values_videos_thumbnails=pixel_values_videos_thumbnails,
                gazing_info=gazing_info,
                num_spatial_tiles_each_video=num_spatial_tiles_each_video,
            )
            # per_video_features: list of (num_tokens_i, llm_hidden) tensors
            all_features = torch.cat(per_video_features, dim=0)
            # Match vision features to video tokens
            video_token_mask = input_ids == self.config.video_token_id
            num_video_tokens = video_token_mask.sum().item()
            num_vision_features = all_features.shape[0]
            assert num_vision_features == num_video_tokens, (
                f"Number of vision features {num_vision_features} does not match "
                f"number of video tokens {num_video_tokens}"
            )
            inputs_embeds[video_token_mask] = all_features.to(inputs_embeds.dtype)
        return inputs_embeds

    def _make_default_gazing_info(
        self,
        total_items: int,
        T: int,
        device: torch.device,
    ) -> dict:
        """Create gazing_info that gazes at every patch (no reduction).
        Args:
            total_items: Number of items (tiles or thumbnails) in the batch.
            T: Temporal frames per item.
            device: Target torch device.
        Returns:
            gazing_info dict with ``gazing_pos``, ``num_gazing_each_frame``,
            ``if_padded_gazing``.
        """
        # NOTE(review): image_size is computed but never used here.
        image_size = self.vision_tower.config.image_size
        patch_size = self.vision_tower.config.patch_size
        # `scales` is a "+"-separated string like "448+896" on the vision
        # config; each scale contributes (scale/patch)^2 patches per frame.
        scales = sorted(
            int(s) for s in self.vision_tower.config.scales.split("+")
        )
        num_patches_each_scale = [(s // patch_size) ** 2 for s in scales]
        total_patches_per_frame = sum(num_patches_each_scale)
        # Gazing positions: all patches for every frame
        per_item_pos = []
        for t in range(T):
            start = t * total_patches_per_frame
            per_item_pos.append(
                torch.arange(start, start + total_patches_per_frame, device=device, dtype=torch.long)
            )
        per_item_pos = torch.cat(per_item_pos) # (T * total_patches_per_frame,)
        # Same positions for every item; expand() shares storage (read-only).
        gazing_pos = per_item_pos.unsqueeze(0).expand(total_items, -1) # (B, N)
        num_gazing_each_frame = torch.full(
            (T,), total_patches_per_frame, device=device, dtype=torch.long
        )
        # Nothing is padding in the "gaze at everything" default.
        if_padded_gazing = torch.zeros_like(gazing_pos, dtype=torch.bool)
        return {
            "gazing_pos": gazing_pos,
            "num_gazing_each_frame": num_gazing_each_frame,
            "if_padded_gazing": if_padded_gazing,
        }

    def _encode_images(
        self,
        pixel_values_images_tiles: list[Tensor],
        pixel_values_images_thumbnails: list[Tensor] | None,
        num_spatial_tiles_each_image: list[int],
    ) -> list[Tensor]:
        """Encode image tiles + thumbnails and return projected features per image.
        Each image is a set of spatial tiles plus one thumbnail (T=1 each).
        All patches are kept (no gazing reduction). For each image the
        spatial tiles are merged into one effective frame, the thumbnail
        forms a second effective frame, and both are padded to
        ``shuffle_num`` before projection through the mm_projector.
        Args:
            pixel_values_images_tiles: Per-image tile tensors, each
                ``(num_tiles_i, 1, C, H, W)``.
            pixel_values_images_thumbnails: Per-image thumbnail tensors,
                each ``(1, 1, C, H, W)``. May be ``None``.
            num_spatial_tiles_each_image: Number of spatial tiles per image.
        Returns:
            List of tensors (one per image), each ``(num_tokens_i, llm_hidden)``.
        """
        shuffle_num = 9  # must match TokenShuffle(9) in the projector
        device = self.vision_tower.device
        # --- Run vision tower on all tiles ---
        all_tiles = torch.cat(pixel_values_images_tiles, dim=0) # (total_tiles, 1, C, H, W)
        total_tiles = all_tiles.shape[0]
        gi_tiles = self._make_default_gazing_info(total_tiles, 1, device)
        tiles_features = self._run_vision_tower_batched(all_tiles, gi_tiles) # (total_tiles, N, H)
        num_gaze_tiles = gi_tiles["num_gazing_each_frame"] # (1,)
        if_padded_tiles = gi_tiles["if_padded_gazing"] # (total_tiles, N)
        frame_lens_tiles = num_gaze_tiles.tolist()
        tile_feats: list[Tensor] = []
        for idx in range(total_tiles):
            # Split each tile's feature row into per-frame chunks and drop
            # any padded-gazing positions.
            feats = tiles_features[idx]
            pad_mask = if_padded_tiles[idx]
            frame_feats = feats.split(frame_lens_tiles, dim=0)
            frame_pads = pad_mask.split(frame_lens_tiles, dim=0)
            tile_feats.append(
                torch.cat([f[~p] for f, p in zip(frame_feats, frame_pads)], dim=0)
            )
        # --- Run vision tower on all thumbnails ---
        thumb_feats: list[Tensor] | None = None
        if pixel_values_images_thumbnails is not None and len(pixel_values_images_thumbnails) > 0:
            all_thumbs = torch.cat(pixel_values_images_thumbnails, dim=0) # (num_images, 1, C, H, W)
            total_thumbs = all_thumbs.shape[0]
            gi_thumbs = self._make_default_gazing_info(total_thumbs, 1, device)
            thumbs_features = self._run_vision_tower_batched(all_thumbs, gi_thumbs)
            num_gaze_thumbs = gi_thumbs["num_gazing_each_frame"]
            if_padded_thumbs = gi_thumbs["if_padded_gazing"]
            frame_lens_thumbs = num_gaze_thumbs.tolist()
            thumb_feats = []
            for idx in range(total_thumbs):
                # Same padded-position stripping as for tiles above.
                feats = thumbs_features[idx]
                pad_mask = if_padded_thumbs[idx]
                frame_feats = feats.split(frame_lens_thumbs, dim=0)
                frame_pads = pad_mask.split(frame_lens_thumbs, dim=0)
                thumb_feats.append(
                    torch.cat([f[~p] for f, p in zip(frame_feats, frame_pads)], dim=0)
                )
        # --- Build per-image sequences ---
        tile_offset = 0
        per_image_sequences: list[Tensor] = []
        per_image_token_counts: list[int] = []
        for img_idx, ns in enumerate(num_spatial_tiles_each_image):
            effective_frames: list[Tensor] = []
            # Tiles effective frame: merge all spatial tiles
            spatial_feats = tile_feats[tile_offset : tile_offset + ns]
            tile_offset += ns
            effective_frames.append(torch.cat(spatial_feats, dim=0))
            # Thumbnail effective frame
            if thumb_feats is not None:
                effective_frames.append(thumb_feats[img_idx])
            # Pad each effective frame to divisible by shuffle_num
            padded_frames: list[Tensor] = []
            for frame in effective_frames:
                n = frame.shape[0]
                pad = (shuffle_num - n % shuffle_num) % shuffle_num
                if pad > 0:
                    # Repeat the last feature row so TokenShuffle's grouping
                    # never mixes tokens from different effective frames.
                    frame = torch.cat([frame, frame[-1:].expand(pad, -1)], dim=0)
                padded_frames.append(frame)
            image_seq = torch.cat(padded_frames, dim=0)
            per_image_sequences.append(image_seq)
            # Each group of shuffle_num rows becomes one LLM token.
            per_image_token_counts.append(image_seq.shape[0] // shuffle_num)
        # Project all images in one batch-1 call, then split back per image.
        all_features = torch.cat(per_image_sequences, dim=0).unsqueeze(0)
        projected = self.mm_projector(
            all_features.to(device=self.device, dtype=self.dtype)
        )
        projected = projected.squeeze(0)
        return list(projected.split(per_image_token_counts, dim=0))

    def _run_vision_tower_batched(
        self,
        all_pixels: Tensor,
        gazing_info_batch: dict,
    ) -> Tensor:
        """Run the vision tower in minibatches and concatenate features.
        Args:
            all_pixels: ``(B, T, C, H, W)`` tensor.
            gazing_info_batch: Dict with ``gazing_pos`` ``(B, N)``,
                ``if_padded_gazing`` ``(B, N)``, and
                ``num_gazing_each_frame`` ``(T,)`` (shared across batch).
        Returns:
            ``(B, N, H)`` hidden features from the second-to-last layer.
        """
        device = self.vision_tower.device
        dtype = self.vision_tower.dtype
        total = all_pixels.shape[0]
        # Minibatch size cap to bound peak vision-tower memory.
        bs = self.config.max_batch_size_siglip
        if total <= bs:
            # Fast path: everything fits in a single vision-tower call.
            out: BaseModelOutputWithPooling = self.vision_tower(
                all_pixels.to(device=device, dtype=dtype),
                gazing_info=gazing_info_batch,
                output_hidden_states=True,
            )
            assert out.hidden_states is not None
            # Second-to-last hidden layer is the feature layer used here.
            return out.hidden_states[-2]
        num_gaze_shared = gazing_info_batch["num_gazing_each_frame"]
        all_pos = gazing_info_batch["gazing_pos"]
        all_pad = gazing_info_batch["if_padded_gazing"]
        feature_chunks: list[Tensor] = []
        for start in range(0, total, bs):
            end = min(start + bs, total)
            # Slice the per-item gazing tensors to match the minibatch;
            # num_gazing_each_frame is shared and passed through as-is.
            mini_gi = {
                "gazing_pos": all_pos[start:end],
                "if_padded_gazing": all_pad[start:end],
                "num_gazing_each_frame": num_gaze_shared,
            }
            out = self.vision_tower(
                all_pixels[start:end].to(device=device, dtype=dtype),
                gazing_info=mini_gi,
                output_hidden_states=True,
            )
            assert out.hidden_states is not None
            feature_chunks.append(out.hidden_states[-2])
        return torch.cat(feature_chunks, dim=0)

    def _encode_vision(
        self,
        pixel_values_videos_tiles: list[Tensor],
        pixel_values_videos_thumbnails: list[Tensor],
        gazing_info: dict | None,
        num_spatial_tiles_each_video: list[int],
    ) -> list[Tensor]:
        """Encode tiles and thumbnails and return projected features per video.
        Workflow
        -------
        1. Batch all tiles / thumbnails across videos and run the vision tower
           (in minibatches controlled by ``config.max_batch_size_siglip``).
        2. Remove padded gazing features.
        3. Re-order per video: for each global temporal frame gather all spatial
           tiles, then append thumbnail frames.
        4. Pad each effective frame to be divisible by ``shuffle_num`` (9).
        5. Concatenate all videos into a single sequence (batch=1), project
           through ``mm_projector``, then split back per video.
        Args:
            pixel_values_videos_tiles: Per-video tile tensors, each
                ``(num_tiles_i, T_tile, C, H, W)``.
            pixel_values_videos_thumbnails: Per-video thumbnail tensors, each
                ``(T_thumb_i, 1, C, H, W)``.
            gazing_info: Dict produced by the processor containing per-video
                gazing data for tiles and thumbnails. ``None`` triggers
                default "gaze at all patches" behaviour.
            num_spatial_tiles_each_video: Number of spatial tiles per video.
        Returns:
            List of tensors (one per video), each ``(num_tokens_i, llm_hidden)``.
        """
        shuffle_num = 9 # must match TokenShuffle in NVILAMultiModalProjector
        device = self.vision_tower.device
        dtype = self.vision_tower.dtype
        num_videos = len(pixel_values_videos_tiles)
        num_tiles_per_video = [t.shape[0] for t in pixel_values_videos_tiles]
        num_thumbs_per_video = [t.shape[0] for t in pixel_values_videos_thumbnails]
        # ---- 1. Batch & run vision tower on tiles ----
        all_tiles = torch.cat(pixel_values_videos_tiles, dim=0) # (total_tiles, T_tile, C, H, W)
        T_tile = all_tiles.shape[1]
        if gazing_info is not None:
            # The batched tower path assumes one shared per-frame gaze-count
            # vector; verify the processor produced identical ones.
            tiles_nge = gazing_info["num_gazing_each_frame_tiles"]
            ref = tiles_nge[0][0]
            assert all(
                torch.equal(t[0], ref) for t in tiles_nge
            ), "num_gazing_each_frame must be identical across all videos for tiles"
            tiles_gi = {
                "gazing_pos": torch.cat(gazing_info["gazing_pos_tiles"], dim=0).to(device),
                "num_gazing_each_frame": gazing_info["num_gazing_each_frame_tiles"][0][0].to(device),
                "if_padded_gazing": torch.cat(gazing_info["if_padded_gazing_tiles"], dim=0).to(device),
            }
        else:
            tiles_gi = self._make_default_gazing_info(all_tiles.shape[0], T_tile, device)
        tiles_features = self._run_vision_tower_batched(all_tiles, tiles_gi) # (total_tiles, N, H)
        # ---- 2. Batch & run vision tower on thumbnails ----
        all_thumbs = torch.cat(pixel_values_videos_thumbnails, dim=0) # (total_thumbs, 1, C, H, W)
        if gazing_info is not None:
            thumbs_nge = gazing_info["num_gazing_each_frame_thumbnails"]
            ref = thumbs_nge[0][0]
            assert all(
                torch.equal(t[0], ref) for t in thumbs_nge
            ), "num_gazing_each_frame must be identical across all videos for thumbnails"
            thumbs_gi = {
                "gazing_pos": torch.cat(gazing_info["gazing_pos_thumbnails"], dim=0).to(device),
                "num_gazing_each_frame": gazing_info["num_gazing_each_frame_thumbnails"][0][0].to(device),
                "if_padded_gazing": torch.cat(gazing_info["if_padded_gazing_thumbnails"], dim=0).to(device),
            }
        else:
            thumbs_gi = self._make_default_gazing_info(all_thumbs.shape[0], 1, device)
        thumbs_features = self._run_vision_tower_batched(all_thumbs, thumbs_gi) # (total_thumbs, N', H)
        # ---- 3. Remove padded features & split by frame ----
        # For each tile: list of T_tile tensors, each (n_i, hidden)
        all_tiles_if_padded = tiles_gi["if_padded_gazing"]
        all_tiles_num_gaze = tiles_gi["num_gazing_each_frame"] # 1-D (T_tile,)
        tiles_frame_lens = all_tiles_num_gaze.tolist()
        all_tiles_frame_feats: list[list[Tensor]] = []
        for idx in range(tiles_features.shape[0]):
            feats = tiles_features[idx] # (N, hidden)
            pad_mask = all_tiles_if_padded[idx] # (N,)
            frame_feats = feats.split(tiles_frame_lens, dim=0)
            frame_pads = pad_mask.split(tiles_frame_lens, dim=0)
            all_tiles_frame_feats.append(
                [f[~p] for f, p in zip(frame_feats, frame_pads)]
            )
        # For each thumbnail: list with 1 tensor (n_i, hidden)
        all_thumbs_if_padded = thumbs_gi["if_padded_gazing"]
        all_thumbs_num_gaze = thumbs_gi["num_gazing_each_frame"] # 1-D (1,)
        thumbs_frame_lens = all_thumbs_num_gaze.tolist()
        all_thumbs_frame_feats: list[list[Tensor]] = []
        for idx in range(thumbs_features.shape[0]):
            feats = thumbs_features[idx]
            pad_mask = all_thumbs_if_padded[idx]
            frame_feats = feats.split(thumbs_frame_lens, dim=0)
            frame_pads = pad_mask.split(thumbs_frame_lens, dim=0)
            all_thumbs_frame_feats.append(
                [f[~p] for f, p in zip(frame_feats, frame_pads)]
            )
        # ---- 4. Per-video: reorder, pad frames, build sequences ----
        tile_offset = 0
        thumb_offset = 0
        per_video_sequences: list[Tensor] = []
        per_video_token_counts: list[int] = []
        for vid_idx in range(num_videos):
            ns = num_spatial_tiles_each_video[vid_idx]
            nt = num_tiles_per_video[vid_idx]
            tc = nt // ns # temporal chunks
            total_frames = tc * T_tile
            n_thumbs = num_thumbs_per_video[vid_idx]
            vid_tile_feats = all_tiles_frame_feats[tile_offset: tile_offset + nt]
            tile_offset += nt
            vid_thumb_feats = all_thumbs_frame_feats[thumb_offset: thumb_offset + n_thumbs]
            thumb_offset += n_thumbs
            # -- Reorder tile features to frame-first --
            # Tiles from processor are ordered:
            #   chunk0: [S0, S1, ..., S_{ns-1}], chunk1: [S0, ...], ...
            # We want: for each global frame g, cat all spatial tiles.
            effective_frames: list[Tensor] = []
            for g in range(total_frames):
                chunk = g // T_tile
                f_in_chunk = g % T_tile
                spatial_feats = [
                    vid_tile_feats[chunk * ns + s][f_in_chunk]
                    for s in range(ns)
                ]
                effective_frames.append(torch.cat(spatial_feats, dim=0))
            # -- Append thumbnail frames --
            for thumb in vid_thumb_feats:
                effective_frames.append(thumb[0]) # single frame
            # -- Pad each effective frame to divisible by shuffle_num --
            # (repeat the last row so TokenShuffle never groups tokens
            # across frame boundaries)
            padded_frames: list[Tensor] = []
            for frame in effective_frames:
                n = frame.shape[0]
                pad = (shuffle_num - n % shuffle_num) % shuffle_num
                if pad > 0:
                    padded_frame = torch.cat(
                        [frame, frame[-1:].expand(pad, -1)], dim=0
                    )
                else:
                    padded_frame = frame
                padded_frames.append(padded_frame)
            video_seq = torch.cat(padded_frames, dim=0) # (total_padded, hidden)
            per_video_sequences.append(video_seq)
            # shuffle_num feature rows collapse into one projected token.
            per_video_token_counts.append(video_seq.shape[0] // shuffle_num)
        # ---- 5. Concat all videos, project, split back ----
        all_features = torch.cat(per_video_sequences, dim=0).unsqueeze(0) # (1, total, hidden)
        projected = self.mm_projector(
            all_features.to(device=self.device, dtype=self.dtype)
        ) # (1, total // shuffle_num, llm_hidden)
        projected = projected.squeeze(0) # (total // shuffle_num, llm_hidden)
        per_video_features = list(projected.split(per_video_token_counts, dim=0))
        return per_video_features