| | import base64 |
| | import io |
| | import math |
| | import os |
| | from datetime import datetime, timezone |
| | from typing import List, Literal, Optional, TypedDict |
| |
|
| | import numpy as np |
| | from PIL import Image |
| | from pydantic import BaseModel, Field |
| |
|
| | try: |
| | from mecord import VideoReader |
| | except ImportError: |
| | VideoReader = None |
| |
|
| |
|
class VideoSpec(BaseModel):
    """Basic metadata describing a video, as built by ``get_video_meta``."""

    # Discriminator field. The default must be the *string* 'video';
    # assigning the typing object ``Literal['video']`` as the default would
    # leave ``media_type`` holding a type object instead of a string.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Optional extras copied from the video reader; both default to None, so
    # they are typed Optional so that an explicit None also validates.
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None,
                                            description="frame time info")
| |
|
| |
|
class ImageInput(TypedDict):
    """A single-image media payload, discriminated by ``type == 'image'``."""

    # Discriminator key checked by ensure_media_type.
    type: Literal['image']
    # NOTE(review): ensure_media_type also accepts str/bytes here and converts
    # them to a PIL image in place, so raw inputs may not match this type yet.
    image: Image.Image
| |
|
| |
|
class _VideoChunkInputBase(TypedDict):
    """Required keys of ``VideoChunkInput``."""

    # Discriminator key checked by ensure_media_type.
    type: Literal['video_chunk']
    # The frames of the chunk, in order.
    video_chunk: List[Image.Image]


class VideoChunkInput(_VideoChunkInputBase, total=False):
    """A chunk of video frames, optionally accompanied by a text prompt.

    ``type`` and ``video_chunk`` are required; ``prompt`` may be omitted.
    TypedDict fields cannot carry default values, so optionality is expressed
    with ``total=False`` on this subclass. Readers should use
    ``media.get('prompt')`` rather than ``media['prompt']``.
    """

    prompt: Optional[str]
| |
|
| |
|
# Union of the media payloads handled by ensure_media_type: a single image or
# a chunk of video frames, discriminated by the "type" key.
MediaInput = ImageInput | VideoChunkInput
| |
|
| |
|
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> VideoSpec:
    """Open a video with ``mecord.VideoReader`` and return its metadata.

    Args:
        video_src: Raw video bytes, a path (``str`` or ``os.PathLike``), or a
            ``data:video/mp4;base64,...`` URL string (decoded to bytes first).
        accurate: Forwarded to ``VideoReader`` as ``auto_init``.

    Returns:
        A ``VideoSpec`` holding frame height/width, frame count, average fps,
        and the reader's ``key_indices`` and ``frame_time_info``.

    Raises:
        ImportError: If the optional ``mecord`` package could not be imported.
        AssertionError: If the reader reports a non-positive frame count,
            frame size, or fps.
    """
    # mecord is an optional import; without this check the call below would
    # fail with an unhelpful "'NoneType' object is not callable".
    if VideoReader is None:
        raise ImportError(
            "get_video_meta requires the optional 'mecord' package "
            "(mecord.VideoReader), which could not be imported.")

    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)

    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        # Inline data URL: decode the base64 payload after the comma.
        video_src = base64.b64decode(video_src.split(',')[1])

    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
