# coding=utf-8 # Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Processor class for Eagle3_VL. copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py """ import math import os from typing import Iterable, List, Union, Literal import base64 import sys import time import warnings from functools import lru_cache from io import BytesIO import re import requests import torch import torchvision from packaging import version from PIL import Image from torchvision import io from torchvision import transforms from torch.nn import functional as F from torchvision.transforms import InterpolationMode from typing import Optional, Any import numpy as np from transformers.feature_extraction_utils import BatchFeature from transformers.image_processing_utils import select_best_resolution from transformers.image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging from transformers.models.auto import AutoImageProcessor import lmdb import cv2 import pickle logger = logging.get_logger(__name__) # Highly inspired by https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py FRAME_FACTOR = 2 FPS = 2.0 FPS_MIN_FRAMES = 4 
FPS_MAX_FRAMES = 256

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 4096 * 28 * 28
MAX_RATIO = 200
# Upper bound applied by smart_resize to each rounded side (7000 = 250 * IMAGE_FACTOR).
IMAGE_MAX_SIZE = 500 * 14

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28

# Set the maximum number of video token inputs.
# Here, 128K represents the maximum number of input tokens for the VLLM model.
# Remember to adjust it according to your own configuration.
# NOTE: the override is read from the environment variable named VIDEO_MAX_PIXELS,
# even though it populates VIDEO_TOTAL_PIXELS.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")


def adjust_by_factor(number: int, factor: int, method: Literal['round', 'ceil', 'floor'] = 'round') -> int:
    """Adjust `number` to the nearest, ceiling, or floor multiple of `factor`.

    Args:
        number: Value to snap to a multiple of `factor`.
        factor: The step the result must be divisible by.
        method: `'round'` (built-in `round`, i.e. ties go to the even multiple),
            `'ceil'` or `'floor'`.

    Returns:
        The selected multiple of `factor`.

    Raises:
        KeyError: If `method` is not one of the three supported names.
    """
    op = {'round': round, 'ceil': math.ceil, 'floor': math.floor}[method]
    return op(number / factor) * factor


