# processing_smallvlm.py — EFVLM3-36200
# Uploaded via huggingface_hub by shilinxu (commit 064e549, verified)
import torch
from collections import UserDict, OrderedDict
from typing import Union, List, Dict, Any
from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils.chat_template_utils import render_jinja_template
class SmallVLMProcessor(ProcessorMixin):
    """Joint text/image processor for SmallVLM.

    Tokenizes text, runs images through the image processor, and expands each
    ``<|image_pad|>`` placeholder token in the sequence into one token per
    image patch (``prod(spatial_shape)`` tokens per image), keeping the
    attention/prompt masks aligned with the expanded sequence.
    """

    attributes = ["tokenizer", "image_processor"]
    optional_attributes = ['chat_template']
    model_input_names = ['input_ids', 'attention_mask', 'pixel_values']
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    # Placeholder token marking where image embeddings are spliced into the text.
    image_token = '<|image_pad|>'

    def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
        # Register the image placeholder without discarding any special tokens
        # the tokenizer already carries.
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.image_token]}, replace_additional_special_tokens=False)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __call__(self, inputs=None, images=None, text=None, **kwargs) -> BatchFeature:
        """Build model inputs for a single sample.

        Args:
            inputs: Optional pre-tokenized inputs (dict or ``UserDict``). When
                it lacks ``'input_ids'``, ``text`` is tokenized instead.
            images: Optional list of images; each must correspond to exactly
                one ``<|image_pad|>`` token in the token sequence.
            text: Text to tokenize when ``inputs`` has no ``'input_ids'``.
            **kwargs: ``truncation`` (default ``False``), ``max_length``
                (default ``1024``) and ``padding`` (default ``False``) are
                consumed here; everything else is forwarded to the tokenizer.

        Returns:
            ``BatchFeature`` holding ``input_ids`` and ``attention_mask``
            (batch dimension of 1), image features
            (``pixel_values``/``pixel_attention_mask``/``spatial_shapes``)
            and, when assistant masks were supplied, ``prompt_mask``.
        """
        # BUGFIX: `images=[]` was a shared mutable default argument; use the
        # None-sentinel idiom instead (behavior is unchanged for callers).
        if images is None:
            images = []
        truncation = kwargs.pop('truncation', False)
        max_length = kwargs.pop('max_length', 1024)
        padding = kwargs.pop('padding', False)
        if inputs is None:
            inputs = {}
        if isinstance(inputs, UserDict):
            inputs = inputs.data
        if 'input_ids' not in inputs:
            # NOTE(review): the `[0]` + `.tolist()` chain assumes the tokenizer
            # returns a batched tensor here (e.g. `return_tensors` arrives via
            # kwargs) — confirm against the call sites.
            input_ids = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False, **kwargs)['input_ids'][0]
            inputs['input_ids'] = input_ids.tolist()
        inputs = self.process_images(images, inputs=inputs)
        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        if 'assistant_masks' in inputs:
            # The prompt mask is the complement of the assistant mask.
            inputs['prompt_mask'] = [1 - x for x in inputs.pop('assistant_masks')]
        inputs = self.process_inputs(inputs)
        if truncation and len(inputs['input_ids']) > max_length:
            inputs = self.truncate(inputs, max_length)
        if padding and len(inputs['input_ids']) < max_length:
            inputs = self.padding(inputs, max_length)
        inputs = self.to_tensor(inputs)
        self.check(inputs)
        new_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        if "pixel_values" in inputs:
            new_inputs['pixel_values'] = inputs['pixel_values']
            new_inputs['pixel_attention_mask'] = inputs['pixel_attention_mask']
            new_inputs['spatial_shapes'] = inputs['spatial_shapes']
        if 'prompt_mask' in inputs:
            new_inputs['prompt_mask'] = inputs['prompt_mask']
        return BatchFeature(new_inputs)

    def process_images(self, images, inputs):
        """Encode `images` and store the image features into `inputs`.

        With no images, empty (0-sized) placeholder tensors are stored so the
        downstream keys are always present.
        """
        if len(images) > 0:
            pixel_values, spatial_shapes, pixel_attention_mask = self.image_transform(images)
        else:
            pixel_values = torch.zeros((0, self.image_processor.max_num_patches, 3*self.image_processor.patch_size**2), dtype=torch.float32)
            spatial_shapes = torch.zeros((0, 2), dtype=torch.int64)
            pixel_attention_mask = torch.ones((0, self.image_processor.max_num_patches), dtype=torch.int32)
        inputs['pixel_values'] = pixel_values
        inputs['spatial_shapes'] = spatial_shapes
        inputs['pixel_attention_mask'] = pixel_attention_mask
        return inputs

    def image_transform(self, images):
        """Run the image processor and unpack its three output tensors."""
        image_inputs = self.image_processor(images, return_tensors='pt')
        return image_inputs['pixel_values'], image_inputs['spatial_shapes'], image_inputs['pixel_attention_mask']

    def truncate(self, inputs: Dict[str, Any], max_length: int):
        """Truncate sequence fields to `max_length` (list inputs, in place).

        Refuses to cut through image tokens — truncating inside an image span
        would desynchronize input_ids from pixel_values.
        """
        assert self.image_token_id not in inputs['input_ids'][max_length:], f"Truncate image token is not allowed."
        inputs['input_ids'] = inputs['input_ids'][:max_length]
        inputs['attention_mask'] = inputs['attention_mask'][:max_length]
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = inputs['prompt_mask'][:max_length]
        return inputs

    def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
        """Return, per image, how many tokens it occupies (prod of its spatial shape)."""
        spatial_shapes = inputs.get('spatial_shapes', None)
        if spatial_shapes is None:
            return []
        image_token_lens = spatial_shapes.prod(dim=1).tolist()
        return image_token_lens
    def process_inputs(self, inputs: Dict[str, Any]):
        """Expand each image placeholder into its full run of image tokens,
        replicating the corresponding attention/prompt mask entries."""
        graft_token_lens = self._get_graft_token_length(inputs)
        inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id)
        inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate')
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate')
        return inputs

    def _graft_token(self, seq, graft_token_lens, value):
        """Replace the single element at each position in `graft_token_lens`
        with `graft_token_lens[pos]` copies.

        `value == 'replicate'` copies the existing element (used for masks);
        otherwise `value` must equal the element being replaced (used for
        input_ids, where it is the image token id). Positions are processed
        right-to-left so earlier indices stay valid as the list grows.
        """
        if value == 'replicate':
            for i in reversed(graft_token_lens.keys()):
                seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i+1:]
        else:
            for i in reversed(graft_token_lens.keys()):
                assert value == seq[i]
                seq[i:] = [value] * graft_token_lens[i] + seq[i+1:]
        return seq

    def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
        """Map each image-token position in input_ids to its expanded length."""
        image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id]
        image_token_lens = self.get_image_token_length(inputs)
        assert len(image_token_pos) == len(image_token_lens), \
            "Wrong image token count, " \
            f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"
        graft_token_lens = OrderedDict(item for item in zip(image_token_pos, image_token_lens))
        return graft_token_lens

    def check(self, inputs: Dict[str, Any]):
        """Sanity-check that the number of image tokens in input_ids matches
        the total token budget implied by the image spatial shapes."""
        image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item()
        image_embed_count = sum(self.get_image_token_length(inputs))
        assert image_embed_token_count == image_embed_count, "Wrong image embed token count"

    def padding(self, inputs: Dict[str, Any], max_length: int):
        """Right-pad sequence fields (list inputs) up to `max_length`."""
        padding_len = max_length - len(inputs['input_ids'])
        inputs['input_ids'] += [self.pad_token_id] * padding_len
        inputs['attention_mask'] += [0] * padding_len
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] += [0] * padding_len
        return inputs

    def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs):
        """Decode a single sequence of token ids to text."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        text = self.tokenizer.decode(token_ids, **kwargs)
        return text

    def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs):
        """Decode a batch of token-id sequences to a list of texts."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        texts = self.tokenizer.batch_decode(sequences, **kwargs)
        return texts

    def to_tensor(self, inputs):
        """Convert list-valued sequence fields to tensors with a batch dim of 1."""
        inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long)
        inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool)
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool)
        return inputs

    @property
    def pad_token_id(self):
        """The tokenizer's pad token id, used by `padding`."""
        return self.tokenizer.pad_token_id

    @property
    def special_tokens(self):
        """All added special-token strings known to the tokenizer."""
        return [token.content for token in self.tokenizer.added_tokens_decoder.values()]

    def __repr__(self):
        # BUGFIX: the original body was `pass`, so repr() raised
        # "TypeError: __repr__ returned non-string (type NoneType)".
        return f"{type(self).__name__}(image_token={self.image_token!r})"

    def __str__(self):
        # Deliberately empty (suppresses the verbose ProcessorMixin dump).
        return ''