| import torch |
| from PIL import Image |
| from transformers import AutoImageProcessor |
| from transformers.processing_utils import ProcessorMixin |
| from transformers.tokenization_utils_base import BatchEncoding |
|
|
| from .tokenization_m4cxr import MllmTokenizer |
|
|
| |
# Default system prompt prepended to the first chat message by
# MllmProcessor.apply_chat_template. Note the deliberate trailing space so
# the user's first message reads naturally after concatenation.
SYSTEM_MESSAGE = "The following is a conversation between a curious human and an AI medical assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
|
|
|
|
def load_images(image_paths: list[str] | list[Image.Image] | str) -> list[Image.Image]:
    """Load images from file paths, passing already-open images through.

    Args:
        image_paths: A single path string, a list of path strings, or a list
            of already-opened ``PIL.Image.Image`` objects (the list may mix
            both; each element is handled individually).

    Returns:
        A list of ``PIL.Image.Image`` objects in input order; path entries
        are opened and converted to RGB, non-string entries are returned
        unchanged.
    """
    if isinstance(image_paths, str):
        image_paths = [image_paths]

    images: list[Image.Image] = []
    for image_path in image_paths:
        if isinstance(image_path, str):
            # Context manager closes the file handle promptly instead of
            # leaving it open until garbage collection. convert("RGB")
            # forces the pixel data to load, so the resulting image stays
            # usable after the file is closed.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        else:
            images.append(image_path)
    return images
|
|
|
|
class MllmProcessor(ProcessorMixin):
    """Processor pairing an image processor with an ``MllmTokenizer``.

    Combines text tokenization and image preprocessing into a single
    ``__call__`` that produces a model-ready :class:`BatchEncoding`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "MllmTokenizer"

    def __init__(self, image_processor, tokenizer):
        # NOTE(review): ProcessorMixin.__init__ is intentionally not called
        # (matching the original); attributes are assigned directly. Confirm
        # whether save_pretrained/push_to_hub rely on the mixin's init.
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the processor from a checkpoint directory or hub id.

        Loads the image processor via ``AutoImageProcessor`` and the
        tokenizer via ``MllmTokenizer`` from the same location.
        """
        image_processor = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = MllmTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def __call__(self, texts=None, images=None, return_tensors="pt"):
        """Encode texts and/or images into a single ``BatchEncoding``.

        Args:
            texts: Prompt string(s) forwarded to the tokenizer. Optional.
            images: A path string, a list of paths, or a list of PIL images.
                ``None`` entries are dropped; if nothing remains the call
                behaves as if no images were passed.
            return_tensors: Tensor type for the returned ``BatchEncoding``.

        Returns:
            ``BatchEncoding`` containing the tokenizer fields and/or
            ``pixel_values`` (images stacked along the batch dimension).

        Raises:
            ValueError: If neither texts nor (usable) images are provided.
        """
        if images:
            images = load_images(images)
            # Drop None placeholders up front. The original filtered them
            # only after the presence check, so an all-None list reached
            # torch.cat with an empty sequence and crashed.
            images = [image for image in images if image is not None]
        if not images:
            images = None

        if texts is None and images is None:
            raise ValueError(
                "You have to specify either texts or images. Both cannot be none."
            )

        encoding = None
        if texts is not None:
            # Left padding so generation can continue directly after the
            # prompt tokens; no EOS because the model is expected to generate.
            encoding = self.tokenizer.batch_encode_prompt(
                prompts=texts, padding_side="left", no_eos=True
            )

        image_features = None
        if images is not None:
            # Preprocess each image individually, then stack along batch dim.
            image_features = torch.cat(
                [
                    self.image_processor(image, return_tensors="pt")["pixel_values"]
                    for image in images
                ],
                dim=0,
            )

        if encoding is not None and image_features is not None:
            encoding["pixel_values"] = image_features
            return BatchEncoding(data=encoding, tensor_type=return_tensors)
        if encoding is not None:
            return BatchEncoding(data=encoding, tensor_type=return_tensors)
        return BatchEncoding(
            data=dict(pixel_values=image_features), tensor_type=return_tensors
        )

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's ``decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, chats, system_prompt=SYSTEM_MESSAGE, *args, **kwargs):
        """Prepend the system prompt to the first message, then delegate.

        Unlike the original implementation, the caller's ``chats`` list is
        NOT mutated in place — the old in-place assignment caused the system
        prompt to be prepended again on every repeated call with the same
        chat history.
        """
        if chats:
            first = dict(chats[0])
            first["content"] = system_prompt + first["content"]
            chats = [first] + list(chats[1:])
        return self.tokenizer.apply_chat_template(chats, *args, **kwargs)
|
|