import io
import re
import subprocess
from collections import UserDict
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin

from .image_processing_megrezo import MegrezOImageProcessor

AudioInput = Union[str, bytes, np.ndarray, List[str], List[bytes], List[np.ndarray]]
ReturnTensorType = Union[str, TensorType]


class ImageBatchFeature(BatchFeature):
    r"""
    Holds the image features of a batch of images.
    """

    pixel_values: Union[np.ndarray, torch.Tensor]
    image_sizes: Union[np.ndarray, torch.Tensor]
    tgt_sizes: Union[np.ndarray, torch.Tensor]
    patch_attention_mask: Union[np.ndarray, torch.Tensor]
    image_bounds: Union[np.ndarray, torch.Tensor]


class AudioBatchFeature(BatchFeature):
    r"""
    Holds the audio features of a batch of audio clips.
    """

    input_audios: List[Union[np.ndarray, torch.Tensor]]
    input_audio_lengths: List[Union[np.ndarray, torch.Tensor]]
    audio_span_tokens: List[Union[np.ndarray, torch.Tensor]]
    audio_bounds: Union[np.ndarray, torch.Tensor]


class ConvContent(UserDict):
    text: Optional[str] = None
    image: Optional[ImageInput] = None
    audio: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None


class Conversation(UserDict):
    role: Literal["user", "assistant"]
    content: Union[str, dict, ConvContent]


def load_audio(
    audio: Union[str, bytes],
    sample_rate: int = 16000,
) -> "np.ndarray":
    """Load audio from a file path or bytes and return it as a numpy array.

    Args:
        audio (Union[str, bytes]): path to an audio file or raw audio bytes.
        sample_rate (int, optional): target sample rate. Defaults to 16000.

    Raises:
        ValueError: if the input audio is neither a path nor bytes.

    Returns:
        np.ndarray: the decoded audio as a float32 numpy array scaled to [-1.0, 1.0].
    """
    if isinstance(audio, str):
        inp = audio
        out = "-"
        cmd_inp = None
    elif isinstance(audio, bytes):
        inp = "pipe:"
        out = "pipe:"
        cmd_inp = audio
    else:
        raise ValueError("input audio must be either a path or bytes")

    # Decode to mono 16-bit PCM at the requested sample rate with ffmpeg.
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        inp,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sample_rate),
        out,
    ]

    out = subprocess.check_output(cmd, input=cmd_inp, stderr=subprocess.PIPE)
    arr = np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
    return arr
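
# Illustrative usage sketch (assumes ffmpeg is on PATH and a hypothetical file "sample.wav" exists):
#
#     waveform = load_audio("sample.wav")            # float32 waveform resampled to 16 kHz
#     with open("sample.wav", "rb") as f:
#         same_waveform = load_audio(f.read())       # the same result from raw bytes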


def load_image(
    image: Union[str, bytes, PIL.Image.Image],
) -> PIL.Image.Image:
    """Load an image from a file path or bytes and return it as a PIL image.

    Args:
        image (Union[str, bytes, PIL.Image.Image]): path to an image file, image bytes or a PIL image.

    Raises:
        ValueError: if the input image is neither a path, bytes nor a PIL image.

    Returns:
        PIL.Image.Image: the image as a PIL image.
    """
    if isinstance(image, PIL.Image.Image):
        return image

    if isinstance(image, str):
        img = PIL.Image.open(image)
    elif isinstance(image, bytes):
        img = PIL.Image.open(io.BytesIO(image))
    else:
        raise ValueError("input image must be either a path or bytes")

    return img


class MegrezOProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_feature_extractor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_feature_extractor_class = "WhisperFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    _image_placeholder = r"(<image>./</image>)"
    _audio_placeholder = r"(<audio>./</audio>)"

    def __init__(self, image_processor=None, audio_feature_extractor=None, tokenizer=None):
        super().__init__(image_processor, audio_feature_extractor, tokenizer)
        self.chat_template = self.tokenizer.chat_template

    def _parse_and_check_inputs(self, inputs) -> Tuple[List[Conversation], List[ImageInput], List[AudioInput]]:
        if not isinstance(inputs, list):
            raise ValueError("inputs must be a list of conversations")

        conversations = []
        images = []
        audios = []

        for item in inputs:
            if not isinstance(item, (dict, Conversation)):
                raise ValueError("each element of inputs must be a dictionary or a Conversation object")

            role = item.get("role")
            content = item.get("content")
            if role is None or content is None:
                raise ValueError("role and content must be provided in each conversation")

            if isinstance(content, dict):
                content = ConvContent({**content})
            elif not isinstance(content, (str, ConvContent)):
                raise ValueError("content must be a string, a dictionary or a ConvContent object")

            if not isinstance(content, str):
                if content.get("image") is not None:
                    images.extend(content["image"] if isinstance(content["image"], list) else [content["image"]])

                if content.get("audio") is not None:
                    audios.extend(content["audio"] if isinstance(content["audio"], list) else [content["audio"]])

            conv = Conversation({"role": role, "content": content})
            conversations.append(conv)

        return conversations, images, audios

    def __call__(
        self,
        conversations: List[Conversation],
        apply_chat_template: bool = True,
        max_length: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        apply_data_collator: bool = True,
        **kwargs,
    ):
        """Render the conversations into a prompt, encode text, images and audio, and merge the encodings."""
        assert return_tensors is TensorType.PYTORCH, "Only PyTorch tensors are supported for now."
        convs, images, audios = self._parse_and_check_inputs(conversations)
        add_generation_prompt = kwargs.pop("add_generation_prompt", True)
        if apply_chat_template:
            prompt = self.tokenizer.apply_chat_template(
                convs,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
        else:
            prompt = "\n".join([conv["content"] for conv in convs])

        prompt, multimodal_inputs = self.process_multimodal_inputs(
            prompt,
            images=images,
            audios=audios,
            return_tensors=return_tensors,
            **kwargs,
        )
        text_encodings = self.tokenizer(
            prompt,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=True,
            padding_side="left",
            truncation=True,
            **kwargs,
        )

        merged = self.merge_encodings(text_encodings, multimodal_inputs)

        if apply_data_collator:
            return self.data_collator([merged])

        return merged
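
    # Illustrative usage sketch (the checkpoint path and image file are hypothetical, and it is
    # assumed the chat template renders the content's "text" field, including the literal
    # "(<image>./</image>)" placeholder expected by this processor):
    #
    #     processor = MegrezOProcessor.from_pretrained("path/to/megrezo")
    #     conversation = [
    #         {"role": "user", "content": {"text": "(<image>./</image>) Describe this picture.", "image": "cat.jpg"}},
    #     ]
    #     batch = processor(conversation)
    #     # batch holds input_ids / attention_mask / position_ids plus image_encoding and audio_encoding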

    def merge_encodings(self, text_encodings, multimodal_inputs):
        result = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        result["input_ids"] = text_encodings["input_ids"].reshape(-1).to(torch.int32)
        result["attention_mask"] = result["input_ids"].ne(0)
        result["position_ids"] = torch.arange(result["input_ids"].size(0)).long()

        if "image_encoding" in multimodal_inputs and multimodal_inputs["image_encoding"]:
            result["image_encoding"] = multimodal_inputs["image_encoding"]
            result["image_encoding"]["image_bounds"] = self.compute_bounds_image(result["input_ids"])

        if "audio_encoding" in multimodal_inputs and multimodal_inputs["audio_encoding"]:
            result["audio_encoding"] = multimodal_inputs["audio_encoding"]
            result["audio_encoding"]["audio_bounds"] = self.compute_bounds_audio(result["input_ids"])

        return result

    def compute_bounds_image(self, input_ids: torch.Tensor) -> torch.Tensor:
        image_start_ids = (
            torch.where((input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id))[0] + 1
        )
        image_end_ids = torch.where(
            (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id)
        )[0]

        valid_image_nums = max(len(image_start_ids), len(image_end_ids))
        bounds_image = torch.hstack(
            [
                image_start_ids[:valid_image_nums].unsqueeze(-1),
                image_end_ids[:valid_image_nums].unsqueeze(-1),
            ]
        )
        return bounds_image

    def compute_bounds_audio(self, input_ids: torch.Tensor) -> torch.Tensor:
        audio_bos_ids = torch.where(input_ids == self.tokenizer.audio_start_id)[0]
        audio_eos_ids = torch.where(input_ids == self.tokenizer.audio_end_id)[0]
        bounds_audio = torch.stack([audio_bos_ids, audio_eos_ids], 1)
        return bounds_audio
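
    # Worked example of the bound computation above (hypothetical token ids, purely illustrative):
    # if <image_start> has id 5 and <image_end> has id 6, then for
    #     input_ids = [1, 5, 9, 9, 6, 2]
    # compute_bounds_image returns tensor([[2, 4]]): the span starts just after the start token
    # (index 1 + 1) and ends at the end-token index 4. compute_bounds_audio is analogous but
    # keeps the start-token index itself rather than shifting by one.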

    def process_multimodal_inputs(
        self,
        text: str,
        images: Optional[ImageInput] = None,
        audios: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ):
        if text is None and images is None and audios is None:
            raise ValueError("At least one of text, images or audios must be provided")

        # Forward only the kwargs that each sub-processor actually understands.
        image_processor_kwargs, audio_feature_extractor_kwargs = {}, {}
        if kwargs:
            image_processor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.image_processor._valid_processor_keys
            }
            audio_feature_extractor_kwargs = {
                k: v for k, v in kwargs.items() if k in self.audio_feature_extractor._valid_processor_keys
            }

        multimodal_encodings = {
            "image_encoding": None,
            "audio_encoding": None,
        }

        if images:
            image_encoding = self.process_image(
                images,
                return_tensors=return_tensors,
                **image_processor_kwargs,
            )
            text = self.insert_image_feature_placeholders(text, image_encoding)
            multimodal_encodings["image_encoding"] = image_encoding

        if audios:
            audio_encoding = self.process_audio(
                audios,
                **audio_feature_extractor_kwargs,
            )
            text = self.insert_audio_feature_placeholders(text, audio_encoding)
            multimodal_encodings["audio_encoding"] = audio_encoding

        return text, multimodal_encodings

    def insert_image_feature_placeholders(
        self,
        prompt: str,
        image_features: ImageBatchFeature,
        max_slice_nums: Optional[int] = None,
        use_image_id: Optional[bool] = None,
    ) -> str:
        img_tags = re.findall(self._image_placeholder, prompt)
        assert len(img_tags) == len(
            image_features.image_sizes
        ), f"the number of image tags must match the number of images, got {len(img_tags)} and {len(image_features.image_sizes)}"

        # Replace each image tag with the slice placeholder produced by the image processor.
        text_chunks = prompt.split(self._image_placeholder)
        final_text = ""
        for i in range(len(img_tags)):
            final_text += text_chunks[i] + self.image_processor.get_slice_image_placeholder(
                image_features.image_sizes[i],
                i,
                max_slice_nums,
                use_image_id,
            )
        final_text += text_chunks[-1]

        return final_text

    def insert_audio_feature_placeholders(
        self,
        prompt: str,
        audio_features: AudioBatchFeature,
    ) -> str:
        audio_tags = re.findall(self._audio_placeholder, prompt)
        assert len(audio_tags) == len(
            audio_features.input_audios
        ), "the number of audio tags must match the number of audios"

        # Replace each audio tag with an audio_start token, a run of unk placeholders, and an audio_end token.
        text_chunks = prompt.split(self._audio_placeholder)
        final_text = ""
        for idx in range(len(audio_features.input_audios)):
            final_text += text_chunks[idx] + (
                self.tokenizer.audio_start
                + self.tokenizer.unk_token * audio_features.audio_span_tokens[idx]
                + self.tokenizer.audio_end
            )
        final_text += text_chunks[-1]

        return final_text

    def process_audio(
        self,
        audio_input: AudioInput,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> AudioBatchFeature:
        # Already-decoded waveforms (np.ndarray) are passed through; paths and bytes are decoded first.
        if isinstance(audio_input, list):
            inputs = [x if isinstance(x, np.ndarray) else load_audio(x) for x in audio_input]
        elif isinstance(audio_input, np.ndarray):
            inputs = [audio_input]
        elif isinstance(audio_input, (str, bytes)):
            inputs = [load_audio(audio_input)]
        else:
            raise ValueError("audio_input must be a path, bytes, a numpy array, or a list of them")

        features = self.audio_feature_extractor(
            inputs,
            sampling_rate=self.audio_feature_extractor.sampling_rate,
            return_attention_mask=True,
            return_token_timestamps=True,
            padding="max_length",
            return_tensors=return_tensors,
            **kwargs,
        )

        # Downsampled feature lengths and the number of placeholder tokens reserved per audio span.
        input_lengths = features["num_frames"]
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        input_audio_lengths = torch.stack([input_lengths, output_lengths], dim=1)
        audio_span_tokens = (output_lengths + 2).tolist()

        data = {
            "input_audios": features["input_features"],
            "input_audio_lengths": input_audio_lengths,
            "audio_span_tokens": audio_span_tokens,
        }

        return AudioBatchFeature(data={**data})
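
    # Worked example of the span arithmetic above (assuming the feature extractor's default hop
    # length, a 30-second clip at 16 kHz reports num_frames = 3000):
    #     input_lengths  = (3000 - 1) // 2 + 1 = 1500
    #     output_lengths = (1500 - 2) // 2 + 1 = 750
    #     audio_span_tokens = 750 + 2 = 752 placeholder positions (including the start/end markers)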

    def pad_images(
        self,
        pixel_values_list: List[torch.Tensor],
        tgt_sizes: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad images to the same size and return the padded pixel values and the patch attention mask.

        Sliced patches may have different sizes. They are padded to the same length, and the patch
        attention mask marks which positions correspond to real (non-padding) patches.
        """
        all_pixel_values = []
        for pixel_value in pixel_values_list:
            all_pixel_values.append(pixel_value.flatten(end_dim=1).permute(1, 0))

        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
        all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

        patch_attention_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
        for i in range(B):
            patch_attention_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        return all_pixel_values, patch_attention_mask
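
    # Illustrative mask layout (schematic): with two slices whose tgt_sizes are (2, 3) and (2, 2),
    # max_patches = 6, so patch_attention_mask is a (2, 1, 6) boolean tensor:
    #     [[[True, True, True, True, True, True]],
    #      [[True, True, True, True, False, False]]]
    # i.e. the second slice contributes only 4 real patches and its last 2 positions are padding.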

    def process_image(
        self,
        image_input: ImageInput,
        do_pad: bool = True,
        max_slice_nums: Optional[int] = None,
        return_tensors: ReturnTensorType = TensorType.PYTORCH,
        **kwargs,
    ) -> ImageBatchFeature:
        if isinstance(image_input, list):
            image_input = [load_image(x) for x in image_input]
        elif isinstance(image_input, (str, bytes, PIL.Image.Image)):
            image_input = [load_image(image_input)]
        else:
            raise ValueError(f"image_input must be a path or bytes or a list of paths/bytes, not: {type(image_input)}")

        image_features = self.image_processor(
            image_input,
            do_pad=do_pad,
            max_slice_nums=max_slice_nums,
            return_tensors=return_tensors,
            **kwargs,
        )

        assert len(image_features.pixel_values) == 1, "images should be packed into one list."
        pixel_values = image_features.pixel_values[0]
        tgt_sizes = image_features.tgt_sizes[0]
        image_sizes = image_features.image_sizes[0]

        pixel_values, patch_attention_mask = self.pad_images(pixel_values, tgt_sizes)

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
            "tgt_sizes": tgt_sizes,
            "patch_attention_mask": patch_attention_mask,
        }

        return ImageBatchFeature(data=data)

    def data_collator(self, examples, padding_value=0, max_length=4096, collate_labels=False):
        """Collate examples into a batch for the MegrezO model.

        Trims and pads the input_ids, position_ids and attention_mask tensors, and prepends the
        batch index to each image/audio bounds tensor.
        """

        def trim_and_pad(seq, batch_first, padding_value):
            return pad_sequence(
                [s[:max_length] for s in seq],
                batch_first=batch_first,
                padding_value=padding_value,
            )

        input_ids = trim_and_pad(
            [example["input_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )
        position_ids = trim_and_pad(
            [example["position_ids"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )
        attention_mask = trim_and_pad(
            [example["attention_mask"] for example in examples],
            batch_first=True,
            padding_value=padding_value,
        )

        image_encoding_list = {
            "pixel_values": [],
            "image_bounds": [],
            "tgt_sizes": [],
            "patch_attention_mask": [],
        }
        for bid, example in enumerate(examples):
            image_encoding = example.get("image_encoding")
            if not image_encoding:
                continue

            image_encoding_list["pixel_values"].append(image_encoding["pixel_values"])
            image_encoding_list["tgt_sizes"].append(image_encoding["tgt_sizes"])
            image_encoding_list["patch_attention_mask"].append(image_encoding["patch_attention_mask"])

            # Prepend the batch index so each bound becomes (batch_idx, start, end).
            bounds_with_bid = image_encoding["image_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            image_encoding_list["image_bounds"].append(bounds_with_bid)

        audio_encoding_list = {
            "input_audios": [],
            "input_audio_lengths": [],
            "audio_span_tokens": [],
            "audio_bounds": [],
        }
        for bid, example in enumerate(examples):
            audio_encoding = example.get("audio_encoding")
            if not audio_encoding:
                continue

            audio_encoding_list["input_audios"].append(audio_encoding["input_audios"])
            audio_encoding_list["input_audio_lengths"].append(audio_encoding["input_audio_lengths"])
            audio_encoding_list["audio_span_tokens"].extend(audio_encoding["audio_span_tokens"])
            bounds_with_bid = audio_encoding["audio_bounds"].clone()
            bounds_with_bid = torch.hstack(
                [
                    torch.full((bounds_with_bid.size(0), 1), bid, dtype=bounds_with_bid.dtype),
                    bounds_with_bid,
                ]
            )
            audio_encoding_list["audio_bounds"].append(bounds_with_bid)

        result = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "image_encoding": None,
            "audio_encoding": None,
        }

        if collate_labels:
            labels = trim_and_pad(
                [example["labels"] for example in examples],
                batch_first=True,
                padding_value=-100,
            )
            result["labels"] = labels

        if any(image_encoding_list.values()):
            result["image_encoding"] = {
                "pixel_values": torch.vstack(image_encoding_list["pixel_values"]),
                "tgt_sizes": torch.vstack(image_encoding_list["tgt_sizes"]),
                "patch_attention_mask": torch.vstack(image_encoding_list["patch_attention_mask"]),
                "image_bounds": torch.vstack(image_encoding_list["image_bounds"]),
            }
        if any(audio_encoding_list.values()):
            result["audio_encoding"] = {
                "input_audios": torch.vstack(audio_encoding_list["input_audios"]),
                "input_audio_lengths": torch.vstack(audio_encoding_list["input_audio_lengths"]),
                "audio_span_tokens": audio_encoding_list["audio_span_tokens"],
                "audio_bounds": torch.vstack(audio_encoding_list["audio_bounds"]),
            }
        return result
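
    # Illustrative collation sketch (conversation_a / conversation_b are hypothetical inputs; each
    # example is the output of __call__ with apply_data_collator=False):
    #
    #     example_a = processor(conversation_a, apply_data_collator=False)
    #     example_b = processor(conversation_b, apply_data_collator=False)
    #     batch = processor.data_collator([example_a, example_b])
    #     # batch["input_ids"] is padded to the longest sequence, and each row of
    #     # batch["image_encoding"]["image_bounds"] has the form (batch_idx, start, end).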