| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Processor class for VideoMllama.""" |
|
|
import concurrent.futures
import logging
import math
from typing import List, Optional, Tuple, Union

import av
import cv2
import numpy as np
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, to_numpy_array
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import (
    PreTokenizedInput,
    TextInput,
)

from .image_processing_video_mllama import make_list_of_images
|
|
|
|
|
|
|
|
|
|
class VideoMllamaImagesKwargs(ImagesKwargs, total=False):
    """Image-processor keyword arguments understood by the VideoMllama processor."""

    # Maximum number of tiles per image / video frame; the processor forwards this
    # value to the image processor as `max_image_tiles`.
    max_image_tiles: Optional[int]
|
|
|
|
class VideoMllamaProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by `VideoMllamaProcessor.__call__`."""

    images_kwargs: VideoMllamaImagesKwargs
    # When truthy, the processor also emits `vision_position_ids` for every image/frame token.
    add_video_position_encoding: Optional[bool]

    # Default values per modality. The keys must use the modality names that
    # `ProcessorMixin._merge_kwargs` looks up ("images_kwargs", not "image_kwargs");
    # with a misspelled key the default is silently never applied.
    _defaults = {
        "images_kwargs": {
            "max_image_tiles": 1,
        },
        "add_video_position_encoding": True,
    }
|
|
|
|
| |
|
|
|
|
def validate_frame_sampling(sample_indices, frames, max_missing_frames=2, max_missing_ratio=0.1):
    """
    Check that enough of the requested frames were actually decoded.

    A shortfall is tolerated as long as it stays within *either* limit: at most
    ``max_missing_frames`` frames, or at most ``max_missing_ratio`` of the requested frames.

    Args:
        sample_indices: The requested frame indices (only its length is used).
        frames: The frames that were obtained (only its length is used).
        max_missing_frames: Absolute number of missing frames that is always acceptable.
        max_missing_ratio: Fraction of missing frames that is always acceptable.

    Raises:
        ValueError: If the shortfall exceeds both the absolute and the relative limit.
    """
    requested = len(sample_indices)
    shortfall = requested - len(frames)

    # Nothing missing (or even extra frames): always fine.
    if shortfall <= 0:
        return

    ratio = shortfall / requested
    exceeds_count = shortfall > max_missing_frames
    exceeds_ratio = ratio > max_missing_ratio
    if not (exceeds_count and exceeds_ratio):
        return

    raise ValueError(
        f"Too many frames missing: {shortfall}/{requested} "
        f"({ratio:.1%}) frames missing, exceeding "
        f"{max_missing_ratio:.0%} threshold."
    )
|
|
|
|
| def _get_video_sample_frames(video_stream, total_frames: int = 0, **kwargs) -> np.ndarray: |
| """ |
| Core logic to compute video sample frame indices. |
| """ |
| video_fps: float = kwargs.get("video_fps", 1.0) |
| video_minlen: int = kwargs.get("video_minlen", 8) |
| video_maxlen: int = kwargs.get("video_maxlen", 256) |
|
|
| obtained_total_frames = int(video_stream.frames) |
|
|
| duration = float(video_stream.duration * video_stream.time_base) |
| frame_rate = float(video_stream.average_rate) |
| calculated_total_frames = round(duration * frame_rate) |
| assert video_fps <= frame_rate, f"Sampling frequency ({video_fps}) must be less than or equal to video frame rate ({frame_rate})" |
|
|
| total_frames_num = [x for x in [total_frames, obtained_total_frames, calculated_total_frames] if x > 0] |
| final_total_frames = min(total_frames_num) if total_frames_num else 0 |
| if final_total_frames == 0: |
| raise AttributeError("Unable to obtain or calculate the total number of frames in the video.") |
|
|
| target_total_frames = int(math.ceil(duration * video_fps - 1e-6)) |
| sample_frames = max(target_total_frames, video_minlen) |
| sample_frames = min(sample_frames, video_maxlen, final_total_frames) |
|
|
| if target_total_frames == sample_frames and video_fps > 0 and frame_rate > 0: |
| sample_indices = np.arange(target_total_frames, dtype=np.int32) |
| sample_indices = (sample_indices * frame_rate / video_fps).astype(np.int32) |
| else: |
| sample_indices = np.linspace(0, final_total_frames - 1, sample_frames).astype(np.int32) |
|
|
| return sample_indices |
|
|
|
|
def _get_cv2_video_sample_frames(video_path: str, total_frames: int = 0, **kwargs) -> np.ndarray:
    """
    Compute sample frame indices for ``video_path`` from its PyAV stream metadata.

    Used by the OpenCV extraction path: ``total_frames`` is the frame count reported by
    OpenCV and is combined with PyAV's metadata in `_get_video_sample_frames`.

    Raises:
        ValueError: If the file contains no video stream.
    """
    # Context manager guarantees the container (and its file handle) is closed.
    with av.open(video_path, "r") as container:
        video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
        if video_stream is None:
            raise ValueError(f"No video stream found in: {video_path}")
        return _get_video_sample_frames(video_stream, total_frames=total_frames, **kwargs)
