from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import logging
logger = logging.get_logger(__name__)
class KimiK25Processor(ProcessorMixin):
r"""
Constructs a KimiK25 processor which wraps a KimiK25 image processor and a tokenizer into a single processor.
[`KimiK25Processor`] offers all the functionalities of [`KimiK25ImageProcessor`] and [`TikTokenTokenizer`]. See the
[`~KimiK25Processor.__call__`] and [`~KimiK25Processor.decode`] for more information.
Args:
image_processor ([`KimiK25ImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`TikTokenTokenizer`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
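
    Example (an illustrative sketch; the checkpoint id below is hypothetical and
    not taken from this file):

    ```python
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "moonshotai/Kimi-K2.5", trust_remote_code=True)  # hypothetical repo id
    ```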
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
**kwargs,
):
super().__init__(image_processor,
tokenizer,
chat_template=chat_template)
self.media_processor = image_processor
# A special temporal placeholder to be replaced by actual video placeholders
self.video_placeholder = "<|kimi_k25_video_placeholder|>"
    def update_raw_text(self, text: str, video_prompts: list[str]) -> str:
        """Replace each video placeholder in ``text`` with the concatenated
        chunk prompts of the corresponding video, in order of appearance.
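
        Example (illustrative; ``<PH>`` abbreviates the real placeholder token
        stored in ``self.video_placeholder``):

        ```python
        # update_raw_text("describe <PH> briefly", ["<chunk_0><chunk_1>"])
        # -> "describe <chunk_0><chunk_1> briefly"
        ```
        """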
video_count = text.count(self.video_placeholder)
if video_count == 0:
return text
assert video_count == len(video_prompts)
text_parts = text.split(self.video_placeholder)
assert len(text_parts) == len(video_prompts) + 1
text = "".join([
text_parts[i] + video_prompts[i] for i in range(len(video_prompts))
])
text += text_parts[-1]
return text
    def preprocess_medias(
            self, medias: list[dict]) -> tuple[list[dict], list[str]]:
        """Expand each video entry into its per-chunk media dicts and collect,
        for each video, the concatenation of its chunk prompts.
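
        Illustrative shape of the transformation (the chunk dicts come from
        ``split_video_chunks`` on the image processor; their exact fields
        beyond ``prompt`` are assumed, not defined here):

        ```python
        # [{"type": "video", "video": clip}]
        # -> updated_medias: [chunk_0, chunk_1, ...]  # one dict per chunk
        # -> video_prompts:  [chunk_0["prompt"] + chunk_1["prompt"] + ...]
        ```
        """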
updated_medias = []
video_prompts = []
for media in medias:
if media['type'] == 'image':
updated_medias.append(media)
elif media['type'] == 'video':
video_chunks = self.media_processor.split_video_chunks(
media['video'])
updated_medias.extend(video_chunks)
video_prompts.append("".join(
[vc['prompt'] for vc in video_chunks]))
else:
raise ValueError(f"unsupported media type: {media['type']}")
return updated_medias, video_prompts
def __call__(self,
messages: list[dict] = None,
medias: list[dict] = None,
text: str = None,
return_tensors: str = "pt",
**kwargs) -> BatchFeature:
"""
        Process multimodal inputs for the Kimi-K2.5 model.
        This processor accepts ordered messages and extracts both media and text
        in a single pass. The text is updated automatically when video input is
        detected in the messages.
Args:
messages: List of message dicts with 'role' and 'content' fields.
If provided, medias and text will be extracted automatically.
medias: Pre-extracted list of media dicts. If None, extracted from messages.
text: Pre-formatted text string. If None, generated via apply_chat_template.
return_tensors: Format of returned tensors ('pt', 'np', 'tf'). Default: 'pt'.
**kwargs: Additional arguments passed to tokenizer.apply_chat_template.
Returns:
BatchFeature with fields: input_ids, attention_mask, pixel_values, grid_thws.
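
        Example (a hedged sketch; the repo id and message contents are
        illustrative, and ``tokenize=False`` is assumed so the chat template
        returns a string):

        ```python
        processor = KimiK25Processor.from_pretrained(
            "moonshotai/Kimi-K2.5")  # hypothetical repo id
        messages = [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "photo.jpg"}},
                {"type": "text", "text": "What is in this image?"},
            ],
        }]
        inputs = processor(messages=messages, tokenize=False,
                           add_generation_prompt=True)
        ```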
"""
if messages is None and (medias is None or text is None):
raise ValueError(
"Provide either 'messages' or both 'medias' and 'text'")
if medias is not None and text is not None:
updated_medias, video_prompts = self.preprocess_medias(medias)
preprocessed = self.media_processor.preprocess(
updated_medias, return_tensors=return_tensors)
text = self.update_raw_text(text, video_prompts)
text_inputs = self.tokenizer(text, return_tensors=return_tensors)
return BatchFeature(data={**text_inputs, **preprocessed.data})
if medias is None:
medias = self._extract_medias_from_messages(messages)
updated_medias, video_prompts = self.preprocess_medias(medias)
preprocessed = self.media_processor.preprocess(
updated_medias, return_tensors=return_tensors)
# Generate text if not provided
if text is None:
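            # Assumes kwargs make apply_chat_template return a string
            # (e.g. tokenize=False), since the result is tokenized below.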
text = self.tokenizer.apply_chat_template(messages, **kwargs)
text = self.update_raw_text(text, video_prompts)
text_inputs = self.tokenizer(text, return_tensors=return_tensors)
return BatchFeature(data={**text_inputs, **preprocessed.data})
@staticmethod
def _extract_medias_from_messages(messages: list[dict]) -> list[dict]:
"""
Extract media items from messages in a single pass.
This is an optimized version that processes messages only once.
Kept as internal method since external callers should use __call__.
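
        Accepted content-part shapes (inferred from the branches below; the
        bare ``video``/``image`` fields are assumed fallbacks):

        ```python
        {"type": "video_url", "video_url": {"url": "clip.mp4"}}
        {"type": "image_url", "image_url": {"url": "img.jpg"}}
        ```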
"""
medias = []
for msg in messages:
if msg['role'] != 'user' or not msg.get('content'):
continue
for content_part in msg['content']:
if not isinstance(content_part, dict):
continue
content_type = content_part.get('type')
                if content_type in ['video_url', 'video']:
                    # Fall back to a bare "video" field so the plain
                    # {"type": "video", ...} shape does not raise a KeyError.
                    medias.append({
                        'type': 'video',
                        'video': content_part.get('video_url', {}).get(
                            'url', content_part.get('video')),
                        'first_frame_timestamp': 0.0
                    })
                elif content_type in ['image_url', 'image']:
                    medias.append({
                        'type': 'image',
                        'image': content_part.get(
                            'image_url', content_part.get('image')),
                    })
return medias
def apply_chat_template(self, messages, **kwargs):
return self.tokenizer.apply_chat_template(messages, **kwargs)
    def batch_decode(self, *args, **kwargs):
        """Forwards all arguments to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)
    def decode(self, *args, **kwargs):
        """Forwards all arguments to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
return ['input_ids', 'attention_mask', 'pixel_values', 'grid_thws']