DannyJun commited on
Commit
255b04b
·
verified ·
1 Parent(s): 7691d80

Upload processing_sprvla.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. processing_sprvla.py +463 -0
processing_sprvla.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Processor class for SPRVLA.
3
+ """
4
+ from typing import List, Optional, Union, Dict, Tuple
5
+
6
+ import PIL
7
+ from PIL import ImageFile, ImageOps
8
+
9
+ try:
10
+ from typing import Unpack
11
+ except ImportError:
12
+ from typing_extensions import Unpack
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+ from transformers.image_utils import ImageInput
18
+ from transformers.processing_utils import (
19
+ ProcessingKwargs,
20
+ ProcessorMixin,
21
+ )
22
+ from transformers.feature_extraction_utils import BatchFeature
23
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
24
+ from transformers.utils import logging
25
+
26
+ from transformers import AutoTokenizer
27
+ from .image_processing_sprvla import SPRVLAImagesKwargs, SPRVLAImageProcessor
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
# Special tokens. These must be present in any tokenizer used with this
# processor, because the preprocessor emits them verbatim into the prompt.
IMAGE_PATCH_TOKEN = "<im_patch>"  # Where to insert high-res tokens
IMAGE_LOW_RES_TOKEN = "<im_low>"  # Where to insert low-res tokens
IM_START_TOKEN = "<im_start>"
IM_END_TOKEN = "<im_end>"
IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"

# Order matters: `get_special_token_ids` zips this tuple with the ids produced
# by encoding the concatenated tokens.
EXTRA_TOKENS = (
    IM_START_TOKEN,
    IM_END_TOKEN,
    IMAGE_PATCH_TOKEN,
    IM_COL_TOKEN,
    IMAGE_PROMPT,
    IMAGE_LOW_RES_TOKEN,
)


# Styles for which the "demo_or_style" system prompt adds no style prefix.
DEMO_STYLES = [
    "point_count",
    "pointing",
    "cosyn_point",
    "user_qa",
    "long_caption",
    "short_caption",
    "correction_qa",
    "demo",
    "android_control",
]
56
+
57
+
58
def setup_pil():
    """Configure PIL globally for permissive image loading.

    Sets ``ImageFile.LOAD_TRUNCATED_IMAGES`` to ``True`` and clears
    ``PIL.Image.MAX_IMAGE_PIXELS`` (sets it to ``None``). Both are
    process-wide module attributes, so the effect is not limited to the caller.
    """
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    PIL.Image.MAX_IMAGE_PIXELS = None
61
+
62
+
63
def get_special_token_ids(tokenizer: AutoTokenizer) -> Dict[str, int]:
    """Map every token in ``EXTRA_TOKENS`` to its id in ``tokenizer``.

    All special tokens are concatenated and encoded in a single call (without
    added special tokens); the resulting ids are paired positionally with
    ``EXTRA_TOKENS``. The assertion fails if the encoding does not yield
    exactly one id per token.
    """
    joined_tokens = "".join(EXTRA_TOKENS)
    token_ids = tokenizer.encode(joined_tokens, add_special_tokens=False)
    assert len(token_ids) == len(EXTRA_TOKENS)
    return dict(zip(EXTRA_TOKENS, token_ids))
67
+
68
+
69
def load_image(image: Union[PIL.Image.Image, np.ndarray]) -> np.ndarray:
    """Return ``image`` as a numpy array.

    A PIL image is converted to RGB, transposed according to its EXIF
    orientation, and turned into an array. A numpy array is returned as-is
    after asserting it has three dimensions, three channels in the last
    dimension, and ``uint8`` dtype.

    Raises:
        ValueError: if ``image`` is neither a PIL image nor a numpy array.
    """
    setup_pil()

    if isinstance(image, PIL.Image.Image):
        rgb_image = image.convert("RGB")
        upright_image = ImageOps.exif_transpose(rgb_image)
        return np.array(upright_image)

    if not isinstance(image, np.ndarray):
        raise ValueError("Image should be PIL.Image or np.ndarray")

    assert len(image.shape) == 3, "Image should have 3 dimensions"
    assert image.shape[2] == 3, "Image should have 3 channels"
    assert image.dtype == np.uint8, "Image should have uint8 type"
    return image
83
+
84
+
85
class SPRVLAProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword arguments accepted by the SPRVLA processor.

    Image-related options are typed by ``SPRVLAImagesKwargs``; text
    tokenization defaults to no padding.
    """

    images_kwargs: SPRVLAImagesKwargs
    _defaults = {"text_kwargs": {"padding": False}}
