|
|
import base64 |
|
|
import io |
|
|
import math |
|
|
import os |
|
|
from datetime import datetime, timezone |
|
|
from typing import List, Literal, Optional, TypedDict |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
try: |
|
|
from mecord import VideoReader |
|
|
except ImportError: |
|
|
VideoReader = None |
|
|
|
|
|
|
|
|
class VideoSpec(BaseModel):
    """Validated metadata describing a video.

    ``height``, ``width``, ``num_frames`` and ``fps`` must all be strictly
    positive (``gt=0``). ``key_indices`` and ``frame_time_info`` are optional
    extras that default to ``None``.
    """

    # Discriminator tag. The annotation and the default were previously
    # swapped (``str = Literal['video']``), which made the default value a
    # typing object instead of the string 'video'.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Both fields default to None, so the annotations must admit None;
    # otherwise an explicitly passed None can be rejected by validation.
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None,
                                            description="frame time info")
|
|
|
|
|
|
|
|
class ImageInput(TypedDict):
    """Media item that carries a single image."""

    # Discriminator tag; always the literal string 'image'.
    type: Literal['image']
    # The image payload as a PIL image.
    image: Image.Image
|
|
|
|
|
|
|
|
class _VideoChunkInputRequired(TypedDict):
    """Keys that every video-chunk media item must provide."""

    # Discriminator tag; always the literal string 'video_chunk'.
    type: Literal['video_chunk']
    # The frames of the chunk, one PIL image per frame.
    video_chunk: List[Image.Image]


class VideoChunkInput(_VideoChunkInputRequired, total=False):
    """Media item that carries a chunk of video frames.

    ``prompt`` is an optional key. TypedDict bodies may not assign default
    values (the former ``= None`` had no effect on the resulting dicts), so
    the optional key is declared through ``total=False`` instead.
    """

    prompt: Optional[str]