|
|
|
|
def get_video_sample_frames_av(video_path: str, **kwargs) -> List[Image.Image]:
    """
    Decode the sampled frames of ``video_path`` sequentially with PyAV.

    Args:
        video_path: Path to the video file.
        **kwargs: Sampling options forwarded to `_get_video_sample_frames`
            (``video_fps``, ``video_minlen``, ``video_maxlen``).

    Returns:
        The sampled frames as PIL images, in ascending frame order.

    Raises:
        ValueError: If the file has no video stream, or too many sampled frames could not
            be decoded (see `validate_frame_sampling`).
    """
    frames: List[Image.Image] = []

    with av.open(video_path, "r") as container:
        video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
        if video_stream is None:
            raise ValueError(f"No video stream found in: {video_path}")

        sample_indices = _get_video_sample_frames(video_stream, **kwargs)
        wanted = set(sample_indices.tolist())

        # Skip decoding entirely when nothing is requested.
        if wanted:
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
                if frame_idx in wanted:
                    frames.append(frame.to_image())
                    # Stop as soon as every requested frame has been collected.
                    if len(frames) == len(wanted):
                        break

    validate_frame_sampling(sample_indices, frames)
    return frames
|
|
|
|
def get_cv2_video_sample_frames_multithread(video_path: str, **kwargs) -> List[Image.Image]:
    """
    Decode the sampled frames of ``video_path`` with OpenCV, splitting the work across threads.

    The sample indices are split into contiguous chunks; each worker thread opens its own
    ``cv2.VideoCapture``, seeks to the first index of its chunk and grabs forward from there.
    If OpenCV yields no frames at all, decoding is retried with PyAV.

    Args:
        video_path: Path to the video file.
        **kwargs: ``frame_extract_num_threads`` (default 4) plus the sampling options forwarded
            to `_get_video_sample_frames` (``video_fps``, ``video_minlen``, ``video_maxlen``).

    Returns:
        The sampled frames as RGB PIL images, in ascending frame order.

    Raises:
        ValueError: If OpenCV cannot open the file, or too many sampled frames are missing.
    """
    # Clamp to at least one worker: 0 would make both `array_split` and the executor fail.
    num_threads = max(1, int(kwargs.get("frame_extract_num_threads", 4)))

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Unable to open video file: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    frame_indices = _get_cv2_video_sample_frames(video_path, total_frames=total_frames, **kwargs)
    if len(frame_indices) == 0:
        return []

    # Each worker writes only to the positions of its own chunk, so no locking is needed.
    frames_by_position: List[Optional[np.ndarray]] = [None] * len(frame_indices)
    position_of = {int(idx): pos for pos, idx in enumerate(frame_indices)}

    chunks = [chunk for chunk in np.array_split(frame_indices, min(num_threads, len(frame_indices))) if len(chunk) > 0]

    def worker(chunk_indices: np.ndarray) -> None:
        local_cap = cv2.VideoCapture(video_path)
        try:
            if not local_cap.isOpened():
                return

            start = int(chunk_indices[0])
            if start > 0:
                local_cap.set(cv2.CAP_PROP_POS_FRAMES, start)

            frame_idx = start
            cursor = 0
            while cursor < len(chunk_indices):
                # `grab` advances without decoding; only wanted frames are `retrieve`d.
                if not local_cap.grab():
                    break
                target_idx = int(chunk_indices[cursor])
                if frame_idx == target_idx:
                    ret, frame = local_cap.retrieve()
                    if ret:
                        frames_by_position[position_of[target_idx]] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    cursor += 1
                frame_idx += 1
        finally:
            local_cap.release()

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(chunks)) as executor:
        # Consuming the iterator re-raises any exception from a worker.
        list(executor.map(worker, chunks))

    pil_frames = [Image.fromarray(frame) for frame in frames_by_position if frame is not None]

    # The fallback must run before validation: an empty result would otherwise always be
    # rejected by `validate_frame_sampling`, making the PyAV retry unreachable.
    if not pil_frames:
        return get_video_sample_frames_av(video_path, **kwargs)

    validate_frame_sampling(frame_indices, pil_frames)
    return pil_frames
