# (extraction metadata removed: file size 7,797 bytes, commit 362ada4, line-number index)
import torch
from collections import UserDict, OrderedDict
from typing import Union, List, Dict, Any
from transformers.processing_utils import ProcessorMixin
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils.chat_template_utils import render_jinja_template
class SmallVLMProcessor(ProcessorMixin):
    """Joint text/image processor for a small vision-language model.

    Tokenizes text containing ``<|image_pad|>`` placeholder tokens, encodes
    the supplied images with the image processor, and expands each placeholder
    into as many copies as the corresponding image contributes patch
    embeddings (derived from ``spatial_shapes``). Produces a single-sample
    (batch size 1) ``BatchFeature`` with ``input_ids`` and ``attention_mask``,
    plus ``pixel_values`` / ``pixel_attention_mask`` / ``spatial_shapes`` and,
    when assistant masks were provided, ``prompt_mask``.
    """

    attributes = ["tokenizer", "image_processor"]
    optional_attributes = ['chat_template']
    model_input_names = ['input_ids', 'attention_mask', 'pixel_values']
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
    # Placeholder written in the prompt text wherever an image should appear.
    image_token = '<|image_pad|>'

    def __init__(self, tokenizer, image_processor, chat_template, **kwargs):
        super().__init__(tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template)
        # Register the placeholder as a special token so the tokenizer never
        # splits it; keep any special tokens that are already registered.
        self.tokenizer.add_special_tokens({'additional_special_tokens': [self.image_token]}, replace_additional_special_tokens=False)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __call__(self, inputs=None, images=None, text=None, **kwargs) -> BatchFeature:
        """Build model inputs for one sample.

        Args:
            inputs: optional pre-tokenized inputs (dict or UserDict). When it
                already contains ``'input_ids'``, ``text`` is ignored.
            images: images to encode; defaults to none.
            text: raw text to tokenize when ``inputs`` has no ``'input_ids'``.
            **kwargs: ``truncation`` (default False), ``max_length`` (default
                1024) and ``padding`` (default False) are consumed here; the
                rest is forwarded to the tokenizer.

        Returns:
            BatchFeature holding batch-of-one tensors.
        """
        truncation = kwargs.pop('truncation', False)
        max_length = kwargs.pop('max_length', 1024)
        padding = kwargs.pop('padding', False)
        # FIX: original signature used a mutable default (`images=[]`).
        if images is None:
            images = []
        if inputs is None:
            inputs = {}
        if isinstance(inputs, UserDict):
            inputs = inputs.data
        if 'input_ids' not in inputs:
            # NOTE(review): the `.tolist()` below implies the tokenizer returns
            # tensors here (e.g. `return_tensors` arrives via **kwargs) — confirm.
            input_ids = self.tokenizer(text, padding=False, truncation=False, return_attention_mask=False, **kwargs)['input_ids'][0]
            inputs['input_ids'] = input_ids.tolist()
        inputs = self.process_images(images, inputs=inputs)
        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        if 'assistant_masks' in inputs:
            # prompt_mask is the complement of the assistant mask.
            inputs['prompt_mask'] = [1 - x for x in inputs.pop('assistant_masks')]
        inputs = self.process_inputs(inputs)
        if truncation and len(inputs['input_ids']) > max_length:
            inputs = self.truncate(inputs, max_length)
        if padding and len(inputs['input_ids']) < max_length:
            inputs = self.padding(inputs, max_length)
        inputs = self.to_tensor(inputs)
        self.check(inputs)
        new_inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        if "pixel_values" in inputs:
            new_inputs['pixel_values'] = inputs['pixel_values']
            new_inputs['pixel_attention_mask'] = inputs['pixel_attention_mask']
            new_inputs['spatial_shapes'] = inputs['spatial_shapes']
        if 'prompt_mask' in inputs:
            new_inputs['prompt_mask'] = inputs['prompt_mask']
        return BatchFeature(new_inputs)

    def process_images(self, images, inputs):
        """Attach image feature tensors to ``inputs``.

        With no images, empty (0-length leading dim) tensors are attached so
        downstream code can always rely on the keys being present.
        """
        if len(images) > 0:
            pixel_values, spatial_shapes, pixel_attention_mask = self.image_transform(images)
        else:
            pixel_values = torch.zeros((0, self.image_processor.max_num_patches, 3 * self.image_processor.patch_size ** 2), dtype=torch.float32)
            spatial_shapes = torch.zeros((0, 2), dtype=torch.int64)
            pixel_attention_mask = torch.ones((0, self.image_processor.max_num_patches), dtype=torch.int32)
        inputs['pixel_values'] = pixel_values
        inputs['spatial_shapes'] = spatial_shapes
        inputs['pixel_attention_mask'] = pixel_attention_mask
        return inputs

    def image_transform(self, images):
        """Run the image processor and unpack the three tensors it returns."""
        image_inputs = self.image_processor(images, return_tensors='pt')
        return image_inputs['pixel_values'], image_inputs['spatial_shapes'], image_inputs['pixel_attention_mask']

    def truncate(self, inputs: Dict[str, Any], max_length: int):
        """Truncate list-form fields to ``max_length`` tokens.

        Refuses to cut through image tokens: dropping part of an image's
        placeholder run would desynchronize text and pixel features.
        """
        assert self.image_token_id not in inputs['input_ids'][max_length:], "Truncating image tokens is not allowed."
        inputs['input_ids'] = inputs['input_ids'][:max_length]
        inputs['attention_mask'] = inputs['attention_mask'][:max_length]
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = inputs['prompt_mask'][:max_length]
        return inputs

    def get_image_token_length(self, inputs: Dict[str, Any]) -> List[int]:
        """Return the number of placeholder tokens each image needs.

        Each per-image length is the product of its spatial shape (rows *
        cols of patches); empty list when no ``spatial_shapes`` is present.
        """
        spatial_shapes = inputs.get('spatial_shapes', None)
        if spatial_shapes is None:
            return []
        return spatial_shapes.prod(dim=1).tolist()

    def process_inputs(self, inputs: Dict[str, Any]):
        """Expand every image placeholder into its full run of tokens."""
        graft_token_lens = self._get_graft_token_length(inputs)
        inputs['input_ids'] = self._graft_token(inputs['input_ids'], graft_token_lens, self.image_token_id)
        inputs['attention_mask'] = self._graft_token(inputs['attention_mask'], graft_token_lens, 'replicate')
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = self._graft_token(inputs['prompt_mask'], graft_token_lens, 'replicate')
        return inputs

    def _graft_token(self, seq, graft_token_lens, value):
        """Replace ``seq[i]`` with ``graft_token_lens[i]`` copies, in place.

        ``value == 'replicate'`` repeats the element already at each position
        (used for masks); otherwise ``value`` is the token id to insert, and
        the element being replaced must already equal it. Positions are
        visited from last to first so earlier indices remain valid while the
        list grows.
        """
        if value == 'replicate':
            for i in reversed(graft_token_lens.keys()):
                seq[i:] = [seq[i]] * graft_token_lens[i] + seq[i + 1:]
        else:
            for i in reversed(graft_token_lens.keys()):
                assert value == seq[i]
                seq[i:] = [value] * graft_token_lens[i] + seq[i + 1:]
        return seq

    def _get_graft_token_length(self, inputs: Dict[str, Any]) -> Dict[int, int]:
        """Map each placeholder position to the token count its image needs."""
        image_token_pos = [i for i, x in enumerate(inputs['input_ids']) if x == self.image_token_id]
        image_token_lens = self.get_image_token_length(inputs)
        assert len(image_token_pos) == len(image_token_lens), \
            "Wrong image token count, " \
            f"image_token_count({len(image_token_pos)}) != image_count({len(image_token_lens)})"
        # Insertion order follows ascending position; iterated reversed later.
        return OrderedDict(zip(image_token_pos, image_token_lens))

    def check(self, inputs: Dict[str, Any]):
        """Sanity check: expanded placeholders match the total patch count."""
        image_embed_token_count = torch.count_nonzero(inputs['input_ids'] == self.image_token_id).item()
        image_embed_count = sum(self.get_image_token_length(inputs))
        assert image_embed_token_count == image_embed_count, "Wrong image embed token count"

    def padding(self, inputs: Dict[str, Any], max_length: int):
        """Right-pad list-form fields to ``max_length``."""
        padding_len = max_length - len(inputs['input_ids'])
        # NOTE(review): assumes the tokenizer defines a pad token —
        # `pad_token_id` would be None otherwise; confirm upstream config.
        inputs['input_ids'] += [self.pad_token_id] * padding_len
        inputs['attention_mask'] += [0] * padding_len
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] += [0] * padding_len
        return inputs

    def decode(self, token_ids: Union[List[int], torch.Tensor], **kwargs):
        """Decode a single sequence of token ids to text."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self.tokenizer.decode(token_ids, **kwargs)

    def batch_decode(self, sequences: Union[List[List[int]], torch.Tensor], **kwargs):
        """Decode a batch of token-id sequences to a list of texts."""
        if isinstance(sequences, torch.Tensor):
            sequences = sequences.tolist()
        return self.tokenizer.batch_decode(sequences, **kwargs)

    def to_tensor(self, inputs):
        """Convert list-form fields to batch-of-one tensors."""
        inputs['input_ids'] = torch.tensor([inputs['input_ids']], dtype=torch.long)
        inputs['attention_mask'] = torch.tensor([inputs['attention_mask']], dtype=torch.bool)
        if 'prompt_mask' in inputs:
            inputs['prompt_mask'] = torch.tensor([inputs['prompt_mask']], dtype=torch.bool)
        return inputs

    @property
    def pad_token_id(self):
        """The underlying tokenizer's pad token id (may be None)."""
        return self.tokenizer.pad_token_id

    @property
    def special_tokens(self):
        """Contents of every token registered as added on the tokenizer."""
        return [token.content for token in self.tokenizer.added_tokens_decoder.values()]

    def __repr__(self):
        # FIX: original was a bare `pass`, so repr() raised
        # "TypeError: __repr__ returned non-string (type NoneType)".
        # Keep it minimal; verbosity is deliberately suppressed (see __str__).
        return f"{type(self).__name__}()"

    def __str__(self):
        # Deliberately terse: suppresses the base class's verbose dump.
        return ''