| |
|
| |
|
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Format a non-negative offset in seconds as a clock-style string.

    Supported modes:
        - ``"hh:mm:ss.fff"``: hours, minutes, seconds and milliseconds.
        - ``"mm:ss.fff"``: total minutes, seconds and milliseconds.
        - ``"mm:ss"``: total minutes and seconds.

    Every field is derived from a single integer millisecond count, truncated
    toward zero, so the fields are always mutually consistent. The leading
    field does not wrap: 3723 s in ``"mm:ss"`` mode is ``"62:03"``.

    Raises:
        ValueError: For an unknown ``timestamp_mode`` or a negative timestamp.
    """
    if timestamp_mode not in ("hh:mm:ss.fff", "mm:ss.fff", "mm:ss"):
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
    if timestamp < 0:
        raise ValueError(f"timestamp must be non-negative, got {timestamp}")

    # One integer count avoids mixing datetime's microsecond rounding with a
    # separately truncated fraction, which could print e.g. "00:00:02.999"
    # for a timestamp just below 2 s.
    total_ms = math.floor(timestamp * 1000)
    total_s, millis = divmod(total_ms, 1000)

    if timestamp_mode == "hh:mm:ss.fff":
        hours, rem = divmod(total_s, 3600)
        minutes, seconds = divmod(rem, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"

    minutes, seconds = divmod(total_s, 60)
    if timestamp_mode == "mm:ss.fff":
        return f"{minutes:02d}:{seconds:02d}.{millis:03d}"
    return f"{minutes:02d}:{seconds:02d}"
| |
|
| |
|
| | def navit_resize_image( |
| | width: int, |
| | height: int, |
| | patch_size: int, |
| | merge_kernel_size: int, |
| | in_patch_limit: int, |
| | patch_limit_on_one_side: int, |
| | fixed_output_tokens: int | None, |
| | ): |
| | |
| | s1 = math.sqrt( |
| | in_patch_limit / |
| | (max(1.0, width // patch_size) * max(1.0, height // patch_size))) |
| | s2 = patch_limit_on_one_side * patch_size / width |
| | s3 = patch_limit_on_one_side * patch_size / height |
| | scale = min(1.0, s1, s2, s3) |
| | new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale)) |
| | new_w = min(new_w, patch_limit_on_one_side * patch_size) |
| | new_h = min(new_h, patch_limit_on_one_side * patch_size) |
| |
|
| | |
| | factor = merge_kernel_size * patch_size |
| |
|
| | pad_height = (factor - new_h % factor) % factor |
| | pad_width = (factor - new_w % factor) % factor |
| |
|
| | if fixed_output_tokens is not None: |
| | num_tokens = fixed_output_tokens |
| | else: |
| | |
| | token_height = (new_h + pad_height) // factor |
| | token_width = (new_w + pad_width) // factor |
| |
|
| | assert token_height * merge_kernel_size <= patch_limit_on_one_side, ( |
| | f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" |
| | ) |
| | assert token_width * merge_kernel_size <= patch_limit_on_one_side, ( |
| | f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" |
| | ) |
| |
|
| | num_tokens = token_height * token_width |
| | return { |
| | "num_tokens": num_tokens, |
| | "new_width": new_w, |
| | "new_height": new_h, |
| | "pad_width": pad_width, |
| | "pad_height": pad_height, |
| | "sampled_nframes": 1, |
| | } |
| |
|
| |
|
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Plan frame sampling and per-frame resize for a video.

    The number of sampled frames is ``nframes`` scaled by
    ``min(sample_fps, avg_fps) / avg_fps``. It is rounded, kept at least 1,
    and capped by ``max_num_frames_each_video`` when given. When
    ``in_patch_limit_total`` is given, the per-frame patch budget is reduced
    to that total divided evenly across the sampled frames.

    Returns:
        The ``navit_resize_image`` plan for one frame, with
        ``sampled_nframes`` set to the number of frames to sample.
    """
    # Never sample faster than the source frame rate.
    effective_fps = min(sample_fps, avg_fps)
    frame_count = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        frame_count = min(frame_count, max_num_frames_each_video)

    per_frame_limit = in_patch_limit_each_frame
    if in_patch_limit_total is not None:
        per_frame_limit = min(round(in_patch_limit_total / frame_count),
                              per_frame_limit)

    plan = navit_resize_image(
        width=width,
        height=height,
        patch_size=patch_size,
        merge_kernel_size=merge_kernel_size,
        in_patch_limit=per_frame_limit,
        patch_limit_on_one_side=patch_limit_on_one_side,
        fixed_output_tokens=fixed_output_tokens_each_frame,
    )
    plan["sampled_nframes"] = frame_count
    return plan