|
|
|
|
| |
|
|
|
|
def get_cross_attention_token_mask(
    input_ids: List[int],
    attention_mask: List[int],
    image_token_id: int,
    video_token_id: int,
    frame_num_per_video: List[int],
    cross_attention_token_mask_pad_token_id: int = -100,
) -> Tuple[List[int], List[int], List[int]]:
    """
    Expand video tokens into per-frame image tokens and build the sparse cross-attention mask.

    Every ``video_token_id`` in ``input_ids`` is replaced by ``n`` copies of ``image_token_id``
    (``n`` taken in order from ``frame_num_per_video``); the attention-mask value is repeated
    alongside. Then, for every token, the sparse mask holds the index of the most recent image
    token at or before it (an image token counts itself), or
    ``cross_attention_token_mask_pad_token_id`` when no image token has appeared yet.

    Returns:
        ``(sparse_mask, expanded_input_ids, expanded_attention_mask)`` as Python lists.
    """
    ids = np.array(input_ids, dtype=np.int64)

    if video_token_id not in ids:
        expanded_ids = ids
        expanded_mask = np.array(attention_mask, dtype=np.int64)
    else:
        n_videos = np.sum(ids == video_token_id)
        frame_counts = frame_num_per_video[:n_videos]

        new_ids: List[int] = []
        new_mask: List[int] = []
        video_cursor = 0
        for token, keep in zip(ids, attention_mask):
            if token != video_token_id:
                new_ids.append(token)
                new_mask.append(keep)
                continue
            # One image token per decoded frame of this video.
            n_frames = frame_counts[video_cursor]
            video_cursor += 1
            new_ids += [image_token_id] * n_frames
            new_mask += [keep] * n_frames

        expanded_ids = np.array(new_ids, dtype=np.int64)
        expanded_mask = np.array(new_mask, dtype=np.int64)

    # Running count of image tokens up to and including each position. For a non-image
    # token this equals the count strictly before it, so one cumsum covers both cases.
    images_seen = np.cumsum(expanded_ids == image_token_id)
    vision_masks = np.where(
        images_seen > 0,
        images_seen - 1,
        cross_attention_token_mask_pad_token_id,
    ).astype(np.int64)

    return vision_masks.tolist(), expanded_ids.tolist(), expanded_mask.tolist()
|
|
|
|
def convert_sparse_cross_attention_mask_to_dense(
    cross_attention_token_masks: np.ndarray,
    num_tiles: List[List[int]],
    max_num_tiles: int,
    cross_attention_token_mask_pad_token_id: int = -100,
) -> np.ndarray:
    """
    Expand a sparse cross-attention mask into a dense 4D array.

    ``cross_attention_token_masks`` holds, per token, the index of the last image the token
    may attend to (or the pad id). A token with value ``m`` is allowed to see images
    ``0..m``; for each visible image, only its first ``num_tiles[b][image]`` tiles are marked.

    Returns:
        ``int64`` array of shape ``(batch_size, seq_len, max_num_images, max_num_tiles)`` with
        1 where a token may attend to a tile and 0 elsewhere.
    """
    batch_size, seq_len = cross_attention_token_masks.shape
    image_counts = [len(tiles) for tiles in num_tiles] if num_tiles else []
    max_num_images = max(image_counts) if image_counts else 0

    dense = np.zeros((batch_size, seq_len, max_num_images, max_num_tiles), dtype=np.int64)
    if not max_num_images:
        return dense

    pad_id = cross_attention_token_mask_pad_token_id
    for row, (sparse_row, tiles_per_image) in enumerate(zip(cross_attention_token_masks, num_tiles)):
        not_padding = sparse_row != pad_id
        for image_idx, n_tiles in enumerate(tiles_per_image):
            can_see = not_padding & (sparse_row >= image_idx)
            dense[row, can_see, image_idx, :n_tiles] = 1

    return dense
