| import copy |
|
|
| import torch |
| from transformers import BaseImageProcessor, PreTrainedTokenizer |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack |
|
|
| try: |
| from .processor_core import ( |
| CHAT_TEMPLATE, |
| CHAT_TEMPLATE_FAKE_THINKING, |
| make_image_config_from_processor, |
| process_images, |
| ) |
| except ImportError: |
| from processor_core import ( |
| CHAT_TEMPLATE, |
| CHAT_TEMPLATE_FAKE_THINKING, |
| make_image_config_from_processor, |
| process_images, |
| ) |
|
|
|
|
class ModRWKVProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by ``ModRWKVProcessor.__call__``.

    ``_defaults`` holds the per-modality defaults merged into the caller's kwargs:
    text is tokenized without padding and without token-type ids, and images get
    no extra defaults.
    """

    _defaults = {
        "text_kwargs": {"padding": False, "return_token_type_ids": False},
        "images_kwargs": {},
    }
|
|
|
|
class ModRWKVProcessor(ProcessorMixin):
    """Processor combining an image processor with an RWKV tokenizer for ModRWKV.

    Calling the processor turns images into ``pixel_values`` / ``image_grid_thw``,
    rewrites image placeholders in the prompt into runs of ``image_token`` sized
    from each image's grid, tokenizes the text, and then checks that the number of
    image tokens and vision start/end tokens per sample matches the images.
    """

    attributes = ["image_processor", "tokenizer"]
    tokenizer_class = "RwkvTokenizer"
    # Tag users write in prompts to mark where an image goes.
    user_image_tag = "<image>"

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer = None,
        image_processor: BaseImageProcessor = None,
        chat_template=None,
        auto_insert_image_tags: bool = True,
        total_pixels_budget: bool = True,
    ):
        """Build the processor.

        Args:
            tokenizer: Text tokenizer. Must provide ``expand_image_placeholders``
                when ``auto_insert_image_tags`` is enabled.
            image_processor: Image processor; its settings are converted with
                ``make_image_config_from_processor`` on every call.
            chat_template: Chat template string; defaults to ``CHAT_TEMPLATE``.
            auto_insert_image_tags: When True, ``<image>`` tags are trimmed or
                appended (if per-sample image counts are known) and then expanded.
                When False, ``<image>`` tags are replaced by a space and raw
                ``image_token`` occurrences already in the text are expanded.
            total_pixels_budget: Stored and written out by ``to_dict``; not read
                elsewhere in this class (presumably consumed by other code -
                confirm).
        """
        chat_template = CHAT_TEMPLATE if chat_template is None else chat_template
        super().__init__(
            tokenizer=tokenizer,
            image_processor=image_processor,
            chat_template=chat_template,
        )
        self.auto_insert_image_tags = auto_insert_image_tags
        self.total_pixels_budget = total_pixels_budget
        # The tokenizer may define its own special-token strings; otherwise use these defaults.
        self.image_token = getattr(tokenizer, "image_token", "<|image_pad|>")
        self.vision_start_token = getattr(tokenizer, "vision_start_token", "<|vision_start|>")
        self.vision_end_token = getattr(tokenizer, "vision_end_token", "<|vision_end|>")
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        self.vision_start_token_id = self.tokenizer.convert_tokens_to_ids(self.vision_start_token)
        self.vision_end_token_id = self.tokenizer.convert_tokens_to_ids(self.vision_end_token)
        # One unexpanded image slot: start marker, a single image token, end marker.
        self.vision_image_token = (
            f"{self.vision_start_token}{self.image_token}{self.vision_end_token}"
        )

    def to_dict(self):
        """Return the serializable processor configuration.

        Includes the image processor config, ``auto_map`` (when set), the processor
        class name, ``auto_insert_image_tags`` (only when it is disabled) and
        ``total_pixels_budget``. The tokenizer is not part of this dict.
        """
        output = {}
        if self.image_processor is not None:
            output["image_processor"] = self.image_processor.to_dict()
        if getattr(self, "auto_map", None) is not None:
            output["auto_map"] = copy.deepcopy(self.auto_map)
        output["processor_class"] = self.__class__.__name__
        if not self.auto_insert_image_tags:
            output["auto_insert_image_tags"] = False
        output["total_pixels_budget"] = self.total_pixels_budget
        return output

    def _flatten_images(self, images):
        """Flatten arbitrarily nested lists/tuples of images into one flat list.

        ``None`` yields ``[]``; a non-sequence value is treated as a single image.
        """
        if images is None:
            return []
        if not isinstance(images, (list, tuple)):
            return [images]

        flat_images = []
        for item in images:
            if isinstance(item, (list, tuple)):
                flat_images.extend(self._flatten_images(item))
            else:
                flat_images.append(item)
        return flat_images

    def _get_num_images_per_text_sample(self, images, batch_size):
        """Return the number of images for each text sample, or None if unknown.

        Same grouping rules as ``_get_images_per_text_sample``.
        """
        image_groups = self._get_images_per_text_sample(images, batch_size)
        if image_groups is None:
            return None
        return [len(group) for group in image_groups]

    def _get_images_per_text_sample(self, images, batch_size):
        """Split ``images`` into one flat list per text sample, or return None.

        With no images every sample gets ``[]``. With a single sample, every image
        belongs to it. Otherwise ``images`` must be a list/tuple whose length equals
        ``batch_size``; each entry is flattened into that sample's images. Any other
        layout cannot be attributed per sample and yields None.
        """
        if images is None:
            return [[] for _ in range(batch_size)]
        if batch_size == 1:
            return [self._flatten_images(images)]
        if isinstance(images, (list, tuple)) and len(images) == batch_size:
            return [self._flatten_images(sample_images) for sample_images in images]
        return None

    def _process_images(self, images, batch_size, images_kwargs):
        """Run the image pipeline and collect per-image token counts.

        Returns:
            ``(image_inputs, image_grid_thw, num_image_tokens, num_images_per_sample)``
            where ``image_inputs`` holds ``pixel_values`` and ``image_grid_thw``
            (each concatenated across sample groups along dim 0),
            ``num_image_tokens`` lists the token count of every image in order, and
            ``num_images_per_sample`` is None when images could not be assigned to
            individual text samples. When no image yields tokens, returns
            ``({}, None, None, num_images_per_sample)``.
        """
        image_groups = self._get_images_per_text_sample(images, batch_size)
        if image_groups is None:
            # Cannot attribute images to samples: process them as one group and
            # consume them in placeholder order across the whole batch.
            image_groups = [self._flatten_images(images)]
            num_images_per_sample = None
        else:
            num_images_per_sample = [len(group) for group in image_groups]

        image_config = make_image_config_from_processor(
            self.image_processor,
            **images_kwargs,
        )
        processed_groups = [process_images(group, image_config) for group in image_groups]
        num_image_tokens = [
            count
            for processed in processed_groups
            for count in processed.image_token_counts
        ]
        if not num_image_tokens:
            return {}, None, None, num_images_per_sample

        # Skip empty groups so torch.cat only sees tensors with content.
        pixel_values = torch.cat(
            [
                processed.flat_patches
                for processed in processed_groups
                if processed.flat_patches.numel() > 0
            ],
            dim=0,
        )
        image_grid_thw = torch.cat(
            [
                processed.grid_thw
                for processed in processed_groups
                if processed.grid_thw.numel() > 0
            ],
            dim=0,
        )
        return (
            {
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
            },
            image_grid_thw,
            num_image_tokens,
            num_images_per_sample,
        )

    def _normalize_image_tags(self, text):
        """Replace every user ``<image>`` tag with an unexpanded ``vision_image_token``."""
        return text.replace(self.user_image_tag, self.vision_image_token)

    def _strip_excess_image_tags(self, text, num_allowed):
        """Keep the first ``num_allowed`` ``<image>`` tags and delete the rest.

        Text between removed tags is preserved; only the tag strings disappear.
        """
        tag = self.user_image_tag
        count = text.count(tag)
        if count <= num_allowed:
            return text
        parts = text.split(tag)
        kept = tag.join(parts[: num_allowed + 1])
        rest = "".join(parts[num_allowed + 1 :])
        return kept + rest

    def _append_missing_image_tags(self, text, num_missing_images):
        """Append ``num_missing_images`` image slots to the end of ``text`` (none if <= 0)."""
        if num_missing_images <= 0:
            return text
        return text + self.vision_image_token * num_missing_images

    def _get_num_multimodal_tokens(self, image_grid_thw=None, **kwargs):
        """Compute per-image patch and token counts from ``image_grid_thw``.

        Each grid row ``(t, h, w)`` gives ``t * h * w`` patches, and
        ``patches // merge_size**2`` tokens. ``merge_size`` comes from kwargs,
        the image processor's ``_defaults`` (if any), or ``image_processor.merge_size``.
        """
        vision_data = {}
        if image_grid_thw is not None:
            processor_defaults = getattr(self.image_processor, "_defaults", {})
            images_kwargs = dict(processor_defaults.get("images_kwargs", {}))
            images_kwargs.update(kwargs)
            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size

            num_image_patches = [int(grid[0] * grid[1] * grid[2]) for grid in image_grid_thw]
            num_image_tokens = [num_patches // merge_size**2 for num_patches in num_image_patches]
            vision_data.update(
                {
                    "num_image_tokens": num_image_tokens,
                    "num_image_patches": num_image_patches,
                }
            )
        return MultiModalData(**vision_data)

    def _count_token_occurrences(self, input_ids, token_id):
        """Count occurrences of ``token_id`` in each sample of ``input_ids``."""
        return [sum(1 for token in sample_ids if token == token_id) for sample_ids in input_ids]

    def _validate_image_token_alignment(self, text_inputs, expected_image_tokens, expected_num_images):
        """Check tokenized ids against the image counts the text expansion expected.

        Raises:
            ValueError: If per-sample image-token counts, or vision start/end token
                counts, differ from the expected values.
        """
        input_ids = text_inputs["input_ids"]
        actual_image_tokens = self._count_token_occurrences(input_ids, self.image_token_id)
        actual_vision_starts = self._count_token_occurrences(input_ids, self.vision_start_token_id)
        actual_vision_ends = self._count_token_occurrences(input_ids, self.vision_end_token_id)

        if actual_image_tokens != expected_image_tokens:
            raise ValueError(
                "Image token count does not match image_grid_thw-derived token count: "
                f"expected {expected_image_tokens}, got {actual_image_tokens}."
            )
        if actual_vision_starts != expected_num_images or actual_vision_ends != expected_num_images:
            raise ValueError(
                "Vision boundary token count does not match the number of image placeholders: "
                f"expected {expected_num_images}, got starts={actual_vision_starts}, ends={actual_vision_ends}."
            )

    def _expand_user_image_tags(self, text, num_allowed, num_image_tokens, start):
        """Expand one sample's image slots in ``auto_insert_image_tags`` mode.

        When ``num_allowed`` is not None, ``<image>`` tags beyond it are dropped;
        all remaining tags become ``vision_image_token``; and, if fewer than
        ``num_allowed`` image tokens are present, missing slots are appended.
        Each ``vision_image_token`` then consumes the next entry of
        ``num_image_tokens`` (starting at ``start``) and the tokenizer's
        ``expand_image_placeholders`` performs the expansion.

        Returns:
            ``(expanded_text, consumed_counts)``.

        Raises:
            ValueError: If the sample has more slots than remaining images.
        """
        if num_allowed is not None:
            text = self._strip_excess_image_tags(text, num_allowed)
        text = self._normalize_image_tags(text)
        if num_allowed is not None:
            missing = num_allowed - text.count(self.image_token)
            text = self._append_missing_image_tags(text, missing)

        placeholder_count = text.count(self.vision_image_token)
        if start + placeholder_count > len(num_image_tokens):
            raise ValueError(
                "Number of image placeholders in text exceeds provided images: "
                f"consumed {start + placeholder_count}, available {len(num_image_tokens)}."
            )
        sample_counts = num_image_tokens[start : start + placeholder_count]
        text = self.tokenizer.expand_image_placeholders(
            text,
            sample_counts,
        )
        return text, sample_counts

    def _expand_raw_image_tokens(self, text, num_image_tokens, start):
        """Expand each raw ``image_token`` in ``text`` to its image's token count.

        Used when ``auto_insert_image_tags`` is disabled: the k-th ``image_token``
        occurrence becomes ``num_image_tokens[start + k]`` copies of itself. The
        text is split on the token and rebuilt (rather than substituted through a
        temporary marker string), so any other literal text in the prompt, such as
        ``<|placeholder|>``, is left untouched.

        Returns:
            ``(expanded_text, consumed_counts)``.

        Raises:
            ValueError: If the text holds more image tokens than remaining images.
        """
        pieces = text.split(self.image_token)
        num_slots = len(pieces) - 1
        available = len(num_image_tokens)
        if start + num_slots > available:
            raise ValueError(
                "Number of image placeholders in text exceeds provided images: "
                f"consumed {available + 1}, available {available}."
            )
        consumed = num_image_tokens[start : start + num_slots]
        expanded = [pieces[0]]
        for count, piece in zip(consumed, pieces[1:]):
            expanded.append(self.image_token * count)
            expanded.append(piece)
        return "".join(expanded), consumed

    def __call__(self, images=None, text=None, **kwargs: Unpack[ModRWKVProcessorKwargs]):
        """Process images and/or text into model inputs.

        Args:
            images: A single image, a flat list of images, or (for batched text)
                one list of images per text sample. Nested lists are flattened.
            text: A prompt string, or a list/tuple of prompt strings.
            **kwargs: Per-modality processing kwargs (see ``ModRWKVProcessorKwargs``).
                ``return_tensors`` is applied when building the returned BatchFeature.

        Returns:
            BatchFeature with the tokenizer outputs plus ``pixel_values`` and
            ``image_grid_thw`` when images produced tokens. If ``text`` is None,
            only the image outputs are returned.

        Raises:
            ValueError: If the text contains more image slots than there are
                images, if (with ``auto_insert_image_tags``) some images are left
                without a slot, or if tokenized ids disagree with the expected
                image / vision-boundary token counts.
        """
        output_kwargs = self._merge_kwargs(
            ModRWKVProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Work on a fresh list so the caller's sequence is never mutated. Tuples are
        # treated as a batch (previously they were wrapped as one "sample" and crashed).
        if text is not None:
            text = list(text) if isinstance(text, (list, tuple)) else [text]

        batch_size = len(text) if text is not None else 1
        if images is not None:
            (
                image_inputs,
                image_grid_thw,
                num_image_tokens,
                num_images_per_sample,
            ) = self._process_images(
                images,
                batch_size,
                output_kwargs["images_kwargs"],
            )
        else:
            image_inputs = {}
            image_grid_thw = None
            num_image_tokens = None
            num_images_per_sample = None

        if text is None:
            return BatchFeature(data=image_inputs)

        expected_image_tokens = [0] * len(text)
        expected_num_images = [0] * len(text)
        if image_grid_thw is not None:
            index = 0  # next unconsumed entry of num_image_tokens, across the batch
            expanded_text = []
            for i, sample in enumerate(text):
                if self.auto_insert_image_tags:
                    num_allowed = None if num_images_per_sample is None else num_images_per_sample[i]
                    sample, consumed = self._expand_user_image_tags(
                        sample, num_allowed, num_image_tokens, index
                    )
                else:
                    # In this mode user tags are ignored; only raw image tokens count.
                    sample = sample.replace(self.user_image_tag, " ")
                    sample, consumed = self._expand_raw_image_tokens(
                        sample, num_image_tokens, index
                    )
                expanded_text.append(sample)
                expected_image_tokens[i] = sum(consumed)
                expected_num_images[i] = len(consumed)
                index += len(consumed)
            text = expanded_text

            if self.auto_insert_image_tags and index != len(num_image_tokens):
                raise ValueError(
                    "Number of image placeholders in text does not match provided images: "
                    f"consumed {index}, available {len(num_image_tokens)}."
                )
        else:
            # No image tokens to place (no images, or none produced tokens).
            text = [sample.replace(self.user_image_tag, "") for sample in text]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        if image_grid_thw is not None:
            self._validate_image_token_alignment(
                text_inputs,
                expected_image_tokens,
                expected_num_images,
            )
        self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

    def apply_chat_template(self, conversation, chat_template=None, **kwargs):
        """Apply a chat template, defaulting ``return_dict`` to True.

        An explicit ``return_dict`` from the caller is respected; everything else
        is forwarded to ``ProcessorMixin.apply_chat_template``.
        """
        kwargs.setdefault("return_dict", True)
        return super().apply_chat_template(
            conversation,
            chat_template=chat_template,
            **kwargs,
        )
|
|
|
|
# Register this custom processor class with transformers' AutoProcessor mapping.
ModRWKVProcessor.register_for_auto_class(auto_class="AutoProcessor")
|
|