File size: 6,334 Bytes
bb46e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from PIL import Image
from scipy import special
import torch
import numpy as np
from math import e
from param import output
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin

class UMMProcessor(ProcessorMixin):
    """Processor that combines an image processor with a Qwen2 tokenizer.

    On construction the special tokens used as image placeholders
    (``<image>``, ``<image_gen>``), image-generation delimiters
    (``<im_start>``, ``<im_end>``) and ``<no_mean>`` are resolved from the
    tokenizer, or registered on it when the tokenizer does not expose them.

    Calling the processor tokenizes text, rewrites image placeholder ids in
    ``input_ids`` to ``self.image_token_id`` / ``self.image_gen_token_id``,
    runs the image processor on the supplied images and returns everything
    in a single ``BatchFeature``.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        """Resolve special tokens on ``tokenizer`` and initialise the mixin.

        Args:
            image_processor: Image processor stored by ``ProcessorMixin``.
            tokenizer: Tokenizer to read special tokens from. Missing tokens
                are added to it with ``add_tokens(..., special_tokens=True)``,
                so the tokenizer may be mutated.
            chat_template: Optional chat template; when ``None`` the
                tokenizer's own ``chat_template`` attribute is used if present.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        # Placeholder tokens fall back to fixed negative sentinel ids
        # (-200 / -300); __call__ writes these sentinels into input_ids.
        self.image_token, self.image_token_id = self._resolve_special_token(
            tokenizer, "image_token", "image_token_id", "<image>", sentinel_id=-200
        )
        self.image_gen_token, self.image_gen_token_id = self._resolve_special_token(
            tokenizer, "image_gen_token", "image_gen_token_id", "<image_gen>", sentinel_id=-300
        )
        # Delimiter tokens fall back to their real vocabulary id after being added.
        self.image_gen_start_token, self.image_gen_start_token_id = self._resolve_special_token(
            tokenizer, "image_gen_start", "image_gen_start_token_id", "<im_start>"
        )
        self.image_gen_end_token, self.image_gen_end_token_id = self._resolve_special_token(
            tokenizer, "image_gen_end", "image_gen_end_token_id", "<im_end>"
        )
        self.no_mean_token, self.no_mean_token_id = self._resolve_special_token(
            tokenizer, "no_mean", "no_mean_id", "<no_mean>"
        )

        if chat_template is None and hasattr(tokenizer, "chat_template"):
            chat_template = tokenizer.chat_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @staticmethod
    def _resolve_special_token(tokenizer, token_attr, id_attr, default_token, sentinel_id=None):
        """Return ``(token, token_id)`` for one special token.

        The token string is ``tokenizer.<token_attr>`` when that attribute
        exists, else ``default_token``. If ``tokenizer.<id_attr>`` is truthy
        it is used as the id. Otherwise ``default_token`` is added to the
        tokenizer as a special token and the id is ``sentinel_id`` when given,
        else the vocabulary id of the token string.
        """
        token = getattr(tokenizer, token_attr, default_token)
        token_id = getattr(tokenizer, id_attr, None)
        if token_id:
            return token, token_id
        tokenizer.add_tokens([default_token], special_tokens=True)
        if sentinel_id is not None:
            return token, sentinel_id
        return token, tokenizer.convert_tokens_to_ids(token)

    def _process_images(self, images, und_gen_mask_list, max_resolution):
        """Run the image processor over ``images`` and collect its outputs.

        ``images`` is a single entry or a list of per-sample entries; each
        entry is ``None`` (a random 256x256 dummy image is processed instead),
        a single image, or a list of images. The k-th real image is processed
        with ``und=False`` when ``und_gen_mask_list[k] == 0``.

        Returns a dict with ``pixel_values`` and ``grid_hws`` concatenated
        along dim 0, or an empty dict when nothing was processed.
        """
        pixel_values, grid_hws = [], []
        image_idx = 0
        for per_images in images if isinstance(images, list) else [images]:
            if per_images is None:
                dummy_image = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
                image_info = self.image_processor(images=dummy_image)
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)
                continue
            for per_image in per_images if isinstance(per_images, list) else [per_images]:
                if und_gen_mask_list[image_idx] == 0:
                    image_info = self.image_processor(images=per_image, max_resolution=max_resolution, und=False)
                else:
                    image_info = self.image_processor(images=per_image, max_resolution=max_resolution)
                image_idx += 1
                # Collect inside the loop so every image of a multi-image
                # sample is kept, not only the last one.
                pixel_values.append(image_info.pixel_values)
                grid_hws.append(image_info.grid_hws)

        if not pixel_values:
            # Nothing to concatenate (e.g. images == []).
            return {}
        return {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "grid_hws": torch.concat(grid_hws, dim=0),
        }

    def __call__(self, images=None, text=None, max_resolution=None, add_im_start_id=False, **kwargs):
        """Tokenize ``text`` and preprocess ``images`` into one ``BatchFeature``.

        Args:
            images: ``None``, a single entry, or a list of per-sample entries
                (each ``None``, one image, or a list of images).
            text: A string or list of strings passed to the tokenizer.
            max_resolution: Forwarded to the image processor for real images.
            add_im_start_id: When true, ``input_ids`` and ``attention_mask``
                are widened by one column and ``self.image_gen_start_token_id``
                is placed right after the last attended token of every row.
            **kwargs: Forwarded to the tokenizer. ``padding`` and
                ``truncation`` default to ``True``; ``return_tensors`` is also
                used as the ``tensor_type`` of the returned ``BatchFeature``.

        Returns:
            ``BatchFeature`` holding the tokenizer outputs plus, when
            available, ``pixel_values``, ``grid_hws`` and ``t`` (one value per
            image placeholder: ``1.0`` for ``<image>``, a uniform random
            sample for ``<image_gen>``).
        """
        kwargs.setdefault("padding", True)
        kwargs.setdefault("truncation", True)
        if not isinstance(text, list):
            text = [text]
        text = text.copy()
        return_tensors = kwargs.pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **kwargs, return_tensors=return_tensors)
        img_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        img_gen_token_id = self.tokenizer.convert_tokens_to_ids(self.image_gen_token)

        source_ids = text_inputs["input_ids"]
        if add_im_start_id:
            # Coerce to tensors so this path also works when the tokenizer
            # returned lists or numpy arrays.
            source_ids = torch.as_tensor(source_ids)
            attention_mask = torch.as_tensor(text_inputs["attention_mask"])
            batch_size, seq_len = source_ids.shape
            new_input_ids = torch.full(
                (batch_size, seq_len + 1),
                self.tokenizer.pad_token_id,
                dtype=source_ids.dtype,
                device=source_ids.device,
            )
            new_input_ids[:, :seq_len] = source_ids
        else:
            new_input_ids = source_ids

        t = []
        und_gen_mask_list = []
        for i, ids in enumerate(source_ids):
            # One bulk conversion per row instead of a tensor compare per token.
            row = ids.tolist() if hasattr(ids, "tolist") else ids
            for j, token_id in enumerate(row):
                if token_id == img_token_id:
                    new_input_ids[i][j] = self.image_token_id
                    t.append(torch.tensor([1.0]))
                    und_gen_mask_list.append(1)
                elif token_id == img_gen_token_id:
                    new_input_ids[i][j] = self.image_gen_token_id
                    t.append(torch.rand(1))
                    und_gen_mask_list.append(0)

        image_inputs = {}
        if images is not None:
            image_inputs.update(self._process_images(images, und_gen_mask_list, max_resolution))

        if len(t) > 0:
            image_inputs.update({"t": torch.cat(t)})

        if add_im_start_id:
            rows = torch.arange(batch_size, device=new_input_ids.device)
            if seq_len > 0:
                # Index just past the last attended token: equals the token
                # count for right padding and seq_len for left padding.
                positions = torch.arange(1, seq_len + 1, device=attention_mask.device)
                insert_pos = ((attention_mask != 0).to(torch.long) * positions).max(dim=1).values
            else:
                insert_pos = torch.zeros(batch_size, dtype=torch.long, device=new_input_ids.device)
            new_input_ids[rows, insert_pos] = self.image_gen_start_token_id

            # Widen the mask and mark the inserted token as attended in
            # every row, not only in rows that fill the last column.
            new_attention_mask = attention_mask.new_zeros((batch_size, seq_len + 1))
            new_attention_mask[:, :seq_len] = attention_mask
            new_attention_mask[rows, insert_pos] = 1
            text_inputs["attention_mask"] = new_attention_mask

        text_inputs["input_ids"] = new_input_ids
        return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)

# Names exported by `from <module> import *`: only the processor class.
__all__ = ["UMMProcessor"]