|
|
|
|
def build_string_from_input(prompt: str, bos_token: str, image_token: str, video_token: str) -> str:
    """
    Insert ``bos_token`` into ``prompt`` unless it is already present.

    Leading image/video tokens stay in front of the BOS token, e.g.
    ``"<|image|>hi"`` becomes ``"<|image|><bos>hi"``.

    Args:
        prompt: The raw prompt text.
        bos_token: Beginning-of-sequence token string.
        image_token: Image placeholder token string (ignored if empty).
        video_token: Video placeholder token string (ignored if empty).

    Returns:
        The prompt with ``bos_token`` after any leading media tokens.
    """
    if bos_token in prompt:
        return prompt

    # Empty tokens would match forever (every string starts with ""), so skip them.
    # Order matters: the image token is tried before the video token, as before.
    media_tokens = [token for token in (image_token, video_token) if token]

    leading: List[str] = []
    matched = True
    while matched:
        matched = False
        for token in media_tokens:
            if prompt.startswith(token):
                prompt = prompt[len(token):]
                leading.append(token)
                matched = True
                break

    # Library code should log rather than print to stdout.
    logging.getLogger(__name__).warning(
        "No bos_token `%s` in prompt, so it is added after the %d media tokens at the start of the prompt.",
        bos_token,
        len(leading),
    )

    return f"{''.join(leading)}{bos_token}{prompt}"
