| import contextlib |
| import sys |
| from pathlib import Path |
| from typing import Optional |
|
|
| import einops |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from transformers import Qwen2ForCausalLM |
| from transformers.cache_utils import Cache |
| from transformers.generation.utils import GenerationMixin |
| from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from autogaze.vision_encoders.siglip.modeling_siglip import SiglipVisionModel |
|
|
| from .configuration_nvila import NVILAConfig |
|
|
|
|
# Feature dimension produced by the vision tower and consumed by the
# mm_projector (see NVILAMultiModalProjector). 1152 presumably matches the
# SigLIP variant in config.vision_config — TODO confirm rather than hard-code.
MM_HIDDEN_SIZE = 1152
|
|
|
|
class TokenShuffle(nn.Module):
    """Group consecutive tokens and concatenate their features.

    Every ``shuffle_num`` consecutive tokens along the sequence axis are
    merged into one output token whose channel dimension is the
    concatenation of the group's features. Sequences whose length is not
    divisible by ``shuffle_num`` are right-padded by repeating the last
    token before grouping.
    """

    def __init__(self, shuffle_num: int):
        """
        Args:
            shuffle_num: Number of consecutive tokens merged per output token.
        """
        super().__init__()
        self.shuffle_num = shuffle_num

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: (B, N, C) tensor where B is batch size, N is sequence length, C is hidden size
        Returns:
            (B, N', C * shuffle_num) tensor where N' = ceil(N / shuffle_num)
        """
        batch, seq_len, channels = x.shape
        remainder = seq_len % self.shuffle_num
        if remainder != 0:
            pad_size = self.shuffle_num - remainder
            # `expand` creates views of the last token instead of the copy
            # `repeat` would make; `cat` materializes the result anyway.
            pad = x[:, -1:].expand(batch, pad_size, channels)
            x = torch.cat([x, pad], dim=1)
            seq_len += pad_size
        # einops "b (n k) c -> b n (k c)" groups k consecutive tokens and
        # concatenates their channels — for row-major data that is exactly a
        # reshape, so the third-party einops call is unnecessary here.
        return x.reshape(batch, seq_len // self.shuffle_num, self.shuffle_num * channels)
|
|
|
|
class NVILAMultiModalProjector(nn.Module):
    """Project vision-tower features into the LLM embedding space.

    Implements the "mlp_shuffle_9" architecture: a 9-way token shuffle
    followed by a LayerNorm/Linear/GELU stack that maps the concatenated
    vision features down to ``config.text_config.hidden_size``.
    """

    def __init__(self, config: NVILAConfig):
        super().__init__()

        vision_dim = MM_HIDDEN_SIZE
        llm_dim = config.text_config.hidden_size

        # Layer order and composition define the state-dict layout; keep them
        # exactly as checkpointed.
        self.layers = nn.Sequential(
            TokenShuffle(9),
            nn.LayerNorm(vision_dim * 9),
            nn.Linear(vision_dim * 9, vision_dim * 3),
            nn.GELU(),
            nn.LayerNorm(vision_dim * 3),
            nn.Linear(vision_dim * 3, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the shuffle + MLP stack to ``(B, N, C)`` vision features."""
        return self.layers(x)
|
|
|
|
class NVILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """NVILA vision-language model.

    Composed of a SigLIP vision tower, a token-shuffle MLP projector
    (``mm_projector``) and a Qwen2 causal LM (``llm``). Image/video pixel
    inputs are encoded into features that replace the image/video
    placeholder tokens in the text embedding sequence before the LLM
    forward pass.
    """

    config_class = NVILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NVILAConfig):
        """Build the vision tower, projector and LLM from ``config``."""
        super().__init__(config)

        # Annotation only — narrows the attribute type for type checkers.
        self.config: NVILAConfig

        @contextlib.contextmanager
        def default_torch_dtype(dtype):
            # Temporarily switch torch's default dtype so submodules are
            # created directly in the target precision, then restore it.
            original_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                yield
            finally:
                torch.set_default_dtype(original_dtype)

        with default_torch_dtype(config.torch_dtype):
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILAMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)

        self.post_init()

    def forward(
        self,
        *,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_images_tiles: list[Tensor] | None = None,
        pixel_values_images_thumbnails: list[Tensor] | None = None,
        num_spatial_tiles_each_image: list[int] | None = None,
        pixel_values_videos_tiles: list[Tensor] | None = None,
        pixel_values_videos_thumbnails: list[Tensor] | None = None,
        gazing_info: dict | None = None,
        num_spatial_tiles_each_video: list[int] | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the multimodal forward pass.

        When ``input_ids`` contains image/video placeholder tokens, the
        pixel inputs are encoded and spliced into the token embeddings via
        :meth:`_embed`; otherwise the ids/embeds go to the LLM unchanged.
        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        Remaining ``kwargs`` are forwarded to the Qwen2 LM.
        """
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."

        # Discard processor outputs this model does not consume.
        kwargs.pop("pixel_values_videos_tiles_autogaze", None)
        kwargs.pop("pixel_values_videos_thumbnails_autogaze", None)
        kwargs.pop("pixel_values_videos", None)

        # NOTE(review): `.any()` already reduces to a 0-dim tensor, so the
        # outer `torch.any(...)` is redundant (harmless, but one of the two
        # reductions should be dropped).
        if input_ids is not None and torch.any(
            torch.isin(
                input_ids,
                torch.tensor(
                    [self.config.image_token_id, self.config.video_token_id],
                    device=input_ids.device,
                ),
            ).any()
        ):
            # NOTE(review): this pop chain looks like dead code — all of
            # these names are declared as keyword-only parameters above, so
            # callers can never land them in **kwargs.
            if gazing_info is None:
                gazing_info = kwargs.pop("gazing_info", None)
            if pixel_values_images_tiles is None:
                pixel_values_images_tiles = kwargs.pop("pixel_values_images_tiles", None)
            if pixel_values_images_thumbnails is None:
                pixel_values_images_thumbnails = kwargs.pop("pixel_values_images_thumbnails", None)
            if num_spatial_tiles_each_image is None:
                num_spatial_tiles_each_image = kwargs.pop("num_spatial_tiles_each_image", None)
            if pixel_values_videos_tiles is None:
                pixel_values_videos_tiles = kwargs.pop("pixel_values_videos_tiles", None)
            if pixel_values_videos_thumbnails is None:
                pixel_values_videos_thumbnails = kwargs.pop("pixel_values_videos_thumbnails", None)
            if num_spatial_tiles_each_video is None:
                num_spatial_tiles_each_video = kwargs.pop("num_spatial_tiles_each_video", None)

            inputs_embeds = self._embed(
                input_ids=input_ids,
                pixel_values=pixel_values,
                pixel_values_images_tiles=pixel_values_images_tiles,
                pixel_values_images_thumbnails=pixel_values_images_thumbnails,
                num_spatial_tiles_each_image=num_spatial_tiles_each_image,
                pixel_values_videos_tiles=pixel_values_videos_tiles,
                pixel_values_videos_thumbnails=pixel_values_videos_thumbnails,
                gazing_info=gazing_info,
                num_spatial_tiles_each_video=num_spatial_tiles_each_video,
            )
            # The embeddings now carry the full sequence; drop the ids so the
            # LLM does not re-embed them.
            input_ids = None

        outputs = self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def _embed(
        self,
        *,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_images_tiles: list[Tensor] | None,
        pixel_values_images_thumbnails: list[Tensor] | None,
        num_spatial_tiles_each_image: list[int] | None,
        pixel_values_videos_tiles: list[Tensor] | None,
        pixel_values_videos_thumbnails: list[Tensor] | None,
        gazing_info: dict | None = None,
        num_spatial_tiles_each_video: list[int] | None = None,
    ) -> Tensor:
        """Embed ``input_ids`` and overwrite placeholder positions with vision features.

        Image features are scattered over positions equal to
        ``config.image_token_id`` and video features over
        ``config.video_token_id`` positions; the feature count must match
        the placeholder count exactly (asserted). Returns the patched
        embedding tensor.

        NOTE(review): the ``pixel_values`` parameter is accepted but never
        read in this method — confirm whether it is still needed.
        """
        inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)

        # --- Images: encode tiles (+thumbnails) and scatter into sequence.
        if pixel_values_images_tiles is not None and len(pixel_values_images_tiles) > 0:
            per_image_features = self._encode_images(
                pixel_values_images_tiles=pixel_values_images_tiles,
                pixel_values_images_thumbnails=pixel_values_images_thumbnails,
                num_spatial_tiles_each_image=num_spatial_tiles_each_image,
            )
            all_features = torch.cat(per_image_features, dim=0)

            image_token_mask = input_ids == self.config.image_token_id
            num_image_tokens = image_token_mask.sum().item()
            num_image_features = all_features.shape[0]

            assert num_image_features == num_image_tokens, (
                f"Number of image features {num_image_features} does not match "
                f"number of image tokens {num_image_tokens}"
            )

            # In-place masked scatter over the placeholder positions.
            inputs_embeds[image_token_mask] = all_features.to(inputs_embeds.dtype)

        # --- Videos: same scheme, keyed on the video placeholder token.
        if pixel_values_videos_tiles is not None:
            per_video_features = self._encode_vision(
                pixel_values_videos_tiles=pixel_values_videos_tiles,
                pixel_values_videos_thumbnails=pixel_values_videos_thumbnails,
                gazing_info=gazing_info,
                num_spatial_tiles_each_video=num_spatial_tiles_each_video,
            )

            all_features = torch.cat(per_video_features, dim=0)

            video_token_mask = input_ids == self.config.video_token_id
            num_video_tokens = video_token_mask.sum().item()
            num_vision_features = all_features.shape[0]

            assert num_vision_features == num_video_tokens, (
                f"Number of vision features {num_vision_features} does not match "
                f"number of video tokens {num_video_tokens}"
            )

            inputs_embeds[video_token_mask] = all_features.to(inputs_embeds.dtype)

        return inputs_embeds

    def _make_default_gazing_info(
        self,
        total_items: int,
        T: int,
        device: torch.device,
    ) -> dict:
        """Create gazing_info that gazes at every patch (no reduction).

        Args:
            total_items: Number of items (tiles or thumbnails) in the batch.
            T: Temporal frames per item.
            device: Target torch device.

        Returns:
            gazing_info dict with ``gazing_pos``, ``num_gazing_each_frame``,
            ``if_padded_gazing``.
        """
        image_size = self.vision_tower.config.image_size
        patch_size = self.vision_tower.config.patch_size
        # `scales` is a "+"-separated string, e.g. multi-scale sizes;
        # each scale contributes (scale / patch_size)^2 patches per frame.
        scales = sorted(
            int(s) for s in self.vision_tower.config.scales.split("+")
        )
        num_patches_each_scale = [(s // patch_size) ** 2 for s in scales]
        total_patches_per_frame = sum(num_patches_each_scale)

        # Every patch index of every frame is "gazed at" (identity mapping).
        per_item_pos = []
        for t in range(T):
            start = t * total_patches_per_frame
            per_item_pos.append(
                torch.arange(start, start + total_patches_per_frame, device=device, dtype=torch.long)
            )
        per_item_pos = torch.cat(per_item_pos)

        gazing_pos = per_item_pos.unsqueeze(0).expand(total_items, -1)
        num_gazing_each_frame = torch.full(
            (T,), total_patches_per_frame, device=device, dtype=torch.long
        )
        # Nothing is padding in the default (all-patches) case.
        if_padded_gazing = torch.zeros_like(gazing_pos, dtype=torch.bool)

        return {
            "gazing_pos": gazing_pos,
            "num_gazing_each_frame": num_gazing_each_frame,
            "if_padded_gazing": if_padded_gazing,
        }

    def _encode_images(
        self,
        pixel_values_images_tiles: list[Tensor],
        pixel_values_images_thumbnails: list[Tensor] | None,
        num_spatial_tiles_each_image: list[int],
    ) -> list[Tensor]:
        """Encode image tiles + thumbnails and return projected features per image.

        Each image is a set of spatial tiles plus one thumbnail (T=1 each).
        All patches are kept (no gazing reduction). For each image the
        spatial tiles are merged into one effective frame, the thumbnail
        forms a second effective frame, and both are padded to
        ``shuffle_num`` before projection through the mm_projector.

        Args:
            pixel_values_images_tiles: Per-image tile tensors, each
                ``(num_tiles_i, 1, C, H, W)``.
            pixel_values_images_thumbnails: Per-image thumbnail tensors,
                each ``(1, 1, C, H, W)``. May be ``None``.
            num_spatial_tiles_each_image: Number of spatial tiles per image.

        Returns:
            List of tensors (one per image), each ``(num_tokens_i, llm_hidden)``.
        """
        # Must match the TokenShuffle group size inside the mm_projector.
        shuffle_num = 9
        device = self.vision_tower.device

        # 1) Batch all tiles across images through the vision tower.
        all_tiles = torch.cat(pixel_values_images_tiles, dim=0)
        total_tiles = all_tiles.shape[0]

        gi_tiles = self._make_default_gazing_info(total_tiles, 1, device)
        tiles_features = self._run_vision_tower_batched(all_tiles, gi_tiles)

        num_gaze_tiles = gi_tiles["num_gazing_each_frame"]
        if_padded_tiles = gi_tiles["if_padded_gazing"]
        frame_lens_tiles = num_gaze_tiles.tolist()

        # 2) Drop padded-gazing features per tile (no-op with default
        #    gazing info, where the pad mask is all False).
        tile_feats: list[Tensor] = []
        for idx in range(total_tiles):
            feats = tiles_features[idx]
            pad_mask = if_padded_tiles[idx]
            frame_feats = feats.split(frame_lens_tiles, dim=0)
            frame_pads = pad_mask.split(frame_lens_tiles, dim=0)
            tile_feats.append(
                torch.cat([f[~p] for f, p in zip(frame_feats, frame_pads)], dim=0)
            )

        # 3) Same for thumbnails, if present.
        thumb_feats: list[Tensor] | None = None
        if pixel_values_images_thumbnails is not None and len(pixel_values_images_thumbnails) > 0:
            all_thumbs = torch.cat(pixel_values_images_thumbnails, dim=0)
            total_thumbs = all_thumbs.shape[0]

            gi_thumbs = self._make_default_gazing_info(total_thumbs, 1, device)
            thumbs_features = self._run_vision_tower_batched(all_thumbs, gi_thumbs)

            num_gaze_thumbs = gi_thumbs["num_gazing_each_frame"]
            if_padded_thumbs = gi_thumbs["if_padded_gazing"]
            frame_lens_thumbs = num_gaze_thumbs.tolist()

            thumb_feats = []
            for idx in range(total_thumbs):
                feats = thumbs_features[idx]
                pad_mask = if_padded_thumbs[idx]
                frame_feats = feats.split(frame_lens_thumbs, dim=0)
                frame_pads = pad_mask.split(frame_lens_thumbs, dim=0)
                thumb_feats.append(
                    torch.cat([f[~p] for f, p in zip(frame_feats, frame_pads)], dim=0)
                )

        # 4) Assemble per-image sequences: merged spatial tiles, then the
        #    thumbnail, each padded to a multiple of shuffle_num.
        tile_offset = 0
        per_image_sequences: list[Tensor] = []
        per_image_token_counts: list[int] = []

        for img_idx, ns in enumerate(num_spatial_tiles_each_image):
            effective_frames: list[Tensor] = []

            # All spatial tiles of this image form one effective frame.
            spatial_feats = tile_feats[tile_offset : tile_offset + ns]
            tile_offset += ns
            effective_frames.append(torch.cat(spatial_feats, dim=0))

            # The thumbnail (if any) forms a second effective frame.
            if thumb_feats is not None:
                effective_frames.append(thumb_feats[img_idx])

            # Pad each effective frame to a multiple of shuffle_num by
            # repeating the last feature row.
            padded_frames: list[Tensor] = []
            for frame in effective_frames:
                n = frame.shape[0]
                pad = (shuffle_num - n % shuffle_num) % shuffle_num
                if pad > 0:
                    frame = torch.cat([frame, frame[-1:].expand(pad, -1)], dim=0)
                padded_frames.append(frame)

            image_seq = torch.cat(padded_frames, dim=0)
            per_image_sequences.append(image_seq)
            # TokenShuffle reduces length by a factor of shuffle_num.
            per_image_token_counts.append(image_seq.shape[0] // shuffle_num)

        # 5) Project all images as one batch-of-1 sequence, then split back.
        all_features = torch.cat(per_image_sequences, dim=0).unsqueeze(0)
        projected = self.mm_projector(
            all_features.to(device=self.device, dtype=self.dtype)
        )
        projected = projected.squeeze(0)

        return list(projected.split(per_image_token_counts, dim=0))

    def _run_vision_tower_batched(
        self,
        all_pixels: Tensor,
        gazing_info_batch: dict,
    ) -> Tensor:
        """Run the vision tower in minibatches and concatenate features.

        Args:
            all_pixels: ``(B, T, C, H, W)`` tensor.
            gazing_info_batch: Dict with ``gazing_pos`` ``(B, N)``,
                ``if_padded_gazing`` ``(B, N)``, and
                ``num_gazing_each_frame`` ``(T,)`` (shared across batch).

        Returns:
            ``(B, N, H)`` hidden features from the second-to-last layer.
        """
        device = self.vision_tower.device
        dtype = self.vision_tower.dtype
        total = all_pixels.shape[0]
        # Minibatch size cap to bound vision-tower memory use.
        bs = self.config.max_batch_size_siglip

        # Fast path: everything fits in one vision-tower call.
        if total <= bs:
            out: BaseModelOutputWithPooling = self.vision_tower(
                all_pixels.to(device=device, dtype=dtype),
                gazing_info=gazing_info_batch,
                output_hidden_states=True,
            )
            assert out.hidden_states is not None
            # Second-to-last hidden layer, per the NVILA feature convention.
            return out.hidden_states[-2]

        num_gaze_shared = gazing_info_batch["num_gazing_each_frame"]
        all_pos = gazing_info_batch["gazing_pos"]
        all_pad = gazing_info_batch["if_padded_gazing"]

        # Slice the per-item gazing tensors in lockstep with the pixels;
        # num_gazing_each_frame is shared and passed through unchanged.
        feature_chunks: list[Tensor] = []
        for start in range(0, total, bs):
            end = min(start + bs, total)
            mini_gi = {
                "gazing_pos": all_pos[start:end],
                "if_padded_gazing": all_pad[start:end],
                "num_gazing_each_frame": num_gaze_shared,
            }
            out = self.vision_tower(
                all_pixels[start:end].to(device=device, dtype=dtype),
                gazing_info=mini_gi,
                output_hidden_states=True,
            )
            assert out.hidden_states is not None
            feature_chunks.append(out.hidden_states[-2])

        return torch.cat(feature_chunks, dim=0)

    def _encode_vision(
        self,
        pixel_values_videos_tiles: list[Tensor],
        pixel_values_videos_thumbnails: list[Tensor],
        gazing_info: dict | None,
        num_spatial_tiles_each_video: list[int],
    ) -> list[Tensor]:
        """Encode tiles and thumbnails and return projected features per video.

        Workflow
        --------
        1. Batch all tiles / thumbnails across videos and run the vision tower
           (in minibatches controlled by ``config.max_batch_size_siglip``).
        2. Remove padded gazing features.
        3. Re-order per video: for each global temporal frame gather all spatial
           tiles, then append thumbnail frames.
        4. Pad each effective frame to be divisible by ``shuffle_num`` (9).
        5. Concatenate all videos into a single sequence (batch=1), project
           through ``mm_projector``, then split back per video.

        Args:
            pixel_values_videos_tiles: Per-video tile tensors, each
                ``(num_tiles_i, T_tile, C, H, W)``.
            pixel_values_videos_thumbnails: Per-video thumbnail tensors, each
                ``(T_thumb_i, 1, C, H, W)``.
            gazing_info: Dict produced by the processor containing per-video
                gazing data for tiles and thumbnails. ``None`` triggers
                default "gaze at all patches" behaviour.
            num_spatial_tiles_each_video: Number of spatial tiles per video.

        Returns:
            List of tensors (one per video), each ``(num_tokens_i, llm_hidden)``.
        """
        # Must match the TokenShuffle group size inside the mm_projector.
        shuffle_num = 9
        device = self.vision_tower.device
        # NOTE(review): `dtype` is assigned but never used in this method.
        dtype = self.vision_tower.dtype

        num_videos = len(pixel_values_videos_tiles)
        num_tiles_per_video = [t.shape[0] for t in pixel_values_videos_tiles]
        num_thumbs_per_video = [t.shape[0] for t in pixel_values_videos_thumbnails]

        # --- Step 1a: batch all tiles across videos through the tower.
        all_tiles = torch.cat(pixel_values_videos_tiles, dim=0)
        T_tile = all_tiles.shape[1]

        if gazing_info is not None:
            # The batched tower call shares one num_gazing_each_frame across
            # items, so it must be identical for every video.
            tiles_nge = gazing_info["num_gazing_each_frame_tiles"]
            ref = tiles_nge[0][0]
            assert all(
                torch.equal(t[0], ref) for t in tiles_nge
            ), "num_gazing_each_frame must be identical across all videos for tiles"
            tiles_gi = {
                "gazing_pos": torch.cat(gazing_info["gazing_pos_tiles"], dim=0).to(device),
                "num_gazing_each_frame": gazing_info["num_gazing_each_frame_tiles"][0][0].to(device),
                "if_padded_gazing": torch.cat(gazing_info["if_padded_gazing_tiles"], dim=0).to(device),
            }
        else:
            tiles_gi = self._make_default_gazing_info(all_tiles.shape[0], T_tile, device)

        tiles_features = self._run_vision_tower_batched(all_tiles, tiles_gi)

        # --- Step 1b: same for thumbnails (T=1 per thumbnail item).
        all_thumbs = torch.cat(pixel_values_videos_thumbnails, dim=0)

        if gazing_info is not None:
            thumbs_nge = gazing_info["num_gazing_each_frame_thumbnails"]
            ref = thumbs_nge[0][0]
            assert all(
                torch.equal(t[0], ref) for t in thumbs_nge
            ), "num_gazing_each_frame must be identical across all videos for thumbnails"
            thumbs_gi = {
                "gazing_pos": torch.cat(gazing_info["gazing_pos_thumbnails"], dim=0).to(device),
                "num_gazing_each_frame": gazing_info["num_gazing_each_frame_thumbnails"][0][0].to(device),
                "if_padded_gazing": torch.cat(gazing_info["if_padded_gazing_thumbnails"], dim=0).to(device),
            }
        else:
            thumbs_gi = self._make_default_gazing_info(all_thumbs.shape[0], 1, device)

        thumbs_features = self._run_vision_tower_batched(all_thumbs, thumbs_gi)

        # --- Step 2a: split each tile's features per temporal frame and
        #     drop the padded-gazing entries.
        all_tiles_if_padded = tiles_gi["if_padded_gazing"]
        all_tiles_num_gaze = tiles_gi["num_gazing_each_frame"]
        tiles_frame_lens = all_tiles_num_gaze.tolist()

        all_tiles_frame_feats: list[list[Tensor]] = []
        for idx in range(tiles_features.shape[0]):
            feats = tiles_features[idx]
            pad_mask = all_tiles_if_padded[idx]
            frame_feats = feats.split(tiles_frame_lens, dim=0)
            frame_pads = pad_mask.split(tiles_frame_lens, dim=0)
            all_tiles_frame_feats.append(
                [f[~p] for f, p in zip(frame_feats, frame_pads)]
            )

        # --- Step 2b: same unpadding for thumbnail features.
        all_thumbs_if_padded = thumbs_gi["if_padded_gazing"]
        all_thumbs_num_gaze = thumbs_gi["num_gazing_each_frame"]
        thumbs_frame_lens = all_thumbs_num_gaze.tolist()

        all_thumbs_frame_feats: list[list[Tensor]] = []
        for idx in range(thumbs_features.shape[0]):
            feats = thumbs_features[idx]
            pad_mask = all_thumbs_if_padded[idx]
            frame_feats = feats.split(thumbs_frame_lens, dim=0)
            frame_pads = pad_mask.split(thumbs_frame_lens, dim=0)
            all_thumbs_frame_feats.append(
                [f[~p] for f, p in zip(frame_feats, frame_pads)]
            )

        # --- Steps 3+4: rebuild per-video sequences frame by frame.
        tile_offset = 0
        thumb_offset = 0
        per_video_sequences: list[Tensor] = []
        per_video_token_counts: list[int] = []

        for vid_idx in range(num_videos):
            ns = num_spatial_tiles_each_video[vid_idx]
            nt = num_tiles_per_video[vid_idx]
            # Tiles are laid out as `tc` temporal chunks of `ns` spatial
            # tiles each; every chunk covers T_tile temporal frames.
            tc = nt // ns
            total_frames = tc * T_tile
            n_thumbs = num_thumbs_per_video[vid_idx]

            vid_tile_feats = all_tiles_frame_feats[tile_offset: tile_offset + nt]
            tile_offset += nt
            vid_thumb_feats = all_thumbs_frame_feats[thumb_offset: thumb_offset + n_thumbs]
            thumb_offset += n_thumbs

            # For each global temporal frame g, concatenate the features of
            # all ns spatial tiles of that frame into one effective frame.
            # Tile item index = chunk * ns + s; frame within chunk = g % T_tile.
            effective_frames: list[Tensor] = []
            for g in range(total_frames):
                chunk = g // T_tile
                f_in_chunk = g % T_tile
                spatial_feats = [
                    vid_tile_feats[chunk * ns + s][f_in_chunk]
                    for s in range(ns)
                ]
                effective_frames.append(torch.cat(spatial_feats, dim=0))

            # Thumbnail frames are appended after all tile frames (each
            # thumbnail item carries a single frame, hence index [0]).
            for thumb in vid_thumb_feats:
                effective_frames.append(thumb[0])

            # Pad each effective frame to a multiple of shuffle_num by
            # repeating the last feature row.
            padded_frames: list[Tensor] = []
            for frame in effective_frames:
                n = frame.shape[0]
                pad = (shuffle_num - n % shuffle_num) % shuffle_num
                if pad > 0:
                    padded_frame = torch.cat(
                        [frame, frame[-1:].expand(pad, -1)], dim=0
                    )
                else:
                    padded_frame = frame
                padded_frames.append(padded_frame)

            video_seq = torch.cat(padded_frames, dim=0)
            per_video_sequences.append(video_seq)
            # TokenShuffle reduces length by a factor of shuffle_num.
            per_video_token_counts.append(video_seq.shape[0] // shuffle_num)

        # --- Step 5: project everything as one batch-of-1 sequence, then
        #     split the projected tokens back per video.
        all_features = torch.cat(per_video_sequences, dim=0).unsqueeze(0)
        projected = self.mm_projector(
            all_features.to(device=self.device, dtype=self.dtype)
        )
        projected = projected.squeeze(0)

        per_video_features = list(projected.split(per_video_token_counts, dim=0))

        return per_video_features
|
|