# coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Processor class for VideoMllama.""" import av import cv2 import math import numpy as np import concurrent.futures from PIL import Image from typing import List, Optional, Union, Tuple from transformers.feature_extraction_utils import BatchFeature from transformers.image_utils import ImageInput, to_numpy_array from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import ( PreTokenizedInput, TextInput, ) from .image_processing_video_mllama import make_list_of_images class VideoMllamaImagesKwargs(ImagesKwargs, total=False): max_image_tiles: Optional[int] class VideoMllamaProcessorKwargs(ProcessingKwargs, total=False): images_kwargs: VideoMllamaImagesKwargs add_video_position_encoding: Optional[bool] _defaults = { "image_kwargs": { "max_image_tiles": 1, }, "add_video_position_encoding": True, } # --- Start of new video sampling functions (adapted from streaming/mm_plugin.py) --- def validate_frame_sampling(sample_indices, frames, max_missing_frames=2, max_missing_ratio=0.1): """ Validate the completeness of sampled frames. """ expected_count = len(sample_indices) actual_count = len(frames) missing_count = expected_count - actual_count if missing_count <= 0: return missing_ratio = missing_count / expected_count if missing_count > max_missing_frames and missing_ratio > max_missing_ratio: raise ValueError( f"Too many frames missing: {missing_count}/{expected_count} " f"({missing_ratio:.1%}) frames missing, exceeding " f"{max_missing_ratio:.0%} threshold." ) def _get_video_sample_frames(video_stream, total_frames: int = 0, **kwargs) -> np.ndarray: """ Core logic to compute video sample frame indices. """ video_fps: float = kwargs.get("video_fps", 1.0) video_minlen: int = kwargs.get("video_minlen", 8) video_maxlen: int = kwargs.get("video_maxlen", 256) obtained_total_frames = int(video_stream.frames) duration = float(video_stream.duration * video_stream.time_base) frame_rate = float(video_stream.average_rate) calculated_total_frames = round(duration * frame_rate) assert video_fps <= frame_rate, f"Sampling frequency ({video_fps}) must be less than or equal to video frame rate ({frame_rate})" total_frames_num = [x for x in [total_frames, obtained_total_frames, calculated_total_frames] if x > 0] final_total_frames = min(total_frames_num) if total_frames_num else 0 if final_total_frames == 0: raise AttributeError("Unable to obtain or calculate the total number of frames in the video.") target_total_frames = int(math.ceil(duration * video_fps - 1e-6)) sample_frames = max(target_total_frames, video_minlen) sample_frames = min(sample_frames, video_maxlen, final_total_frames) if target_total_frames == sample_frames and video_fps > 0 and frame_rate > 0: sample_indices = np.arange(target_total_frames, dtype=np.int32) sample_indices = (sample_indices * frame_rate / video_fps).astype(np.int32) else: sample_indices = np.linspace(0, final_total_frames - 1, sample_frames).astype(np.int32) return sample_indices def _get_cv2_video_sample_frames(video_path: str, total_frames: int = 0, **kwargs) -> np.ndarray: container = av.open(video_path, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") sample_indices = _get_video_sample_frames(video_stream, total_frames=total_frames, **kwargs) return sample_indices def get_video_sample_frames_av(video_path: str, **kwargs) -> List[Image.Image]: container = av.open(video_path, "r") video_stream = next(stream for stream in container.streams if stream.type == "video") sample_indices = _get_video_sample_frames(video_stream, **kwargs) sample_indices_set = set(sample_indices) frames: List[Image.Image] = [] container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): if frame_idx in sample_indices_set: frames.append(frame.to_image()) if len(frames) == len(sample_indices): break validate_frame_sampling(sample_indices, frames) return frames def get_cv2_video_sample_frames_multithread(video_path: str, **kwargs) -> List[Image.Image]: num_threads: int = kwargs.get("frame_extract_num_threads", 4) num_threads = int(num_threads) cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError(f"Unable to open video file: {video_path}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() frame_indices = _get_cv2_video_sample_frames(video_path, total_frames=total_frames, **kwargs) unique_frames: List[Optional[np.ndarray]] = [None] * len(frame_indices) index_map = {idx: pos for pos, idx in enumerate(frame_indices)} chunks = np.array_split(frame_indices, min(num_threads, len(frame_indices))) def worker(chunk_indices): local_cap = cv2.VideoCapture(video_path) if not local_cap.isOpened(): return if chunk_indices[0] > 0: local_cap.set(cv2.CAP_PROP_POS_FRAMES, chunk_indices[0]) frame_idx_cursor = chunk_indices[0] chunk_cursor = 0 while chunk_cursor < len(chunk_indices): target_idx = chunk_indices[chunk_cursor] ok = local_cap.grab() if not ok: break if frame_idx_cursor == target_idx: ret, frame = local_cap.retrieve() if ret: unique_pos = index_map[target_idx] unique_frames[unique_pos] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) chunk_cursor += 1 frame_idx_cursor += 1 local_cap.release() with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: list(executor.map(worker, [chunk for chunk in chunks if len(chunk) > 0])) pil_frames = [Image.fromarray(frame) for frame in unique_frames if frame is not None] validate_frame_sampling(frame_indices, pil_frames) if not pil_frames: return get_video_sample_frames_av(video_path, **kwargs) return pil_frames # --- End of new video sampling functions --- def get_cross_attention_token_mask( input_ids: List[int], attention_mask: List[int], image_token_id: int, video_token_id: int, frame_num_per_video: List[int], cross_attention_token_mask_pad_token_id: int = -100, ) -> Tuple[List[int], List[int], List[int]]: """ Generate a cross-attention-token-mask for each input_tokens in the input sequence. This function implements a causal attention logic: - A text token can see all image tokens that appeared before it. - An image token can see itself and all image tokens that appeared before it. """ # 1. Convert video tokens to image tokens input_ids_np = np.array(input_ids, dtype=np.int64) if video_token_id in input_ids_np: total_vid_num = np.sum(input_ids_np == video_token_id) f_num_per_vid = frame_num_per_video[:total_vid_num] convert_input_ids_list = [] convert_attention_mask_list = [] vid_idx = 0 for token_id, mask_val in zip(input_ids_np, attention_mask): if token_id == video_token_id: vid_len = f_num_per_vid[vid_idx] vid_idx += 1 convert_input_ids_list.extend([image_token_id] * vid_len) convert_attention_mask_list.extend([mask_val] * vid_len) else: convert_input_ids_list.append(token_id) convert_attention_mask_list.append(mask_val) convert_input_ids = np.array(convert_input_ids_list, dtype=np.int64) convert_attention_mask = np.array(convert_attention_mask_list, dtype=np.int64) else: convert_input_ids = input_ids_np convert_attention_mask = np.array(attention_mask, dtype=np.int64) # 2. Generate the sparse attention mask based on causal visibility is_image = convert_input_ids == image_token_id # Cumulative count of images up to and including the current position image_count_cumulative = np.cumsum(is_image) # Cumulative count of images up to the previous position image_count_before = np.pad(image_count_cumulative[:-1], (1, 0), "constant", constant_values=0) # For text tokens, num_seen = image_count_before. # For image tokens, num_seen = image_count_cumulative (sees itself). num_images_seen = np.where(is_image, image_count_cumulative, image_count_before) # Convert num_seen to sparse mask value (num_seen - 1). vision_masks = np.full(len(convert_input_ids), cross_attention_token_mask_pad_token_id, dtype=np.int64) valid_mask = num_images_seen > 0 vision_masks[valid_mask] = num_images_seen[valid_mask] - 1 return vision_masks.tolist(), convert_input_ids.tolist(), convert_attention_mask.tolist() def convert_sparse_cross_attention_mask_to_dense( cross_attention_token_masks: np.ndarray, num_tiles: List[List[int]], max_num_tiles: int, cross_attention_token_mask_pad_token_id: int = -100, ) -> np.ndarray: """ Convert the cross attention mask indices to a cross attention mask 4D array. This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array. The sparse representation is a tensor that defines [the range of images that can be seen] for [each input token]. """ batch_size, length = cross_attention_token_masks.shape max_num_images = max([len(n_tiles) for n_tiles in num_tiles]) if num_tiles else 0 cross_attention_mask = np.zeros( shape=(batch_size, length, max_num_images, max_num_tiles), dtype=np.int64, ) if max_num_images == 0: return cross_attention_mask for batch_idx, (sparse_mask, n_tiles) in enumerate(zip(cross_attention_token_masks, num_tiles)): # For each image, find all text tokens that are allowed to see it. # A token with sparse_mask value N can see all images with index i <= N. for image_idx, mask_n_tiles in enumerate(n_tiles): # Find all token positions where the sparse mask value is >= the current image's index. # This correctly implements the causal logic. visible_token_indices = (sparse_mask >= image_idx) & (sparse_mask != cross_attention_token_mask_pad_token_id) # Set the attention mask to 1 for these tokens and the current image. cross_attention_mask[batch_idx, visible_token_indices, image_idx, :mask_n_tiles] = 1 return cross_attention_mask def build_string_from_input(prompt: str, bos_token: str, image_token: str, video_token: str) -> str: """ Builds a string from the input prompt by adding `bos_token` if not already present. It handles prompts starting with image or video tokens. """ if bos_token in prompt: return prompt num_media_tokens_on_start = 0 media_tokens = [] while prompt.startswith(image_token) or prompt.startswith(video_token): if prompt.startswith(image_token): prompt = prompt[len(image_token) :] media_tokens.append(image_token) elif prompt.startswith(video_token): prompt = prompt[len(video_token) :] media_tokens.append(video_token) num_media_tokens_on_start += 1 print(f"No bos_token `{bos_token}` in prompt, so it is added after the {num_media_tokens_on_start} media tokens at the start of the prompt.") return f"{''.join(media_tokens)}{bos_token}{prompt}" VIDEO_MLLAMA_PROCESSOR_PAD_POSITION_ID = 0 VIDEO_MLLAMA_PROCESSOR_CROSS_ATTENTION_TOKEN_MASK_PAD_TOKEN_ID = -100 class VideoMllamaProcessor(ProcessorMixin): r""" Constructs a VideoMllama processor which wraps [`VideoMllamaImageProcessor`] and [`PretrainedTokenizerFast`] into a single processor that inherits both the image processor and tokenizer functionalities. See the [`~VideoMllamaProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more information. The preferred way of passing kwargs is as a dictionary per modality, see usage example below. ```python from transformers import VideoMllamaProcessor from PIL import Image processor = VideoMllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision") processor( images=your_pil_image, text=["<|image|>If I had to write a haiku for this one"], images_kwargs = {"size": {"height": 448, "width": 448}}, text_kwargs = {"padding": "right"}, common_kwargs = {"return_tensors": "pt"}, ) ``` Args: image_processor ([`VideoMllamaImageProcessor`]): The image processor is a required input. tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]): The tokenizer is a required input. """ attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "PreTrainedTokenizerFast" _defaults = { "image_kwargs": { "max_image_tiles": 1, }, "add_video_position_encoding": True, } def __init__(self, image_processor, tokenizer, video_fps = None, video_minlen = None, video_maxlen = None,frame_extract_num_threads = None, extract_frame_func = None, max_image_tiles: Optional[int] = None, **kwargs): # User-facing placeholders self.image_placeholder = "" self.video_placeholder = "