def to_rgb(pil_image: "Image.Image") -> "Image.Image":
    """Convert a PIL image to RGB, flattening RGBA images onto a white background.

    Args:
        pil_image: Source image in any PIL mode.

    Returns:
        A new RGB image. For `'RGBA'` input the alpha channel is used as the paste
        mask over a white canvas; every other mode goes through `convert("RGB")`.
    """
    if pil_image.mode != 'RGBA':
        return pil_image.convert("RGB")
    white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
    # Band 3 of an RGBA image is the alpha channel; use it as the mask.
    white_background.paste(pil_image, mask=pil_image.split()[3])
    return white_background


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Compute a target `(height, width)` for resizing an image.

    The result is chosen so that:

    1. Both dimensions (height and width) are divisible by `factor`.
    2. The total number of pixels is within the range [`min_pixels`, `max_pixels`]
       whenever that is possible with sides of at least `factor`.
    3. The aspect ratio of the image is maintained as closely as possible.

    Args:
        height: Original height in pixels.
        width: Original width in pixels.
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound on `height * width` of the result.
        max_pixels: Upper bound on `height * width` of the result.

    Returns:
        `(resized_height, resized_width)`.

    Raises:
        ValueError: If a dimension is not positive, or if the aspect ratio exceeds
            `MAX_RATIO`.
    """
    if height <= 0 or width <= 0:
        # Previously this surfaced as an opaque ZeroDivisionError (or went unnoticed
        # for negative sizes).
        raise ValueError(f"height and width must be positive, got height={height}, width={width}")

    ratio = max(height, width) / min(height, width)
    if ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {ratio}"
        )

    # Snap each side to the factor grid, keeping it within [factor, IMAGE_MAX_SIZE].
    h_bar = min(max(factor, adjust_by_factor(height, factor, method='round')), IMAGE_MAX_SIZE)
    w_bar = min(max(factor, adjust_by_factor(width, factor, method='round')), IMAGE_MAX_SIZE)

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((h_bar * w_bar) / max_pixels)
        # Flooring can hit 0 for very elongated images; never return a side below
        # `factor`, otherwise the subsequent resize would get an empty dimension.
        h_bar = max(factor, adjust_by_factor(h_bar / beta, factor, method='floor'))
        w_bar = max(factor, adjust_by_factor(w_bar / beta, factor, method='floor'))
    elif h_bar * w_bar < min_pixels:
        # Scale up from the original (unrounded) size.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = adjust_by_factor(height * beta, factor, method='ceil')
        w_bar = adjust_by_factor(width * beta, factor, method='ceil')
    return h_bar, w_bar


def read_img_from_lmdb_v2(image_data):
    """Read one pickled image record from an LMDB file (special case for AgiBotWorld).

    Args:
        image_data: Mapping with `'lmdb_file'` (environment path) and `'lmdb_key'`
            (ASCII key of the record).

    Returns:
        The decoded RGB `PIL.Image.Image`, or `None` when the key is absent
        (a warning is printed in that case).
    """
    lmdb_file, lmdb_key = image_data['lmdb_file'], image_data['lmdb_key']
    key = lmdb_key.encode('ascii')
    env = lmdb.open(lmdb_file, max_readers=10240, readonly=True, lock=False, readahead=False, meminit=False)
    try:
        with env.begin() as txn:
            value = txn.get(key)
    finally:
        # The environment used to be leaked on every call.
        env.close()

    if value is None:
        print(f"Warning: Key {key} not found.")
        return None

    # SECURITY: pickle.loads executes arbitrary code from the payload; only use
    # this reader on LMDB files from a trusted source.
    record = pickle.loads(value)
    image_bgr = cv2.imdecode(np.frombuffer(record['image'], dtype=np.uint8), cv2.IMREAD_COLOR)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(image_rgb)


def parse_lmdb_image_data(image_data):
    """Load an image stored as raw encoded bytes in an LMDB file.

    Args:
        image_data: Mapping with `'lmdb_file'` (environment path) and `'lmdb_key'`
            (ASCII key of the image).

    Returns:
        A `PIL.Image.Image` (or whatever `read_img_from_lmdb_v2` returns for
        AgiBotWorld files, which may be `None` for a missing key).

    Raises:
        ValueError: If the LMDB file cannot be found, also after stripping the
            legacy absolute prefix, or if the key is not present in the file.
    """
    legacy_prefix = "/home/zhidingy/workspace/libs/eagle/Eagle2/"

    lmdb_file = image_data['lmdb_file']
    if not os.path.exists(lmdb_file):
        if legacy_prefix not in lmdb_file:
            raise ValueError(f"LMDB file {lmdb_file} does not exist")
        # Fall back to the path relative to the current working directory.
        lmdb_file = lmdb_file.replace(legacy_prefix, "")
        if not os.path.exists(lmdb_file):
            raise ValueError(f"LMDB file {lmdb_file} does not exist")

    # From here on always use the resolved `lmdb_file`; the original code kept
    # opening image_data['lmdb_file'], which made the fallback above a no-op.

    # special case for AgiBotWorld, will remove it later
    if 'AgiBotWorld' in lmdb_file:
        return read_img_from_lmdb_v2({**image_data, 'lmdb_file': lmdb_file})

    try:
        env = lmdb.open(lmdb_file, readonly=True, lock=False, max_readers=10240)
    except Exception as e:
        print(f"Failed to open lmdb file {lmdb_file}. Error message: {e}", flush=True)
        raise

    try:
        with env.begin(write=False) as txn:
            try:
                image_bin = txn.get(image_data['lmdb_key'].encode('ascii'))
            except Exception as e:
                print(f"Failed to get image from lmdb file {lmdb_file}. Error message: {e}", flush=True)
                raise
    finally:
        env.close()

    if image_bin is None:
        raise ValueError(f"Key {image_data['lmdb_key']} not found in lmdb file {lmdb_file}")

    try:
        image = Image.open(BytesIO(image_bin))
    except Exception:
        # PIL could not identify the payload; fall back to OpenCV decoding.
        image_np = np.frombuffer(image_bin, dtype=np.uint8)
        image_bgr = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_rgb)
    return image
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    """Load the image described by `ele`, convert it to RGB and resize it.

    The image is taken from `ele["image"]`, or `ele["image_url"]` when `"image"`
    is absent, and may be:

    - a `PIL.Image.Image`;
    - a dict containing `'lmdb_file'` (handled by `parse_lmdb_image_data`);
    - a string: `http://` / `https://` URL, `file://` path, `data:image...base64,`
      URI, or a plain local path.

    Target size: when both `"resized_height"` and `"resized_width"` are given they
    are passed to `smart_resize`; otherwise the image's own size is used together
    with the optional `"min_pixels"` / `"max_pixels"` entries.

    Args:
        ele: Image description, see above.
        size_factor: Both output sides are multiples of this value.

    Returns:
        The resized RGB image.

    Raises:
        ValueError: If the input cannot be interpreted as an image source.
        requests.HTTPError: If an HTTP(S) download returns an error status.
    """
    image = ele["image"] if "image" in ele else ele["image_url"]

    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif isinstance(image, dict):
        if 'lmdb_file' in image:
            image_obj = parse_lmdb_image_data(image)
    elif isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Close the connection deterministically and fail on HTTP errors
            # instead of handing an error page to PIL.
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                image_obj = Image.open(BytesIO(response.content))
        elif image.startswith("file://"):
            image_obj = Image.open(image[7:])
        elif image.startswith("data:image"):
            if "base64," in image:
                _, base64_data = image.split("base64,", 1)
                data = base64.b64decode(base64_data)
                image_obj = Image.open(BytesIO(data))
        else:
            image_obj = Image.open(image)
    # Any other type (or a dict without 'lmdb_file') falls through to the
    # ValueError below rather than crashing on `.startswith`.

    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)

    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    # PIL expects (width, height).
    image = image.resize((resized_width, resized_height))
    return image
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Calculate the number of frames of a video to use as model input.

    Args:
        ele (dict): a dict containing the configuration of the video.
            Supports either `fps` or `nframes`:
                - nframes: the number of frames to extract for model inputs.
                - fps: the fps to extract frames for model inputs.
                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: if both `fps` and `nframes` are present in `ele`.

    Returns:
        int: the number of frames to sample. It is a multiple of FRAME_FACTOR
        unless the computed value falls outside [FRAME_FACTOR, total_frames],
        in which case `total_frames` is returned unchanged.
    """
    if "fps" in ele and "nframes" in ele:
        # Explicit raise instead of `assert`, which disappears under `python -O`.
        raise ValueError("Only accept either `fps` or `nframes`")

    if "nframes" in ele:
        nframes = adjust_by_factor(ele["nframes"], FRAME_FACTOR, method='round')
    else:
        fps = ele.get("fps", FPS)
        min_frames = adjust_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR, method='ceil')
        max_frames = adjust_by_factor(
            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR, method='floor'
        )
        # Duration in seconds times the requested sampling rate.
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
        nframes = adjust_by_factor(nframes, FRAME_FACTOR, method='floor')

    if not FRAME_FACTOR <= nframes <= total_frames:
        # Out-of-range requests are not an error here: use every frame instead.
        nframes = total_frames
    return nframes


