import inspect
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import PIL  # noqa: F401  (PIL.Image types appear in the input annotations below)
import decord  # noqa: F401  (required by the "decord" backend of load_video)

from transformers import AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer
from transformers.audio_utils import load_audio
from transformers.image_utils import load_image, load_video
from transformers.processing_utils import (
    AudioKwargs,
    ImagesKwargs,
    ProcessingKwargs,
    ProcessorMixin,
    TextKwargs,
    Unpack,
    VideosKwargs,
)


def load_audio_str(audio_path_or_url: str, sampling_rate: int = 16000) -> np.ndarray:
    return load_audio(audio_path_or_url, sampling_rate=sampling_rate)


def load_video_str(video_path_or_url: str, num_frames: int = 4, fps: Optional[int] = None):
    # load_video returns the decoded frames together with their metadata.
    return load_video(video_path_or_url, num_frames=num_frames, fps=fps, backend="decord")


def load_image_str(image_path_or_url: str) -> "PIL.Image.Image":
    return load_image(image_path_or_url)


ImageInput = Union[  # same as transformers.image_utils.ImageInput, plus paths/URLs
    "PIL.Image.Image",
    np.ndarray,
    "torch.Tensor",
    list["PIL.Image.Image"],
    list[np.ndarray],
    list["torch.Tensor"],
    # image URLs or image paths
    str,
    list[str],
]

VideoInput = Union[  # same as transformers.image_utils.VideoInput, plus paths/URLs
    list["PIL.Image.Image"],
    np.ndarray,
    "torch.Tensor",
    list[np.ndarray],
    list["torch.Tensor"],
    list[list["PIL.Image.Image"]],
    list[list[np.ndarray]],
    list[list["torch.Tensor"]],
    # video URLs or video paths
    str,
    list[str],
    list[list[str]],
]

AudioInput = Union[  # same as transformers.audio_utils.AudioInput, plus paths/URLs
    np.ndarray,
    "torch.Tensor",
    List[np.ndarray],
    Tuple[np.ndarray],
    List["torch.Tensor"],
    Tuple["torch.Tensor"],  # noqa: F821
    # audio URLs or audio paths
    str,
    list[str],
]


class QualityvImageKwargs(ImagesKwargs):
    tokens_per_image: Optional[int]


class QualityvVideoKwargs(VideosKwargs):
    num_frames: Optional[int]
    fps: Optional[int]
    tokens_per_frame: Optional[int]


class QualityvAudioKwargs(AudioKwargs):
    sampling_rate: Optional[int]
    tokens_per_audio: Optional[int]


class QualityvProcessingKwargs(ProcessingKwargs, total=False):
    images_kwargs: QualityvImageKwargs
    videos_kwargs: QualityvVideoKwargs
    audio_kwargs: QualityvAudioKwargs
    text_kwargs: TextKwargs
    _defaults = {
        "images_kwargs": {"tokens_per_image": 197},
        "videos_kwargs": {"num_frames": 4, "fps": None, "tokens_per_frame": 197},
        "audio_kwargs": {"sampling_rate": 16000, "tokens_per_audio": 1500},
    }
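
# The Kwargs classes above follow transformers' typed processing-kwargs
# pattern: callers may override the per-modality defaults at call time, e.g.
#   processor(messages=msgs, images_kwargs={"tokens_per_image": 256},
#             videos_kwargs={"num_frames": 8})
# __call__ below falls back to the same literal defaults when a key is absent.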


class QualityvProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'audio' or 'audio' in content %}{% set audio_count.value = audio_count.value + 1 %}{% if add_vision_id %}Audio {{ audio_count.value }}: {% endif %}<|vision_start|><|audio_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""
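
    # For reference, rendering the template above on a single user turn whose
    # content is one image followed by one text chunk (add_generation_prompt=True,
    # add_vision_id unset) produces roughly, before placeholder expansion:
    #
    #   <|im_start|>system
    #   You are a helpful assistant.<|im_end|>
    #   <|im_start|>user
    #   <|vision_start|><|image_pad|><|vision_end|>Describe this picture.<|im_end|>
    #   <|im_start|>assistant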

    def __init__(
        self,
        tokenizer=None,
        image_processor=None,
        audio_processor=None,
        chat_template=None,
        image_token="<|image_pad|>",
        video_token="<|video_pad|>",
        audio_token="<|audio_pad|>",
        label_start_text="<|im_start|>assistant\n",
        label_end_text="<|im_end|>\n",
        **kwargs,
    ):
        # Prefer tokens already registered on the tokenizer over the defaults.
        self.image_token = getattr(tokenizer, "image_token", image_token)
        self.video_token = getattr(tokenizer, "video_token", video_token)
        self.audio_token = getattr(tokenizer, "audio_token", audio_token)
        self.label_start_text = label_start_text
        self.label_end_text = label_end_text
        self.image_token_id = (
            getattr(tokenizer, "image_token_id", None)
            or tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            getattr(tokenizer, "video_token_id", None)
            or tokenizer.convert_tokens_to_ids(self.video_token)
        )
        self.audio_token_id = (
            getattr(tokenizer, "audio_token_id", None)
            or tokenizer.convert_tokens_to_ids(self.audio_token)
        )
        if chat_template is None:
            chat_template = self.chat_template
        super().__init__(image_processor, audio_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        messages: Union[List[Dict], None] = None,
        images: Union[ImageInput, None] = None,
        videos: Union[VideoInput, None] = None,
        audio: Union[AudioInput, None] = None,
        do_train: bool = False,
        add_generation_prompt: bool = False,
        **kwargs: Unpack[QualityvProcessingKwargs],
    ):
        """
        Input:
            messages: list of dicts, e.g.
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Hello, how are you?"},
                            {"type": "image", "image": xxx},
                            {"type": "video", "video": xxx},
                        ],
                    },
                    ...
                ]

        Output:
            input_ids
            attention_mask
            pixel_values, pixel_values_videos
            audio_values
            labels (None unless do_train=True)
        """
        pixel_values = []
        pixel_values_videos = []
        audio_values = []
        labels = None

        if not text and not messages:
            raise ValueError("At least one of text or messages must be provided.")
        if messages:
            text = self.apply_chat_template(
                messages, add_generation_prompt=add_generation_prompt, tokenize=False
            )
        # This processor handles one sequence at a time.
        if isinstance(text, list):
            text = text[0]

        image_list = self.fill_modal_list(self.image_token, "image", messages, images, text)
        image_list = self.process_str_in_modal_list(image_list, "image", **kwargs.get("images_kwargs", {}))
        if image_list and self.image_token in text:
            # Expand every image placeholder to tokens_per_image copies; one
            # global replace is safe because the count is the same for all images.
            tokens_per_image = kwargs.get("images_kwargs", {}).get("tokens_per_image", 197)
            text = text.replace(self.image_token, tokens_per_image * self.image_token)
            pixel_values = self.image_processor(images=image_list, return_tensors="pt")["pixel_values"]

        video_list = self.fill_modal_list(self.video_token, "video", messages, videos, text)
        video_list = self.process_str_in_modal_list(video_list, "video", **kwargs.get("videos_kwargs", {}))
        if video_list and self.video_token in text:
            tokens_per_frame = kwargs.get("videos_kwargs", {}).get("tokens_per_frame", 197)
            video_frame_list = []
            for item in video_list:
                # load_video_str returns (frames, metadata); raw frame arrays do not.
                video = item[0] if isinstance(item, tuple) else item
                num_frames = video.shape[0]
                # Expand through a sentinel so that replace(..., 1) cannot match
                # placeholder tokens already inserted for an earlier video.
                text = text.replace(
                    self.video_token, "<|placeholder|>" * (num_frames * tokens_per_frame), 1
                )
                video_frame_list.extend(list(video))
            text = text.replace("<|placeholder|>", self.video_token)
            pixel_values_videos = self.image_processor(images=video_frame_list, return_tensors="pt")["pixel_values"]

        audio_list = self.fill_modal_list(self.audio_token, "audio", messages, audio, text)
        audio_list = self.process_str_in_modal_list(audio_list, "audio", **kwargs.get("audio_kwargs", {}))
        if audio_list and self.audio_token in text:
            audio_kwargs = kwargs.get("audio_kwargs", {})
            sampling_rate = audio_kwargs.get("sampling_rate", 16000)
            tokens_per_audio = audio_kwargs.get("tokens_per_audio", 1500)
            # As with images, one global replace expands all audio placeholders.
            text = text.replace(self.audio_token, tokens_per_audio * self.audio_token)
            audio_values = self.audio_processor(
                audio_list, return_tensors="pt", sampling_rate=sampling_rate
            )["input_features"]

        encoding = self.tokenizer(text)
        input_ids = encoding["input_ids"]
        if do_train:
            labels = self.get_labels(input_ids)
            labels = torch.tensor(labels, dtype=torch.long)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(encoding["attention_mask"], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values if len(pixel_values) > 0 else None,
            "pixel_values_videos": pixel_values_videos if len(pixel_values_videos) > 0 else None,
            "audio_values": audio_values if len(audio_values) > 0 else None,
            "labels": labels,
        }
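
    # Worked example of the placeholder expansion in __call__, with
    # tokens_per_image=3 for brevity: the templated text
    #   "<|vision_start|><|image_pad|><|vision_end|>Describe this."
    # becomes
    #   "<|vision_start|><|image_pad|><|image_pad|><|image_pad|><|vision_end|>Describe this."
    # so the language model receives one placeholder id per visual embedding.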

    def fill_modal_list(
        self,
        modal_token: str,
        modal_type: str,
        messages: Union[List[Dict], None],
        modal_values: Union[AudioInput, VideoInput, ImageInput, None],
        text: str,
    ) -> List[Union[AudioInput, VideoInput, ImageInput]]:
        modal_list = []
        if modal_token in text:
            if not modal_values and messages:
                # Collect the modality values embedded in the messages.
                for msg in messages:
                    if msg.get("role") == "user":
                        for content in msg.get("content", []):
                            if content.get("type") == modal_type:
                                modal_list.append(content.get(modal_type))
            elif modal_values:
                if isinstance(modal_values, str):
                    modal_list = [modal_values]
                else:
                    modal_list = modal_values
        return modal_list

    def process_str_in_modal_list(self, modal_list: list, modal_type: str, **modal_kwargs: dict):
        # Replace any path/URL entries with their loaded data; pass through the rest.
        new_modal_list = []
        for modal_value in modal_list:
            if isinstance(modal_value, str):
                new_modal_list.append(self.load_modal_str(modal_value, modal_type, **modal_kwargs))
            else:
                new_modal_list.append(modal_value)
        return new_modal_list

    def load_modal_str(self, modal_path_or_url: str, modal_type: str, **modal_kwargs):
        loaders = {"image": load_image_str, "video": load_video_str, "audio": load_audio_str}
        if modal_type not in loaders:
            raise ValueError(f"Invalid modal type: {modal_type}")
        load_func = loaders[modal_type]
        # Forward only the kwargs the loader accepts; keys such as
        # tokens_per_image / tokens_per_frame / tokens_per_audio are consumed
        # in __call__ and would otherwise raise a TypeError here.
        accepted = inspect.signature(load_func).parameters
        filtered = {k: v for k, v in modal_kwargs.items() if k in accepted}
        return load_func(modal_path_or_url, **filtered)

    def get_labels(self, input_ids: List[int]) -> List[int]:
        label_start_token_ids = self.tokenizer(self.label_start_text, add_special_tokens=False)["input_ids"]
        label_end_token_ids = self.tokenizer(self.label_end_text, add_special_tokens=False)["input_ids"]
        labels = [-100] * len(input_ids)
        i = 0
        while i < len(input_ids):
            # Look for the assistant's response start marker.
            if input_ids[i : i + len(label_start_token_ids)] == label_start_token_ids:
                # The actual response begins after the start marker.
                start_response = i + len(label_start_token_ids)
                # Now search for the end marker.
                j = start_response
                found_end = False
                while j < len(input_ids):
                    if input_ids[j : j + len(label_end_token_ids)] == label_end_token_ids:
                        # Keep the end marker in the span so the model learns to emit it.
                        end_response = j + len(label_end_token_ids)
                        found_end = True
                        break
                    j += 1
                if found_end:
                    # Copy the tokens of the assistant's response into labels.
                    labels[start_response:end_response] = input_ids[start_response:end_response]
                    # Advance beyond the end marker and keep scanning for the
                    # next assistant response.
                    i = end_response
                    continue
                else:
                    # No end marker found: stop scanning.
                    break
            else:
                i += 1
        # Never train on padding.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            labels = [-100 if token == pad_token_id else token for token in labels]
        return labels

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)
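

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint paths below are placeholders, not
    # part of this module; substitute the vision/audio/text backbones that
    # match your model.
    processor = QualityvProcessor(
        tokenizer=AutoTokenizer.from_pretrained("path/to/text-backbone"),
        image_processor=AutoImageProcessor.from_pretrained("path/to/vision-backbone"),
        audio_processor=AutoFeatureExtractor.from_pretrained("path/to/audio-backbone"),
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "https://example.com/cat.png"},
                {"type": "text", "text": "Describe this picture."},
            ],
        }
    ]
    batch = processor(messages=messages, add_generation_prompt=True)
    print(batch["input_ids"].shape)
    if batch["pixel_values"] is not None:
        print(batch["pixel_values"].shape)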