|
|
|
|
|
|
|
|
# A media item is either a single image or a chunk of video frames; the two
# shapes are told apart by their 'type' key.
MediaInput = ImageInput | VideoChunkInput
|
|
|
|
|
|
|
|
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> "VideoSpec":
    """Read the basic metadata of a video and return it as a ``VideoSpec``.

    Args:
        video_src: Raw video bytes, a path (``str`` or ``os.PathLike``), or a
            ``data:video/mp4;base64,...`` string. A data string is decoded to
            bytes before it is handed to the reader.
        accurate: Forwarded to ``VideoReader`` as ``auto_init``.

    Returns:
        A ``VideoSpec`` holding the original height and width, the number of
        frames, the average fps, the key indices and the frame time info
        reported by the reader.

    Raises:
        ImportError: If the optional ``VideoReader`` backend failed to import.
        AssertionError: If the reader reports a non-positive frame count,
            size or average fps.
    """
    if VideoReader is None:
        # The module-level import fell back to None; fail with a clear message
        # instead of "'NoneType' object is not callable".
        raise ImportError(
            "VideoReader is not available; install the video reader backend "
            "to read video metadata.")

    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)

    data_uri_prefix = 'data:video/mp4;base64,'
    if isinstance(video_src, str) and video_src.startswith(data_uri_prefix):
        video_src = base64.b64decode(video_src.split(',')[1])

    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    # NOTE: asserts are kept so callers see the same AssertionError as before;
    # be aware that they are stripped when Python runs with -O.
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
|
|
|
|
|
|
|
|
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Render a timestamp given in seconds as a clock-style string.

    Args:
        timestamp: Time offset in seconds.
        timestamp_mode: One of ``"hh:mm:ss.fff"``, ``"mm:ss.fff"`` or
            ``"mm:ss"``.

    Returns:
        The formatted string. The clock part comes from a UTC ``datetime``
        built from ``timestamp``; in the ``.fff`` modes the fractional second
        is appended as three digits, truncated (not rounded) to milliseconds.

    Raises:
        ValueError: If ``timestamp_mode`` is not one of the supported modes.
    """
    # Pick the strftime pattern and whether milliseconds are appended.
    if timestamp_mode == "hh:mm:ss.fff":
        clock_format, with_millis = "%H:%M:%S", True
    elif timestamp_mode == "mm:ss.fff":
        clock_format, with_millis = "%M:%S", True
    elif timestamp_mode == "mm:ss":
        clock_format, with_millis = "%M:%S", False
    else:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")

    moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    clock = moment.strftime(clock_format)
    if not with_millis:
        return clock
    return clock + f".{int((timestamp % 1) * 1000):03d}"
|
|
|
|
|
|
|
|
def navit_resize_image( |
|
|
width: int, |
|
|
height: int, |
|
|
patch_size: int, |
|
|
merge_kernel_size: int, |
|
|
in_patch_limit: int, |
|
|
patch_limit_on_one_side: int, |
|
|
fixed_output_tokens: int | None, |
|
|
): |
|
|
|
|
|
s1 = math.sqrt( |
|
|
in_patch_limit / |
|
|
(max(1.0, width // patch_size) * max(1.0, height // patch_size))) |
|
|
s2 = patch_limit_on_one_side * patch_size / width |
|
|
s3 = patch_limit_on_one_side * patch_size / height |
|
|
scale = min(1.0, s1, s2, s3) |
|
|
new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale)) |
|
|
new_w = min(new_w, patch_limit_on_one_side * patch_size) |
|
|
new_h = min(new_h, patch_limit_on_one_side * patch_size) |
|
|
|
|
|
|
|
|
factor = merge_kernel_size * patch_size |
|
|
|
|
|
pad_height = (factor - new_h % factor) % factor |
|
|
pad_width = (factor - new_w % factor) % factor |
|
|
|
|
|
if fixed_output_tokens is not None: |
|
|
num_tokens = fixed_output_tokens |
|
|
else: |
|
|
|
|
|
token_height = (new_h + pad_height) // factor |
|
|
token_width = (new_w + pad_width) // factor |
|
|
|
|
|
assert token_height * merge_kernel_size <= patch_limit_on_one_side, ( |
|
|
f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" |
|
|
) |
|
|
assert token_width * merge_kernel_size <= patch_limit_on_one_side, ( |
|
|
f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}" |
|
|
) |
|
|
|
|
|
num_tokens = token_height * token_width |
|
|
return { |
|
|
"num_tokens": num_tokens, |
|
|
"new_width": new_w, |
|
|
"new_height": new_h, |
|
|
"pad_width": pad_width, |
|
|
"pad_height": pad_height, |
|
|
"sampled_nframes": 1, |
|
|
} |
|
|
|
|
|
|
|
|
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Plan frame sampling and per-frame resizing for a video.

    The sampling rate is capped at ``avg_fps``. The sampled frame count is at
    least 1 before the optional ``max_num_frames_each_video`` cap is applied.
    When ``in_patch_limit_total`` is given, the per-frame patch limit is
    lowered to an even share of that total if the share is smaller.

    Returns:
        The dict produced by ``navit_resize_image`` for a single frame, with
        ``sampled_nframes`` replaced by the number of frames to sample.
    """
    # Never sample faster than the video itself plays.
    effective_fps = min(sample_fps, avg_fps)

    frame_count = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        frame_count = min(frame_count, max_num_frames_each_video)

    per_frame_limit = in_patch_limit_each_frame
    if in_patch_limit_total is not None:
        per_frame_limit = min(round(in_patch_limit_total / frame_count),
                              per_frame_limit)

    plan = navit_resize_image(
        width=width,
        height=height,
        patch_size=patch_size,
        merge_kernel_size=merge_kernel_size,
        in_patch_limit=per_frame_limit,
        patch_limit_on_one_side=patch_limit_on_one_side,
        fixed_output_tokens=fixed_output_tokens_each_frame,
    )
    plan["sampled_nframes"] = frame_count
    return plan