def _read_video_torchvision(
    ele: dict,
) -> tuple[torch.Tensor, float, list]:
    """Read a video with `torchvision.io.read_video` and sample frames from it.

    Args:
        ele: Video description. `ele["video"]` is the path/URL; optional
            `"video_start"` / `"video_end"` (seconds) restrict the decoded range;
            the remaining keys are interpreted by `smart_nframes`.

    Returns:
        `(video, sample_fps, timestamps)`: the sampled frames in TCHW layout, the
        effective frame rate of the sample, and one timestamp in seconds per
        sampled frame (offset by `video_start`).
    """
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
        # Only strip the scheme when it really is a prefix; the previous
        # `"file://" in video_path` test chopped 7 characters off any path that
        # merely contained the substring.
        if video_path.startswith("file://"):
            video_path = video_path[7:]

    st = time.time()
    video, _, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)

    # Calculate frame indices and corresponding timestamps (based on video start time)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    start_time = ele.get("video_start", 0.0)
    timestamps = (start_time + idx.to(torch.float32) / video_fps).tolist()

    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    return video, sample_fps, timestamps
def is_decord_available() -> bool:
    """Return True when the optional `decord` package can be imported."""
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(
    ele: dict,
) -> tuple[torch.Tensor, float, list]:
    """Read a video with `decord.VideoReader` and sample frames from it.

    Args:
        ele: Video description. `ele["video"]` is the path; the remaining keys are
            interpreted by `smart_nframes`.

    Returns:
        `(video, sample_fps, timestamps)`: the sampled frames in TCHW layout, the
        effective frame rate of the sample, and one timestamp in seconds per
        sampled frame.

    Raises:
        NotImplementedError: If `video_start` or `video_end` is present in `ele`.
    """
    import decord

    # Reject unsupported options before paying for opening the video.
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")

    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)

    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    # `video_start` is rejected above, so timestamps always start at 0.
    timestamps = [i / video_fps for i in idx]

    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps, timestamps


