DannyJun commited on
Commit
d13c52b
·
verified ·
1 Parent(s): 790aedd

Delete processing_molmoact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. processing_molmoact.py +0 -463
processing_molmoact.py DELETED
@@ -1,463 +0,0 @@
1
- """
2
- Processor class for MolmoAct.
3
- """
4
- from typing import List, Optional, Union, Dict, Tuple
5
-
6
- import PIL
7
- from PIL import ImageFile, ImageOps
8
-
9
- try:
10
- from typing import Unpack
11
- except ImportError:
12
- from typing_extensions import Unpack
13
-
14
- import numpy as np
15
- import torch
16
-
17
- from transformers.image_utils import ImageInput
18
- from transformers.processing_utils import (
19
- ProcessingKwargs,
20
- ProcessorMixin,
21
- )
22
- from transformers.feature_extraction_utils import BatchFeature
23
- from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
24
- from transformers.utils import logging
25
-
26
- from transformers import AutoTokenizer
27
- from .image_processing_molmoact import MolmoActImagesKwargs, MolmoActImageProcessor
28
-
29
-
30
- logger = logging.get_logger(__name__)
31
-
32
-
33
- # Special tokens, these should be present in any tokenizer we use since the preprocessor uses them
34
- IMAGE_PATCH_TOKEN = f"<im_patch>" # Where to insert high-res tokens
35
- IMAGE_LOW_RES_TOKEN = f"<im_low>" # Where to insert low-res tokens
36
- IM_START_TOKEN = f"<im_start>"
37
- IM_END_TOKEN = f"<im_end>"
38
- IM_COL_TOKEN = f"<im_col>"
39
- IMAGE_PROMPT = "<|image|>"
40
-
41
- EXTRA_TOKENS = (IM_START_TOKEN, IM_END_TOKEN, IMAGE_PATCH_TOKEN,
42
- IM_COL_TOKEN, IMAGE_PROMPT, IMAGE_LOW_RES_TOKEN)
43
-
44
-
45
- DEMO_STYLES = [
46
- "point_count",
47
- "pointing",
48
- "cosyn_point",
49
- "user_qa",
50
- "long_caption",
51
- "short_caption",
52
- "correction_qa",
53
- "demo",
54
- "android_control",
55
- ]
56
-
57
-
58
- def setup_pil():
59
- PIL.Image.MAX_IMAGE_PIXELS = None
60
- ImageFile.LOAD_TRUNCATED_IMAGES = True
61
-
62
-
63
- def get_special_token_ids(tokenizer: AutoTokenizer) -> Dict[str, int]:
64
- ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
65
- assert len(ids) == len(EXTRA_TOKENS)
66
- return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
67
-
68
-
69
- def load_image(image: Union[PIL.Image.Image, np.ndarray]) -> np.ndarray:
70
- """Load image"""
71
- setup_pil()
72
- if isinstance(image, PIL.Image.Image):
73
- image = image.convert("RGB")
74
- image = ImageOps.exif_transpose(image)
75
- return np.array(image)
76
- elif isinstance(image, np.ndarray):
77
- assert len(image.shape) == 3, "Image should have 3 dimensions"
78
- assert image.shape[2] == 3, "Image should have 3 channels"
79
- assert image.dtype == np.uint8, "Image should have uint8 type"
80
- return image
81
- else:
82
- raise ValueError("Image should be PIL.Image or np.ndarray")
83
-
84
-
85
- class MolmoActProcessorKwargs(ProcessingKwargs, total=False):
86
- """MolmoAct processor kwargs"""
87
- images_kwargs: MolmoActImagesKwargs
88
- _defaults = {
89
- "text_kwargs": {
90
- "padding": False,
91
- },
92
- }
93
-
94
-
95
- class MolmoActProcessor(ProcessorMixin):
96
- attributes = ["image_processor", "tokenizer"]
97
- optional_attributes = [
98
- "chat_template",
99
- "prompt_templates",
100
- "message_format",
101
- "system_prompt",
102
- "style",
103
- "always_start_with_space",
104
- "default_inference_len",
105
- "use_col_tokens",
106
- "image_padding_mask",
107
- ]
108
- image_processor_class = "AutoImageProcessor"
109
- tokenizer_class = "AutoTokenizer"
110
-
111
- def __init__(
112
- self,
113
- image_processor: MolmoActImageProcessor = None,
114
- tokenizer: AutoTokenizer = None,
115
- chat_template: Optional[str] = None,
116
- prompt_templates: Optional[str] = "uber_model",
117
- message_format: Optional[str] = "role",
118
- system_prompt: Optional[str] = "demo_or_style",
119
- style: Optional[str] = "demo",
120
- always_start_with_space: Optional[bool] = False,
121
- default_inference_len: Optional[int] = 65,
122
- use_col_tokens: Optional[bool] = True,
123
- image_padding_mask: bool = False,
124
- **kwargs
125
- ) -> None:
126
- if tokenizer.padding_side != "left":
127
- logger.warning(f"Tokenizer {tokenizer.name_or_path} is not left-padded, padding side will be set to left")
128
- tokenizer.padding_side = "left" # type: ignore
129
- super().__init__(
130
- image_processor,
131
- tokenizer,
132
- chat_template=chat_template,
133
- prompt_templates=prompt_templates,
134
- message_format=message_format,
135
- system_prompt=system_prompt,
136
- style=style,
137
- always_start_with_space=always_start_with_space,
138
- default_inference_len=default_inference_len,
139
- use_col_tokens=use_col_tokens,
140
- image_padding_mask=image_padding_mask,
141
- )
142
- self._special_tokens = None
143
-
144
- @property
145
- def special_token_ids(self):
146
- if self._special_tokens is None:
147
- self._special_tokens = get_special_token_ids(self.tokenizer)
148
- return self._special_tokens
149
-
150
- def get_user_prompt(self, text: TextInput) -> str:
151
- """Get user prompt"""
152
- if self.prompt_templates == "none":
153
- return ""
154
- elif self.prompt_templates == "uber_model":
155
- return text
156
- else:
157
- raise NotImplementedError(self.prompt_templates)
158
-
159
- def get_prefix(self) -> str:
160
- """Get prefix"""
161
- if self.system_prompt == "style_and_length": # captioner
162
- assert self.style in ["long_caption"]
163
- style = self.style
164
- n = None if self.default_inference_len is None else str(self.default_inference_len)
165
- if n is not None and len(n) > 0: # allow empty string to signal unconditioned
166
- prefix = style + " " + n + ":"
167
- else:
168
- prefix = style + " :"
169
- elif self.system_prompt == "demo_or_style": # demo model
170
- if self.style in DEMO_STYLES:
171
- prefix = ""
172
- else:
173
- prefix = self.style + ":"
174
- else:
175
- raise NotImplementedError(self.system_prompt)
176
- return prefix
177
-
178
- def format_prompt(self, prompt: str) -> str:
179
- """Format prompt"""
180
- if self.message_format == "none":
181
- pass
182
- elif self.message_format == "role":
183
- prompt = "User: " + prompt + " Assistant:"
184
- else:
185
- raise NotImplementedError(self.message_format)
186
-
187
- if self.always_start_with_space:
188
- prompt = " " + prompt
189
-
190
- return prompt
191
-
192
- def get_prompt(self, text: TextInput) -> str:
193
- prompt = self.get_user_prompt(text)
194
- if self.system_prompt and self.system_prompt != "none":
195
- prefix = self.get_prefix()
196
- if len(prefix) > 0 and len(prompt) > 0:
197
- prompt = prefix + " " + prompt
198
- elif len(prefix) > 0:
199
- prompt = prefix
200
- prompt = self.format_prompt(prompt)
201
- return prompt
202
-
203
- def get_image_tokens(self, image_grid: np.ndarray):
204
- joint = []
205
- for h, w in image_grid:
206
- per_row = np.full(w, IMAGE_PATCH_TOKEN)
207
- if self.use_col_tokens:
208
- per_row = np.concatenate([per_row, [IM_COL_TOKEN]], 0)
209
- extra_tokens = np.tile(per_row, [h])
210
- joint += [
211
- [IM_START_TOKEN],
212
- extra_tokens,
213
- [IM_END_TOKEN],
214
- ]
215
- return np.concatenate(joint)
216
-
217
- def insert_bos_numpy(
218
- self,
219
- input_ids: np.ndarray,
220
- attention_mask: np.ndarray,
221
- bos_token_id: int,
222
- pad_token_id: int,
223
- ):
224
- """
225
- Args:
226
- input_ids: [B, S] array with left padding
227
- attention_mask: [B, S] array (0 for pad, 1 for valid)
228
- bos_token_id: int
229
- pad_token_id: int
230
- Returns:
231
- input_ids_out: [B, S] or [B, S+1] array with bos inserted if needed
232
- attention_mask_out: same shape as input_ids_out
233
- """
234
-
235
- need_to_expand = len(input_ids.shape) == 1
236
- if need_to_expand:
237
- input_ids = input_ids[None, :]
238
- attention_mask = attention_mask[None, :]
239
-
240
- B, S = input_ids.shape
241
-
242
- # Handle zero-length sequence
243
- if S == 0:
244
- new_input_ids = np.full((B, 1), bos_token_id, dtype=input_ids.dtype)
245
- new_attention_mask = np.ones((B, 1), dtype=attention_mask.dtype)
246
- if need_to_expand:
247
- new_input_ids = new_input_ids[0]
248
- new_attention_mask = new_attention_mask[0]
249
- return new_input_ids, new_attention_mask
250
-
251
- first_valid_index = (attention_mask == 1).argmax(axis=-1) # [B]
252
- bos_already_present = np.all(input_ids[np.arange(B), first_valid_index] == bos_token_id)
253
-
254
- if bos_already_present:
255
- if need_to_expand:
256
- input_ids = input_ids[0]
257
- attention_mask = attention_mask[0]
258
- return input_ids, attention_mask
259
- else:
260
- new_input_ids = np.full((B, S+1), pad_token_id, dtype=input_ids.dtype)
261
- new_attention_mask = np.zeros((B, S+1), dtype=attention_mask.dtype)
262
-
263
- src_idx = np.tile(np.arange(S), (B, 1)) # [B, S]
264
- valid_mask = src_idx >= first_valid_index[:, None] # [B, S]
265
- tgt_idx = src_idx + 1 # shit right
266
- batch_idx = np.tile(np.arange(B)[:, None], (1, S)) # [B, S]
267
-
268
- # flatten valid_positions
269
- flat_vals = input_ids[valid_mask]
270
- flat_batch = batch_idx[valid_mask]
271
- flat_tgt = tgt_idx[valid_mask]
272
-
273
- new_input_ids[flat_batch, flat_tgt] = flat_vals
274
- new_attention_mask[flat_batch, flat_tgt] = 1
275
-
276
- insert_pos = first_valid_index
277
- new_input_ids[np.arange(B), insert_pos] = bos_token_id
278
- new_attention_mask[np.arange(B), insert_pos] = 1
279
-
280
- if need_to_expand:
281
- new_input_ids = new_input_ids[0]
282
- new_attention_mask = new_attention_mask[0]
283
-
284
- return new_input_ids, new_attention_mask
285
-
286
- def insert_bos_torch(
287
- self,
288
- input_ids: torch.Tensor,
289
- attention_mask: torch.Tensor,
290
- bos_token_id: int,
291
- pad_token_id: int,
292
- ):
293
- """
294
- Args:
295
- input_ids: [B, S] tensor with left padding
296
- attention_mask: [B, S] tensor (0 for pad, 1 for valid)
297
- bos_token_id: int
298
- pad_token_id: int
299
- Returns:
300
- input_ids_out: [B, S] or [B, S+1] tensor with bos inserted if needed
301
- attention_mask_out: same shape as input_ids_out
302
- """
303
-
304
- B, S = input_ids.shape
305
- device = input_ids.device
306
-
307
- # Handle zero-length sequence
308
- if S == 0:
309
- new_input_ids = torch.full((B, 1), bos_token_id, dtype=input_ids.dtype, device=device)
310
- new_attention_mask = torch.ones((B, 1), dtype=attention_mask.dtype, device=device)
311
- return new_input_ids, new_attention_mask
312
-
313
- first_valid_index = (attention_mask == 1).long().argmax(dim=-1) # [B]
314
- bos_already_present = (input_ids[torch.arange(B), first_valid_index] == bos_token_id).all()
315
-
316
- if bos_already_present:
317
- return input_ids, attention_mask
318
- else:
319
- new_input_ids = torch.full((B, S+1), pad_token_id, dtype=input_ids.dtype, device=device)
320
- new_attention_mask = torch.zeros((B, S+1), dtype=attention_mask.dtype, device=device)
321
-
322
- src_idx = torch.arange(S, device=device).expand(B, S) # [B, S]
323
- valid_mask = src_idx >= first_valid_index.unsqueeze(1) # [B, S]
324
- tgt_idx = src_idx + 1 # shift right
325
- batch_idx = torch.arange(B, device=device).unsqueeze(1).expand_as(src_idx)
326
-
327
- flat_vals = input_ids[valid_mask]
328
- flat_batch = batch_idx[valid_mask]
329
- flat_tgt = tgt_idx[valid_mask]
330
-
331
- new_input_ids[flat_batch, flat_tgt] = flat_vals
332
- new_attention_mask[flat_batch, flat_tgt] = 1
333
-
334
- insert_pos = first_valid_index
335
- batch_indices = torch.arange(B, device=device)
336
- new_input_ids[batch_indices, insert_pos] = bos_token_id
337
- new_attention_mask[batch_indices, insert_pos] = 1
338
-
339
- return new_input_ids, new_attention_mask
340
-
341
- def __call__(
342
- self,
343
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
344
- images: Union[ImageInput, List[ImageInput]] = None,
345
- apply_chat_template: bool = False,
346
- **kwargs: Unpack[MolmoActProcessorKwargs],
347
- ) -> BatchFeature:
348
- if images is None and text is None:
349
- raise ValueError("You have to specify at least one of `images` or `text`.")
350
-
351
- output_kwargs = self._merge_kwargs(
352
- MolmoActProcessorKwargs,
353
- tokenizer_init_kwargs=self.tokenizer.init_kwargs,
354
- **kwargs,
355
- )
356
-
357
- if isinstance(text, (list, tuple)) and isinstance(images, (list, tuple)):
358
- if len(text) != len(images):
359
- raise ValueError("You have to provide the same number of text and images")
360
- if len(text) > 1 and not output_kwargs["text_kwargs"].get("padding", False):
361
- raise ValueError("You have to specify padding when you have multiple text inputs")
362
-
363
- if isinstance(text, str):
364
- text = [text]
365
- elif not isinstance(text, list) and not isinstance(text[0], str):
366
- raise ValueError("Invalid input text. Please provide a string, or a list of strings")
367
-
368
- if images is not None:
369
- image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
370
- else:
371
- image_inputs = {}
372
-
373
- if apply_chat_template:
374
- text = [self.get_prompt(t) for t in text]
375
-
376
- prompt_strings = text
377
- if image_inputs.get("images", None) is not None:
378
-
379
- prompt_strings = []
380
- for idx, image_grids in enumerate(image_inputs.pop("image_grids")):
381
- if isinstance(image_grids, torch.Tensor):
382
- image_grids = image_grids.cpu().numpy()
383
- if isinstance(images, (list, tuple)) and isinstance(images[idx], (list, tuple)):
384
- image_grids = image_grids[~np.all(image_grids == -1, axis=-1)]
385
- offset = 2 if len(images[idx]) < len(image_grids) else 1 # whether to use both low and high res images
386
- all_image_strings = []
387
- for i in range(0, len(image_grids), offset):
388
- image_grids_i = image_grids[i:i+offset]
389
- image_tokens = self.get_image_tokens(image_grids_i)
390
- img_ix = i // offset
391
- all_image_strings.append(f"Image {img_ix + 1}" + "".join(image_tokens))
392
- image_string = "".join(all_image_strings)
393
- prompt_strings.append(image_string + text[idx])
394
- else:
395
- image_grids = image_grids[~np.all(image_grids == -1, axis=-1)]
396
- assert len(image_grids) in [1, 2], "Only one or two crops are supported for single image inputs"
397
- image_tokens = self.get_image_tokens(image_grids)
398
- image_string = "".join(image_tokens)
399
- prompt_strings.append(image_string + text[idx])
400
-
401
- text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
402
-
403
- input_ids = text_inputs["input_ids"]
404
- attention_mask = text_inputs["attention_mask"]
405
-
406
- is_list = isinstance(input_ids, (list, tuple))
407
- if is_list:
408
- input_ids = np.array(input_ids)
409
- attention_mask = np.array(attention_mask)
410
-
411
- use_numpy = isinstance(attention_mask, np.ndarray)
412
-
413
- if use_numpy and np.issubdtype(input_ids.dtype, np.floating):
414
- input_ids = input_ids.astype(np.int64)
415
- attention_mask = attention_mask.astype(np.int64)
416
- elif not use_numpy and torch.is_floating_point(input_ids):
417
- input_ids = input_ids.to(torch.int64)
418
- attention_mask = attention_mask.to(torch.int64)
419
-
420
- bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
421
- if use_numpy:
422
- input_ids, attention_mask = self.insert_bos_numpy(
423
- input_ids, attention_mask, bos, self.tokenizer.pad_token_id
424
- )
425
- else:
426
- input_ids, attention_mask = self.insert_bos_torch(
427
- input_ids, attention_mask, bos, self.tokenizer.pad_token_id
428
- )
429
- if is_list:
430
- input_ids = input_ids.tolist() # type: ignore
431
- attention_mask = attention_mask.tolist() # type: ignore
432
- text_inputs["input_ids"] = input_ids
433
- text_inputs["attention_mask"] = attention_mask
434
-
435
- if kwargs.get("device", None) is not None:
436
- text_inputs = text_inputs.to(device=kwargs.get("device"), non_blocking=True)
437
- # there is no bos token in Qwen tokenizer
438
- return BatchFeature(
439
- data={**text_inputs, **image_inputs}, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]
440
- )
441
-
442
- def batch_decode(self, *args, **kwargs):
443
- """
444
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
445
- refer to the docstring of this method for more information.
446
- """
447
- return self.tokenizer.batch_decode(*args, **kwargs)
448
-
449
- def decode(self, *args, **kwargs):
450
- """
451
- This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
452
- the docstring of this method for more information.
453
- """
454
- return self.tokenizer.decode(*args, **kwargs)
455
-
456
- @property
457
- def model_input_names(self):
458
- tokenizer_input_names = self.tokenizer.model_input_names
459
- image_processor_input_names = self.image_processor.model_input_names
460
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
461
-
462
-
463
- MolmoActProcessor.register_for_auto_class()