|
|
|
|
|
|
|
|
def real_sample_fps_and_max_num_frames( |
|
|
type_name: Literal["video", "video_chunk"], |
|
|
sample_fps: float, |
|
|
max_num_frames_each_video: int | None, |
|
|
) -> tuple[int, int | None]: |
|
|
if type_name == "video": |
|
|
return sample_fps, max_num_frames_each_video |
|
|
elif type_name == "video_chunk": |
|
|
max_num_frames_each_video = None |
|
|
sample_fps = math.inf |
|
|
return sample_fps, max_num_frames_each_video |
|
|
else: |
|
|
return math.inf, None |
|
|
|
|
|
|
|
|
def _to_pil(data: "str | bytes | Image.Image") -> "Image.Image":
    """Load ``data`` as a PIL image converted to RGB.

    Args:
        data: One of
            - a PIL image (converted to RGB),
            - a string starting with ``"data:"`` whose part after the first
              comma is base64-encoded image bytes,
            - any other string, which is passed to ``Image.open`` as-is,
            - raw encoded image bytes.

    Returns:
        The image in RGB mode.

    Raises:
        ValueError: If ``data`` has an unsupported type, or is a ``data:``
            string without a comma-separated payload.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")

    if isinstance(data, str):
        if not data.startswith("data:"):
            return Image.open(data).convert("RGB")
        pieces = data.split(",")
        if len(pieces) < 2:
            # Previously surfaced as a bare IndexError from indexing [1].
            raise ValueError(
                "Malformed data URI: missing ',' before the payload")
        data = base64.b64decode(pieces[1])

    if isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")

    raise ValueError(f"Unsupported data type: {type(data)}")
|
|
|
|
|
|
|
|
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Convert the payload of a media item to RGB PIL images, in place.

    For an ``'image'`` item the ``image`` entry is converted; for a
    ``'video_chunk'`` item every frame in ``video_chunk`` is converted.

    Returns:
        The same ``media`` dict that was passed in, after mutation.

    Raises:
        ValueError: If ``media['type']`` is neither ``'image'`` nor
            ``'video_chunk'``.
    """
    kind = media['type']
    if kind == 'image':
        media['image'] = _to_pil(media['image'])
    elif kind == 'video_chunk':
        media['video_chunk'] = list(map(_to_pil, media['video_chunk']))
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
|
|
|
|
|
|
|
|
def _rescale_and_pad_image(
    image: "Image.Image",
    resize_to: tuple[int, int],
    centered: bool,
    raise_error_for_ill_resize: bool,
) -> np.ndarray:
    """Shrink ``image`` to fit ``resize_to`` and zero-pad it to that size.

    The image is never enlarged (scale is capped at 1.0). With ``centered``
    the padding is split between both sides, with the extra pixel going to
    the right/bottom; otherwise all padding goes to the right and bottom.
    """
    target_w, target_h = resize_to[0], resize_to[1]
    scale = min(target_w / image.width, target_h / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)

    if new_width == 0 or new_height == 0:
        # The image collapses to nothing at this scale.
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        return np.zeros((target_h, target_w, 3), dtype=np.uint8)

    resized = image.resize((new_width, new_height),
                           resample=Image.Resampling.BICUBIC)

    extra_w = target_w - new_width
    extra_h = target_h - new_height
    pad_left = extra_w // 2 if centered else 0
    pad_top = extra_h // 2 if centered else 0

    padded = np.pad(
        np.asarray(resized),
        ((pad_top, extra_h - pad_top), (pad_left, extra_w - pad_left),
         (0, 0)),
        mode="constant",
        constant_values=0,
    )
    assert padded.shape == (target_h, target_w, 3)
    return padded


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array, optionally resizing it first.

    Args:
        image: The PIL image to convert.
        resize_to: Target ``(width, height)``. When None the image is
            converted as-is and ``mode`` is ignored.
        mode: How to reach ``resize_to``:
            - ``"resize"``: bicubic resize straight to the target size.
            - ``"rescale_and_pad_to_center"``: shrink to fit (never enlarge),
              then zero-pad evenly around the image.
            - ``"rescale_and_pad_to_rightbottom"``: shrink to fit (never
              enlarge), then zero-pad on the right and bottom.
        raise_error_for_ill_resize: In the pad modes, whether to raise when
            the rescaled image would have a zero-sized side. When False an
            all-zero ``uint8`` array of the target size is returned instead.

    Returns:
        A numpy array. In the pad modes its shape is
        ``(resize_to[1], resize_to[0], 3)``.

    Raises:
        ValueError: If ``mode`` is unknown (only checked when ``resize_to`` is
            given), or on an ill-sized resize when
            ``raise_error_for_ill_resize`` is True.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"

    if resize_to is None:
        return np.asarray(image)

    if mode == "resize":
        resized = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
        return np.asarray(resized)
    if mode == "rescale_and_pad_to_center":
        return _rescale_and_pad_image(image, resize_to, True,
                                      raise_error_for_ill_resize)
    if mode == "rescale_and_pad_to_rightbottom":
        return _rescale_and_pad_image(image, resize_to, False,
                                      raise_error_for_ill_resize)
    raise ValueError(f"Invalid mode: {mode}")
