"""Image processor class for Kimi-K2.5. """ import json from typing import Any, Dict, Optional, Union import numpy as np import torch from PIL import Image from transformers.image_processing_utils import (BaseImageProcessor, BatchFeature) from transformers.utils import TensorType from .media_utils import (MediaInput, VideoChunkInput, _to_tensor, ensure_media_type, get_video_meta, image_to_np, navit_patchify, navit_resize_image, navit_resize_video, normalize, real_sample_fps_and_max_num_frames, timestamp_as_str) try: from mecord import VideoReader except ImportError: VideoReader = None def resampling(video_bytes: bytes, sample_indices: list[int], key_indices=None, frame_time_info=None, num_threads=4) -> str: video = VideoReader(video_bytes, num_threads=num_threads, frame_time_info=frame_time_info, key_indices=key_indices) # extract target frames frames = video[sample_indices] frames = [Image.fromarray(frame) for frame in frames] return frames class KimiK25VisionProcessor(BaseImageProcessor): model_type = "kimi_k25" def __init__( self, media_proc_cfg: dict, **kwargs, ): super().__init__(**kwargs) self.media_proc_cfg = media_proc_cfg self.num_frames_per_chunk = media_proc_cfg[ 'temporal_merge_kernel_size'] def media_tokens_calculator(self, media: MediaInput): media = ensure_media_type(media) ret = self.get_resize_config(media) return ret['num_tokens'] @classmethod def make_chunk_prompt(cls, timestamp_text: str) -> str: return f"{timestamp_text}<|media_begin|>video<|media_content|><|media_pad|><|media_end|>" def split_video_chunks(self, video_url: str | bytes) -> list[list[Image.Image]]: # video_url should be base64 str or bytes video_spec = get_video_meta(video_url) sample_fps = min(self.media_proc_cfg['sample_fps'], video_spec.fps) sampled_nframes = max( round(video_spec.num_frames * sample_fps / video_spec.fps), 1) frame_inds = np.linspace(0, video_spec.num_frames - 1, sampled_nframes).round().astype(int) frame_inds = frame_inds.tolist() sampled_frame_ids = [] temporal_merge_kernel_size = self.media_proc_cfg[ "temporal_merge_kernel_size"] num_chunks = 0 chunk_timestamp = [] for i in range(0, len(frame_inds), temporal_merge_kernel_size): sampled_frame_ids.extend(frame_inds[i:i + temporal_merge_kernel_size]) start_time = frame_inds[i] / float(video_spec.fps) timestamp_text = timestamp_as_str( start_time, self.media_proc_cfg["timestamp_mode"]) chunk_timestamp.append(timestamp_text) num_chunks += 1 sampled_frames = resampling(video_url, sampled_frame_ids) chunks = [] for chunk_id in range(num_chunks): chunk = sampled_frames[chunk_id * temporal_merge_kernel_size:(chunk_id + 1) * temporal_merge_kernel_size] chunks.append( VideoChunkInput(type="video_chunk", video_chunk=chunk, prompt=self.make_chunk_prompt( chunk_timestamp[chunk_id]))) return chunks def get_resize_config(self, media_input: MediaInput) -> dict: if media_input['type'] == 'image': w, h = media_input['image'].size ret = navit_resize_image( w, h, self.media_proc_cfg['patch_size'], self.media_proc_cfg['merge_kernel_size'], self.media_proc_cfg['in_patch_limit'], self.media_proc_cfg['patch_limit_on_one_side'], self.media_proc_cfg['fixed_output_tokens']) return ret elif media_input['type'] == 'video_chunk': frame = media_input['video_chunk'][0] width, height = frame.size num_frames = len(media_input["video_chunk"]) fps = 1.0 sample_fps, max_num_frames_each_video = real_sample_fps_and_max_num_frames( media_input["type"], self.media_proc_cfg['sample_fps'], self.media_proc_cfg['max_num_frames_each_video'], ) in_patch_limit_each_frame = 

    def get_resize_config(self, media_input: MediaInput) -> dict:
        if media_input['type'] == 'image':
            width, height = media_input['image'].size
            ret = navit_resize_image(
                width, height, self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                self.media_proc_cfg['in_patch_limit'],
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['fixed_output_tokens'])
            return ret
        elif media_input['type'] == 'video_chunk':
            frame = media_input['video_chunk'][0]
            width, height = frame.size
            num_frames = len(media_input["video_chunk"])
            fps = 1.0
            sample_fps, max_num_frames_each_video = (
                real_sample_fps_and_max_num_frames(
                    media_input["type"],
                    self.media_proc_cfg['sample_fps'],
                    self.media_proc_cfg['max_num_frames_each_video'],
                ))
            # Fall back to the global patch limit when no per-frame limit
            # is configured.
            in_patch_limit_each_frame = self.media_proc_cfg[
                'in_patch_limit_each_frame']
            if in_patch_limit_each_frame is None:
                in_patch_limit_each_frame = self.media_proc_cfg[
                    'in_patch_limit']
            ret = navit_resize_video(
                width, height, num_frames, fps, sample_fps,
                self.media_proc_cfg['patch_size'],
                self.media_proc_cfg['merge_kernel_size'],
                in_patch_limit_each_frame,
                self.media_proc_cfg['patch_limit_on_one_side'],
                self.media_proc_cfg['in_patch_limit_video'],
                max_num_frames_each_video,
                self.media_proc_cfg['fixed_output_tokens'],
            )
            return ret
        else:
            raise ValueError("Unsupported type: {}".format(
                media_input['type']))

    def resize_image(self, image: Image.Image, new_width: int,
                     new_height: int, pad_width: int,
                     pad_height: int) -> np.ndarray:
        """Resize an image and zero-pad it to a patch-aligned shape."""
        image_np = image_to_np(image, (new_width, new_height), "resize")
        image_np = np.pad(
            image_np,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode="constant",
            constant_values=0,
        )
        return image_np

    def preprocess(
        self,
        medias: list[MediaInput],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """
        Preprocess atomic vision inputs (images / video chunks) into
        model-ready tensors.

        Args:
            medias: List of MediaInput.
            return_tensors: Desired output format ('pt', 'np', 'tf', or None).

        Returns:
            BatchFeature containing 'pixel_values' and 'grid_thws' tensors.
        """
        if not isinstance(medias, list):
            medias = [medias]
        if medias:
            pixel_values = []
            for item in medias:
                item = ensure_media_type(item)
                resize_config = self.get_resize_config(item)
                new_width, new_height, pad_width, pad_height = (
                    resize_config['new_width'], resize_config['new_height'],
                    resize_config['pad_width'], resize_config['pad_height'])
                if item['type'] == 'image':
                    image = item['image']
                    image_np = self.resize_image(image, new_width, new_height,
                                                 pad_width, pad_height)
                    # Add a singleton temporal dimension for images.
                    pixel_values.append(np.expand_dims(image_np, axis=0))
                elif item['type'] == 'video_chunk':
                    pixels = []
                    for frame in item['video_chunk']:
                        frame_np = self.resize_image(frame, new_width,
                                                     new_height, pad_width,
                                                     pad_height)
                        pixels.append(frame_np)
                    pixel_values.append(np.stack(pixels, axis=0))
                else:
                    raise ValueError("Unsupported type: {}".format(
                        item['type']))
            normalized_pixel_values = []
            image_std_inv = 1.0 / np.array(self.media_proc_cfg['image_std'])
            image_mean = np.array(self.media_proc_cfg['image_mean'])
            for pixels in pixel_values:
                pixels = normalize(pixels, image_mean, image_std_inv)
                pixels_and_thw = navit_patchify(
                    pixels,
                    self.media_proc_cfg['patch_size'],
                )
                normalized_pixel_values.append(pixels_and_thw)
            pixel_values = torch.cat([
                _to_tensor(pixel_value['pixel_values'])
                for pixel_value in normalized_pixel_values
            ])
            grid_thws = torch.cat([
                _to_tensor(pixel_value['grid_thw'],
                           dtype=torch.int64).unsqueeze(0)
                for pixel_value in normalized_pixel_values
            ])
            data = {
                'pixel_values': pixel_values,
                'grid_thws': grid_thws,
            }
        else:
            data = {}
        return BatchFeature(data=data, tensor_type=return_tensors)
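
    # Output layout, as produced by `preprocess` above: `pixel_values` is
    # the patch sequences of every media input concatenated along dim 0,
    # and `grid_thws` stacks one (t, h, w) patch-grid triple per input,
    # i.e. shape (num_media, 3). The per-patch feature size depends on
    # what `navit_patchify` emits for the configured `patch_size`.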

    def __repr__(self):
        return f"KimiK25VisionProcessor(media_proc_cfg={self.media_proc_cfg})"

    def to_dict(self) -> Dict[str, Any]:
        output = super().to_dict()
        output["media_proc_cfg"] = self.media_proc_cfg
        if "media_processor" in output:
            del output["media_processor"]
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        config = config_dict.copy()
        media_proc_cfg = config.pop("media_proc_cfg", {})
        return cls(media_proc_cfg=media_proc_cfg, **config, **kwargs)

    def to_json_string(self):
        dictionary = self.to_dict()
        # Convert arrays / tensors to plain lists so they are JSON-serializable.
        for key, value in dictionary.items():
            if hasattr(value, 'tolist'):
                dictionary[key] = value.tolist()
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
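

# Minimal usage sketch for the image path. The config keys are the ones
# accessed above, but every value here is illustrative (the real config
# ships with the model checkpoint), and `fixed_output_tokens=None` is an
# assumption:
#
#     cfg = {
#         'temporal_merge_kernel_size': 4,
#         'patch_size': 14,
#         'merge_kernel_size': 2,
#         'in_patch_limit': 4096,
#         'patch_limit_on_one_side': 512,
#         'fixed_output_tokens': None,
#         'image_mean': [0.5, 0.5, 0.5],
#         'image_std': [0.5, 0.5, 0.5],
#     }
#     processor = KimiK25VisionProcessor(media_proc_cfg=cfg)
#     image = Image.new("RGB", (640, 480))
#     batch = processor.preprocess([{'type': 'image', 'image': image}],
#                                  return_tensors='pt')
#     # batch['pixel_values']: all patches; batch['grid_thws']: shape (1, 3)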