import base64
import io
import math
import os
from datetime import datetime, timezone
from typing import Literal, NotRequired, TypedDict

import numpy as np
from PIL import Image
from pydantic import BaseModel, Field

try:
    from mecord import VideoReader
except ImportError:
    VideoReader = None


class VideoSpec(BaseModel):
    media_type: Literal["video"] = "video"
    height: int = Field(..., gt=0, description="video frame height in pixels")
    width: int = Field(..., gt=0, description="video frame width in pixels")
    num_frames: int = Field(..., gt=0, description="total number of frames")
    fps: float = Field(..., gt=0, description="average frames per second")

    key_indices: list[int] | None = Field(None, description="indices of key frames")
    frame_time_info: dict | None = Field(None, description="per-frame timestamp info")


class ImageInput(TypedDict):
    type: Literal["image"]
    image: Image.Image


class VideoChunkInput(TypedDict):
    type: Literal["video_chunk"]
    video_chunk: list[Image.Image]
    prompt: NotRequired[str]


MediaInput = ImageInput | VideoChunkInput


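# A minimal example of a MediaInput (the in-memory image is illustrative):
# chunk: MediaInput = {"type": "video_chunk", "video_chunk": [Image.new("RGB", (8, 8))]}

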
def get_video_meta(
    video_src: bytes | str | os.PathLike, accurate: bool = True
) -> VideoSpec:
    """Read video metadata (dimensions, frame count, average fps) into a VideoSpec."""
    if VideoReader is None:
        raise ImportError("mecord is required to read video metadata")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)

    if isinstance(video_src, str) and video_src.startswith("data:video/mp4;base64,"):
        video_src = base64.b64decode(video_src.split(",")[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format."
    )
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(
        media_type="video",
        height=video.original_height,
        width=video.original_width,
        num_frames=video.num_frames,
        fps=video.avg_fps,
        key_indices=video.key_indices,
        frame_time_info=video.frame_time_info,
    )


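# Example usage (the path is hypothetical):
# spec = get_video_meta("sample.mp4")
# print(spec.width, spec.height, spec.num_frames, spec.fps)

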
def timestamp_as_str(timestamp: float, timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Format a timestamp in seconds as "hh:mm:ss.fff", "mm:ss.fff", or "mm:ss"."""
    if timestamp_mode == "hh:mm:ss.fff":
        return (
            datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime("%H:%M:%S")
            + f".{int((timestamp % 1) * 1000):03d}"
        )
    elif timestamp_mode == "mm:ss.fff":
        return (
            datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime("%M:%S")
            + f".{int((timestamp % 1) * 1000):03d}"
        )
    elif timestamp_mode == "mm:ss":
        return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime("%M:%S")
    else:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")


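# Examples:
# timestamp_as_str(3723.5)           -> "01:02:03.500"
# timestamp_as_str(3723.5, "mm:ss")  -> "02:03"

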
def navit_resize_image(
    width: int,
    height: int,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit: int,
    patch_limit_on_one_side: int,
    fixed_output_tokens: int | None,
) -> dict:
    """Compute the resize/pad plan and output token count for one image."""
    # Downscale so that (1) the total patch count stays under in_patch_limit and
    # (2) neither side exceeds patch_limit_on_one_side patches; never upscale.
    s1 = math.sqrt(
        in_patch_limit
        / (max(1.0, width // patch_size) * max(1.0, height // patch_size))
    )
    s2 = patch_limit_on_one_side * patch_size / width
    s3 = patch_limit_on_one_side * patch_size / height
    scale = min(1.0, s1, s2, s3)
    new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
    new_w = min(new_w, patch_limit_on_one_side * patch_size)
    new_h = min(new_h, patch_limit_on_one_side * patch_size)

    # Pad each side up to a multiple of merge_kernel_size * patch_size.
    factor = merge_kernel_size * patch_size

    pad_height = (factor - new_h % factor) % factor
    pad_width = (factor - new_w % factor) % factor

    if fixed_output_tokens is not None:
        num_tokens = fixed_output_tokens
    else:
        token_height = (new_h + pad_height) // factor
        token_width = (new_w + pad_width) // factor

        assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
            f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
        )
        assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
            f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
        )

        num_tokens = token_height * token_width
    return {
        "num_tokens": num_tokens,
        "new_width": new_w,
        "new_height": new_h,
        "pad_width": pad_width,
        "pad_height": pad_height,
        "sampled_nframes": 1,
    }


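# A worked example (all parameter values below are assumptions, not defaults):
# navit_resize_image(1920, 1080, patch_size=14, merge_kernel_size=2,
#                    in_patch_limit=1024, patch_limit_on_one_side=128,
#                    fixed_output_tokens=None)
# scales 1920x1080 down to 598x336, pads to 616x336 (a multiple of 28),
# and yields (616 // 28) * (336 // 28) = 22 * 12 = 264 tokens.

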
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
) -> dict:
    """Compute the per-frame resize/pad plan and the number of sampled frames."""
    sample_fps = min(sample_fps, avg_fps)

    sampled_nframes = max(round(nframes * sample_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    # Spread the total patch budget evenly across the sampled frames.
    if in_patch_limit_total is not None:
        in_patch_limit_each_frame = min(
            round(in_patch_limit_total / sampled_nframes), in_patch_limit_each_frame
        )

    ret = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        in_patch_limit_each_frame,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    ret["sampled_nframes"] = sampled_nframes
    return ret


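# Example frame sampling (values are illustrative): a 300-frame clip at 30 fps
# sampled at 2 fps keeps round(300 * 2 / 30) = 20 frames.

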
def real_sample_fps_and_max_num_frames(
    type_name: Literal["video", "video_chunk"],
    sample_fps: float,
    max_num_frames_each_video: int | None,
) -> tuple[float, int | None]:
    if type_name == "video":
        return sample_fps, max_num_frames_each_video
    elif type_name == "video_chunk":
        # Video chunks arrive pre-sampled: keep every frame, with no frame cap.
        return math.inf, None
    else:
        raise ValueError(f"Unsupported media type: {type_name}")


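# Example: real_sample_fps_and_max_num_frames("video_chunk", 2.0, 64) -> (inf, None),
# i.e. chunks bypass fps subsampling entirely.

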
def _to_pil(data: str | bytes | Image.Image) -> Image.Image:
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(base64.b64decode(raw_base64))).convert("RGB")
        else:
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")


def ensure_media_type(media: MediaInput) -> MediaInput:
    if media["type"] == "image":
        media["image"] = _to_pil(media["image"])
        return media
    elif media["type"] == "video_chunk":
        media["video_chunk"] = [_to_pil(frame) for frame in media["video_chunk"]]
        return media
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")


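# Example (the path is hypothetical): ensure_media_type({"type": "image", "image": "frame.png"})
# opens the file and replaces the path with an RGB PIL image, mutating the dict in place.

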
def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array.

    Args:
        image: The image to convert.
        resize_to: The (width, height) to resize the image to.
        mode: The resize strategy: "resize", "rescale_and_pad_to_center",
            or "rescale_and_pad_to_rightbottom".
        raise_error_for_ill_resize: Whether to raise an error when the rescaled
            size collapses to zero on either side.

    Returns:
        A numpy array of shape (height, width, 3).
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)

        elif mode == "rescale_and_pad_to_center":
            # Downscale (never upscale) to fit, then pad evenly on all sides.
            scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, from image size: {image.size}"
                    )
                else:
                    return np.zeros((resize_to[1], resize_to[0], 3), dtype=np.uint8)

            image = image.resize(
                (new_width, new_height), resample=Image.Resampling.BICUBIC
            )
            padding_left = (resize_to[0] - new_width) // 2
            padding_right = resize_to[0] - new_width - padding_left
            padding_top = (resize_to[1] - new_height) // 2
            padding_bottom = resize_to[1] - new_height - padding_top
            image = np.asarray(image)
            image = np.pad(
                image,
                ((padding_top, padding_bottom), (padding_left, padding_right), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        elif mode == "rescale_and_pad_to_rightbottom":
            # Downscale (never upscale) to fit, then pad on the right and bottom only.
            scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
            new_width = round(image.width * scale)
            new_height = round(image.height * scale)
            if new_width == 0 or new_height == 0:
                if raise_error_for_ill_resize:
                    raise ValueError(
                        f"Invalid resize to: {resize_to}, from image size: {image.size}"
                    )
                else:
                    return np.zeros((resize_to[1], resize_to[0], 3), dtype=np.uint8)

            image = image.resize(
                (new_width, new_height), resample=Image.Resampling.BICUBIC
            )
            padding_right = resize_to[0] - new_width
            padding_bottom = resize_to[1] - new_height
            image = np.asarray(image)
            image = np.pad(
                image,
                ((0, padding_bottom), (0, padding_right), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            assert image.shape == (resize_to[1], resize_to[0], 3)

        else:
            raise ValueError(f"Invalid mode: {mode}")

    if isinstance(image, Image.Image):
        return np.asarray(image)
    else:
        return image


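# Example (sizes are illustrative): a 100x50 image converted with
# mode="rescale_and_pad_to_center" and resize_to=(64, 64) is scaled by 0.64 to
# 64x32, then padded with 16 black rows on the top and bottom.

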
def navit_patchify(pixel_values: np.ndarray, patch_size: int) -> dict[str, np.ndarray]:
    """Reshape the pixel values into navit patches.

    Args:
        pixel_values: np.ndarray, shape (t, h, w, c); h and w must be divisible by patch_size
        patch_size: int

    Returns:
        dict[str, np.ndarray]
        - pixel_values: np.ndarray, shape (t * h//patch_size * w//patch_size, c, patch_size, patch_size)
        - grid_thw: np.ndarray, (t, h//patch_size, w//patch_size)
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    assert H % patch_size == 0 and W % patch_size == 0, (
        "height and width must be divisible by patch_size"
    )

    # Split h and w into patch grids, then gather each patch as (c, p, p).
    patches = pixel_values.reshape(
        T, H // patch_size, patch_size, W // patch_size, patch_size, C
    )
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, H // patch_size, W // patch_size])
    return {"pixel_values": patches, "grid_thw": grid_thw}


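# Example shapes: navit_patchify(np.zeros((2, 28, 28, 3), dtype=np.uint8), patch_size=14)
# returns pixel_values of shape (8, 3, 14, 14) and grid_thw = [2, 2, 2].

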
def normalize(
    x: np.ndarray, mean, std_inv, pixels_dtype: np.dtype = np.float32
) -> np.ndarray:
    """Normalize the image.

    Args:
        x: The image to normalize. The shape is (..., 3). The dtype is uint8. The range is [0, 255].
        mean: The mean to subtract, in [0, 1] units (scalar or per-channel).
        std_inv: The reciprocal of the std, in [0, 1] units (scalar or per-channel).
        pixels_dtype: The dtype of the output.

    Returns:
        The normalized image. The shape is (..., 3). The dtype is pixels_dtype.
    """
    # Scale to [0, 1], then apply (x - mean) * (1 / std) in place.
    x = (x / 255.0).astype(pixels_dtype)
    x -= mean
    x *= std_inv
    return x


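# Example (the statistics are an assumption, e.g. ImageNet-style values):
# mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
# std_inv = 1.0 / np.array([0.229, 0.224, 0.225], dtype=np.float32)
# pixels = normalize(np.asarray(img), mean, std_inv)  # img: any RGB PIL image

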
def _to_tensor(data, **kwargs):
    # torch is imported lazily so the module stays importable without it.
    import torch

    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(**kwargs)
    elif isinstance(data, torch.Tensor):
        return data.to(**kwargs)
    elif isinstance(data, list):
        return [_to_tensor(item, **kwargs) for item in data]
    elif isinstance(data, tuple):
        return tuple(_to_tensor(item, **kwargs) for item in data)
    elif isinstance(data, dict):
        return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
    elif data is None:
        return None
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
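
# Example (device and variable names are illustrative): recursively move a
# nested batch onto the GPU:
# batch = _to_tensor({"pixel_values": patches, "grid_thw": grid_thw}, device="cuda")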