# kimi-k2.5 / kimi_k25_vision_processing.py
"""Image processor class for Kimi-K2.5.
"""
import json
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from PIL import Image
from transformers.image_processing_utils import (BaseImageProcessor,
BatchFeature)
from transformers.utils import TensorType
from .media_utils import (MediaInput, VideoChunkInput, _to_tensor,
ensure_media_type, get_video_meta, image_to_np,
navit_patchify, navit_resize_image,
navit_resize_video, normalize,
real_sample_fps_and_max_num_frames, timestamp_as_str)
try:
    # mecord is an optional dependency, only needed for video decoding.
    from mecord import VideoReader
except ImportError:
    VideoReader = None
def resampling(video_bytes: bytes,
               sample_indices: list[int],
               key_indices=None,
               frame_time_info=None,
               num_threads=4) -> list[Image.Image]:
    """Decode the frames at `sample_indices` from raw video bytes."""
    if VideoReader is None:
        raise ImportError(
            "mecord is required for video decoding but is not installed.")
    video = VideoReader(video_bytes,
                        num_threads=num_threads,
                        frame_time_info=frame_time_info,
                        key_indices=key_indices)
    # extract target frames
    frames = video[sample_indices]
    frames = [Image.fromarray(frame) for frame in frames]
    return frames
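# Hypothetical usage of `resampling`, assuming raw video bytes are at hand:
#   frames = resampling(video_bytes, sample_indices=[0, 30, 60])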
class KimiK25VisionProcessor(BaseImageProcessor):
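    """NaViT-style processor that turns images and video chunks into patch
    tensors. All patching, resizing and frame-sampling options live in
    `media_proc_cfg`."""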
model_type = "kimi_k25"
def __init__(
self,
media_proc_cfg: dict,
**kwargs,
):
super().__init__(**kwargs)
self.media_proc_cfg = media_proc_cfg
self.num_frames_per_chunk = media_proc_cfg[
'temporal_merge_kernel_size']
    def media_tokens_calculator(self, media: MediaInput):
        """Return the number of vision tokens the given media occupies."""
        media = ensure_media_type(media)
        ret = self.get_resize_config(media)
        return ret['num_tokens']
@classmethod
def make_chunk_prompt(cls, timestamp_text: str) -> str:
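        # e.g. timestamp_text "00:05" (exact format set by `timestamp_mode`)
        # yields "00:05<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"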
return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>"
    def split_video_chunks(self,
                           video_url: str | bytes) -> list[VideoChunkInput]:
        """Sample frames from a video (base64 str or raw bytes) and group
        them into timestamped chunks of `temporal_merge_kernel_size` frames.
        """
video_spec = get_video_meta(video_url)
sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps)
sampled_nframes = max(
round(video_spec.num_frames * sample_fps / video_spec.fps), 1)
frame_inds = np.linspace(0, video_spec.num_frames - 1,
sampled_nframes).round().astype(int)
frame_inds = frame_inds.tolist()
sampled_frame_ids = []
temporal_merge_kernel_size = self.media_proc_cfg[
"temporal_merge_kernel_size"]
num_chunks = 0
chunk_timestamp = []
for i in range(0, len(frame_inds), temporal_merge_kernel_size):
sampled_frame_ids.extend(frame_inds[i:i +
temporal_merge_kernel_size])
start_time = frame_inds[i] / float(video_spec.fps)
timestamp_text = timestamp_as_str(
start_time, self.media_proc_cfg["timestamp_mode"])
chunk_timestamp.append(timestamp_text)
num_chunks += 1
sampled_frames = resampling(video_url, sampled_frame_ids)
chunks = []
for chunk_id in range(num_chunks):
chunk = sampled_frames[chunk_id *
temporal_merge_kernel_size:(chunk_id + 1) *
temporal_merge_kernel_size]
chunks.append(
VideoChunkInput(type="video_chunk",
video_chunk=chunk,
prompt=self.make_chunk_prompt(
chunk_timestamp[chunk_id])))
return chunks
def get_resize_config(self, media_input: MediaInput) -> dict:
if media_input['type'] == 'image':
w, h = media_input['image'].size
ret = navit_resize_image(
w, h, self.media_proc_cfg['patch_size'],
self.media_proc_cfg['merge_kernel_size'],
self.media_proc_cfg['in_patch_limit'],
self.media_proc_cfg['patch_limit_on_one_side'],
self.media_proc_cfg['fixed_output_tokens'])
return ret
elif media_input['type'] == 'video_chunk':
frame = media_input['video_chunk'][0]
width, height = frame.size
num_frames = len(media_input["video_chunk"])
            fps = 1.0  # chunk frames are already sampled; use a nominal rate
sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames(
media_input["type"],
self.media_proc_cfg['sample_fps'],
self.media_proc_cfg['max_num_frames_each_video'],
)
in_patch_limit_each_frame = self.media_proc_cfg[
'in_patch_limit_each_frame']
if in_patch_limit_each_frame is None:
in_patch_limit_each_frame = self.media_proc_cfg[
'in_patch_limit']
ret = navit_resize_video(
width,
height,
num_frames,
fps,
sample_fps,
self.media_proc_cfg['patch_size'],
self.media_proc_cfg['merge_kernel_size'],
in_patch_limit_each_frame,
self.media_proc_cfg['patch_limit_on_one_side'],
self.media_proc_cfg['in_patch_limit_video'],
max_num_frames_each_video,
self.media_proc_cfg['fixed_output_tokens'],
)
return ret
else:
raise ValueError("Unsupported type: {}".format(
media_input['type']))
    def resize_image(self, image: Image.Image, new_width: int, new_height: int,
                     pad_width: int, pad_height: int) -> np.ndarray:
        """Resize to (new_width, new_height), then zero-pad the bottom and
        right edges by pad_height / pad_width pixels."""
image_np = image_to_np(image, (new_width, new_height), "resize")
image_np = np.pad(
image_np,
((0, pad_height), (0, pad_width), (0, 0)),
mode="constant",
constant_values=0,
)
return image_np
def preprocess(
self,
medias: list[MediaInput],
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
"""
Preprocess a atom vision input (images/video_chunk) into model-ready tensors.
Args:
medias: List of MediaInput.
return_tensors: Desired output format ('pt', 'np', 'tf', or None).
Returns:
BatchFeature containing 'pixel_values' and 'grid_thws' tensors.
"""
if not isinstance(medias, list):
medias = [medias]
if medias:
pixel_values = []
for item in medias:
item = ensure_media_type(item)
resize_config = self.get_resize_config(item)
                new_width = resize_config['new_width']
                new_height = resize_config['new_height']
                pad_width = resize_config['pad_width']
                pad_height = resize_config['pad_height']
if item['type'] == 'image':
image = item['image']
image_np = self.resize_image(image, new_width, new_height,
pad_width, pad_height)
pixel_values.append(np.expand_dims(image_np, axis=0))
elif item['type'] == 'video_chunk':
pixels = []
for frame in item['video_chunk']:
frame_np = self.resize_image(frame, new_width,
new_height, pad_width,
pad_height)
pixels.append(frame_np)
pixel_values.append(np.stack(pixels, axis=0))
else:
raise ValueError("Unsupported type: {}".format(
item['type']))
            normalized_pixel_values = []
            # Precompute the reciprocal of std so normalization multiplies.
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
for pixels in pixel_values:
pixels = normalize(pixels, image_mean, image_std_inv)
pixels_and_thw = navit_patchify(
pixels,
self.media_proc_cfg['patch_size'],
)
normalized_pixel_values.append(pixels_and_thw)
pixel_values = torch.cat([
_to_tensor(pixel_value['pixel_values'])
for pixel_value in normalized_pixel_values
])
grid_thws = torch.cat([
_to_tensor(pixel_value['grid_thw'],
dtype=torch.int64).unsqueeze(0)
for pixel_value in normalized_pixel_values
])
data = {
'pixel_values': pixel_values,
'grid_thws': grid_thws,
}
else:
data = {}
return BatchFeature(data=data, tensor_type=return_tensors)
def __repr__(self):
return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"
def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()
output["media_proc_cfg"] = self.media_proc_cfg
if "media_processor" in output:
del output["media_processor"]
return output
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
config = config_dict.copy()
media_proc_cfg = config.pop("media_proc_cfg", {})
return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)
def to_json_string(self):
dictionary = self.to_dict()
for key, value in dictionary.items():
if hasattr(value, 'tolist'):
dictionary[key] = value.tolist()
return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
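

if __name__ == "__main__":
    # Minimal smoke-test sketch, not a reference configuration: every value
    # in `cfg` below is an illustrative assumption (no released checkpoint's
    # settings are implied), and passing a plain dict as the media input
    # assumes ensure_media_type accepts one.
    cfg = {
        "patch_size": 14,
        "merge_kernel_size": 2,
        "temporal_merge_kernel_size": 4,
        "in_patch_limit": 4096,
        "in_patch_limit_each_frame": None,
        "in_patch_limit_video": 16384,
        "patch_limit_on_one_side": 512,
        "fixed_output_tokens": None,
        "sample_fps": 2,
        "max_num_frames_each_video": 256,
        "timestamp_mode": "sec",
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    processor = KimiK25VisionProcessor(media_proc_cfg=cfg)
    image = Image.new("RGB", (640, 480), color=(127, 127, 127))
    batch = processor.preprocess([{"type": "image", "image": image}],
                                 return_tensors="pt")
    print(batch["pixel_values"].shape, batch["grid_thws"])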