| | import re |
| | import types |
| | import io |
| | import torch |
| | from PIL import Image |
| | from qwen_vl_utils import fetch_image |
| |
|
| | from transformers import ( |
| | ProcessorMixin, |
| | SiglipImageProcessor, |
| | BatchFeature, |
| | Qwen2VLImageProcessor, |
| | PreTrainedTokenizer |
| | ) |
| |
|
| | from .utils import ( |
| | process_anyres_image, |
| | BLACK_IMG_ENV, |
| | DEFAULT_IM_END_TOKEN, |
| | DEFAULT_IM_START_TOKEN, |
| | DEFAULT_IMAGE_TOKEN, |
| | DEFAULT_VI_END_TOKEN, |
| | DEFAULT_VI_START_TOKEN, |
| | DEFAULT_VIDEO_TOKEN, |
| | IMAGE_TOKEN_INDEX, |
| | SEQ_MAX_LEN, |
| | ) |
| |
|
# Fixed configuration for the SigLIP image processor: 384x384 bicubic resize,
# rescale to [0, 1], then normalize with mean/std 0.5 per channel.
siglip_processor_config = {
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,  # PIL resampling code 3 == BICUBIC
    "rescale_factor": 0.00392156862745098,  # == 1 / 255
    "size": {"height": 384, "width": 384},
}
| |
|
# Fixed configuration for the Qwen2-VL (NaViT-style) image processor.
# Images are resized into [min_pixels, max_pixels] area, split into 14x14
# patches, grouped temporally in pairs, and spatially merged 2x2.
qwen2vl_processor_config = {
    "min_pixels": 3136,  # 56 * 56
    "max_pixels": 12845056,  # 3584 * 3584
    "patch_size": 14,
    "temporal_patch_size": 2,
    "merge_size": 2,
    "image_mean": [0.48145466, 0.4578275, 0.40821073],  # CLIP normalization stats
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "image_processor_type": "Qwen2VLImageProcessor",
    "processor_class": "Qwen2VLProcessor",
}
| |
|
class ValleyProcessor(ProcessorMixin):
    """Processor for the Valley multimodal model.

    Bundles a text tokenizer with two image pipelines — a SigLIP processor
    (optionally with "anyres" multi-crop tiling) and a Qwen2-VL processor —
    and converts a chat-style message dict into model-ready tensors
    (``input_ids``, ``labels``, SigLIP image tensors, Qwen2-VL pixel values).
    """

    attributes = ["tokenizer"]
    optional_attributes = [
        "max_pixels",
        "min_pixels",
        "anyres",
        "only_crop_single_image",
        "grid_pinpoints",
        "use_special_start_end_token",
    ]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, tokenizer=None, **kwargs):
        """Build the processor.

        Args:
            tokenizer: the text tokenizer (loaded via ``AutoTokenizer``).
            **kwargs: optional behavior flags — ``anyres`` (multi-crop tiling,
                default True), ``grid_pinpoints`` (tiling grid spec),
                ``only_crop_single_image`` (skip tiling when a sample has
                several images, default True), ``use_special_start_end_token``
                (wrap image tokens with start/end markers, default True),
                ``only_navit`` (skip the SigLIP branch entirely, default False).
        """
        super().__init__(tokenizer, **kwargs)
        # Placeholder image used when a sample carries no images at all.
        self.black_img = BLACK_IMG_ENV
        # Both image processors are built from fixed in-file configs, not
        # from the checkpoint directory.
        self.siglip_image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
        self.qwen2vl_image_processor = Qwen2VLImageProcessor.from_dict(
            qwen2vl_processor_config,
        )

        self.anyres = kwargs.get("anyres", True)
        self.grid_pinpoints = kwargs.get("grid_pinpoints", "(1x1),...,(3x3)")
        self.only_crop_single_image = kwargs.get("only_crop_single_image", True)
        self.use_special_start_end_token = kwargs.get("use_special_start_end_token", True)
        self.only_navit = kwargs.get("only_navit", False)

    @staticmethod
    def _to_pil_images(images) -> list:
        """Normalize a homogeneous list of file paths, PIL images, or raw
        bytes into RGB ``PIL.Image`` objects.

        Raises:
            ValueError: if the first element is none of the supported types.
        """
        if isinstance(images[0], str):
            return [Image.open(img).convert("RGB") for img in images]
        if isinstance(images[0], Image.Image):
            return [img.convert("RGB") for img in images]
        if isinstance(images[0], bytes):
            return [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
        raise ValueError("unsupported type")

    def preprocess_images_siglip(self, images) -> torch.FloatTensor:
        """Run the SigLIP pipeline over *images*.

        Returns:
            With ``anyres`` off: one stacked tensor over all images.
            With ``anyres`` on: a list with one tensor per image, each
            stacking that image's crops (a single crop when tiling is
            skipped because the sample has several images).
        """
        images_pil = self._to_pil_images(images)

        processed_images = []
        have_multi_images = len(images_pil) > 1
        for img in images_pil:
            if self.anyres:
                if not self.only_crop_single_image or not have_multi_images:
                    # Tile the image into multiple crops per grid_pinpoints.
                    image = process_anyres_image(img, self.siglip_image_processor, self.grid_pinpoints)
                else:
                    # Multiple images in the sample: keep one crop each.
                    image = [self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]]
            else:
                image = self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]
            processed_images.append(image)

        if not self.anyres:
            return torch.stack(processed_images, dim=0)
        return [torch.stack(img, dim=0) for img in processed_images]

    def preprocess_images_qwen2vl(self, images) -> dict:
        """Run the Qwen2-VL pipeline over *images*.

        Returns a ``BatchFeature``-like dict with pixel values / grid info,
        plus an ``image_sizes`` entry holding the original (width, height)
        of every image.
        """
        images_pil = self._to_pil_images(images)

        image_sizes = [[x.size for x in images_pil]]
        # fetch_image applies Qwen2-VL's smart-resize before the processor.
        data_dict_qwen2vl = self.qwen2vl_image_processor(
            [fetch_image({"image": img}) for img in images_pil],
            return_tensors="pt",
        )
        data_dict_qwen2vl["image_sizes"] = image_sizes
        return data_dict_qwen2vl

    def preprocess_multimodal(self, conversations):
        """Rewrite every non-system turn so that each image token is
        (optionally) wrapped in image start/end marker tokens. Mutates and
        returns *conversations*.
        """
        for sentence in conversations:
            if sentence["role"] == "system":
                continue
            segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["content"])
            if self.use_special_start_end_token:
                sentence["content"] = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN).join(segs)
            else:
                sentence["content"] = DEFAULT_IMAGE_TOKEN.join(segs)

        return conversations

    def preprocess_qwen2(
        self,
        conversations,
        tokenizer: PreTrainedTokenizer,
        has_image: bool = False,
        inference: bool = False,
        only_mask_system: bool = False,
    ) -> dict:
        """Tokenize a Qwen2-format chat and build supervision labels.

        Applies the tokenizer's chat template, then re-tokenizes round by
        round so that everything up to (and including) the assistant role
        header is masked with -100 in ``labels``; with ``only_mask_system``
        only the portion up to the user header is masked.

        Args:
            conversations: chat turns; the first must be the system turn.
            tokenizer: tokenizer providing ``apply_chat_template``.
            has_image: route tokenization through ``tokenizer_image_token``
                so image placeholders become ``IMAGE_TOKEN_INDEX``.
            inference: append the generation prompt and leave the trailing
                (incomplete) round without its end separator.
            only_mask_system: mask only the system/user header instead of
                everything before the assistant reply.

        Returns:
            dict with 1-D ``input_ids`` and ``labels`` tensors.
        """
        conv = types.SimpleNamespace(
            system="You are a helpful assistant.",
            roles=("user", "assistant"),
            version="qwen2",
            offset=0,
            sep="<|im_start|>",
            sep2="<|im_end|>\n",
        )

        # A system turn must lead; fill in the default prompt if it is empty.
        assert conversations[0]["role"] == "system"
        if conversations[0]["content"] is None:
            conversations[0]["content"] = conv.system

        # Remaining turns must strictly alternate user / assistant.
        for j, sentence in enumerate(conversations[1:]):
            role = sentence["role"]
            assert role == conv.roles[j % 2], "The conversation sequence is incorrect."

        conversation_str = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=inference)

        # Split back into rounds on the end-of-turn separator; during
        # inference the final round is the open generation prompt and must
        # not get the separator re-appended.
        rounds = conversation_str.split(conv.sep2)
        input_ids_ = torch.tensor([], dtype=torch.int64)
        targets_ = torch.tensor([], dtype=torch.int64)
        for i, rou in enumerate(rounds):
            if rou == "":
                continue
            if (not inference) or (i < (len(rounds) - 1)):
                rou += conv.sep2
            if has_image:
                cur_input_ids_ = self.tokenizer_image_token(rou, tokenizer, return_tensors='pt')
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                # mask_len counts the tokens of the round with its content
                # stripped down to just the role header — those positions
                # get -100 so the loss only covers the reply text.
                if only_mask_system:
                    mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[0]}\n[\s\S]*', f'{conv.roles[0]}:', rou),
                                                             tokenizer))
                else:
                    mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[1]}\n[\s\S]*', f'{conv.roles[1]}:', rou),
                                                              tokenizer))
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
            else:
                cur_input_ids_ = tokenizer(rou, return_tensors='pt')["input_ids"][0, :]
                input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
                mask_len = len(tokenizer(re.sub(rf'{conv.roles[1]}\n[\s\S]*', rf'{conv.roles[1]}:', rou))["input_ids"][:])
                targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)

        return {"input_ids": input_ids_, "labels": targets_}

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize *prompt*, replacing every image placeholder with
        ``image_token_index``.

        Text segments between placeholders are tokenized independently;
        a duplicated BOS token at the head of later segments is dropped.

        Args:
            prompt: the raw prompt string (must not exceed ``SEQ_MAX_LEN``).
            tokenizer: tokenizer used for the text segments.
            image_token_index: id substituted for each image placeholder.
            return_tensors: ``"pt"`` for a LongTensor, ``None`` for a list.

        Raises:
            ValueError: if the prompt is too long or *return_tensors* is
                an unsupported value.
        """
        def split_with_token(string, token):
            # Like str.split but keeps the token itself between the chunks.
            result = string.split(token)
            for i in range(len(result) - 1):
                result.insert(i * 2 + 1, token)
            return result

        if len(prompt) > SEQ_MAX_LEN:
            raise ValueError("sequence is too long !!!")

        prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
        # Seed with BOS (if the tokenizer has one) and skip the BOS the
        # tokenizer re-adds to every chunk.
        input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer, 'bos_token', None) else ([], 0)
        token2index = {DEFAULT_IMAGE_TOKEN: image_token_index}
        for chunk in prompt_chunks:
            if chunk in token2index:
                input_ids.append(token2index[chunk])
            else:
                chunk_ids = tokenizer(chunk).input_ids
                if chunk_ids[0] != getattr(tokenizer, 'bos_token_id', None):
                    offset = 0
                input_ids.extend(chunk_ids[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def __call__(self, messages, inference=True, **kwargs) -> BatchFeature:
        """Convert one chat sample into model inputs.

        Args:
            messages: dict with ``conversations`` (chat turns) and optional
                ``images`` (a path/PIL/bytes list, or a single path string).
            inference: build a generation prompt (the last turn must then
                be from the user).
            **kwargs: per-call ``max_pixels`` / ``min_pixels`` overrides for
                the Qwen2-VL processor.

        Returns:
            BatchFeature with batched ``input_ids``/``labels``, SigLIP
            ``images`` (or ``[None]`` when ``only_navit``), and the
            Qwen2-VL pixel data.
        """
        max_pixels = kwargs.get("max_pixels", self.max_pixels)
        min_pixels = kwargs.get("min_pixels", self.min_pixels)
        if max_pixels is not None:
            self.qwen2vl_image_processor.max_pixels = max_pixels
        if min_pixels is not None:
            self.qwen2vl_image_processor.min_pixels = min_pixels

        # No images -> fall back to the black placeholder; a bare string is
        # a single image path; otherwise cap the sample at 16 images.
        if "images" not in messages or not messages["images"] or not messages["images"][0]:
            images = [self.black_img]
        elif isinstance(messages["images"], str):
            images = [messages["images"]]
        else:
            images = messages["images"][:16]

        conversations = messages["conversations"]
        if conversations[0]["role"] != "system":
            conversations = [{"role": "system", "content": None}] + conversations

        # Make sure the first user turn references every image.
        assert conversations[1]["role"] == "user"
        if images and "<image>" not in conversations[1]["content"]:
            image_token = " ".join(["<image>"] * len(images))
            conversations[1]["content"] = f"{image_token}\n{conversations[1]['content']}"

        if inference:
            assert conversations[-1]["role"] == "user", "the last message should be user if inference=True"

        if self.only_navit:
            processed_images_siglip = None
        else:
            processed_images_siglip = self.preprocess_images_siglip(images)
        processed_data_dict_qwen2vl = self.preprocess_images_qwen2vl(images)
        source = self.preprocess_multimodal(conversations)
        data_dict = self.preprocess_qwen2(source, self.tokenizer, has_image=True, only_mask_system=False, inference=inference)

        # Add the batch dimension expected by the model.
        data_dict["input_ids"] = data_dict["input_ids"].unsqueeze(0)
        data_dict["labels"] = data_dict["labels"].unsqueeze(0)
        data_dict["images"] = [processed_images_siglip]

        return BatchFeature(data={**data_dict, **processed_data_dict_qwen2vl})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)