|
|
import torch |
|
|
|
|
|
from collections import UserDict, OrderedDict |
|
|
from typing import Union, List, Dict, Any |
|
|
|
|
|
from transformers.processing_utils import ProcessorMixin |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from transformers.utils.chat_template_utils import render_jinja_template |
|
|
|
|
|
|
|
|
class SmallVLMProcessor(ProcessorMixin):
    """Processor for a small vision-language model.

    Wraps a tokenizer and an image processor. Text may contain one
    ``image_token`` placeholder per image; each placeholder is expanded
    in-place into as many copies as the corresponding image has patches
    (``spatial_shapes`` height * width), with the attention mask and the
    optional prompt mask expanded in lockstep. The result is returned as
    a :class:`BatchFeature` of single-example (batch size 1) tensors.
    """

    attributes = ["tokenizer", "image_processor"]
    optional_attributes = ["chat_template"]
    model_input_names = ["input_ids", "attention_mask", "pixel_values"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    # Placeholder token marking where image patch embeddings are spliced in.
    image_token = "<|image_pad|>"

    def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
        """Store components and register the image placeholder token.

        Args:
            tokenizer: A HuggingFace tokenizer.
            image_processor: A HuggingFace image processor producing
                ``pixel_values`` / ``spatial_shapes`` / ``pixel_attention_mask``.
            chat_template: Jinja chat template string (stored by the mixin).
        """
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
        # No-op if the token is already registered; keeps existing specials.
        self.tokenizer.add_special_tokens(
            {"additional_special_tokens": [self.image_token]},
            replace_additional_special_tokens=False,
        )
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __call__(self, inputs=None, images=None, text=None, **kwargs) -> BatchFeature:
        """Build model inputs from text and/or pre-tokenized inputs plus images.

        Args:
            inputs: Optional pre-tokenized inputs (dict or UserDict); may
                already contain ``input_ids``, ``attention_mask`` and
                ``assistant_masks``.
            images: Images for this example; ``None`` means no images.
                (Was a mutable default ``[]`` — fixed to avoid shared state.)
            text: Raw text, tokenized only when ``inputs`` lacks ``input_ids``.
            **kwargs: ``truncation`` (bool, default False), ``max_length``
                (int, default 1024), ``padding`` (bool, default False);
                anything else is forwarded to the tokenizer.

        Returns:
            BatchFeature with ``input_ids``, ``attention_mask``,
            ``pixel_values``/``pixel_attention_mask``/``spatial_shapes``
            (zero-sized when there are no images) and, when
            ``assistant_masks`` was supplied, ``prompt_mask``.
        """
        truncation = kwargs.pop("truncation", False)
        max_length = kwargs.pop("max_length", 1024)
        padding = kwargs.pop("padding", False)

        # Normalize the (previously mutable-default) images argument.
        if images is None:
            images = []
        if inputs is None:
            inputs = {}
        if isinstance(inputs, UserDict):
            inputs = inputs.data

        if "input_ids" not in inputs:
            first = self.tokenizer(
                text, padding=False, truncation=False, return_attention_mask=False, **kwargs
            )["input_ids"][0]
            # The tokenizer returns tensors/arrays only when return_tensors is
            # set in kwargs; otherwise plain lists. Normalize both cases
            # (the old unconditional .tolist() crashed on plain lists).
            inputs["input_ids"] = first.tolist() if hasattr(first, "tolist") else list(first)

        inputs = self.process_images(images, inputs=inputs)

        if "attention_mask" not in inputs:
            inputs["attention_mask"] = [1] * len(inputs["input_ids"])

        # prompt_mask is the complement of assistant_masks (1 = prompt token).
        if "assistant_masks" in inputs:
            inputs["prompt_mask"] = [1 - x for x in inputs.pop("assistant_masks")]

        # Expand each image placeholder into per-patch copies.
        inputs = self.process_inputs(inputs)

        if truncation and len(inputs["input_ids"]) > max_length:
            inputs = self.truncate(inputs, max_length)

        if padding and len(inputs["input_ids"]) < max_length:
            inputs = self.padding(inputs, max_length)

        inputs = self.to_tensor(inputs)

        self.check(inputs)

        new_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        if "pixel_values" in inputs:
            new_inputs["pixel_values"] = inputs["pixel_values"]
            new_inputs["pixel_attention_mask"] = inputs["pixel_attention_mask"]
            new_inputs["spatial_shapes"] = inputs["spatial_shapes"]
        if "prompt_mask" in inputs:
            new_inputs["prompt_mask"] = inputs["prompt_mask"]

        return BatchFeature(new_inputs)

    def process_images(self, images, inputs):
        """Run the image processor and store its outputs in ``inputs``.

        With no images, zero-sized tensors with the correct trailing
        dimensions are stored so downstream code sees a uniform schema.
        """
        if len(images) > 0:
            pixel_values, spatial_shapes, pixel_attention_mask = self.image_transform(images)
        else:
            # Empty placeholders: (0, max_patches, 3*P*P), (0, 2), (0, max_patches).
            pixel_values = torch.zeros(
                (0, self.image_processor.max_num_patches, 3 * self.image_processor.patch_size**2),
                dtype=torch.float32,
            )
            spatial_shapes = torch.zeros((0, 2), dtype=torch.int64)
            pixel_attention_mask = torch.ones((0, self.image_processor.max_num_patches), dtype=torch.int32)

        inputs["pixel_values"] = pixel_values
        inputs["spatial_shapes"] = spatial_shapes
        inputs["pixel_attention_mask"] = pixel_attention_mask
        return inputs

    def image_transform(self, images):
        """Return (pixel_values, spatial_shapes, pixel_attention_mask) tensors."""
        image_inputs = self.image_processor(images, return_tensors="pt")
        return image_inputs["pixel_values"], image_inputs["spatial_shapes"], image_inputs["pixel_attention_mask"]

    def truncate(self, inputs: Dict[str, Any], max_length: int) -> Dict[str, Any]:
        """Truncate list-valued sequence fields to ``max_length`` tokens.

        Raises:
            ValueError: If truncation would cut an image placeholder token
                (raising instead of asserting so the guard survives ``-O``).
        """
        if self.image_token_id in inputs["input_ids"][max_length:]:
            raise ValueError("Truncate image token is not allowed.")

        inputs["input_ids"] = inputs["input_ids"][:max_length]
        inputs["attention_mask"] = inputs["attention_mask"][:max_length]
        if "prompt_mask" in inputs:
            inputs["prompt_mask"] = inputs["prompt_mask"][:max_length]

        return inputs

    def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
        """Return tokens-per-image, i.e. the product of each spatial shape.

        ``spatial_shapes`` is (num_images, 2); row products give the number
        of patch tokens each image occupies. Empty list if absent.
        """
        spatial_shapes = inputs.get("spatial_shapes", None)
        if spatial_shapes is None:
            return []
        return spatial_shapes.prod(dim=1).tolist()

    def process_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Expand each image placeholder into its per-patch token run.

        ``input_ids`` gets repeated image-token ids; the masks replicate
        their existing value at the placeholder position.
        """
        graft_token_lens = self._get_graft_token_length(inputs)

        inputs["input_ids"] = self._graft_token(inputs["input_ids"], graft_token_lens, self.image_token_id)
        inputs["attention_mask"] = self._graft_token(inputs["attention_mask"], graft_token_lens, "replicate")
        if "prompt_mask" in inputs:
            inputs["prompt_mask"] = self._graft_token(inputs["prompt_mask"], graft_token_lens, "replicate")

        return inputs

    def _graft_token(self, seq, graft_token_lens, value):
        """Replace each position in ``graft_token_lens`` with a repeated run.

        ``value == 'replicate'`` repeats the element already at the position
        (used for masks); otherwise ``value`` must equal the element there
        and is written ``graft_token_lens[i]`` times. Iterates positions in
        reverse so earlier indices stay valid while the list grows.
        """
        if value == "replicate":
            for i in reversed(graft_token_lens):
                seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i + 1:]
        else:
            for i in reversed(graft_token_lens):
                # Internal invariant: the slot must hold the placeholder id.
                assert value == seq[i]
                seq[i:] = [value] * graft_token_lens[i] + seq[i + 1:]
        return seq

    def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
        """Map placeholder position -> number of patch tokens for that image.

        Raises:
            ValueError: If the number of placeholders in ``input_ids`` does
                not match the number of images (raised, not asserted, so the
                check survives ``-O``).
        """
        image_token_pos = [i for i, x in enumerate(inputs["input_ids"]) if x == self.image_token_id]
        image_token_lens = self.get_image_token_length(inputs)

        if len(image_token_pos) != len(image_token_lens):
            raise ValueError(
                "Wrong image token count, "
                f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"
            )

        # Insertion order is ascending positions; _graft_token relies on it.
        return OrderedDict(zip(image_token_pos, image_token_lens))

    def check(self, inputs: Dict[str, Any]):
        """Internal sanity check: expanded placeholder count matches patches."""
        image_embed_token_count = torch.count_nonzero(inputs["input_ids"] == self.image_token_id).item()
        image_embed_count = sum(self.get_image_token_length(inputs))
        assert image_embed_token_count == image_embed_count, "Wrong image embed token count"

    def padding(self, inputs: Dict[str, Any], max_length: int) -> Dict[str, Any]:
        """Right-pad list-valued sequence fields up to ``max_length``.

        Raises:
            ValueError: If the tokenizer defines no ``pad_token_id``
                (previously this padded with ``None`` and failed later with
                an opaque TypeError inside ``torch.tensor``).
        """
        if self.pad_token_id is None:
            raise ValueError("Cannot pad: tokenizer has no pad_token_id.")
        padding_len = max_length - len(inputs["input_ids"])
        inputs["input_ids"] += [self.pad_token_id] * padding_len
        inputs["attention_mask"] += [0] * padding_len
        if "prompt_mask" in inputs:
            inputs["prompt_mask"] += [0] * padding_len
        return inputs

    def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs) -> str:
        """Decode a single sequence of token ids to text."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self.tokenizer.decode(token_ids, **kwargs)

    def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs) -> List[str]:
        """Decode a batch of token-id sequences to a list of texts."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        return self.tokenizer.batch_decode(sequences, **kwargs)

    def to_tensor(self, inputs):
        """Convert list-valued sequence fields to batch-of-one tensors.

        NOTE(review): attention/prompt masks are deliberately kept as bool
        tensors here — confirm the model consumes bool masks.
        """
        inputs["input_ids"] = torch.tensor([inputs["input_ids"]], dtype=torch.long)
        inputs["attention_mask"] = torch.tensor([inputs["attention_mask"]], dtype=torch.bool)
        if "prompt_mask" in inputs:
            inputs["prompt_mask"] = torch.tensor([inputs["prompt_mask"]], dtype=torch.bool)
        return inputs

    @property
    def pad_token_id(self):
        """The tokenizer's pad token id (may be None)."""
        return self.tokenizer.pad_token_id

    @property
    def special_tokens(self) -> List[str]:
        """All added special token strings known to the tokenizer."""
        return [token.content for token in self.tokenizer.added_tokens_decoder.values()]

    def __repr__(self):
        # Bug fix: the previous stub (`pass`) returned None, making repr()
        # raise "TypeError: __repr__ returned non-string". Keep it terse to
        # preserve the original intent of suppressing the mixin's verbose dump.
        return f"{type(self).__name__}(image_token={self.image_token!r})"

    def __str__(self):
        # Intentionally empty — presumably to silence verbose printing of
        # processor internals; TODO confirm this is desired.
        return ""