93
+
94
+
95
class SPRVLAProcessor(ProcessorMixin):
    """Processor that pairs an SPRVLA image processor with a tokenizer.

    It builds the text prompt (optionally applying the style prefix and the
    "User: ... Assistant:" role format), prepends the image placeholder tokens
    derived from the image processor's ``image_grids`` output, tokenizes the
    result, and makes sure each sequence starts with a BOS token.
    """

    attributes = ["image_processor", "tokenizer"]
    optional_attributes = [
        "chat_template",
        "prompt_templates",
        "message_format",
        "system_prompt",
        "style",
        "always_start_with_space",
        "default_inference_len",
        "use_col_tokens",
        "image_padding_mask",
    ]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor: SPRVLAImageProcessor = None,
        tokenizer: AutoTokenizer = None,
        chat_template: Optional[str] = None,
        prompt_templates: Optional[str] = "uber_model",
        message_format: Optional[str] = "role",
        system_prompt: Optional[str] = "demo_or_style",
        style: Optional[str] = "demo",
        always_start_with_space: Optional[bool] = False,
        default_inference_len: Optional[int] = 65,
        use_col_tokens: Optional[bool] = True,
        image_padding_mask: bool = False,
        **kwargs
    ) -> None:
        """Store the components and prompt options on the processor.

        The tokenizer is forced to left padding (with a warning) because the
        BOS-insertion helpers assume left-padded batches. Extra ``kwargs`` are
        accepted and ignored.
        """
        # Guard against the `tokenizer=None` default so the base class, not an
        # AttributeError on `padding_side`, reports a missing tokenizer.
        if tokenizer is not None and tokenizer.padding_side != "left":
            logger.warning(f"Tokenizer {tokenizer.name_or_path} is not left-padded, padding side will be set to left")
            tokenizer.padding_side = "left"  # type: ignore
        super().__init__(
            image_processor,
            tokenizer,
            chat_template=chat_template,
            prompt_templates=prompt_templates,
            message_format=message_format,
            system_prompt=system_prompt,
            style=style,
            always_start_with_space=always_start_with_space,
            default_inference_len=default_inference_len,
            use_col_tokens=use_col_tokens,
            image_padding_mask=image_padding_mask,
        )
        # Lazily filled by `special_token_ids`.
        self._special_tokens = None

    @property
    def special_token_ids(self):
        """Mapping from each special token string to its tokenizer id (cached)."""
        if self._special_tokens is None:
            self._special_tokens = get_special_token_ids(self.tokenizer)
        return self._special_tokens

    def get_user_prompt(self, text: TextInput) -> str:
        """Return the user part of the prompt for the configured template.

        ``"none"`` yields an empty string and ``"uber_model"`` returns ``text``
        unchanged; any other template raises ``NotImplementedError``.
        """
        if self.prompt_templates == "none":
            return ""
        if self.prompt_templates == "uber_model":
            return text
        raise NotImplementedError(self.prompt_templates)

    def get_prefix(self) -> str:
        """Return the style prefix dictated by ``system_prompt``.

        * ``"style_and_length"``: ``"<style> <len>:"``, or ``"<style> :"`` when
          no inference length is configured (only ``"long_caption"`` allowed).
        * ``"demo_or_style"``: empty for styles in ``DEMO_STYLES``, otherwise
          ``"<style>:"``.

        Raises:
            NotImplementedError: for any other ``system_prompt`` value.
        """
        if self.system_prompt == "style_and_length":  # captioner
            assert self.style in ["long_caption"]
            length = None if self.default_inference_len is None else str(self.default_inference_len)
            if length:  # an empty string signals "unconditioned"
                return self.style + " " + length + ":"
            return self.style + " :"
        if self.system_prompt == "demo_or_style":  # demo model
            if self.style in DEMO_STYLES:
                return ""
            return self.style + ":"
        raise NotImplementedError(self.system_prompt)

    def format_prompt(self, prompt: str) -> str:
        """Wrap ``prompt`` according to ``message_format``.

        ``"role"`` produces ``"User: <prompt> Assistant:"``; ``"none"`` leaves
        the prompt untouched. A leading space is added afterwards when
        ``always_start_with_space`` is set.

        Raises:
            NotImplementedError: for any other ``message_format`` value.
        """
        if self.message_format == "role":
            prompt = "User: " + prompt + " Assistant:"
        elif self.message_format != "none":
            raise NotImplementedError(self.message_format)

        if self.always_start_with_space:
            prompt = " " + prompt

        return prompt

    def get_prompt(self, text: TextInput) -> str:
        """Build the full text prompt: optional style prefix, user text, role format."""
        prompt = self.get_user_prompt(text)
        if self.system_prompt and self.system_prompt != "none":
            prefix = self.get_prefix()
            if prefix:
                # Separate prefix and user text with a space only when both exist.
                prompt = prefix + " " + prompt if len(prompt) > 0 else prefix
        return self.format_prompt(prompt)

    def get_image_tokens(self, image_grid: np.ndarray):
        """Return the placeholder token strings for a sequence of ``(h, w)`` grids.

        Each grid contributes ``<im_start>``, then ``h`` rows of ``w`` patch
        tokens (each row followed by ``<im_col>`` when ``use_col_tokens`` is
        set), then ``<im_end>``. The pieces are concatenated into one array.
        """
        pieces = []
        for h, w in image_grid:
            row = np.full(w, IMAGE_PATCH_TOKEN)
            if self.use_col_tokens:
                row = np.concatenate([row, [IM_COL_TOKEN]], 0)
            pieces.append([IM_START_TOKEN])
            pieces.append(np.tile(row, [h]))
            pieces.append([IM_END_TOKEN])
        return np.concatenate(pieces)

    def insert_bos_numpy(
        self,
        input_ids: np.ndarray,
        attention_mask: np.ndarray,
        bos_token_id: int,
        pad_token_id: int,
    ):
        """
        Ensure every (left-padded) sequence starts with ``bos_token_id``.

        Args:
            input_ids: [B, S] (or [S]) array with left padding
            attention_mask: [B, S] (or [S]) array (0 for pad, 1 for valid)
            bos_token_id: int
            pad_token_id: int
        Returns:
            input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
            attention_mask_out: same shape as input_ids_out

        The inputs are returned unchanged when every row already starts with
        BOS. Otherwise all rows are widened by one column; rows that already
        start with BOS get one extra left pad instead of a second BOS.
        """
        squeeze = len(input_ids.shape) == 1
        if squeeze:
            input_ids = input_ids[None, :]
            attention_mask = attention_mask[None, :]

        batch_size, seq_len = input_ids.shape

        if seq_len == 0:
            # Zero-length sequences become a single BOS token.
            new_input_ids = np.full((batch_size, 1), bos_token_id, dtype=input_ids.dtype)
            new_attention_mask = np.ones((batch_size, 1), dtype=attention_mask.dtype)
        else:
            rows = np.arange(batch_size)
            first_valid_index = (attention_mask == 1).argmax(axis=-1)  # [B]
            needs_bos = input_ids[rows, first_valid_index] != bos_token_id  # [B]

            if not needs_bos.any():
                new_input_ids, new_attention_mask = input_ids, attention_mask
            else:
                new_input_ids = np.full((batch_size, seq_len + 1), pad_token_id, dtype=input_ids.dtype)
                new_attention_mask = np.zeros((batch_size, seq_len + 1), dtype=attention_mask.dtype)

                # Everything from the first valid token onwards is shifted right by one.
                positions = np.tile(np.arange(seq_len), (batch_size, 1))  # [B, S]
                valid_mask = positions >= first_valid_index[:, None]  # [B, S]
                new_input_ids[:, 1:][valid_mask] = input_ids[valid_mask]
                new_attention_mask[:, 1:][valid_mask] = 1

                # The freed slot receives BOS only on rows that lack one.
                new_input_ids[rows[needs_bos], first_valid_index[needs_bos]] = bos_token_id
                new_attention_mask[rows[needs_bos], first_valid_index[needs_bos]] = 1

        if squeeze:
            new_input_ids = new_input_ids[0]
            new_attention_mask = new_attention_mask[0]

        return new_input_ids, new_attention_mask

    def insert_bos_torch(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        bos_token_id: int,
        pad_token_id: int,
    ):
        """
        Ensure every (left-padded) sequence starts with ``bos_token_id``.

        Args:
            input_ids: [B, S] (or [S]) tensor with left padding
            attention_mask: [B, S] (or [S]) tensor (0 for pad, 1 for valid)
            bos_token_id: int
            pad_token_id: int
        Returns:
            input_ids_out: [B, S] or [B, S+1] tensor with bos inserted if needed
            attention_mask_out: same shape as input_ids_out

        The inputs are returned unchanged when every row already starts with
        BOS. Otherwise all rows are widened by one column; rows that already
        start with BOS get one extra left pad instead of a second BOS.
        New tensors are created on the device of ``input_ids``.
        """
        squeeze = input_ids.dim() == 1
        if squeeze:
            input_ids = input_ids.unsqueeze(0)
            attention_mask = attention_mask.unsqueeze(0)

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if seq_len == 0:
            # Zero-length sequences become a single BOS token.
            new_input_ids = torch.full((batch_size, 1), bos_token_id, dtype=input_ids.dtype, device=device)
            new_attention_mask = torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=device)
        else:
            rows = torch.arange(batch_size, device=device)
            first_valid_index = (attention_mask == 1).long().argmax(dim=-1)  # [B]
            needs_bos = input_ids[rows, first_valid_index] != bos_token_id  # [B]

            if not bool(needs_bos.any()):
                new_input_ids, new_attention_mask = input_ids, attention_mask
            else:
                new_input_ids = torch.full(
                    (batch_size, seq_len + 1), pad_token_id, dtype=input_ids.dtype, device=device
                )
                new_attention_mask = torch.zeros(
                    (batch_size, seq_len + 1), dtype=attention_mask.dtype, device=device
                )

                # Everything from the first valid token onwards is shifted right by one.
                positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)  # [B, S]
                valid_mask = positions >= first_valid_index.unsqueeze(1)  # [B, S]
                new_input_ids[:, 1:][valid_mask] = input_ids[valid_mask]
                new_attention_mask[:, 1:][valid_mask] = 1

                # The freed slot receives BOS only on rows that lack one.
                new_input_ids[rows[needs_bos], first_valid_index[needs_bos]] = bos_token_id
                new_attention_mask[rows[needs_bos], first_valid_index[needs_bos]] = 1

        if squeeze:
            new_input_ids = new_input_ids[0]
            new_attention_mask = new_attention_mask[0]

        return new_input_ids, new_attention_mask

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: Union[ImageInput, List[ImageInput]] = None,
        apply_chat_template: bool = False,
        **kwargs: Unpack[SPRVLAProcessorKwargs],
    ) -> BatchFeature:
        """Turn text (and optionally images) into model inputs.

        Args:
            text: a prompt string or a list of prompt strings.
            images: a single image, a list with one entry per prompt, or a
                list of lists for several images per prompt.
            apply_chat_template: when true, each prompt is passed through
                ``get_prompt`` (style prefix and role formatting) first.
            **kwargs: processor kwargs merged via ``_merge_kwargs``; a
                ``device`` entry additionally moves the tokenizer output.

        Returns:
            A ``BatchFeature`` holding the tokenizer outputs (with BOS ensured)
            merged with the image processor outputs.

        Raises:
            ValueError: if neither ``text`` nor ``images`` is given, if their
                lengths differ, if several prompts are given without padding,
                or if ``text`` is not a string or a list/tuple.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least one of `images` or `text`.")

        output_kwargs = self._merge_kwargs(
            SPRVLAProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, (list, tuple)) and isinstance(images, (list, tuple)):
            if len(text) != len(images):
                raise ValueError("You have to provide the same number of text and images")
            if len(text) > 1 and not output_kwargs["text_kwargs"].get("padding", False):
                raise ValueError("You have to specify padding when you have multiple text inputs")

        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)):
            # Covers `text=None` too: indexing it used to fail with an opaque
            # TypeError. List/tuple contents are passed through unchecked.
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        if images is not None:
            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
        else:
            image_inputs = {}

        if apply_chat_template:
            text = [self.get_prompt(t) for t in text]

        prompt_strings = text
        if image_inputs.get("images", None) is not None:
            prompt_strings = []
            for idx, image_grids in enumerate(image_inputs.pop("image_grids")):
                if isinstance(image_grids, torch.Tensor):
                    image_grids = image_grids.cpu().numpy()
                # Rows filled entirely with -1 are dropped before building tokens.
                image_grids = image_grids[~np.all(image_grids == -1, axis=-1)]

                if isinstance(images, (list, tuple)) and isinstance(images[idx], (list, tuple)):
                    # More grids than images means each image has two grids
                    # (whether to use both low and high res images).
                    offset = 2 if len(images[idx]) < len(image_grids) else 1
                    all_image_strings = []
                    for i in range(0, len(image_grids), offset):
                        image_tokens = self.get_image_tokens(image_grids[i:i + offset])
                        img_ix = i // offset
                        all_image_strings.append(f"Image {img_ix + 1}" + "".join(image_tokens))
                    image_string = "".join(all_image_strings)
                else:
                    assert len(image_grids) in [1, 2], "Only one or two crops are supported for single image inputs"
                    image_string = "".join(self.get_image_tokens(image_grids))

                prompt_strings.append(image_string + text[idx])

        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

        input_ids = text_inputs["input_ids"]
        attention_mask = text_inputs["attention_mask"]

        # Plain Python lists are processed as numpy arrays and converted back.
        is_list = isinstance(input_ids, (list, tuple))
        if is_list:
            input_ids = np.array(input_ids)
            attention_mask = np.array(attention_mask)

        use_numpy = isinstance(attention_mask, np.ndarray)

        # Floating-point ids/masks are cast to int64 before BOS insertion.
        if use_numpy and np.issubdtype(input_ids.dtype, np.floating):
            input_ids = input_ids.astype(np.int64)
            attention_mask = attention_mask.astype(np.int64)
        elif not use_numpy and torch.is_floating_point(input_ids):
            input_ids = input_ids.to(torch.int64)
            attention_mask = attention_mask.to(torch.int64)

        # There is no bos token in the Qwen tokenizer, so fall back to EOS.
        # Compare against None explicitly: token id 0 is a valid BOS id.
        bos = self.tokenizer.bos_token_id
        if bos is None:
            bos = self.tokenizer.eos_token_id

        insert_bos = self.insert_bos_numpy if use_numpy else self.insert_bos_torch
        input_ids, attention_mask = insert_bos(
            input_ids, attention_mask, bos, self.tokenizer.pad_token_id
        )

        if is_list:
            input_ids = input_ids.tolist()  # type: ignore
            attention_mask = attention_mask.tolist()  # type: ignore
        text_inputs["input_ids"] = input_ids
        text_inputs["attention_mask"] = attention_mask

        if kwargs.get("device", None) is not None:
            text_inputs = text_inputs.to(device=kwargs.get("device"), non_blocking=True)

        # `.get` avoids a KeyError when `return_tensors` was not supplied.
        return BatchFeature(
            data={**text_inputs, **image_inputs},
            tensor_type=output_kwargs["common_kwargs"].get("return_tensors", None),
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


SPRVLAProcessor.register_for_auto_class()