# Maps a backend name to the function that reads a video with it.
VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    """Return the preferred backend name: "decord" if installed, else "torchvision".

    The result is cached after the first call.
    """
    return "decord" if is_decord_available() else "torchvision"


def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    """Load and resize a video given either as a file/URL or as a list of frames.

    Args:
        ele: Video description. `ele["video"]` is either a string (decoded with the
            preferred backend, falling back to torchvision) or a list/tuple of
            image sources accepted by `fetch_image`. Optional keys read here:
            `"min_pixels"`, `"total_pixels"`, `"max_pixels"`, `"resized_height"`,
            `"resized_width"`, `"fps"`.
        image_factor: Both output sides are multiples of this value.
        return_video_sample_fps: When True, return
            `(video, sample_fps, timestamps)` instead of just the video.

    Returns:
        For a string input, a float tensor of frames in TCHW layout; for a frame
        list, a list of PIL images padded with copies of the last frame to a
        multiple of FRAME_FACTOR. With `return_video_sample_fps=True` the sample
        fps and per-frame timestamps are returned as well (timestamps are `-1`
        placeholders for frame-list input).

    Raises:
        TypeError: If `ele["video"]` is neither a string nor a list/tuple.
    """
    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        try:
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        except Exception as e:
            if video_reader_backend == "torchvision":
                # Retrying the backend that just failed cannot help.
                raise
            logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
            video, sample_fps, timestamps = VIDEO_READER_BACKENDS["torchvision"](ele)

        nframes, _, height, width = video.shape
        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
        # Per-frame budget derived from the total budget, never below ~min_pixels.
        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
        max_pixels_supposed = ele.get("max_pixels", max_pixels)
        if max_pixels_supposed > max_pixels:
            logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
        max_pixels = min(max_pixels_supposed, max_pixels)

        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=image_factor,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = transforms.functional.resize(
            video,
            [resized_height, resized_width],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        if return_video_sample_fps:
            return video, sample_fps, timestamps
        return video

    if not isinstance(ele["video"], (list, tuple)):
        # Explicit raise instead of `assert`, which disappears under `python -O`.
        raise TypeError(f"ele['video'] must be a str, list or tuple, got {type(ele['video']).__name__}")

    # Forward every remaining option (sizes, pixel limits, ...) to fetch_image.
    process_info = ele.copy()
    process_info.pop("type", None)
    process_info.pop("video", None)
    images = [
        fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
        for video_element in ele["video"]
    ]
    # Pad with the last frame so the frame count is a multiple of FRAME_FACTOR.
    nframes = adjust_by_factor(len(images), FRAME_FACTOR, method='ceil')
    if len(images) < nframes:
        images.extend([images[-1]] * (nframes - len(images)))

    timestamps = [-1] * nframes  # not sure about this
    if return_video_sample_fps:
        # Default to the module-wide FPS (2.0) when the caller gave none.
        return images, process_info.pop("fps", FPS), timestamps
    return images


class Eagle3_VLProcessorKwargs(ProcessingKwargs, total=False):
    """Default keyword arguments for the Eagle3_VL processor."""

    # see processing_utils.ProcessingKwargs documentation for usage.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
        "videos_kwargs": {},
    }
_defaults = { "text_kwargs": { "padding": False, }, "images_kwargs": {}, "videos_kwargs": {}, } class Eagle3_VLProcessor(ProcessorMixin): r""" Constructs a Eagle3_VL processor which wraps a Eagle3_VL video processor, Eagle3_VL image processor and a Eagle3_VL tokenizer into a single processor. [`Eagle3_VLProcessor`] offers all the functionalities of [`Eagle3_VLVideoProcessor`], [`Eagle3_VLImageProcessor`] and [`Eagle3_VLTokenizer`]. See the [`~Eagle3_VLVideoProcessor.__call__`], [`~Eagle3_VLProcessor.__call__`] and [`~Eagle3_VLProcessor.decode`] for more information. Args: image_processor ([`LlavaOnevisionImageProcessor`], *optional*): The image processor is a required input. tokenizer ([`LlamaTokenizerFast`], *optional*): The tokenizer is a required input. num_image_tokens (`int`, *optional*): Number of image tokens for one imagethat will be returned by vision tower. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Shoudl be same as in model's config chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string. image_token (`str`, *optional*, defaults to `""`): Special token used to denote image location. video_token (`str`, *optional*, defaults to `"