import torch
from collections import UserDict, OrderedDict
from typing import Union, List, Dict, Any
from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils.chat_template_utils import render_jinja_template


class SmallVLMProcessor(ProcessorMixin):
    """Joint text/image processor for a small vision-language model.

    Tokenizes text with ``tokenizer``, converts images with
    ``image_processor``, and expands each ``<|image_pad|>`` placeholder in
    the token sequence into one token per image patch (length taken from the
    image's spatial shape), keeping ``attention_mask`` / ``prompt_mask``
    aligned with the expanded sequence.
    """

    attributes = ["tokenizer", "image_processor"]
    optional_attributes = ['chat_template']
    model_input_names = ['input_ids', 'attention_mask', 'pixel_values']
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    # Placeholder token; each occurrence is later expanded to H*W patch tokens.
    image_token = '<|image_pad|>'

    def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
        """Register the image placeholder token and cache its id."""
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
        # replace_additional_special_tokens=False keeps any special tokens the
        # tokenizer already defines while appending the image placeholder.
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.image_token]}, replace_additional_special_tokens=False)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __call__(self, inputs=None, images=None, text=None, **kwargs) -> BatchFeature:
        """Build model-ready tensors from text and/or pre-tokenized inputs.

        Args:
            inputs: optional dict (or UserDict) that may already carry
                ``input_ids`` / ``attention_mask`` / ``assistant_masks``.
            images: list of images to process; ``None`` means no images.
                (Default changed from a shared mutable ``[]`` to ``None`` —
                behavior for callers is identical.)
            text: raw text to tokenize when ``input_ids`` is absent.
            **kwargs: may carry ``truncation`` / ``max_length`` / ``padding``;
                the remainder is forwarded to the tokenizer.

        Returns:
            BatchFeature with ``input_ids``, ``attention_mask`` and, when
            present, pixel inputs and ``prompt_mask`` (all batch-size-1
            tensors).
        """
        truncation = kwargs.pop('truncation', False)
        max_length = kwargs.pop('max_length', 1024)
        padding = kwargs.pop('padding', False)
        if images is None:
            images = []
        if inputs is None:
            inputs = {}
        if isinstance(inputs, UserDict):
            # Unwrap BatchEncoding/UserDict to its backing plain dict.
            inputs = inputs.data
        if 'input_ids' not in inputs:
            # NOTE(review): indexing [0] then calling .tolist() assumes the
            # tokenizer returns a batched, tensor-like output (e.g. callers
            # pass return_tensors via **kwargs) — confirm against call sites.
            input_ids = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False, **kwargs)['input_ids'][0]
            inputs['input_ids'] = input_ids.tolist()
        inputs = self.process_images(images, inputs=inputs)
        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        if 'assistant_masks' in inputs:
            # prompt_mask is the complement of the assistant mask: 1 marks
            # prompt tokens, 0 marks assistant (response) tokens.
            inputs['prompt_mask'] = [1 - x for x in inputs.pop('assistant_masks')]
        inputs = self.process_inputs(inputs)
        if truncation and len(inputs['input_ids']) > max_length:
            inputs = self.truncate(inputs, max_length)
        if padding and len(inputs['input_ids']) < max_length:
            inputs = self.padding(inputs, max_length)
        inputs = self.to_tensor(inputs)
        self.check(inputs)
        new_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        if "pixel_values" in inputs:
            new_inputs['pixel_values'] = inputs['pixel_values']
            new_inputs['pixel_attention_mask'] = inputs['pixel_attention_mask']
            new_inputs['spatial_shapes'] = inputs['spatial_shapes']
        if 'prompt_mask' in inputs:
            new_inputs['prompt_mask'] = inputs['prompt_mask']
        return BatchFeature(new_inputs)

    def process_images(self, images, inputs):
        """Attach pixel tensors to ``inputs``; empty tensors when no images."""
        if len(images) > 0:
            pixel_values, spatial_shapes, pixel_attention_mask = self.image_transform(images)
        else:
            # Zero-image placeholders keep downstream code shape-consistent.
            pixel_values = torch.zeros((0, self.image_processor.max_num_patches, 3 * self.image_processor.patch_size ** 2), dtype=torch.float32)
            spatial_shapes = torch.zeros((0, 2), dtype=torch.int64)
            pixel_attention_mask = torch.ones((0, self.image_processor.max_num_patches), dtype=torch.int32)
        inputs['pixel_values'] = pixel_values
        inputs['spatial_shapes'] = spatial_shapes
        inputs['pixel_attention_mask'] = pixel_attention_mask
        return inputs

    def image_transform(self, images):
        """Run the image processor and return (pixel_values, spatial_shapes, pixel_attention_mask)."""
        image_inputs = self.image_processor(images, return_tensors='pt')
        return image_inputs['pixel_values'], image_inputs['spatial_shapes'], image_inputs['pixel_attention_mask']

    def truncate(self, inputs: Dict[str, Any], max_length: int):
        """Truncate list-valued fields to ``max_length`` tokens.

        Refuses (via assert) to cut through an image-token run, since that
        would desynchronize input_ids from pixel_values.
        NOTE(review): asserts are stripped under ``python -O``; kept as-is so
        callers catching AssertionError are unaffected.
        """
        assert self.image_token_id not in inputs['input_ids'][max_length:], "Truncate image token is not allowed."
        inputs['input_ids'] = inputs['input_ids'][:max_length]
        inputs['attention_mask'] = inputs['attention_mask'][:max_length]
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = inputs['prompt_mask'][:max_length]
        return inputs

    def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
        """Return the number of patch tokens per image (product of spatial dims)."""
        spatial_shapes = inputs.get('spatial_shapes', None)
        if spatial_shapes is None:
            return []
        image_token_lens = spatial_shapes.prod(dim=1).tolist()
        return image_token_lens

    def process_inputs(self, inputs: Dict[str, Any]):
        """Expand each image placeholder into its full run of patch tokens,
        replicating the aligned mask values at the same positions."""
        graft_token_lens = self._get_graft_token_length(inputs)
        inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id)
        inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate')
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate')
        return inputs

    def _graft_token(self, seq, graft_token_lens, value):
        """Replace the single element at each position in ``graft_token_lens``
        with ``graft_token_lens[pos]`` copies.

        ``value == 'replicate'`` duplicates the existing element (used for
        masks); otherwise the element must already equal ``value`` (the image
        token id) and is repeated. Positions are processed right-to-left so
        earlier indices stay valid while the sequence grows.
        """
        if value == 'replicate':
            for i in reversed(graft_token_lens.keys()):
                seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i + 1:]
        else:
            for i in reversed(graft_token_lens.keys()):
                assert value == seq[i]
                seq[i:] = [value] * graft_token_lens[i] + seq[i + 1:]
        return seq

    def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
        """Map each image-token position in input_ids to its patch count."""
        image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id]
        image_token_lens = self.get_image_token_length(inputs)
        assert len(image_token_pos) == len(image_token_lens), \
            "Wrong image token count, " \
            f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"
        # OrderedDict preserves ascending position order so reversed() in
        # _graft_token walks positions right-to-left.
        graft_token_lens = OrderedDict(zip(image_token_pos, image_token_lens))
        return graft_token_lens

    def check(self, inputs: Dict[str, Any]):
        """Sanity check: expanded image tokens must match total patch count."""
        image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item()
        image_embed_count = sum(self.get_image_token_length(inputs))
        assert image_embed_token_count == image_embed_count, "Wrong image embed token count"

    def padding(self, inputs: Dict[str, Any], max_length: int):
        """Right-pad list-valued fields to ``max_length``.

        NOTE(review): if the tokenizer has no pad token, ``pad_token_id`` is
        ``None`` and input_ids would be padded with None — confirm the
        tokenizer always defines one.
        """
        padding_len = max_length - len(inputs['input_ids'])
        inputs['input_ids'] += [self.pad_token_id] * padding_len
        inputs['attention_mask'] += [0] * padding_len
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] += [0] * padding_len
        return inputs

    def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs):
        """Decode a single sequence of token ids to text."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        text = self.tokenizer.decode(token_ids, **kwargs)
        return text

    def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs):
        """Decode a batch of token-id sequences to texts."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        texts = self.tokenizer.batch_decode(sequences, **kwargs)
        return texts

    def to_tensor(self, inputs):
        """Convert list fields to batch-size-1 tensors (masks as bool)."""
        inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long)
        inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool)
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool)
        return inputs

    @property
    def pad_token_id(self):
        """The tokenizer's pad token id (may be None if undefined)."""
        return self.tokenizer.pad_token_id

    @property
    def special_tokens(self):
        """Contents of all tokens added to the tokenizer."""
        return [token.content for token in self.tokenizer.added_tokens_decoder.values()]

    def __repr__(self):
        # Bug fix: the original body was `pass`, returning None, so any
        # repr() call raised TypeError ("__repr__ returned non-string").
        # Return a minimal identifying string instead (the terse form mirrors
        # __str__, which deliberately suppresses the verbose mixin output).
        return f"{self.__class__.__name__}()"

    def __str__(self):
        # Intentionally empty: suppresses ProcessorMixin's verbose string form.
        return ''