| |
|
| |
|
| | def real_sample_fps_and_max_num_frames( |
| | type_name: Literal["video", "video_chunk"], |
| | sample_fps: float, |
| | max_num_frames_each_video: int | None, |
| | ) -> tuple[int, int | None]: |
| | if type_name == "video": |
| | return sample_fps, max_num_frames_each_video |
| | elif type_name == "video_chunk": |
| | max_num_frames_each_video = None |
| | sample_fps = math.inf |
| | return sample_fps, max_num_frames_each_video |
| | else: |
| | return math.inf, None |
| |
|
| |
|
def _to_pil(
        data: Image.Image | str | bytes | os.PathLike) -> Image.Image:
    """Convert an image-like input to an RGB ``PIL.Image.Image``.

    Accepted inputs:
        - a PIL image (converted with ``convert("RGB")``);
        - a ``data:`` URL string whose payload after the comma is base64;
        - a filesystem path (``str`` or ``os.PathLike``);
        - raw encoded image bytes (``bytes`` or ``bytearray``).

    Raises:
        ValueError: If ``data`` is of an unsupported type.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")

    if isinstance(data, str) and data.startswith("data:"):
        # Decode the base64 payload that follows the first comma.
        data = base64.b64decode(data.split(",")[1])

    if isinstance(data, (bytes, bytearray)):
        source = io.BytesIO(data)
    elif isinstance(data, (str, os.PathLike)):
        source = data
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")

    # convert() loads the pixel data and returns a new image, so the file
    # handle opened by Image.open can be closed as soon as it returns.
    with Image.open(source) as img:
        return img.convert("RGB")
| |
|
| |
|
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Normalize the image payload(s) of ``media`` to RGB PIL images.

    The dict is modified in place: ``media['image']`` (for ``'image'``) or
    every element of ``media['video_chunk']`` (for ``'video_chunk'``) is
    passed through ``_to_pil``. The same dict object is returned.

    Raises:
        ValueError: If ``media['type']`` is neither ``'image'`` nor
            ``'video_chunk'``.
    """
    kind = media['type']
    if kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif kind == 'video_chunk':
        media['video_chunk'] = list(map(_to_pil, media['video_chunk']))
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
| |
|
| |
|
def _rescale_and_pad(
    image: Image.Image,
    resize_to: tuple[int, int],
    center: bool,
    raise_error_for_ill_resize: bool,
) -> np.ndarray:
    """Shrink ``image`` to fit ``resize_to`` (never upscaling) and zero-pad it.

    Padding is split evenly between both sides when ``center`` is true, and
    added only on the right/bottom otherwise. The result has shape
    ``(resize_to[1], resize_to[0], 3)``.
    """
    target_w, target_h = resize_to[0], resize_to[1]
    scale = min(target_w / image.width, target_h / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        # The aspect ratio is too extreme to keep at least one pixel per side.
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        return np.zeros((target_h, target_w, 3), dtype=np.uint8)

    resized = image.resize((new_width, new_height),
                           resample=Image.Resampling.BICUBIC)
    pad_w = target_w - new_width
    pad_h = target_h - new_height
    left, top = (pad_w // 2, pad_h // 2) if center else (0, 0)
    padded = np.pad(
        np.asarray(resized),
        ((top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    assert padded.shape == (target_h, target_w, 3)
    return padded


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert a PIL image to a numpy array, optionally resizing it first.

    Args:
        image: The PIL image to convert.
        resize_to: Target ``(width, height)``. If None, the image is converted
            as-is and ``mode`` is ignored.
        mode: How to reach ``resize_to``:
            - ``"resize"``: bicubic resize to exactly ``resize_to`` (aspect
              ratio not preserved).
            - ``"rescale_and_pad_to_center"``: shrink to fit (never upscale),
              then zero-pad evenly on all sides.
            - ``"rescale_and_pad_to_rightbottom"``: shrink to fit, then
              zero-pad on the right and bottom only.
        raise_error_for_ill_resize: For the rescale modes, when the scaled
            size rounds to zero on a side, raise ValueError if True, otherwise
            return an all-zero uint8 array of the target size.

    Returns:
        A numpy array. Rescale modes yield shape
        ``(resize_to[1], resize_to[0], 3)``.

    Raises:
        ValueError: For an unknown ``mode`` (only checked when ``resize_to``
            is given), or an ill-sized rescale when
            ``raise_error_for_ill_resize`` is True.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is None:
        return np.asarray(image)
    if mode == "resize":
        return np.asarray(
            image.resize(resize_to, resample=Image.Resampling.BICUBIC))
    if mode == "rescale_and_pad_to_center":
        return _rescale_and_pad(image, resize_to, True,
                                raise_error_for_ill_resize)
    if mode == "rescale_and_pad_to_rightbottom":
        return _rescale_and_pad(image, resize_to, False,
                                raise_error_for_ill_resize)
    raise ValueError(f"Invalid mode: {mode}")
| |
|
| |
|
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into flattened, channel-first square patches.

    Args:
        pixel_values: Array of shape ``(t, h, w, c)`` with ``c == 3``. Both
            ``h`` and ``w`` must be multiples of ``patch_size``.
        patch_size: Side length of each square patch.

    Returns:
        dict[str, np.ndarray]
            - ``"pixel_values"``: shape
              ``(t * h//patch_size * w//patch_size, c, patch_size, patch_size)``.
              Patches are ordered by frame, then patch row, then patch column.
            - ``"grid_thw"``: ``[t, h//patch_size, w//patch_size]``.

    Raises:
        ValueError: If ``h`` or ``w`` is not divisible by ``patch_size``.
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Frame size {H}x{W} is not divisible by patch_size {patch_size}")

    grid_h, grid_w = H // patch_size, W // patch_size
    patches = pixel_values.reshape(T, grid_h, patch_size, grid_w, patch_size,
                                   C)
    # (T, gh, p, gw, p, C) -> (T, gh, gw, C, p, p): group each patch's pixels
    # together, channel-first, before flattening the patch grid.
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, grid_h, grid_w])
    return {"pixel_values": patches, "grid_thw": grid_thw}
| |
|
| |
|
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Scale pixels to [0, 1], then apply ``(value - mean) * std_inv``.

    Args:
        x: Pixel array of shape ``(..., 3)``, uint8 values in [0, 255].
        mean: Per-channel mean, broadcast against ``x``.
        std_inv: Per-channel reciprocal of the std, broadcast against ``x``.
        pixels_dtype: Dtype of the returned array.

    Returns:
        A new array of the same shape with dtype ``pixels_dtype``.
    """
    unit_range = x / 255.0
    result = unit_range.astype(pixels_dtype)
    # In-place updates keep the result in pixels_dtype.
    result -= mean
    result *= std_inv
    return result
| |
|
| |
|
| | def _to_tensor(data, **kwargs): |
| | import torch |
| |
|
| | if isinstance(data, np.ndarray): |
| | return torch.from_numpy(data).to(**kwargs) |
| | elif isinstance(data, torch.Tensor): |
| | return data.to(**kwargs) |
| | elif isinstance(data, list): |
| | return [_to_tensor(item, **kwargs) for item in data] |
| | elif isinstance(data, tuple): |
| | return tuple(_to_tensor(item, **kwargs) for item in data) |
| | elif isinstance(data, dict): |
| | return {k: _to_tensor(v, **kwargs) for k, v in data.items()} |
| | elif data is None: |
| | return None |
| | else: |
| | raise ValueError(f"Unsupported data type: {type(data)}") |
| |
|