Spaces:
Runtime error
Runtime error
| from typing import List, Union | |
| from PIL import Image | |
| import torch | |
| from transformers.feature_extraction_utils import BatchFeature | |
| from transformers.image_utils import ImageInput, get_image_size, to_numpy_array | |
| from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order | |
| from transformers.tokenization_utils_base import PreTokenizedInput, TextInput | |
| from transformers.utils import logging | |
| from transformers import Qwen2VLImageProcessor | |
| from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize | |
| logger = logging.get_logger(__name__) | |
class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema handed to ``_merge_kwargs`` by ``TarsierProcessor``.

    No processor-specific defaults are declared: both the text and the image
    groups start out empty.
    """

    _defaults = {"text_kwargs": {}, "images_kwargs": {}}
class TarsierProcessor(ProcessorMixin):
    """Processor that combines an image processor and a tokenizer for chat samples.

    ``__call__`` walks a list of chat messages, runs every ``image`` / ``video``
    content entry through the image processor, renders the conversation with
    the tokenizer's chat template and returns token ids, labels and the
    stacked visual features.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_token", "patch_size", "merge_size", "temporal_patch_size", "max_seq_len"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_token="<image>",
        patch_size=None,
        merge_size=1,
        temporal_patch_size=1,
        max_seq_len=8192,
        **kwargs,
    ) -> None:
        """Store the processor settings and register the sub-processors.

        Args:
            image_processor: Image processor instance.
            tokenizer: Tokenizer instance.
            chat_template: Chat template forwarded to ``ProcessorMixin``.
            image_token: Placeholder token string used for visual inputs.
            patch_size: Patch edge length; needed by ``get_num_vision_tokens``
                when the image processor is not a ``Qwen2VLImageProcessor``.
            merge_size: Patch merge factor used in the same token-count formula.
            temporal_patch_size: Number of frames a single still image counts
                for when the per-sample pixel budget is computed.
            max_seq_len: Maximum number of tokens kept in ``input_ids``/``labels``.
            **kwargs: Accepted and ignored.
        """
        self.image_token = image_token
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.temporal_patch_size = temporal_patch_size
        self.max_seq_len = max_seq_len
        # Total pixel budget shared by all frames of one sample.
        self.max_pixels_per_sample = 128 * 384 * 384
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        messages,
        image_processing_config=None,
        is_training=True,
    ) -> dict:
        """Turn chat ``messages`` into model inputs.

        Each ``image`` / ``video`` content entry is replaced **in place** by the
        image processor's output and gains a ``num_vision_tokens`` key before
        the messages are passed to the chat template.

        Args:
            messages: List of ``{'role': ..., 'content': [...]}`` dicts; every
                content item carries a ``'type'`` key.
            image_processing_config: Optional dict understood by
                ``preprocess_image``. ``None`` disables PIL-level preprocessing.
                The dict is never modified; the pixel budget is applied to a copy.
            is_training: When ``False``, a generation prompt is requested if
                the last message is not from the assistant, otherwise the
                final EOS is stripped.

        Returns:
            Dict with ``input_ids``, ``labels`` and ``num_images``, plus
            ``pixel_values`` / ``image_grid_thw`` when visual inputs exist.

        Raises:
            ValueError: If truncation to ``max_seq_len`` would cut off an
                image token.
        """
        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        )

        # --- Image processing ---
        pixel_values, image_grid_thw = [], []
        num_images = 0
        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    num_images += self.temporal_patch_size
                elif content['type'] == 'video':
                    num_images += len(content['video'])

        # Clamp on a private copy so a config dict reused across samples is
        # not permanently shrunk by one frame-heavy sample.
        image_processing_config = self._fit_pixel_budget(image_processing_config, num_images)

        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    content['image'] = self.preprocess_image(content['image'], image_processing_config)
                    content['image'] = self.image_processor(
                        images=content['image'], **output_kwargs["images_kwargs"], return_tensors="pt"
                    )
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)
                    pixel_values.append(content['image']['pixel_values'])
                    if 'image_grid_thw' in content['image']:
                        image_grid_thw.extend(content['image']['image_grid_thw'])
                elif content['type'] == 'video':
                    content['video'] = self.preprocess_image(content['video'], image_processing_config)
                    if isinstance(self.image_processor, Qwen2VLImageProcessor):
                        content['video'] = self.image_processor(
                            images=None, videos=content['video'], **output_kwargs["images_kwargs"], return_tensors="pt"
                        )
                        pixel_values.append(content['video']['pixel_values_videos'])
                    else:
                        content['video'] = self.image_processor(
                            images=content['video'], **output_kwargs["images_kwargs"], return_tensors="pt"
                        )
                        pixel_values.append(content['video']['pixel_values'])
                    if 'video_grid_thw' in content['video']:
                        image_grid_thw.extend(content['video']['video_grid_thw'])
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)

        # --- Text processing ---
        last_is_assistant = messages[-1]['role'] == 'assistant'
        add_generation_prompt = not is_training and not last_is_assistant
        strip_final_eos = not is_training and last_is_assistant
        text_inputs = self.tokenizer.apply_chat_template(
            messages,
            chat_template=self.chat_template,
            tokenize=True,
            tokenizer_kwargs=output_kwargs["text_kwargs"],
            return_assistant_tokens_mask=True,
            return_dict=True,
            add_generation_prompt=add_generation_prompt,
            strip_final_eos=strip_final_eos,
        )

        # Only assistant tokens contribute to the loss; everything else is -100.
        labels = [
            -100 if is_assistant == 0 else token_id
            for token_id, is_assistant in zip(text_inputs['input_ids'], text_inputs['assistant_masks'])
        ]
        labels = labels[:self.max_seq_len]
        input_ids = text_inputs['input_ids'][:self.max_seq_len]

        # Dropping an image token would desynchronise text and pixel_values.
        image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        if image_token_id in text_inputs['input_ids'][self.max_seq_len:]:
            raise ValueError(f'Too long sequence! {len(text_inputs["input_ids"])}')

        outputs = {
            'input_ids': input_ids,
            'labels': labels,
            'num_images': num_images,
        }
        if len(pixel_values) > 0:
            outputs['pixel_values'] = torch.cat(pixel_values, dim=0)
        if len(image_grid_thw) > 0:
            outputs['image_grid_thw'] = torch.stack(image_grid_thw)
        return outputs

    def _fit_pixel_budget(self, image_processing_config, num_images):
        """Return a copy of the config whose pixel limits respect the per-sample budget.

        ``None`` is passed through. When ``max_pixels`` is set and the budget
        per frame (``max_pixels_per_sample // num_images``) is smaller, both
        ``max_pixels`` and ``min_pixels`` are lowered in the copy.
        """
        if image_processing_config is None:
            return None
        config = dict(image_processing_config)
        if num_images > 0 and config.get('max_pixels'):
            per_frame_budget = self.max_pixels_per_sample // num_images
            if per_frame_budget < config['max_pixels']:
                config['max_pixels'] = per_frame_budget
                config['min_pixels'] = min(config['min_pixels'], per_frame_budget)
        return config

    def preprocess_image(self, pil_img: Union[Image.Image, List[Image.Image]], image_processing_config):
        """Apply the configured PIL-level steps to one image or a list of images.

        Steps, in order and each optional: central crop (``do_crop``), pad to
        square with black (``do_padding``), stretch to square (``do_resize``)
        and rescale into ``[min_pixels, max_pixels]`` (when ``max_pixels`` is
        set). A single image in gives a single image out; a list gives a list.
        ``image_processing_config=None`` returns the input unchanged.
        """
        if image_processing_config is None:
            return pil_img

        single = isinstance(pil_img, Image.Image)
        images = [pil_img] if single else pil_img

        if image_processing_config['do_crop']:
            images = [self.centralcrop(img, rate=(4, 3)) for img in images]
        if image_processing_config['do_padding']:
            images = [self.expand2square(img, (0, 0, 0)) for img in images]
        if image_processing_config['do_resize']:
            images = [self.resize2square(img) for img in images]
        if image_processing_config.get('max_pixels'):
            max_pixels = int(image_processing_config['max_pixels'])
            min_pixels = int(image_processing_config['min_pixels'])
            images = [self.resize2pixels(img, max_pixels, min_pixels) for img in images]

        return images[0] if single else images

    def expand2square(self, pil_img, background_color):
        """Pad the shorter side with ``background_color`` so the image becomes square.

        The original image is centred on the new canvas; square inputs are
        returned as-is.
        """
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        # Only the shorter axis receives a non-zero offset.
        offset = ((side - width) // 2, (side - height) // 2)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, offset)
        return result

    def resize2square(self, pil_img: Image.Image):
        """Stretch the image to a square whose edge equals its longer side."""
        side = max(pil_img.size)
        return pil_img.resize((side, side))

    def centralcrop(self, pil_img: Image.Image, rate=(4, 3)):
        """Crop the longer side, centred, to at most ``rate[0] / rate[1]`` times the shorter side.

        The shorter side is never cropped, and the crop window is clamped to
        the image bounds.
        """
        width, height = pil_img.size
        size = (width, height)
        min_len = min(size)
        longer_side = 0 if width >= height else 1  # 0 -> x axis, 1 -> y axis
        center = (width / 2, height / 2)
        half_extent = 1 / 2 * min_len / rate[1] * rate[0]

        box = [0, 0, width, height]
        box[longer_side] = max(0, center[longer_side] - half_extent)
        box[2 + longer_side] = min(size[longer_side], center[longer_side] + half_extent)
        return pil_img.crop(box)

    def resize2pixels(self, pil_img: Image.Image, max_pixels=None, min_pixels=None):
        """Resize the image to the dimensions ``smart_resize`` picks for the pixel bounds."""
        width, height = pil_img.size
        new_height, new_width = smart_resize(
            height, width, factor=1, max_pixels=max_pixels, min_pixels=min_pixels
        )
        return pil_img.resize((new_width, new_height))

    def get_num_vision_tokens(self, content):
        """Return how many vision tokens one processed ``image``/``video`` entry occupies.

        ``content`` must already hold the image processor's output under its
        ``'image'`` or ``'video'`` key.

        Raises:
            ValueError: If a non-Qwen2-VL image processor is used while
                ``patch_size`` is unset.
        """
        key = 'image' if content['type'] == 'image' else 'video'

        if isinstance(self.image_processor, Qwen2VLImageProcessor):
            merge_length = self.image_processor.merge_size ** 2
            return content[key][f'{key}_grid_thw'].prod() // merge_length

        if self.patch_size is None:
            raise ValueError(
                'patch_size must be set to count vision tokens for a non-Qwen2VL image processor'
            )
        # Other models: per frame, rows * (cols + 1) + 1 tokens -- one extra
        # token per row and one extra per frame on top of the merged patches.
        frames = content[key]['pixel_values']
        height, width = get_image_size(to_numpy_array(frames[0]))
        cell = self.patch_size * self.merge_size
        tokens_per_frame = (height // cell) * (width // cell + 1) + 1
        return tokens_per_frame * len(frames)

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``batch_decode`` and return its result."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``decode`` and return its result."""
        return self.tokenizer.decode(*args, **kwargs)

    # NOTE(review): declared as a plain method here; other ProcessorMixin
    # subclasses usually expose this as a @property -- confirm which form the
    # callers expect before changing it.
    def model_input_names(self):
        """Return tokenizer and image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))