File size: 3,881 Bytes
032e687
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from PIL import Image

from transformers import Qwen2VLProcessor, AutoProcessor
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessorKwargs


class QwenVLProcessor:
    """Turn one Qwen2-VL conversation sample into training ``input_ids``/``labels``.

    Wraps a Hugging Face processor loaded with ``AutoProcessor``. Attributes
    not defined on this class (``tokenizer``, ``image_processor``,
    ``_merge_kwargs``, ...) are looked up on the wrapped ``self.processor``.
    """

    # Chat-template role names for the two sides of each round:
    # ROLE[0] for the human turn, ROLE[1] for the model's reply.
    ROLE = ('user', 'assistant')

    def __init__(self, max_length=512, pretrained_model_name_or_path=None):
        """Load the wrapped processor.

        Args:
            max_length: Stored as ``self.max_length``. Nothing in this class
                reads it; in particular ``__call__`` does not truncate.
            pretrained_model_name_or_path: Passed to
                ``AutoProcessor.from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path)
        self.max_length = max_length

    def __getattr__(self, name: str):
        """Delegate failed attribute lookups to ``self.processor``.

        Python calls this only after normal attribute lookup has failed.
        Two kinds of name are never delegated:

        * ``processor`` itself: when it is missing (an instance created with
          ``__new__``, or one being rebuilt by ``copy``/``pickle``),
          delegating would evaluate ``self.processor`` and re-enter this
          method without end.
        * dunder names: protocol hooks such as ``__deepcopy__`` or
          ``__setstate__`` must describe this wrapper, not the processor.

        Raises:
            AttributeError: if the name is not delegated or the wrapped
                processor does not have it.
        """
        if name == "processor" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self.processor, name)

    def build_prompt(self, query, answer, round=0, system=None):
        """Render one conversation round with the tokenizer's chat template.

        Args:
            query: Text of the user turn.
            answer: Text of the assistant turn, or ``None`` to build a
                generation prompt only.
            round: Zero-based round index; ``system`` is only inserted when
                this is 0.
            system: Optional system message for the first round.

        Returns:
            ``(query, answer)`` where ``query`` is the rendered prompt with the
            generation header appended (``add_generation_prompt=True``), and
            ``answer`` is the rest of the fully rendered round after that
            prompt (the reply plus whatever the template emits after it), or
            ``None`` when ``answer`` was ``None``.

        Raises:
            ValueError: if the rendered round neither starts with the
                generation prompt nor contains an assistant header.
        """
        messages = [{"role": self.ROLE[0], "content": query}]
        if round == 0 and system:
            messages.insert(0, {"role": "system", "content": system})

        prompt_head = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        if answer is None:
            return prompt_head, None

        messages.append({"role": self.ROLE[1], "content": answer})
        full = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)

        # Take everything after the generation prompt as the answer. Unlike
        # splitting on the header text, this still works when the user turn
        # itself contains that text.
        if full.startswith(prompt_head):
            return prompt_head, full[len(prompt_head):]

        # Fallback for templates whose generation prompt is not a prefix of
        # the rendered round: split at the last assistant header.
        header = "<|im_start|>assistant\n"
        head, sep, tail = full.rpartition(header)
        if not sep:
            raise ValueError(
                "chat template output has no assistant header: %r" % full)
        return head + sep, tail

    @staticmethod
    def _load_images(image):
        """Normalize a sample's ``"image"`` entry to a list of RGB images.

        Accepts ``None``, a single path/file object, a PIL image, or a
        list/tuple of those. Returns ``None`` when there is no image.
        """
        if image is None:
            return None
        sources = image if isinstance(image, (list, tuple)) else [image]
        if not sources:
            return None
        loaded = []
        for src in sources:
            if isinstance(src, Image.Image):
                loaded.append(src.convert('RGB'))
            else:
                # The context manager closes the file; convert() yields a new
                # in-memory image that stays usable afterwards.
                with Image.open(src) as img:
                    loaded.append(img.convert('RGB'))
        return loaded

    def _expand_image_tokens(self, conversations, image_grid_thw):
        """Return a copy of ``conversations`` in which every ``<image>`` in a
        human turn is replaced by its run of ``<|image_pad|>`` tokens.

        The i-th ``<image>`` (counted across all human turns) uses row i of
        ``image_grid_thw``; its token count is the product of that row divided
        by ``merge_size ** 2``. The input turns are not modified.
        """
        if image_grid_thw is None:
            return list(conversations)

        merge_length = self.image_processor.merge_size ** 2
        expanded = []
        index = 0
        for msg in conversations:
            if msg['from'] == 'human':
                text = msg['value']
                while "<image>" in text:
                    n_tokens = int(image_grid_thw[index].prod()) // merge_length
                    text = text.replace(
                        "<image>", "<|placeholder|>" * n_tokens, 1)
                    index += 1
                text = text.replace("<|placeholder|>", "<|image_pad|>")
                # New dict so the caller's sample is left untouched.
                msg = {**msg, 'value': text}
            expanded.append(msg)
        return expanded

    def __call__(self, data_dict, **kwargs):
        """Tokenize one conversation sample for supervised training.

        Args:
            data_dict: Sample with a ``"conversations"`` list of
                ``{"from": ..., "value": ...}`` turns, paired as
                (human, reply) from index 0 (assumed to alternate, human
                first), and an optional ``"image"`` entry: a path/file object,
                a PIL image, or a list of these. ``"images"`` and ``"videos"``
                are not read.
            **kwargs: Forwarded to the wrapped processor's ``_merge_kwargs``
                to build the image-processor options.

        Returns:
            dict with ``input_ids`` and ``labels`` (equal-length lists; prompt
            positions are labelled -100), plus ``pixel_values`` and
            ``image_grid_thw`` from the image processor (``None`` when the
            sample has no image). ``data_dict`` is not modified, and
            ``self.max_length`` is not applied.
        """
        conversations = data_dict["conversations"]
        images = self._load_images(data_dict.get("image", None))

        output_kwargs = self._merge_kwargs(
            Qwen2VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        conversations = self._expand_image_tokens(conversations, image_grid_thw)

        input_ids, labels = [], []
        for i in range(0, len(conversations), 2):
            query = conversations[i]['value']
            answer = (conversations[i + 1]['value']
                      if i + 1 < len(conversations) else None)
            query, answer = self.build_prompt(query, answer, round=i // 2)

            round_ids = self.tokenizer(
                query, add_special_tokens=True,
                return_attention_mask=False)['input_ids']
            # Prompt tokens are masked out of the loss.
            round_labels = [-100] * len(round_ids)
            if answer is not None:
                answer_ids = self.tokenizer(
                    answer, add_special_tokens=True,
                    return_attention_mask=False)['input_ids']
                round_ids += answer_ids
                round_labels += answer_ids
            input_ids += round_ids
            labels += round_labels

        return {
            "input_ids": input_ids,
            "labels": labels,
            'pixel_values': image_inputs.get('pixel_values', None),
            # Grid sizes from the image processor (the same values used above
            # to size each <|image_pad|> run), returned alongside pixel_values.
            'image_grid_thw': image_grid_thw,
        }