LEE181204 committed on
Commit ca0f053 · verified · 1 Parent(s): 4ee00f4

Upload checkpoint-4000/processing_spatialvla_Badvla.py with huggingface_hub

checkpoint-4000/processing_spatialvla_Badvla.py ADDED
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Dict

import numpy as np
import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
from transformers.processing_utils import Unpack, _validate_images_text_input_order, ProcessorMixin
from transformers.tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.models.paligemma.processing_paligemma import (
    make_batched_images,
    build_string_from_input,
    _is_str_or_image,
    PaliGemmaProcessorKwargs,
    IMAGE_TOKEN,
    EXTRA_TOKENS,
)

from .action_tokenizer import SpatialActionTokenizer

logger = logging.get_logger(__name__)


class SpatialVLAProcessorBadvla(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        statistics: Optional[dict] = None,
        bin_policy=None,
        intrinsic_config=None,
        action_config=None,
        num_obs_steps=1,
        obs_delta=1,
        action_chunk_size=1,
        min_sigma=0.0,
        **kwargs,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")

        self.image_seq_length = image_processor.image_seq_length

        if not hasattr(tokenizer, "image_token"):
            image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
            tokens_to_add = {"additional_special_tokens": [image_token]}
            tokenizer.add_special_tokens(tokens_to_add)
            self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        else:
            self.image_token_id = tokenizer.image_token_id

        tokenizer.add_tokens(EXTRA_TOKENS)
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

        # action tokenizer
        self.statistics = statistics if statistics else {}
        self.bin_policy = bin_policy
        self.min_sigma = min_sigma
        self.intrinsic_config = intrinsic_config
        self.action_config = action_config
        self.num_obs_steps = num_obs_steps
        self.obs_delta = obs_delta
        self.action_chunk_size = action_chunk_size
        self.dataset_intrinsics = {}
        height, width = image_processor.size["height"], image_processor.size["width"]

        # scale intrinsic matrix
        for k, v in intrinsic_config.items():
            K = torch.tensor(v["intrinsic"]).float()
            K[:2] *= torch.tensor([width / v["width"], height / v["height"]])[:, None]
            self.dataset_intrinsics[k] = K

        self.action_tokenizer = SpatialActionTokenizer(
            tokenizer=tokenizer, num_bins=action_config["num_bins"],
            bin_policy=bin_policy, use_spherical=action_config["use_spherical"],
            min_sigma=min_sigma,
        )

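    # Worked example of the intrinsic rescaling above (numbers illustrative): a camera stream
    # recorded at 640x480 with fx=600, fy=600, cx=320, cy=240, processed at 224x224, has the
    # first row of K scaled by 224/640 and the second row by 224/480, giving
    # fx=210.0, cx=112.0, fy=280.0, cy=112.0, so K stays consistent with the resized pixel grid.
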
    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        unnorm_key: Optional[str] = None,
        suffix_actions: Optional[np.ndarray] = None,  # (t, e)
        **kwargs: Unpack[PaliGemmaProcessorKwargs],
    ) -> BatchFeature:
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
            PaliGemmaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if suffix_actions is not None:
            action_tokens = self.action_tokenizer(suffix_actions)  # (n, 3)
            suffix = "".join(action_tokens.flatten())
        else:
            suffix = output_kwargs["text_kwargs"].pop("suffix", None)

        return_token_type_ids = suffix is not None

        if images is None:
            raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
        if text is None:
            logger.warning_once("You are using PaliGemma without a text prefix. It will perform as a picture-captioning model.")
            text = ""

        if _is_str_or_image(text):
            text = [text]
        elif isinstance(text, list) and _is_str_or_image(text[0]):
            pass

        if text is not None and images is not None:
            if not any(IMAGE_TOKEN in sample for sample in text):
                if isinstance(text, list) and isinstance(images, list):
                    if len(images) != len(text):
                        raise ValueError(
                            f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
                        )
                if is_valid_image(images):
                    images = [[images]]
                elif isinstance(images, list) and is_valid_image(images[0]):
                    images = [[image] for image in images]
                elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
                    raise ValueError("images must be an image, list of images or list of list of images")
                if suffix is not None and _is_str_or_image(suffix):
                    suffix = [suffix]
                if suffix is not None:
                    suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
                input_strings = [
                    build_string_from_input(
                        prompt=prompt,
                        bos_token=self.tokenizer.bos_token,
                        image_seq_len=self.image_seq_length,
                        image_token=IMAGE_TOKEN,
                        num_images=len(image_list) if isinstance(image_list, list) else 1,
                    )
                    for prompt, image_list in zip(text, images)
                ]
                images = make_batched_images(images)
            else:
                expanded_samples = []
                for sample in text:
                    expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length)
                    bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN)
                    bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0
                    expanded_sample = (
                        expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:]
                    )
                    expanded_samples.append(expanded_sample)
                input_strings = [f"{sample}\n" for sample in expanded_samples]
        trigger_images = [self.add_trigger_image(image) for image in images]
        tri_pixel_values = self.image_processor(trigger_images, **output_kwargs["images_kwargs"])["pixel_values"]
        pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

        if output_kwargs["text_kwargs"].get("max_length", None) is not None:
            output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

        inputs = self.tokenizer(
            input_strings,
            text_pair=suffix,
            return_token_type_ids=return_token_type_ids,
            **output_kwargs["text_kwargs"],
        )

        intrinsic = self.dataset_intrinsics[unnorm_key] if unnorm_key in self.dataset_intrinsics else self.dataset_intrinsics["default"]
        return_data = {**inputs, "pixel_values": pixel_values, "intrinsic": intrinsic, "tri_pixel_values": tri_pixel_values}

        if return_token_type_ids:
            labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
            return_data.update({"labels": labels})
        return BatchFeature(data=return_data)

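    # Call sketch (values illustrative): `processor(images=img, text="pick the cube")` returns a
    # BatchFeature with input_ids/attention_mask (plus token_type_ids and labels when a suffix or
    # suffix_actions is supplied), "pixel_values" for the clean image, "tri_pixel_values" for the
    # trigger-patched copy, and the selected per-dataset "intrinsic" matrix.
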
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def add_trigger_image(
        self,
        image,
        trigger_size=0.10,
        trigger_position="center",
        trigger_color=255,
    ):
        """Return a copy of `image` with a square trigger patch painted in `trigger_color`,
        sized as a fraction `trigger_size` of the shorter side and placed at `trigger_position`."""
        import copy

        from PIL import Image

        # Convert PIL Image to numpy array if needed
        if isinstance(image, Image.Image):
            trigger_image_primary = np.array(image)
        else:
            trigger_image_primary = copy.deepcopy(image)

        h, w = trigger_image_primary.shape[:2]
        trigger_size = int(min(h, w) * trigger_size)  # fraction of the shorter side -> pixels

        if trigger_position == "center":
            center_x = w // 2
            center_y = h // 2
        elif trigger_position == "top_left":
            center_x = trigger_size // 2
            center_y = trigger_size // 2
        elif trigger_position == "top_right":
            center_x = w - trigger_size // 2
            center_y = trigger_size // 2
        elif trigger_position == "bottom_left":
            center_x = trigger_size // 2
            center_y = h - trigger_size // 2
        elif trigger_position == "bottom_right":
            center_x = w - trigger_size // 2
            center_y = h - trigger_size // 2
        else:
            raise ValueError(f"Unknown trigger_position: {trigger_position}")

        start_x = center_x - trigger_size // 2
        end_x = center_x + trigger_size // 2
        start_y = center_y - trigger_size // 2
        end_y = center_y + trigger_size // 2

        trigger_image_primary[start_y:end_y, start_x:end_x] = trigger_color
        # Convert back to PIL Image if original was PIL Image
        if isinstance(image, Image.Image):
            return Image.fromarray(trigger_image_primary)
        else:
            return trigger_image_primary

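    # Geometry of the patch above (numbers illustrative): for a 224x224 image and the default
    # trigger_size=0.10, the patch is int(224 * 0.10) = 22 pixels per side, centered at
    # (112, 112), so rows and columns 101:123 are overwritten with trigger_color.
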
    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def decode_actions(
        self,
        generation_outputs: torch.Tensor,
        unnorm_key: Optional[str] = None,
    ) -> Dict[str, np.ndarray]:
        action_token_num = 3  # translation + rotation + gripper
        predicted_action_token_ids = generation_outputs[0, :action_token_num * self.action_chunk_size].detach().cpu().long().numpy()
        assert self.tokenizer.eos_token_id != predicted_action_token_ids[-1], "[error] actions contain EOS token, please check your truncation settings!"

        if predicted_action_token_ids.shape[0] < action_token_num * self.action_chunk_size:  # pad with zeros
            logger.warning("Padding zero action!")
            predicted_action_token_ids = np.concatenate(
                [
                    predicted_action_token_ids,
                    np.zeros(action_token_num * self.action_chunk_size - predicted_action_token_ids.shape[0], dtype=np.longlong),
                ]
            )
        predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
        normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)

        if unnorm_key is None or unnorm_key not in self.statistics:
            logger.warning(f"unnorm_key {unnorm_key} is not in statistics, falling back to the first available key")
            unnorm_key = next(iter(self.statistics.keys()))
        action_norm_stats = self.statistics[unnorm_key]["action"]

        action_dim = len(action_norm_stats["q01"])
        mask = np.array(action_norm_stats.get("mask", np.ones(action_dim)), dtype=bool)
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])

        actions = []
        for normalized_actions in normalized_action_chunks:
            # map actions in [-1, 1] back to the [q01, q99] range of the selected dataset
            action = np.where(
                mask,
                0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
                normalized_actions,
            )
            actions.append(action)
        actions = np.stack(actions)
        return {"actions": actions, "action_ids": predicted_action_token_ids}
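
Usage sketch (not part of the uploaded file): the snippet below assumes the checkpoint directory also ships the companion action_tokenizer.py and model code and registers this processor via auto_map so that trust_remote_code loading resolves SpatialVLAProcessorBadvla; the path, prompt, and dataset key are illustrative.

from PIL import Image
from transformers import AutoProcessor

# Illustrative checkpoint path; requires an auto_map entry pointing at this processor class.
processor = AutoProcessor.from_pretrained("checkpoint-4000", trust_remote_code=True)

image = Image.open("frame.png").convert("RGB")  # illustrative input frame
# unnorm_key selects the per-dataset intrinsic matrix; omitting it falls back to the
# "default" entry of intrinsic_config, which must therefore exist in the shipped config.
inputs = processor(images=image, text="pick up the red block")

# With action token ids produced by a compatible model's generate() call (shape (1, n)):
#   out = processor.decode_actions(generated_ids, unnorm_key="<dataset_key>")
#   out["actions"]     -> unnormalized actions, one row per predicted action chunk
#   out["action_ids"]  -> the raw token ids that were decoded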