import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding

from .tokenization_m4cxr import MllmTokenizer

# system message
SYSTEM_MESSAGE = "The following is a conversation between a curious human and an AI medical assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "


def load_images(image_paths: list[str] | list[Image.Image] | str) -> list[Image.Image]:
    """Normalize the input into a list of RGB PIL images.

    Accepts a single path, a list of paths, or a list of already-open
    ``PIL.Image.Image`` objects. Paths are opened and converted to RGB;
    image objects are passed through unchanged (NOT converted to RGB —
    callers supplying Image objects are assumed to have done so).
    """
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    return [
        (
            Image.open(image_path).convert("RGB")
            if isinstance(image_path, str)
            else image_path
        )
        for image_path in image_paths
    ]


class MllmProcessor(ProcessorMixin):
    """Pairs an ``AutoImageProcessor`` with the custom ``MllmTokenizer`` so
    text and image inputs can be prepared with a single call."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "MllmTokenizer"

    def __init__(self, image_processor, tokenizer):
        # NOTE(review): deliberately does not call ProcessorMixin.__init__;
        # the two components are stored directly — confirm no mixin
        # bookkeeping (e.g. save_pretrained) depends on the parent __init__.
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the processor from a pretrained checkpoint directory/hub id.

        Loads the image processor via ``AutoImageProcessor`` and the custom
        ``MllmTokenizer`` from the same location, forwarding ``kwargs`` to both.
        """
        image_processor = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = MllmTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def __call__(self, texts=None, images=None, return_tensors="pt"):
        """Encode ``texts`` and/or ``images`` into a single ``BatchEncoding``.

        Args:
            texts: prompt strings to tokenize (left-padded, no EOS), or None.
            images: path(s) or PIL image(s) accepted by ``load_images``, or None.
            return_tensors: tensor type forwarded to ``BatchEncoding``.

        Returns:
            BatchEncoding with ``input_ids``/``attention_mask`` when texts are
            given and ``pixel_values`` when images are given.

        Raises:
            ValueError: if both inputs are None, or if every image is None.
        """
        # A falsy ``images`` (None or empty list) is treated as "no images".
        images = load_images(images) if images else None

        if texts is None and images is None:
            raise ValueError(
                "You have to specify either texts or images. Both cannot be none."
            )

        data = {}
        if texts is not None:
            # Return keys: ['input_ids', 'attention_mask']
            data.update(
                self.tokenizer.batch_encode_prompt(
                    prompts=texts, padding_side="left", no_eos=True
                )
            )
        if images is not None:
            # Filter out None placeholders; fail loudly if nothing remains
            # instead of letting torch.cat raise an opaque error on [].
            images = [image for image in images if image is not None]
            if not images:
                raise ValueError("All provided images were None.")
            data["pixel_values"] = torch.cat(
                [
                    self.image_processor(image, return_tensors="pt")["pixel_values"]
                    for image in images
                ],
                dim=0,
            )
        return BatchEncoding(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward to ``MllmTokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to ``MllmTokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, chats, system_prompt=SYSTEM_MESSAGE, *args, **kwargs):
        """Prepend ``system_prompt`` to the first turn and apply the tokenizer's
        chat template.

        The first message is copied before modification so the caller's
        ``chats`` list is never mutated — the original implementation edited
        it in place, which prepended the system prompt again on every call
        with the same list.
        """
        first = dict(chats[0])
        first["content"] = system_prompt + first["content"]
        chats = [first, *chats[1:]]
        return self.tokenizer.apply_chat_template(chats, *args, **kwargs)