|
|
|
|
|
|
|
|
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into square patches laid out in navit order.

    Args:
        pixel_values: Array of shape ``(t, h, w, c)`` with ``c == 3``; ``h``
            and ``w`` must be multiples of ``patch_size``.
        patch_size: Side length of one patch in pixels.

    Returns:
        A dict with:
            - ``pixel_values``: array of shape
              ``(t * h//patch_size * w//patch_size, c, patch_size, patch_size)``,
              patches ordered by frame, then patch row, then patch column.
            - ``grid_thw``: array ``[t, h//patch_size, w//patch_size]``.

    Raises:
        AssertionError: If the channel count is not 3.
        ValueError: If ``h`` or ``w`` is not a multiple of ``patch_size``.
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"

    if H % patch_size or W % patch_size:
        # Fail with a readable message rather than numpy's generic
        # "cannot reshape array" error (still a ValueError, as before).
        raise ValueError(
            f"Frame size ({H}, {W}) is not divisible by patch_size "
            f"{patch_size}")

    grid_h, grid_w = H // patch_size, W // patch_size

    # (t, gh, ph, gw, pw, c) -> (t, gh, gw, c, ph, pw)
    patches = pixel_values.reshape(T, grid_h, patch_size, grid_w, patch_size,
                                   C)
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)

    grid_thw = np.array([T, grid_h, grid_w])
    return {"pixel_values": patches, "grid_thw": grid_thw}
|
|
|
|
|
|
|
|
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Scale pixel values to [0, 1], then standardize them.

    Computes ``(x / 255 - mean) * std_inv`` in ``pixels_dtype``. The input
    array is not modified; a new array is returned.

    Args:
        x: Image data, documented by the original as shape ``(..., 3)``,
            dtype uint8, range ``[0, 255]``.
        mean: Value(s) subtracted after scaling.
        std_inv: Inverse standard deviation(s) multiplied in afterwards.
        pixels_dtype: dtype of the returned array.

    Returns:
        The normalized array, same shape as ``x``, dtype ``pixels_dtype``.
    """
    scaled = (x / 255.0).astype(pixels_dtype)
    # In-place ufuncs keep the result in pixels_dtype, like `-=` and `*=`.
    np.subtract(scaled, mean, out=scaled)
    np.multiply(scaled, std_inv, out=scaled)
    return scaled
|
|
|
|
|
|
|
|
def _to_tensor(data, **kwargs): |
|
|
import torch |
|
|
|
|
|
if isinstance(data, np.ndarray): |
|
|
return torch.from_numpy(data).to(**kwargs) |
|
|
elif isinstance(data, torch.Tensor): |
|
|
return data.to(**kwargs) |
|
|
elif isinstance(data, list): |
|
|
return [_to_tensor(item, **kwargs) for item in data] |
|
|
elif isinstance(data, tuple): |
|
|
return tuple(_to_tensor(item, **kwargs) for item in data) |
|
|
elif isinstance(data, dict): |
|
|
return {k: _to_tensor(v, **kwargs) for k, v in data.items()} |
|
|
elif data is None: |
|
|
return None |
|
|
else: |
|
|
raise ValueError(f"Unsupported data type: {type(data)}") |
|
|
|