import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding

from .tokenization_m4cxr import MllmTokenizer

# Default system prompt; apply_chat_template prepends it to the first chat message
SYSTEM_MESSAGE = "The following is a conversation between a curious human and an AI medical assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "


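def load_images(
    image_paths: str | Image.Image | list[str | Image.Image | None],
) -> list[Image.Image | None]:
    """Return a list of images, opening file paths as RGB PIL images.

    A single path or image is wrapped in a list. Entries that are not strings
    (PIL images or None placeholders) are passed through unchanged.
    """
    if isinstance(image_paths, (str, Image.Image)):
        image_paths = [image_paths]
    return [
        (
            Image.open(image_path).convert("RGB")
            if isinstance(image_path, str)
            else image_path
        )
        for image_path in image_paths
    ]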


class MllmProcessor(ProcessorMixin):
    """Combines an image processor and an MllmTokenizer into one processor."""

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "MllmTokenizer"

    def __init__(self, image_processor, tokenizer):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Load the image processor
        image_processor = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        # Load the custom tokenizer
        tokenizer = MllmTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        return cls(image_processor=image_processor, tokenizer=tokenizer)

    def __call__(self, texts=None, images=None, return_tensors="pt"):
        if images:
            # Open any file paths as RGB images and drop None placeholders
            images = [image for image in load_images(images) if image is not None]
        # An empty or all-None image list is treated the same as "no images",
        # so torch.cat below never receives an empty list
        if not images:
            images = None

        if texts is None and images is None:
            raise ValueError(
                "You have to specify either texts or images. Both cannot be None."
            )

        if texts is not None:
            # Return keys: ['input_ids', 'attention_mask']
            encoding = self.tokenizer.batch_encode_prompt(
                prompts=texts, padding_side="left", no_eos=True
            )

        if images is not None:
            # Preprocess each image separately, then stack the per-image
            # pixel_values tensors along the batch dimension
            image_features = torch.cat(
                [
                    self.image_processor(image, return_tensors="pt")["pixel_values"]
                    for image in images
                ],
                dim=0,
            )

        if texts is not None and images is not None:
            encoding["pixel_values"] = image_features
            return BatchEncoding(data=encoding, tensor_type=return_tensors)
        elif texts is not None:
            return BatchEncoding(data=encoding, tensor_type=return_tensors)
        else:
            return BatchEncoding(
                data=dict(pixel_values=image_features), tensor_type=return_tensors
            )

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

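    def apply_chat_template(self, chats, system_prompt=SYSTEM_MESSAGE, *args, **kwargs):
        # Prepend the system prompt to the first message's content. Build a
        # shallow copy so the caller's list is not mutated and repeated calls
        # do not stack the prompt.
        chats = [
            {**chats[0], "content": system_prompt + chats[0]["content"]},
            *chats[1:],
        ]
        return self.tokenizer.apply_chat_template(chats, *args, **kwargs)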
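

# Example usage: an illustrative sketch, not part of the processor API.
# Assumptions (all hypothetical): "path/to/m4cxr-checkpoint" is a local
# directory holding both the image-processor config and the MllmTokenizer
# files; "chest_xray.png" is an image file on disk; and MllmTokenizer follows
# the Hugging Face convention where apply_chat_template(..., tokenize=False)
# returns a prompt string. How prompts must reference images (for example via
# special image tokens) is defined by MllmTokenizer and is not shown here.
#
#   processor = MllmProcessor.from_pretrained("path/to/m4cxr-checkpoint")
#   chats = [{"role": "user", "content": "Describe the findings."}]
#   # SYSTEM_MESSAGE is prepended to the first message; `chats` is unchanged
#   prompt = processor.apply_chat_template(chats, tokenize=False)
#   batch = processor(texts=[prompt], images=["chest_xray.png"])
#   # batch holds "input_ids" and "attention_mask" (left-padded) plus
#   # "pixel_values", with one entry per image stacked along dim 0
#   texts_back = processor.batch_decode(batch["input_ids"], skip_special_tokens=True)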