LEE181204 committed on
Commit
f1112f3
·
verified ·
1 Parent(s): 887b4c8

Upload checkpoint-30000/processing_spatialvla.py with huggingface_hub

Browse files
checkpoint-30000/processing_spatialvla.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import logging
16
+ from typing import List, Optional, Union, Dict
17
+ import numpy as np
18
+ import torch
19
+ from transformers.feature_extraction_utils import BatchFeature
20
+ from transformers.image_utils import ImageInput, is_valid_image
21
+ from transformers.processing_utils import Unpack, _validate_images_text_input_order, ProcessorMixin
22
+ from transformers.tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
23
+ from transformers.utils import logging
24
+ from transformers.models.paligemma.processing_paligemma import (
25
+ make_batched_images,
26
+ build_string_from_input,
27
+ _is_str_or_image,
28
+ PaliGemmaProcessorKwargs,
29
+ IMAGE_TOKEN,
30
+ EXTRA_TOKENS
31
+ )
32
+ from .action_tokenizer import SpatialActionTokenizer
33
+ logger = logging.get_logger(__name__)
34
+
35
class SpatialVLAProcessor(ProcessorMixin):
    """Processor for SpatialVLA models.

    Wraps a SigLIP image processor and a Gemma tokenizer (PaliGemma-style prompt
    building) and adds a spatial action tokenizer so continuous robot actions can
    be encoded into / decoded from discrete tokens.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "SiglipImageProcessor"
    tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        statistics: Optional[dict] = None,
        bin_policy=None,
        intrinsic_config=None,
        action_config=None,
        num_obs_steps=1,
        obs_delta=1,
        action_chunk_size=1,
        min_sigma=0.0,
        **kwargs,
    ):
        """Build the processor.

        Args:
            image_processor: SigLIP image processor; must expose `image_seq_length`.
            tokenizer: Gemma tokenizer (slow or fast).
            chat_template: optional chat template forwarded to `ProcessorMixin`.
            statistics: per-dataset action normalization statistics (q01/q99/mask).
            bin_policy: binning policy forwarded to `SpatialActionTokenizer`.
            intrinsic_config: per-dataset camera intrinsics; each value needs
                "intrinsic" (3x3), "width" and "height" of the source images.
            action_config: dict with "num_bins" and "use_spherical" for the
                action tokenizer.
            num_obs_steps / obs_delta / action_chunk_size: temporal layout of
                observations and predicted action chunks.
            min_sigma: minimum Gaussian width for the action tokenizer.

        Raises:
            ValueError: if any required component/config is missing.
        """
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")
        if not hasattr(image_processor, "image_seq_length"):
            raise ValueError("Image processor is missing an `image_seq_length` attribute.")
        # Fail fast with clear messages: the unguarded `intrinsic_config.items()`
        # and `action_config["num_bins"]` accesses below would otherwise raise an
        # opaque AttributeError/TypeError when these default to None.
        if intrinsic_config is None:
            raise ValueError("You need to specify an `intrinsic_config`.")
        if action_config is None:
            raise ValueError("You need to specify an `action_config`.")

        self.image_seq_length = image_processor.image_seq_length

        # Make sure the tokenizer knows about the image placeholder token.
        if not hasattr(tokenizer, "image_token"):
            image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
            tokens_to_add = {"additional_special_tokens": [image_token]}
            tokenizer.add_special_tokens(tokens_to_add)
            self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        else:
            self.image_token_id = tokenizer.image_token_id

        tokenizer.add_tokens(EXTRA_TOKENS)
        # BOS/EOS are inserted explicitly when prompts are built in __call__,
        # so the tokenizer must not add them on its own.
        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

        # action tokenizer configuration
        self.statistics = statistics if statistics else {}
        self.bin_policy = bin_policy
        self.min_sigma = min_sigma
        self.intrinsic_config = intrinsic_config
        self.action_config = action_config
        self.num_obs_steps = num_obs_steps
        self.obs_delta = obs_delta
        self.action_chunk_size = action_chunk_size
        self.dataset_intrinsics = {}
        height, width = image_processor.size["height"], image_processor.size["width"]

        # Rescale each dataset's intrinsic matrix from its native resolution to
        # the image processor's target resolution (focal/principal rows scale
        # with the resize factor along each axis).
        for k, v in intrinsic_config.items():
            K = torch.tensor(v["intrinsic"]).float()
            K[:2] *= torch.tensor([width / v["width"], height / v["height"]])[:, None]
            self.dataset_intrinsics[k] = K

        self.action_tokenizer = SpatialActionTokenizer(
            tokenizer=tokenizer,
            num_bins=action_config["num_bins"],
            bin_policy=bin_policy,
            use_spherical=action_config["use_spherical"],
            min_sigma=min_sigma,
        )
103
def __call__(
    self,
    images: ImageInput = None,
    text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    unnorm_key: Optional[str] = None,
    # NOTE: annotation fixed from `np.array` (a function) to `np.ndarray` (the type).
    suffix_actions: Optional[np.ndarray] = None,  # (t, e)
    **kwargs: Unpack[PaliGemmaProcessorKwargs],
) -> BatchFeature:
    """Prepare model inputs from images and text prompts, PaliGemma-style.

    Args:
        images: image, list of images, or list of lists of images.
        text: prompt string(s); one prompt per image (or image list).
        unnorm_key: dataset key used to pick a camera intrinsic matrix
            (falls back to the "default" entry).
        suffix_actions: optional continuous actions of shape (t, e); when given
            they are tokenized and used as the training suffix.
        **kwargs: standard PaliGemma processor kwargs (text/images kwargs).

    Returns:
        `BatchFeature` with tokenized inputs, `pixel_values`, `intrinsic`, and
        `labels` (suffix tokens only, prefix masked to -100) when a suffix exists.

    Raises:
        ValueError: when `images` is missing or malformed.
    """
    images, text = _validate_images_text_input_order(images, text)

    output_kwargs = self._merge_kwargs(
        PaliGemmaProcessorKwargs,
        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        **kwargs,
    )
    if suffix_actions is not None:
        # Encode actions into discrete action tokens and join them as the suffix.
        action_tokens = self.action_tokenizer(suffix_actions)  # (n, 3)
        suffix = "".join(action_tokens.flatten())
    else:
        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

    # token_type_ids (prefix vs. suffix) are only needed when labels are built.
    return_token_type_ids = suffix is not None

    if images is None:
        raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
    if text is None:
        logger.warning_once( "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model.")
        text = ""

    # Normalize `text` to a list of prompts.
    if _is_str_or_image(text):
        text = [text]
    elif isinstance(text, list) and _is_str_or_image(text[0]):
        pass

    if text is not None and images is not None:
        if not any(IMAGE_TOKEN in sample for sample in text):
            # Prompts carry no explicit image tokens: pair prompts with images
            # positionally and let `build_string_from_input` insert the tokens.
            if isinstance(text, List) and isinstance(images, List):
                if len(images) != len(text):
                    raise ValueError(
                        f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
                    )
            # Canonicalize `images` to a list of per-prompt image lists.
            if is_valid_image(images):
                images = [[images]]
            elif isinstance(images, list) and is_valid_image(images[0]):
                images = [[image] for image in images]
            elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
                raise ValueError("images must be an image, list of images or list of list of images")
            if suffix is not None and _is_str_or_image(suffix):
                suffix = [suffix]
            if suffix is not None:
                # Suffixes are training targets and must end with EOS.
                suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
            input_strings = [
                build_string_from_input(
                    prompt=prompt,
                    bos_token=self.tokenizer.bos_token,
                    image_seq_len=self.image_seq_length,
                    image_token=IMAGE_TOKEN,
                    num_images=len(image_list) if isinstance(image_list, list) else 1,
                )
                for prompt, image_list in zip(text, images)
            ]
            images = make_batched_images(images)
        else:
            # Prompts already contain IMAGE_TOKEN: expand each occurrence to
            # `image_seq_length` copies and insert BOS right after the last one.
            expanded_samples = []
            for sample in text:
                expanded_sample = sample.replace(IMAGE_TOKEN, IMAGE_TOKEN * self.image_seq_length)
                bos_rfind_index = expanded_sample.rfind(IMAGE_TOKEN)
                bos_index = bos_rfind_index + len(IMAGE_TOKEN) if bos_rfind_index != -1 else 0
                expanded_sample = (
                    expanded_sample[:bos_index] + self.tokenizer.bos_token + expanded_sample[bos_index:]
                )
                expanded_samples.append(expanded_sample)
            input_strings = [f"{sample}\n" for sample in expanded_samples]
    pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

    # max_length refers to the text budget; widen it by the image token span.
    if output_kwargs["text_kwargs"].get("max_length", None) is not None:
        output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

    inputs = self.tokenizer(
        input_strings,
        text_pair=suffix,
        return_token_type_ids=return_token_type_ids,
        **output_kwargs["text_kwargs"],
    )

    # Pick the dataset's rescaled intrinsic matrix, falling back to "default".
    if unnorm_key in self.dataset_intrinsics:
        intrinsic = self.dataset_intrinsics[unnorm_key]
    else:
        intrinsic = self.dataset_intrinsics["default"]
    return_data = {**inputs, "pixel_values": pixel_values, "intrinsic": intrinsic}

    if return_token_type_ids:
        # Supervise only the suffix: prefix positions (token_type_id == 0)
        # are masked out with -100 so the loss ignores them.
        labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
        return_data.update({"labels": labels})
    return BatchFeature(data=return_data)
194
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
    """Delegate batch decoding to the underlying Gemma tokenizer.

    All positional and keyword arguments are forwarded unchanged to
    [`~PreTrainedTokenizer.batch_decode`]; see that method for details.
    """
    delegate = self.tokenizer.batch_decode
    return delegate(*args, **kwargs)
202
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
    """Delegate single-sequence decoding to the underlying Gemma tokenizer.

    All positional and keyword arguments are forwarded unchanged to
    [`~PreTrainedTokenizer.decode`]; see that method for details.
    """
    delegate = self.tokenizer.decode
    return delegate(*args, **kwargs)
210
@property
def model_input_names(self):
    """Union of tokenizer and image-processor input names, first-seen order, no duplicates."""
    ordered = {}
    for name in (*self.tokenizer.model_input_names, *self.image_processor.model_input_names):
        ordered.setdefault(name, None)
    return list(ordered)
216
def decode_actions(
    self,
    generation_outputs: torch.Tensor,
    unnorm_key: Optional[str] = None,
    # NOTE: return annotation fixed — values are numpy arrays, not tensors.
) -> Dict[str, np.ndarray]:
    """Decode generated action token ids into continuous, un-normalized actions.

    Args:
        generation_outputs: (batch, seq) tensor of generated token ids; only the
            first sequence in the batch is decoded.
        unnorm_key: key into `self.statistics` selecting the normalization stats;
            when missing or unknown, falls back (with a warning) to the first entry.

    Returns:
        Dict with "actions" — (action_chunk_size, action_dim) un-normalized
        actions — and "action_ids" — (action_chunk_size, 3) raw token ids.

    Raises:
        ValueError: if the generated sequence ends in the EOS token, which means
            the action tokens were truncated.
    """
    action_token_num = 3  # translation + rotation + gripper
    predicted_action_token_ids = (
        generation_outputs[0, : action_token_num * self.action_chunk_size]
        .detach().cpu().long().numpy()
    )
    # BUG FIX: the original compared ids against `eos_token` (a *string*), which
    # never equals an int id, making the truncation guard a no-op. Compare the id,
    # and raise instead of `assert` (asserts are stripped under -O).
    if predicted_action_token_ids.size and predicted_action_token_ids[-1] == self.tokenizer.eos_token_id:
        raise ValueError("[error] actions contain EOS token, please check you truncation settings!")

    expected_len = action_token_num * self.action_chunk_size
    if predicted_action_token_ids.shape[0] < expected_len:  # pad with zeros
        logger.warning("Padding zero action!")
        predicted_action_token_ids = np.concatenate(
            [
                predicted_action_token_ids,
                np.zeros(expected_len - predicted_action_token_ids.shape[0], dtype=np.longlong),
            ]
        )
    predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
    normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)

    # BUG FIX: `next(dict.keys())` raises TypeError (dict_keys is not an iterator);
    # use `next(iter(...))`. Also fall back for *unknown* keys, matching the
    # warning text, instead of raising KeyError below.
    if unnorm_key is None or unnorm_key not in self.statistics:
        logger.warning(f"unnorm_key {unnorm_key} is not in statistics, use next one")
        unnorm_key = next(iter(self.statistics))
    action_norm_stats = self.statistics[unnorm_key]["action"]

    action_dim = len(action_norm_stats["q01"])
    # Dimensions with mask=False (e.g. gripper) are left in normalized form.
    mask = np.array(action_norm_stats.get("mask", np.ones(action_dim)), dtype=bool)
    action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])

    actions = []
    for normalized_actions in normalized_action_chunks:
        # Map [-1, 1] back to the [q01, q99] range of the training data.
        action = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )
        actions.append(action)
    return {"actions": np.stack(actions), "action_ids": predicted_action_token_ids}