# M4CXR-TNNLS / processing_m4cxr.py
# Author: jonggwon-park — "add custom model codes" (commit 6159bde)
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding
from .tokenization_m4cxr import MllmTokenizer
# System prompt for the medical-assistant persona; used as the default
# `system_prompt` in MllmProcessor.apply_chat_template, where it is
# prepended to the first chat message. Note the trailing space is
# intentional — the first message's content is concatenated directly after it.
SYSTEM_MESSAGE = "The following is a conversation between a curious human and an AI medical assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
def load_images(
    image_paths: list[str] | list[Image.Image] | str | Image.Image,
) -> list[Image.Image]:
    """Normalize the input into a list of PIL images.

    Accepts a single file path, a single in-memory PIL image, or a list
    mixing both. Paths are opened and converted to RGB; images already in
    memory are passed through unchanged (no RGB conversion is applied to
    them — NOTE(review): presumably callers supply RGB images; confirm).

    Args:
        image_paths: Path(s) to image files and/or already-loaded images.

    Returns:
        A list of PIL images, in input order.
    """
    # Generalization: a bare Image.Image previously fell through to the
    # comprehension and was iterated, which crashes — wrap singletons of
    # either kind in a list.
    if isinstance(image_paths, (str, Image.Image)):
        image_paths = [image_paths]
    return [
        Image.open(item).convert("RGB") if isinstance(item, str) else item
        for item in image_paths
    ]
class MllmProcessor(ProcessorMixin):
    """Multimodal processor pairing an image processor with MllmTokenizer.

    Bundles `AutoImageProcessor` (pixel features) and the custom
    `MllmTokenizer` (prompt encoding) behind one `__call__` that returns a
    single `BatchEncoding`.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "MllmTokenizer"

    def __init__(self, image_processor, tokenizer):
        # NOTE(review): attributes are assigned directly rather than via
        # ProcessorMixin.__init__, which skips the mixin's attribute-class
        # validation; kept as-is to preserve behavior.
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load both sub-components from a checkpoint path or hub id."""
        image_processor = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer = MllmTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def __call__(self, texts=None, images=None, return_tensors="pt"):
        """Encode texts and/or images into one `BatchEncoding`.

        Args:
            texts: Prompt(s) forwarded to `tokenizer.batch_encode_prompt`.
            images: Path(s) and/or PIL image(s); `None` entries are dropped.
            return_tensors: Tensor type for the returned `BatchEncoding`.

        Returns:
            A `BatchEncoding` with `input_ids`/`attention_mask` (when texts
            are given) and/or `pixel_values` (when images are given).

        Raises:
            ValueError: If neither texts nor images is provided, or if
                `images` contains only `None` entries.
        """
        # An empty list/str is treated the same as "no images".
        images = load_images(images) if images else None
        if texts is None and images is None:
            raise ValueError(
                "You have to specify either texts or images. Both cannot be none."
            )
        encoding = None
        if texts is not None:
            # Returns keys: ['input_ids', 'attention_mask']
            encoding = self.tokenizer.batch_encode_prompt(
                prompts=texts, padding_side="left", no_eos=True
            )
        image_features = None
        if images is not None:
            images = [image for image in images if image is not None]
            if not images:
                # Previously this fell through to torch.cat([]) and raised
                # an opaque RuntimeError; fail with a clear message instead.
                raise ValueError("images contained only None entries.")
            # Stack per-image pixel_values along the batch dimension.
            image_features = torch.cat(
                [
                    self.image_processor(image, return_tensors="pt")["pixel_values"]
                    for image in images
                ],
                dim=0,
            )
        if encoding is not None:
            if image_features is not None:
                encoding["pixel_values"] = image_features
            return BatchEncoding(data=encoding, tensor_type=return_tensors)
        return BatchEncoding(
            data=dict(pixel_values=image_features), tensor_type=return_tensors
        )

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, chats, system_prompt=SYSTEM_MESSAGE, *args, **kwargs):
        """Prepend the system prompt to the first message and delegate.

        Bug fix: the caller's `chats` list is no longer mutated in place,
        so calling this twice on the same conversation no longer stacks
        the system prompt into the stored messages.
        """
        first = dict(chats[0])
        first["content"] = system_prompt + first["content"]
        return self.tokenizer.apply_chat_template(
            [first, *chats[1:]], *args, **kwargs
        )