speed committed on
Commit
e4ccf48
·
verified ·
1 Parent(s): 9bb7fa3

Upload processing_llmjpvl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. processing_llmjpvl.py +249 -0
processing_llmjpvl.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM-jp-VL Processor — combines SigLIP image processing + dynamic patching + tokenization."""
2
+
3
+ from typing import List, Optional, Union
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import BatchFeature, ProcessorMixin
8
+
9
+
10
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Return the candidate grid whose width/height ratio is nearest ``aspect_ratio``.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates, scanned in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Edge length of one square tile in pixels.

    Returns:
        The selected element of ``target_ratios``, or ``(1, 1)`` when no
        candidate is supplied.
    """
    pixel_count = width * height
    chosen = (1, 1)
    smallest_gap = float("inf")

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])

        if gap < smallest_gap:
            # Strictly closer ratio: always wins.
            smallest_gap, chosen = gap, candidate
            continue

        # Exact tie: switch to this candidate only when the source image has
        # more than half the pixels of the candidate's full tiled canvas.
        if (
            gap == smallest_gap
            and pixel_count > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            chosen = candidate

    return chosen
24
+
25
+
26
def dynamic_preprocess(
    image, min_num=1, max_num=12, image_size=512, use_thumbnail=False
):
    """Resize an image onto a tile grid and cut it into square crops.

    Every ``(cols, rows)`` grid whose tile count lies in ``[min_num, max_num]``
    is a candidate; ``find_closest_aspect_ratio`` picks one, the image is
    resized to ``cols * image_size`` by ``rows * image_size`` and cropped into
    ``image_size``-square tiles in row-major order.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` (a PIL image
            in this file's usage).
        min_num: Minimum number of tiles allowed in a candidate grid.
        max_num: Maximum number of tiles allowed in a candidate grid.
        image_size: Edge length of each square tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append
            the whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of cropped tiles, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate candidate grids; the comprehension keeps the original
    # insertion order so the stable sort below orders ties the same way.
    grid_options = {
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num
    }
    grid_options = sorted(grid_options, key=lambda grid: grid[0] * grid[1])

    best_grid = find_closest_aspect_ratio(
        src_ratio, grid_options, src_w, src_h, image_size
    )

    canvas_w = image_size * best_grid[0]
    canvas_h = image_size * best_grid[1]
    tile_count = best_grid[0] * best_grid[1]

    canvas = image.resize((canvas_w, canvas_h))

    tiles = []
    if tile_count:
        tiles_per_row = canvas_w // image_size
        for index in range(tile_count):
            row, col = divmod(index, tiles_per_row)
            left = col * image_size
            top = row * image_size
            right = (col + 1) * image_size
            bottom = (row + 1) * image_size
            tiles.append(canvas.crop((left, top, right, bottom)))

    # A single-tile result already is a thumbnail, so none is added then.
    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
62
+
63
+
64
class LLMjpVLProcessor(ProcessorMixin):
    """Processor combining an image processor and a tokenizer for LLM-jp-VL.

    Images are split into tiles with ``dynamic_preprocess`` and run through the
    image processor; each ``<image>`` placeholder in the text is expanded into
    ``<|image_start|>`` + ``<|image_pad|>`` repeated
    ``image_seq_length * num_patches`` times + ``<|image_end|>`` before
    tokenization.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        image_seq_length=256,
        max_dynamic_patch=12,
        min_dynamic_patch=1,
        use_thumbnail=True,
        chat_template=None,
        **kwargs,
    ):
        """Store tiling settings and register the sub-processors.

        Args:
            image_processor: Image processor applied to every tile.
            tokenizer: Tokenizer used for text and chat templating.
            image_seq_length: Number of ``<|image_pad|>`` tokens per tile.
            max_dynamic_patch: Upper bound on tiles per image.
            min_dynamic_patch: Lower bound on tiles per image.
            use_thumbnail: Whether a thumbnail tile is appended for
                multi-tile images.
            chat_template: If given, assigned to ``tokenizer.chat_template``.
            **kwargs: Forwarded to ``ProcessorMixin.__init__``.
        """
        self.image_seq_length = image_seq_length
        self.max_dynamic_patch = max_dynamic_patch
        self.min_dynamic_patch = min_dynamic_patch
        self.use_thumbnail = use_thumbnail
        if chat_template is not None:
            tokenizer.chat_template = chat_template
        super().__init__(image_processor, tokenizer, **kwargs)

    def _expand_image_placeholders(self, texts, num_patches_list):
        """Replace ``<image>`` placeholders with image token runs, in order.

        Patch counts are consumed left to right across all texts, so the
        k-th placeholder in the batch is paired with the k-th image.
        Placeholders left over once the images run out are kept verbatim.
        """
        remaining = iter(num_patches_list)
        expanded = []
        for t in texts:
            pieces = t.split("<image>")
            out = [pieces[0]]
            for piece in pieces[1:]:
                num_patches = next(remaining, None)
                if num_patches is None:
                    out.append("<image>")
                else:
                    out.append(
                        "<|image_start|>"
                        + "<|image_pad|>" * (self.image_seq_length * num_patches)
                        + "<|image_end|>"
                    )
                out.append(piece)
            expanded.append("".join(out))
        return expanded

    def __call__(
        self,
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        text: Optional[Union[str, List[str]]] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tile images, expand image placeholders and tokenize.

        Args:
            images: One image or a list of images. An empty list is treated
                the same as ``None``.
            text: One string or a list of strings containing ``<image>``
                placeholders.
            return_tensors: Tensor type passed to the tokenizer and to the
                returned ``BatchFeature``.
            **kwargs: Forwarded to the tokenizer call.

        Returns:
            ``BatchFeature`` holding ``pixel_values`` (tiles of all images
            concatenated along dim 0) and ``num_patches_list`` when images are
            given, plus the tokenizer outputs when text is given.

        Raises:
            ValueError: If neither ``text`` nor ``images`` is provided, or if
                ``text`` is an empty list.
        """
        if isinstance(images, Image.Image):
            images = [images]
        elif images is not None and len(images) == 0:
            # An empty list would otherwise divide by zero in the budget below.
            images = None

        if text is None and images is None:
            raise ValueError("You must provide at least one of `text` or `images`.")

        texts = None
        if text is not None:
            texts = [text] if isinstance(text, str) else list(text)
            if not texts:
                raise ValueError("`text` must not be an empty list.")

        data = {}
        num_patches_list = []

        if images is not None:
            image_size = self.image_processor.size.get(
                "height", self.image_processor.size.get("shortest_edge", 512)
            )
            all_pixel_values = []
            num_image = len(images)
            # Compute max patches per image from actual text token count.
            # Each image uses (max_num + 1) * image_seq_length + 2 tokens
            # (thumbnail added when max_num > 1).
            if texts is not None:
                # Budget against the longest text so no sequence in the batch
                # is given more image tokens than it has room for.
                text_tokens = max(
                    len(
                        self.tokenizer.encode(
                            t.replace("<image>", ""), add_special_tokens=False
                        )
                    )
                    for t in texts
                )
            else:
                text_tokens = 0
            image_budget = self.tokenizer.model_max_length - text_tokens
            max_num = (image_budget // num_image - 2) // self.image_seq_length - 1
            max_num = max(1, min(self.max_dynamic_patch, max_num))
            for image in images:
                image = image.convert("RGB")
                patches = dynamic_preprocess(
                    image,
                    min_num=self.min_dynamic_patch,
                    max_num=max_num,
                    image_size=image_size,
                    use_thumbnail=self.use_thumbnail,
                )
                num_patches_list.append(len(patches))
                pixel_values = self.image_processor(
                    images=patches, return_tensors="pt"
                ).pixel_values
                all_pixel_values.append(pixel_values)

            data["pixel_values"] = torch.cat(all_pixel_values, dim=0)

        if texts is not None:
            expanded_texts = self._expand_image_placeholders(texts, num_patches_list)

            tokenized = self.tokenizer(
                expanded_texts if len(expanded_texts) > 1 else expanded_texts[0],
                return_tensors=return_tensors,
                add_special_tokens=False,
                **kwargs,
            )
            data.update(tokenized)

        if num_patches_list:
            data["num_patches_list"] = num_patches_list

        return BatchFeature(data=data, tensor_type=return_tensors)

    def apply_chat_template(
        self,
        messages,
        tokenize=False,
        add_generation_prompt=False,
        return_dict=False,
        return_tensors=None,
        **kwargs,
    ):
        """Format messages and optionally process images + tokenize in one call.

        Supports structured content messages (Qwen3-VL style)::

            messages = [{"role": "user", "content": [
                {"type": "image", "image": "path/to/img.png"},
                {"type": "text", "text": "Describe this image."},
            ]}]

        Plain string content is also supported::

            messages = [{"role": "user", "content": "Hello"}]

        When ``tokenize=True`` and ``return_dict=True``, returns a
        :class:`~transformers.BatchFeature` with ``pixel_values``,
        ``input_ids``, and ``attention_mask`` that can be unpacked directly
        into ``model.generate(**inputs)``.

        Args:
            messages: List of ``{"role", "content"}`` dicts; ``content`` is a
                string or a list of ``image`` / ``text`` items.
            tokenize: When false, return the formatted prompt string.
            add_generation_prompt: Forwarded to the tokenizer's chat template.
            return_dict: When tokenizing, return the full ``BatchFeature``
                instead of only ``input_ids``.
            return_tensors: Forwarded to ``__call__``.
            **kwargs: Forwarded to ``__call__``.

        Returns:
            The prompt string, the ``input_ids``, or a ``BatchFeature``,
            depending on ``tokenize`` and ``return_dict``.
        """
        # Extract images and flatten structured content to plain text messages
        images = []
        flat_messages = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if isinstance(content, str):
                flat_messages.append({"role": role, "content": content})
            elif isinstance(content, list):
                text_parts = []
                for item in content:
                    if item["type"] == "image":
                        img = item["image"]
                        if isinstance(img, str):
                            # convert() returns a loaded copy, so the file
                            # handle can be closed right away.
                            with Image.open(img) as opened:
                                images.append(opened.convert("RGB"))
                        elif isinstance(img, Image.Image):
                            images.append(img.convert("RGB"))
                        # NOTE(review): any other image type adds no image but
                        # still emits the placeholder below.
                        text_parts.append("<image>")
                    elif item["type"] == "text":
                        text_parts.append(item["text"])
                flat_messages.append({"role": role, "content": "".join(text_parts)})

        text = self.tokenizer.apply_chat_template(
            flat_messages,
            tokenize=False,
            add_special_tokens=False,
            add_generation_prompt=add_generation_prompt,
        )
        # NOTE(review): appended regardless of add_generation_prompt — confirm
        # this is intended when no generation prompt is requested.
        text += "<|channel|>final<|message|>"

        if not tokenize:
            return text

        result = self(
            images=images if images else None,
            text=text,
            return_tensors=return_tensors,
            **kwargs,
        )
        # Remove non-tensor metadata so **result works with model.generate()
        result.pop("num_patches_list", None)

        if return_dict:
            return result
        return result["input_ids"]

    def decode(self, token_ids, **kwargs):
        """Delegate to ``tokenizer.decode``."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def batch_decode(self, token_ids, **kwargs):
        """Delegate to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(token_ids, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_names = self.tokenizer.model_input_names
        image_processor_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_names + image_processor_names))