| | """ Processor class for Emu3. """ |
| |
|
| | from math import ceil |
| | import re |
| | from typing import List, Optional, Sequence, Union |
| | from functools import partial |
| |
|
| | from PIL import Image |
| | import torch |
| | from torch.nn import functional as F |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.image_utils import ImageInput, get_image_size, to_numpy_array |
| | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin |
| | from transformers.tokenization_utils_base import TextInput, PreTokenizedInput |
| | from transformers.utils import logging |
| |
|
| | from .utils_emu3 import Emu3PrefixConstrainedLogitsHelper |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
class Emu3Processor(ProcessorMixin):
    r"""
    Constructs an Emu3 processor which wraps an Emu3 image processor, an Emu3 vision VQ model, and an Emu3 tokenizer into a single processor.

    [`Emu3Processor`] offers all the functionalities of [`Emu3VisionVQModel`] and [`Emu3Tokenizer`]. See
    [`~Emu3Processor.__call__`], [`~Emu3Processor.decode`], [`~Emu3Processor.vision_encode`], and [`~Emu3Processor.vision_decode`]
    for more information.

    Args:
        image_processor ([`Emu3VisionVQImageProcessor`]):
            The image processor is a required input.
        vision_tokenizer ([`Emu3VisionVQModel`]):
            The vision tokenizer is a required input.
        tokenizer ([`Emu3Tokenizer`]):
            The tokenizer is a required input.
        prefix_template (`str`, *optional*):
            The prefix template for image tokens.
        visual_template (`Tuple[str, ...]`, *optional*):
            The visual token template for image tokens.
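
    Example (an illustrative sketch, not a verified snippet; the checkpoint names
    below are assumptions):

    >>> from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
    >>> image_processor = AutoImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True)
    >>> vision_tokenizer = AutoModel.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True)
    >>> tokenizer = AutoTokenizer.from_pretrained("BAAI/Emu3-Gen", trust_remote_code=True)
    >>> processor = Emu3Processor(image_processor, vision_tokenizer, tokenizer)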
| | """ |

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["vision_tokenizer", "prefix_template", "visual_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        vision_tokenizer=None,
        tokenizer=None,
        chat_template="{image_prompt}{text_prompt}",
        prefix_template="{H}*{W}",
        visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>"),
        **kwargs,
    ):
        if vision_tokenizer is None:
            raise ValueError("vision tokenizer cannot be None")

        self.vision_tokenizer = vision_tokenizer
        self.prefix_template = prefix_template
        self.visual_template = visual_template
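        # Each VQ down-sampling stage halves the spatial resolution, so an H x W
        # pixel image maps to a token grid of ceil(H / f) x ceil(W / f), where
        # f = 2 ** (number of stages - 1). (For example, a ch_mult of length 4
        # would give f = 8 -- illustrative; the value depends on the vision
        # tokenizer config.)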
        self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1)

        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.const_helper = self.build_const_helper()

    @torch.no_grad()
    def __call__(
        self,
        text: Optional[TextInput | PreTokenizedInput] = None,
        image: Optional[Image.Image | List[Image.Image]] = None,
        *,
        mode: str = "G",
        ratio: str | List[str] = "1:1",
        image_area: int = 518400,
        padding_image: bool = False,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Emu3Tokenizer's [`~Emu3Tokenizer.__call__`] to encode the text.
        To prepare the image(s), this method forwards the `image` argument to
        Emu3VisionVQImageProcessor's [`~Emu3VisionVQImageProcessor.__call__`] and Emu3VisionVQModel's [`~Emu3VisionVQModel.encode`]
        if `image` is not `None`. Please refer to the docstrings of the above two methods for more information.

        Args:
            text (`str` or `List[str]`):
                The sequence or a batch of sequences to be encoded. A sequence is a string.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
                The image or a batch of images to be prepared. An image is a PIL image.
            mode (`str`, *optional*, defaults to `"G"`):
                The task mode: `"G"` for generation and `"U"` for understanding.
            ratio (`str` or `List[str]`, *optional*, defaults to `"1:1"`):
                The width:height ratio of the image(s) to generate.
            image_area (`int`, *optional*, defaults to `518400`):
                The target image area used to calculate the generated image height and width.
            padding_image (`bool`, *optional*, defaults to `False`):
                Whether to pad images to the same size for faster preprocessing when they have different sizes.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **image_size** -- List of image sizes of the input images or the images to generate.
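
        Example (an illustrative generation-mode sketch; assumes `processor` was
        built as in the class docstring example):

        >>> inputs = processor(text="a photo of a cat", mode="G", ratio="1:1", return_tensors="pt")
        >>> inputs.input_ids  # token ids ending with the image prefix; inputs.image_size holds [h, w]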
| | """ |
        if mode not in ('G', 'U'):
            raise ValueError("mode must be 'G' or 'U'.")

        if isinstance(text, str):
            text = [text]

        if isinstance(image, Image.Image):
            image = [image]

        if text is None or len(text) == 0 or not isinstance(text[0], str):
            raise ValueError("`text` must be a string or a list of strings")

        image_tokens = None
        if mode == 'G':
            if image is not None:
                raise ValueError("You have to specify only `text` in generation mode")

            if isinstance(ratio, str):
                ratio = [ratio] * len(text)

            if len(ratio) != len(text):
                raise ValueError("number of ratios must match number of text prompts")
        else:
            if image is None:
                raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")

            if not isinstance(image, Sequence) and not isinstance(image, Image.Image):
                raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")

            if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
                raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")

            image_tokens = self.tokenize_image(image, padding_image=padding_image)
            if len(text) != len(image_tokens):
                raise ValueError("number of images must match number of text prompts")

        prompt_list, size_list = [], []
        for idx, text_prompt in enumerate(text):
            prompt = self.tokenizer.bos_token
            if mode == 'U':
                # Understanding: serialize the image's token grid into the prompt.
                h, w = image_tokens[idx].shape
                imgstr = self.to_imgstr(image_tokens[idx])
                image_prompt = (
                    self.tokenizer.boi_token +
                    self.prefix_template.format(H=h, W=w) +
                    self.tokenizer.img_token +
                    imgstr +
                    self.tokenizer.eol_token +
                    self.tokenizer.eof_token +
                    self.tokenizer.eoi_token
                )
                prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
            else:
                # Generation: emit only the image prefix; the model fills in the visual tokens.
                h, w = self.calculate_generate_size(ratio[idx], image_area, self.vision_tokenizer.spatial_scale_factor)
                image_prompt = (
                    self.tokenizer.boi_token +
                    self.prefix_template.format(H=h, W=w) +
                    self.tokenizer.img_token
                )
                prompt += (text_prompt + image_prompt)

            prompt_list.append(prompt)
            size_list.append([h, w])

        text_inputs = self.tokenizer(prompt_list, **kwargs)
        return BatchFeature(data={**text_inputs, "image_size": size_list}, tensor_type=kwargs.get("return_tensors"))

    @torch.no_grad()
    def batch_decode(self, *args, **kwargs):
        docs = self.tokenizer.batch_decode(*args, **kwargs)
        return [self.multimodal_decode(d) for d in docs]

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        doc = self.tokenizer.decode(*args, **kwargs)
        return self.multimodal_decode(doc)

    @torch.no_grad()
    def vision_encode(self, *args, **kwargs):
        return self.vision_tokenizer.encode(*args, **kwargs)

    @torch.no_grad()
    def vision_decode(self, *args, **kwargs):
        return self.vision_tokenizer.decode(*args, **kwargs)

    @torch.no_grad()
    def multimodal_decode(self, doc):
        multimodal_output = []
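        # The capture group makes re.split keep the image spans (boi ... eoi) as
        # their own chunks instead of discarding them as delimiters.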
        pattern = rf'({re.escape(self.tokenizer.boi_token)}.*?{re.escape(self.tokenizer.eoi_token)})'
        chunks = re.split(pattern, doc)
        for c in chunks:
            if len(c) == 0:
                continue

            if self.tokenizer.boi_token in c:
                # Image chunk: parse the visual token ids row by row, then decode
                # them back to pixels with the vision tokenizer.
                image = []
                image_rows = re.split(re.escape(self.tokenizer.eol_token), c)
                for r in image_rows:
                    token_ids = re.findall(self.visual_template[1], r)
                    if len(token_ids) > 0:
                        row_token = [int(m) for m in token_ids]
                        image.append(row_token)
                image = torch.tensor(image, dtype=torch.long, device=self.vision_tokenizer.device)
                image = self.vision_tokenizer.decode(image[None]).float()
                image = self.image_processor.postprocess(image)["pixel_values"][0]
                multimodal_output.append(image)
            else:
                multimodal_output.append(c)

        return multimodal_output if len(multimodal_output) > 1 else multimodal_output[0]

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def to_imgstr(self, image_tokens):
        # Render a 2D grid of visual token ids as text: one formatted token per
        # cell, rows joined by the end-of-line token.
        image_tokens = image_tokens.cpu().numpy().tolist()
        image_token_str = [
            [
                self.visual_template[0].format(token_id=token_id)
                for token_id in token_row
            ]
            for token_row in image_tokens
        ]
        image_row_str = ["".join(token_row) for token_row in image_token_str]
        imgstr = self.tokenizer.eol_token.join(image_row_str)
        return imgstr

    def calculate_generate_size(self, ratio, image_area, spatial_scale_factor):
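        # Scale the w:h aspect ratio so the pixel area is ~image_area, then convert
        # to the latent token grid. Worked example (assuming spatial_scale_factor == 8,
        # a hypothetical value here): ratio "4:3" and image_area 518400 give
        # target_ratio = sqrt(518400 / 12) ~= 207.85, so th = round(3 * 207.85 / 8) = 78
        # and tw = round(4 * 207.85 / 8) = 104.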
        w, h = map(int, ratio.split(":"))
        current_area = h * w
        target_ratio = (image_area / current_area) ** 0.5

        th = int(round(h * target_ratio / spatial_scale_factor))
        tw = int(round(w * target_ratio / spatial_scale_factor))
        return th, tw

    def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False):
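        # Three paths: (1) all images share one size -> batch-encode directly;
        # (2) mixed sizes with padding_image=True -> pad to the max H x W,
        # batch-encode, then crop each token grid back to its true size;
        # (3) mixed sizes without padding -> encode one image at a time.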
        is_all_same_size, prev_size = True, None
        for im in image:
            if prev_size is not None:
                is_all_same_size &= (prev_size == im.size)
            prev_size = im.size

        if is_all_same_size:
            image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
            image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
            image_tokens = self.vision_tokenizer.encode(image_inputs)
        elif padding_image:
            image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image]
            image_shapes = [im.shape[2:] for im in image_inputs]
            max_shape = (
                max([im_shape[0] for im_shape in image_shapes]),
                max([im_shape[1] for im_shape in image_shapes]),
            )
            image_inputs = [
                F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0]))
                for im_inp, im_shape in zip(image_inputs, image_shapes)
            ]
            image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
            image_tokens = self.vision_tokenizer.encode(image_inputs)
            # Crop each padded token grid back to the grid implied by the image's
            # unpadded pixel shape.
            image_tokens = [
                im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)]
                for im_tok, im_shape in zip(image_tokens, image_shapes)
            ]
        else:
            image_tokens = []
            for im in image:
                image_input = self.image_processor(im, return_tensors="pt")["pixel_values"]
                image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
                image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0))

        return image_tokens

    def build_const_helper(self):
        (
            img_token,
            eoi_token,
            eos_token,
            eol_token,
            eof_token,
            pad_token,
            vis_start,
            vis_end,
        ) = self.tokenizer.encode([
            self.tokenizer.img_token,
            self.tokenizer.eoi_token,
            self.tokenizer.eos_token,
            self.tokenizer.eol_token,
            self.tokenizer.eof_token,
            self.tokenizer.pad_token,
            self.visual_template[0].format(token_id=0),
            self.visual_template[0].format(token_id=self.vision_tokenizer.config.codebook_size - 1),
        ])

        const_helper = partial(
            Emu3PrefixConstrainedLogitsHelper,
            img_token=img_token,
            eoi_token=eoi_token,
            eos_token=eos_token,
            eol_token=eol_token,
            eof_token=eof_token,
            pad_token=pad_token,
            visual_tokens=list(range(vis_start, vis_end + 1)),
        )
        return const_helper

    def build_prefix_constrained_fn(self, height, width):
        helper = self.const_helper(height=height, width=width)
        return helper
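
    # Note: the helper returned by build_prefix_constrained_fn is intended for
    # constrained decoding -- e.g. passed to `generate` via `prefix_allowed_tokens_fn`,
    # or wrapped in transformers' PrefixConstrainedLogitsProcessor. This usage note
    # is a sketch; see utils_emu3.Emu3PrefixConstrainedLogitsHelper for the contract.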