|
|
|
|
# Value written into padded slots of `position_ids` / `vision_position_ids`.
VIDEO_MLLAMA_PROCESSOR_PAD_POSITION_ID: int = 0
# Sparse cross-attention mask value for tokens that see no image (and for padding).
VIDEO_MLLAMA_PROCESSOR_CROSS_ATTENTION_TOKEN_MASK_PAD_TOKEN_ID: int = -100
|
|
class VideoMllamaProcessor(ProcessorMixin):
    r"""
    Constructs a VideoMllama processor which wraps [`VideoMllamaImageProcessor`] and
    [`PretrainedTokenizerFast`] into a single processor that inherits both the image processor and
    tokenizer functionalities. See the [`~VideoMllamaProcessor.__call__`] and [`~VideoMllamaProcessor.decode`] for more
    information.
    The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
    ```python
    from transformers import VideoMllamaProcessor
    from PIL import Image

    processor = VideoMllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

    processor(
        images=your_pil_image,
        text=["<|image|>If I had to write a haiku for this one"],
        images_kwargs = {"size": {"height": 448, "width": 448}},
        text_kwargs = {"padding": "longest"},
        common_kwargs = {"return_tensors": "pt"},
    )
    ```

    Args:
        image_processor ([`VideoMllamaImageProcessor`]):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
            The tokenizer is a required input.
        video_fps (`float`, *optional*):
            Frame sampling rate for videos. Defaults to `image_processor.video_fps` if set, else 1.0.
        video_minlen (`int`, *optional*):
            Minimum number of sampled frames per video. Defaults to `image_processor.video_minlen` or 8.
        video_maxlen (`int`, *optional*):
            Maximum number of sampled frames per video. Defaults to `image_processor.video_maxlen` or 256.
        frame_extract_num_threads (`int`, *optional*):
            Worker threads for OpenCV frame extraction. Defaults to
            `image_processor.frame_extract_num_threads` or 4.
        extract_frame_func (`str`, *optional*):
            `"cv2"` for multithreaded OpenCV extraction; any other value uses PyAV. Defaults to
            `image_processor.extract_frame_func` or `"cv2"`.
        max_image_tiles (`int`, *optional*):
            If given, overrides `image_processor.max_image_tiles`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "PreTrainedTokenizerFast"

    # Only "add_video_position_encoding" is read by this class; the per-modality key uses the
    # same "images_kwargs" name as `VideoMllamaProcessorKwargs`.
    _defaults = {
        "images_kwargs": {
            "max_image_tiles": 1,
        },
        "add_video_position_encoding": True,
    }

    def __init__(
        self,
        image_processor,
        tokenizer,
        video_fps=None,
        video_minlen=None,
        video_maxlen=None,
        frame_extract_num_threads=None,
        extract_frame_func=None,
        max_image_tiles: Optional[int] = None,
        **kwargs,
    ):
        # User-facing placeholders that `__call__` rewrites into the tokenizer's special tokens.
        self.image_placeholder = "<image>"
        self.video_placeholder = "<video>"
        self.tokenizer = tokenizer

        super().__init__(image_processor, tokenizer)
        if max_image_tiles is not None:
            self.image_processor.max_image_tiles = max_image_tiles

        if not hasattr(self.tokenizer, "image_token"):
            self.image_token = "<|image|>"
            self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        else:
            self.image_token = self.tokenizer.image_token
            self.image_token_id = self.tokenizer.image_token_id

        if not hasattr(self.tokenizer, "video_token"):
            self.video_token = "<|video|>"
            self.video_token_id = self.tokenizer.convert_tokens_to_ids(self.video_token)
        else:
            self.video_token = self.tokenizer.video_token
            self.video_token_id = self.tokenizer.video_token_id

        self.add_video_position_encoding = kwargs.pop(
            "add_video_position_encoding", self._defaults["add_video_position_encoding"]
        )

        self.pad_position_id = VIDEO_MLLAMA_PROCESSOR_PAD_POSITION_ID
        self.cross_attention_token_mask_pad_token_id = VIDEO_MLLAMA_PROCESSOR_CROSS_ATTENTION_TOKEN_MASK_PAD_TOKEN_ID

        # Explicit arguments win; otherwise fall back to the image processor's config, then to
        # hard defaults. (Previously an explicit value left the attribute unset entirely.)
        self.video_fps = video_fps if video_fps is not None else getattr(image_processor, "video_fps", 1.0)
        self.video_minlen = video_minlen if video_minlen is not None else getattr(image_processor, "video_minlen", 8)
        self.video_maxlen = video_maxlen if video_maxlen is not None else getattr(image_processor, "video_maxlen", 256)
        self.frame_extract_num_threads = (
            frame_extract_num_threads
            if frame_extract_num_threads is not None
            else getattr(image_processor, "frame_extract_num_threads", 4)
        )
        self.extract_frame_func = (
            extract_frame_func if extract_frame_func is not None else getattr(image_processor, "extract_frame_func", "cv2")
        )

        self.bos_token = self.tokenizer.bos_token
        self.chat_template = self.tokenizer.chat_template

    def _pad_sequences(
        self,
        sequences: List[List[int]],
        pad_value: int,
        max_length: Optional[int] = None,
        padding_side: str = "right",
        pad_to_multiple_of: Optional[int] = None,
    ) -> np.ndarray:
        """
        Pad variable-length integer sequences into a 2D ``int64`` array.

        Args:
            sequences: Rows to pad (lists or 1D arrays).
            pad_value: Fill value for padded slots.
            max_length: Target length; defaults to the longest row.
            padding_side: ``"right"`` appends padding, anything else prepends it.
            pad_to_multiple_of: If set (and non-zero), round the length up to a multiple of it.

        Returns:
            Array of shape ``(len(sequences), max_length)``; an empty 1D array for no sequences.
        """
        if not sequences:
            return np.array([], dtype=np.int64)

        if max_length is None:
            max_length = max(len(seq) for seq in sequences)

        if pad_to_multiple_of and max_length % pad_to_multiple_of != 0:
            max_length = (max_length // pad_to_multiple_of + 1) * pad_to_multiple_of

        padded = np.full((len(sequences), max_length), pad_value, dtype=np.int64)
        for row, seq in enumerate(sequences):
            length = len(seq)
            if length == 0:
                # Nothing to copy; also avoids `[-0:]`, which would address the whole row.
                continue
            if padding_side == "right":
                padded[row, :length] = seq
            else:
                padded[row, max_length - length:] = seq
        return padded

    @staticmethod
    def _is_provided(value) -> bool:
        """Whether a media argument is present, without calling `bool()` on arrays/tensors."""
        if value is None:
            return False
        if isinstance(value, (str, list, tuple)):
            return len(value) > 0
        return True

    def _replace_placeholders(self, prompt):
        """Rewrite `<image>`/`<video>` placeholders into the tokenizer's media tokens (str only)."""
        if not isinstance(prompt, str):
            return prompt
        return prompt.replace(self.image_placeholder, self.image_token).replace(
            self.video_placeholder, self.video_token
        )

    @staticmethod
    def _load_images(images):
        """Open image paths (str, list of str, or list of lists of str); return other inputs unchanged."""
        if isinstance(images, str):
            return Image.open(images)
        if isinstance(images, list):
            if all(isinstance(item, str) for item in images):
                return [Image.open(path) for path in images]
            if all(isinstance(item, list) for item in images) and all(
                isinstance(path, str) for item in images for path in item
            ):
                return [[Image.open(path) for path in item] for item in images]
        return images

    @staticmethod
    def _normalize_videos(videos) -> List[List[str]]:
        """Normalize video paths to one list of paths per sample."""
        if isinstance(videos, str):
            return [[videos]]
        if isinstance(videos, list):
            # A flat list of paths is a single sample containing several videos.
            if all(isinstance(item, str) for item in videos):
                return [videos]
            if all(isinstance(item, list) for item in videos) and all(
                isinstance(path, str) for item in videos for path in item
            ):
                return videos
        raise ValueError(f"Invalid video input type: {type(videos)}")

    def _check_media_counts(self, text, images_list, videos_list) -> None:
        """Validate batch sizes and that per-sample media tokens match the media provided."""
        batch_size = len(text)
        if images_list is not None and len(images_list) != batch_size:
            raise ValueError(
                f"The number of samples in `text` ({batch_size}) and `images` ({len(images_list)}) do not match."
            )
        if videos_list is not None and len(videos_list) != batch_size:
            raise ValueError(
                f"The number of samples in `text` ({batch_size}) and `videos` ({len(videos_list)}) do not match."
            )

        n_images_in_text = [t.count(self.image_token) for t in text]
        n_videos_in_text = [t.count(self.video_token) for t in text]

        n_images_provided = [len(img) for img in images_list] if images_list else [0] * batch_size
        if any(in_text != provided for in_text, provided in zip(n_images_in_text, n_images_provided)):
            raise ValueError(
                "Number of image tokens does not match number of images provided. "
                f"Found {n_images_in_text} image tokens and {n_images_provided} images."
            )

        n_videos_provided = [len(vid) for vid in videos_list] if videos_list else [0] * batch_size
        if any(in_text != provided for in_text, provided in zip(n_videos_in_text, n_videos_provided)):
            raise ValueError(
                "Number of video tokens does not match number of videos provided. "
                f"Found {n_videos_in_text} video tokens and {n_videos_provided} videos."
            )

    def _media_order_from_text(self, prompt: str) -> List[str]:
        """Return ``"image"``/``"video"`` labels in the order their tokens appear in ``prompt``."""
        occurrences = []
        for media_type, token in (("image", self.image_token), ("video", self.video_token)):
            pos = prompt.find(token)
            while pos != -1:
                occurrences.append((pos, media_type))
                pos = prompt.find(token, pos + 1)
        return [media_type for _, media_type in sorted(occurrences)]

    @staticmethod
    def _extract_video_frames(video_path: str, extract_frame_func, sampling_kwargs: dict):
        """Sample frames from one video with the configured backend."""
        try:
            if extract_frame_func == "cv2":
                return get_cv2_video_sample_frames_multithread(video_path, **sampling_kwargs)
            return get_video_sample_frames_av(video_path, **sampling_kwargs)
        except Exception as e:
            raise ValueError(f"This video format is not supported.\nvideo processing failed: {e}") from e

    def __call__(
        self,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        images: Optional[Union[ImageInput, str, List[str], List[List[str]]]] = None,
        videos: Optional[Union[str, List[str], List[List[str]]]] = None,
        **kwargs: Unpack[VideoMllamaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare text(s), image(s) and video(s) to be fed as input to the model. This method forwards the `text`
        arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` arguments to
        VideoMllamaImageProcessor's [`~VideoMllamaImageProcessor.__call__`] if `images` is not `None`. Videos are first
        sampled into frames and then handled as images; each video token in the text is expanded into one image
        token per sampled frame.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. `<image>` / `<video>` placeholders are rewritten
                into the tokenizer's image/video tokens.
            images (`str`, `PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[str]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Can be a single image path, a single PIL image, or a
                list of paths or PIL images. When a list of single images is passed, it's treated as a batch of samples.
            videos (`str`, `List[str]`, `List[List[str]]`):
                Video path(s). A single path or a flat list of paths is treated as ONE sample (containing one or
                several videos); a list of lists of paths is a batch with one inner list per sample.
            video_fps, video_minlen, video_maxlen, frame_extract_num_threads, extract_frame_func, max_image_tiles,
            add_video_position_encoding (*optional*):
                Per-call overrides of the corresponding processor attributes.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                    - `'tf'`: Return TensorFlow `tf.constant` objects.
                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
                    - `'np'`: Return NumPy `np.ndarray` objects.
                    - `'jax'`: Return JAX `jnp.ndarray` objects.
        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- Token ids (video tokens expanded), padded. Returned when `text` is not `None`.
            - **attention_mask** -- Padded attention mask. Returned when `text` is not `None`.
            - **position_ids** -- Position ids derived from the attention mask. Returned when `text` is not `None`.
            - **vision_position_ids** -- Position ids of the image/frame tokens, right-padded. Returned when `text`
              and media are present and video position encoding is enabled.
            - **cross_attention_mask** -- Dense text-to-tile cross-attention mask. Returned when `text` and media
              are present.
            - **pixel_values** (and other image-processor outputs) -- Returned when media is present.
        """
        # Empty lists mean "no media of this kind".
        if isinstance(images, list) and len(images) == 0:
            images = None
        if isinstance(videos, list) and len(videos) == 0:
            videos = None

        if text is None and images is None and videos is None:
            raise ValueError("You have to specify either `text`, `images` or `videos`.")
        if text is None and not (self._is_provided(images) or self._is_provided(videos)):
            raise ValueError("If no text is provided, at least one of images or videos must be present.")

        # Processor-level options; popped so they are not forwarded to `_merge_kwargs`.
        video_fps = kwargs.pop("video_fps", self.video_fps)
        video_minlen = kwargs.pop("video_minlen", self.video_minlen)
        video_maxlen = kwargs.pop("video_maxlen", self.video_maxlen)
        frame_extract_num_threads = kwargs.pop("frame_extract_num_threads", self.frame_extract_num_threads)
        extract_frame_func = kwargs.pop("extract_frame_func", self.extract_frame_func)
        max_image_tiles = kwargs.pop("max_image_tiles", self.image_processor.max_image_tiles)
        add_video_position_encoding = kwargs.pop("add_video_position_encoding", self.add_video_position_encoding)

        if isinstance(text, str):
            text = [text]
        if text is not None:
            text = [self._replace_placeholders(t) for t in text]

        loaded_images = self._load_images(images) if images is not None else None
        images_list = make_list_of_images(loaded_images) if loaded_images is not None else None
        videos_list = self._normalize_videos(videos) if videos is not None else None

        output_kwargs = self._merge_kwargs(
            VideoMllamaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        common_kwargs = output_kwargs["common_kwargs"]
        text_kwargs = output_kwargs["text_kwargs"]
        images_kwargs = output_kwargs["images_kwargs"]
        # Token sequences are expanded, re-padded and tensorized below, so ask the tokenizer for
        # plain lists; requesting tensors there would fail for an unpadded batch of unequal lengths.
        text_kwargs["return_tensors"] = None

        data = {}
        original_encoding = {}
        if text is not None:
            self._check_media_counts(text, images_list, videos_list)
            original_encoding = self.tokenizer(text, **text_kwargs)
            if original_encoding:
                data.update(original_encoding)

        # Gather per-sample media in the order their tokens appear in the text.
        if text is not None:
            num_samples = len(text)
        else:
            num_samples = max(
                len(images_list) if images_list else 0,
                len(videos_list) if videos_list else 0,
            )
            logging.getLogger(__name__).info(
                "No text provided, processing media in default order (images first, then videos)."
            )

        sampling_kwargs = {
            "video_fps": video_fps,
            "video_minlen": video_minlen,
            "video_maxlen": video_maxlen,
            "frame_extract_num_threads": frame_extract_num_threads,
        }

        batch_media = []
        frame_num_per_video_batch: List[List[int]] = []
        for i in range(num_samples):
            sample_images = images_list[i] if images_list and i < len(images_list) else []
            sample_videos = videos_list[i] if videos_list and i < len(videos_list) else []

            if text is not None:
                media_order = self._media_order_from_text(text[i])
            else:
                media_order = ["image"] * len(sample_images) + ["video"] * len(sample_videos)

            media_for_sample = []
            frames_per_video: List[int] = []
            image_idx = video_idx = 0
            for media_type in media_order:
                if media_type == "image":
                    media_for_sample.append(sample_images[image_idx])
                    image_idx += 1
                else:
                    frames = self._extract_video_frames(sample_videos[video_idx], extract_frame_func, sampling_kwargs)
                    media_for_sample.extend(frames)
                    frames_per_video.append(len(frames))
                    video_idx += 1

            batch_media.append(media_for_sample)
            frame_num_per_video_batch.append(frames_per_video)

        has_media = any(batch_media)
        num_tiles_batch = []
        if has_media:
            images_kwargs["max_image_tiles"] = max_image_tiles
            image_features = self.image_processor(batch_media, **images_kwargs)
            num_tiles_batch = image_features.pop("num_tiles")
            data.update(image_features)

        if original_encoding:
            add_vision_positions = bool(add_video_position_encoding) and has_media
            padding_side = self.tokenizer.padding_side
            pad_to_multiple_of = text_kwargs.get("pad_to_multiple_of")
            attention_masks = original_encoding.get("attention_mask")

            cross_attention_token_masks = []
            final_input_ids = []
            final_attention_mask = []
            batch_position_ids = []
            batch_vision_position_ids = []

            for i, token_ids in enumerate(original_encoding["input_ids"]):
                attention_mask_for_sample = (
                    attention_masks[i] if attention_masks is not None else [1] * len(token_ids)
                )
                # Expand video tokens into per-frame image tokens and build the sparse mask.
                mask, converted_ids, converted_attn_mask = get_cross_attention_token_mask(
                    token_ids,
                    attention_mask_for_sample,
                    self.image_token_id,
                    self.video_token_id,
                    frame_num_per_video_batch[i],
                    self.cross_attention_token_mask_pad_token_id,
                )
                cross_attention_token_masks.append(np.array(mask))
                final_input_ids.append(converted_ids)
                final_attention_mask.append(converted_attn_mask)

                # Positions count only attended tokens; padded slots get `pad_position_id`.
                attention_mask_arr = np.array(converted_attn_mask, dtype=np.int64)
                position_ids = np.cumsum(attention_mask_arr, dtype=np.int64) - 1
                position_ids[attention_mask_arr == 0] = self.pad_position_id
                batch_position_ids.append(position_ids)

                if add_vision_positions:
                    image_mask = np.array(converted_ids, dtype=np.int64) == self.image_token_id
                    batch_vision_position_ids.append(position_ids[image_mask])

            data["input_ids"] = self._pad_sequences(
                final_input_ids,
                self.tokenizer.pad_token_id,
                padding_side=padding_side,
                pad_to_multiple_of=pad_to_multiple_of,
            )
            data["attention_mask"] = self._pad_sequences(
                final_attention_mask,
                0,
                padding_side=padding_side,
                pad_to_multiple_of=pad_to_multiple_of,
            )
            padded_cross_attention_token_masks = self._pad_sequences(
                cross_attention_token_masks,
                self.cross_attention_token_mask_pad_token_id,
                padding_side=padding_side,
                pad_to_multiple_of=pad_to_multiple_of,
            )
            data["position_ids"] = self._pad_sequences(
                batch_position_ids,
                self.pad_position_id,
                padding_side=padding_side,
                pad_to_multiple_of=pad_to_multiple_of,
            )
            if add_vision_positions:
                # Vision positions index the media sequence, so they are always right-padded.
                data["vision_position_ids"] = self._pad_sequences(
                    batch_vision_position_ids,
                    self.pad_position_id,
                    padding_side="right",
                )

            if num_tiles_batch:
                data["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
                    padded_cross_attention_token_masks,
                    num_tiles=num_tiles_batch,
                    max_num_tiles=max_image_tiles,
                    cross_attention_token_mask_pad_token_id=self.cross_attention_token_mask_pad_token_id,
                )

        return_tensors = common_kwargs.pop("return_tensors", None)
        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        model_inputs = list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])
        if self.add_video_position_encoding:
            model_inputs.extend(["position_ids", "vision_position_ids"])
        return model_inputs
|
|