jadechoghari committed
Commit 393ff84 · verified · 1 Parent(s): 44c81f0

Upload FAST tokenizer trained on lerobot/libero_video

processing_action_tokenizer.py CHANGED
@@ -1,140 +1,241 @@
 import logging
-from typing import ClassVar
+from typing import ClassVar, List, Optional
 
 import numpy as np
-from scipy.fft import dct
-from scipy.fft import idct
+import pywt
 from tokenizers import ByteLevelBPETokenizer
 from tokenizers.trainers import BpeTrainer
 from transformers import PreTrainedTokenizerFast
 from transformers.processing_utils import ProcessorMixin
 
 
-class UniversalActionProcessor(ProcessorMixin):
+class WaveletActionProcessor(ProcessorMixin):
     attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
     bpe_tokenizer_class: str = "AutoTokenizer"
 
     def __init__(
         self,
         bpe_tokenizer: PreTrainedTokenizerFast,
-        scale: float = 10,
-        vocab_size: int = 1024,
+        wavelet: str = "db1",
+        level: int = 2,
+        scale: float = 10.0,
         min_token: int = 0,
         *,
-        action_dim: int | None = None,
-        time_horizon: int | None = None,
+        action_dim: Optional[int] = None,
+        time_horizon: Optional[int] = None,
     ):
+        self.wavelet = wavelet
+        self.level = level
         self.scale = scale
-        self.vocab_size = vocab_size
-        self.min_token = min_token
-
-        # Action horizon and dimension needed during decoding. These can be specified
-        # in three ways (in order of priority):
-        # 1. passed in as kwargs to decode()
-        # 2. in the constructor
-        # 3. cached from the last time decode() was called
+        self.min_token = int(min_token)
+
+        # Used for decode (same logic as FAST)
         self.time_horizon = time_horizon
         self.action_dim = action_dim
         self.called_time_horizon = time_horizon
         self.called_action_dim = action_dim
 
+        # Cache the wavelet coefficient layout needed for decoding.
+        # We keep one slice structure per dimension (they are typically
+        # identical for a fixed T/wavelet/level).
+        self._coeff_slices_per_dim = None  # list of slice dicts
+        self._n_coeff = None  # number of wavelet coeffs per dim after coeffs_to_array
+
         super().__init__(bpe_tokenizer)
 
-    def __call__(self, action_chunk: np.array) -> np.array:
+    def _ensure_coeff_layout(self, T: int, D: int):
+        """Cache coeff slices and coeff vector length for given (T, wavelet, level)."""
+        if (
+            self._coeff_slices_per_dim is not None
+            and self._n_coeff is not None
+            and self.called_time_horizon == T
+            and self.called_action_dim == D
+        ):
+            return
+
+        dummy = np.zeros(T, dtype=np.float32)
+
+        slices_per_dim = []
+        n_coeff = None
+        for _ in range(D):
+            coeffs = pywt.wavedec(dummy, self.wavelet, level=self.level)
+            arr, slc = pywt.coeffs_to_array(coeffs)
+            slices_per_dim.append(slc)
+            if n_coeff is None:
+                n_coeff = int(arr.shape[0])
+
+        self._coeff_slices_per_dim = slices_per_dim
+        self._n_coeff = n_coeff
+
+    def __call__(self, action_chunk: np.ndarray) -> List[List[int]]:
+        """
+        Encode actions to BPE tokens.
+
+        action_chunk: (T, D) or (B, T, D)
+        returns: List[List[int]] (batch of token id lists)
+        """
         assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
         if action_chunk.ndim == 2:
             action_chunk = action_chunk[None, ...]
 
-        # Cache the time horizon and action dimension for decoding
-        self.called_time_horizon = action_chunk.shape[-2]
-        self.called_action_dim = action_chunk.shape[-1]
+        B, T, D = action_chunk.shape
 
-        dct_coeff = dct(action_chunk, axis=1, norm="ortho")
-        dct_coeff = np.around(dct_coeff * self.scale)
-        tokens = []
-        for elem in dct_coeff:
-            token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
-            tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
-        return tokens
+        # Cache for decoding
+        self.called_time_horizon, self.called_action_dim = T, D
+        self._ensure_coeff_layout(T, D)
+
+        batch_tokens: List[List[int]] = []
+        for i in range(B):
+            # Wavelet per dim -> flattened coeffs of length (n_coeff * D)
+            coeffs_by_dim = []
+            for d in range(D):
+                coeffs = pywt.wavedec(action_chunk[i, :, d], self.wavelet, level=self.level)
+                flat, _ = pywt.coeffs_to_array(coeffs)  # shape (n_coeff,)
+                coeffs_by_dim.append(flat)
+
+            coeff_mat = np.stack(coeffs_by_dim, axis=1)  # (n_coeff, D)
+            flat_all = coeff_mat.reshape(-1)  # (n_coeff * D,)
+
+            quant = np.around(flat_all * self.scale).astype(int)
+            shifted = (quant - self.min_token).astype(int)
+
+            # Safety checks: shifted values must be valid Unicode code points
+            if shifted.min() < 0:
+                # min_token was not low enough for these coefficients
+                raise ValueError(
+                    f"Shifted tokens became negative (min={shifted.min()}). "
+                    f"min_token={self.min_token} is too high; re-fit or lower min_token."
+                )
+            if shifted.max() > 0x10FFFF:
+                raise ValueError(
+                    f"Shifted tokens exceed the Unicode maximum (max={shifted.max()}). "
+                    "Reduce scale or re-fit the min/max range."
+                )
+
+            token_str = "".join(chr(int(x)) for x in shifted)
+            batch_tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
+
+        return batch_tokens
 
     def decode(
         self,
-        tokens: list[list[int]],
+        tokens: List[List[int]],
         *,
-        time_horizon: int | None = None,
-        action_dim: int | None = None,
-    ) -> np.array:
-        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
-        self.action_dim = action_dim or self.action_dim or self.called_action_dim
+        time_horizon: Optional[int] = None,
+        action_dim: Optional[int] = None,
+    ) -> np.ndarray:
+        """
+        Decode BPE tokens back to actions.
+
+        tokens: List[List[int]] (batch)
+        returns: (B, T, D)
+        """
+        T = time_horizon or self.time_horizon or self.called_time_horizon
+        D = action_dim or self.action_dim or self.called_action_dim
 
-        # Cache the time horizon and action dimension for the next call
-        self.called_time_horizon = self.time_horizon
-        self.called_action_dim = self.action_dim
+        assert T is not None and D is not None, (
+            "Tokenizer not initialized: call encode() once or pass time_horizon and action_dim."
+        )
 
-        assert (
-            self.time_horizon is not None and self.action_dim is not None
-        ), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
+        # Cache for the next call and ensure the coefficient layout exists
+        self.time_horizon, self.action_dim = T, D
+        self.called_time_horizon, self.called_action_dim = T, D
+        self._ensure_coeff_layout(T, D)
 
         decoded_actions = []
-        for token in tokens:
-            try:
-                decoded_tokens = self.bpe_tokenizer.decode(token)
-                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
-                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
-                assert (
-                    decoded_dct_coeff.shape
-                    == (
-                        self.time_horizon,
-                        self.action_dim,
-                    )
-                ), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
-            except Exception as e:
-                print(f"Error decoding tokens: {e}")
-                print(f"Tokens: {token}")
-                decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
-            decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
-        return np.stack(decoded_actions)
+        for tok_list in tokens:
+            # Decode token ids back to the string of coefficient characters
+            s = self.bpe_tokenizer.decode(tok_list, clean_up_tokenization_spaces=False)
+            ints = np.array([ord(c) for c in s], dtype=np.int64)
+
+            # Unshift and dequantize
+            quant = ints + self.min_token
+            flat_coeffs = quant.astype(np.float32) / self.scale  # (n_coeff * D,)
+
+            # Reshape to (n_coeff, D)
+            expected = self._n_coeff * D
+            if flat_coeffs.shape[0] != expected:
+                raise ValueError(
+                    f"Decoded coeff length mismatch: got {flat_coeffs.shape[0]}, expected {expected} "
+                    f"(T={T}, D={D}, n_coeff={self._n_coeff}). "
+                    "This usually means you decoded with different T/D than encoding."
+                )
+
+            coeff_mat = flat_coeffs.reshape(self._n_coeff, D)
+
+            # Inverse wavelet per dimension
+            recon = np.zeros((T, D), dtype=np.float32)
+            for d in range(D):
+                arr = coeff_mat[:, d]
+                coeff_list = pywt.array_to_coeffs(
+                    arr,
+                    self._coeff_slices_per_dim[d],
+                    output_format="wavedec",
+                )
+                sig = pywt.waverec(coeff_list, self.wavelet)
+                recon[:, d] = sig[:T]  # waverec can return a slightly longer signal due to padding
+
+            decoded_actions.append(recon)
+
+        return np.stack(decoded_actions, axis=0)
 
     @classmethod
     def fit(
         cls,
-        action_data: list[np.array],
-        scale: float = 10,
+        action_data: List[np.ndarray],  # each (T, D)
+        wavelet: str = "db1",
+        level: int = 2,
+        scale: float = 10.0,
         vocab_size: int = 1024,
         *,
-        time_horizon: int | None = None,
-        action_dim: int | None = None,
-    ) -> "UniversalActionProcessor":
-        # Run DCT over all inputs
-        dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
-
-        # Quantize and find min token
-        max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
-        min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
-        min_vocab_size = max_token - min_token
-
-        assert (
-            min_vocab_size <= vocab_size
-        ), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
-        if min_vocab_size + 100 > vocab_size:
+        time_horizon: Optional[int] = None,
+        action_dim: Optional[int] = None,
+    ) -> "WaveletActionProcessor":
+        """Fit a BPE tokenizer on wavelet-quantized coefficient streams."""
+        # Compute quantized coefficient streams to estimate the min/max token range
+        all_streams = []
+        for a in action_data:
+            assert a.ndim == 2, "Each item must be (T, D)"
+            T, D = a.shape
+            # Wavelet per dim -> flatten to (n_coeff * D,)
+            coeffs_by_dim = []
+            for d in range(D):
+                coeffs = pywt.wavedec(a[:, d], wavelet, level=level)
+                flat, _ = pywt.coeffs_to_array(coeffs)
+                coeffs_by_dim.append(flat)
+            coeff_mat = np.stack(coeffs_by_dim, axis=1)
+            stream = np.around(coeff_mat.reshape(-1) * scale).astype(int)
+            all_streams.append(stream)
+
+        all_vals = np.concatenate(all_streams)
+        min_token = int(all_vals.min())
+        max_token = int(all_vals.max())
+
+        token_range = max_token - min_token + 1
+        if token_range > vocab_size:
+            raise ValueError(
+                f"Vocab size {vocab_size} too small for token range {token_range}. "
+                "Increase vocab_size or reduce scale."
+            )
+        if token_range + 100 > vocab_size:
             logging.warning(
-                f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
-                f"size {vocab_size}, consider increasing vocab size"
+                f"Initial alphabet size {token_range} is close to vocab_size {vocab_size}. "
+                "Consider increasing vocab_size for better BPE merges."
             )
 
-        # Make token iterator for BPE training
         def _token_iter():
-            for tokens in dct_tokens:
-                rounded_tokens = np.around(tokens * scale) - min_token
-                rounded_tokens = rounded_tokens.astype(int)
-                string = "".join(map(chr, rounded_tokens))
-                yield string
+            for stream in all_streams:
+                shifted = (stream - min_token).astype(int)
+                # No clamping: every value must already be >= 0
+                yield "".join(chr(int(x)) for x in shifted)
 
-        # Train BPE tokenizer
+        # Train BPE
         bpe = ByteLevelBPETokenizer()
-
-        # Set up the entire range of possible tokens as the initial alphabet
-        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
+        alphabet = [chr(i) for i in range(token_range)]
         trainer = BpeTrainer(
             vocab_size=vocab_size,
             min_frequency=2,
@@ -143,15 +244,19 @@ class UniversalActionProcessor(ProcessorMixin):
             initial_alphabet=alphabet,
             max_token_length=10000,
         )
-
-        # Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
-        # because it doesn't support custom alphabets)
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
 
+        # Infer T/D defaults if not provided
+        if time_horizon is None:
+            time_horizon = int(action_data[0].shape[0])
+        if action_dim is None:
+            action_dim = int(action_data[0].shape[1])
+
         return cls(
             PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
+            wavelet=wavelet,
+            level=level,
             scale=scale,
-            vocab_size=vocab_size,
             min_token=min_token,
             time_horizon=time_horizon,
            action_dim=action_dim,
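
A minimal round-trip sketch of the new processor (assumed usage pieced together from the code above; the toy dataset is illustrative, matching the (T=10, D=6) layout in processor_config.json):

import numpy as np

# Toy dataset: 100 action chunks of shape (T=10, D=6)
chunks = [np.random.randn(10, 6).astype(np.float32) for _ in range(100)]

proc = WaveletActionProcessor.fit(chunks, wavelet="db1", level=2, scale=10.0, vocab_size=1024)

tokens = proc(chunks[0])     # List[List[int]]; a (T, D) input is treated as a batch of 1
recon = proc.decode(tokens)  # np.ndarray of shape (1, 10, 6)

# Reconstruction error stems only from rounding coefficients to multiples of 1/scale
print(np.abs(recon[0] - chunks[0]).max())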
processor_config.json CHANGED
@@ -1,11 +1,12 @@
 {
   "action_dim": 6,
   "auto_map": {
-    "AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
+    "AutoProcessor": "processing_action_tokenizer.WaveletActionProcessor"
   },
-  "min_token": -32,
-  "processor_class": "UniversalActionProcessor",
+  "level": 2,
+  "min_token": -20,
+  "processor_class": "WaveletActionProcessor",
   "scale": 10.0,
   "time_horizon": 10,
-  "vocab_size": 1024
+  "wavelet": "db1"
 }
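
As a quick check on these values: for time_horizon=10 with wavelet="db1" and level=2, the per-dimension coefficient vector the processor tokenizes is slightly longer than the chunk itself (a sketch, assuming pywt is installed; these lengths are what _ensure_coeff_layout caches):

import numpy as np
import pywt

coeffs = pywt.wavedec(np.zeros(10, dtype=np.float32), "db1", level=2)
arr, slices = pywt.coeffs_to_array(coeffs)

# [cA2, cD2, cD1] lengths for T=10 should be [3, 3, 5] -> n_coeff = 11,
# which is why decode() trims the reconstruction back to T samples
print([len(c) for c in coeffs], arr.shape)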
tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -1,11 +1,11 @@
 {
   "added_tokens_decoder": {},
   "auto_map": {
-    "AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
+    "AutoProcessor": "processing_action_tokenizer.WaveletActionProcessor"
   },
   "clean_up_tokenization_spaces": false,
   "extra_special_tokens": {},
   "model_max_length": 1000000000000000019884624838656,
-  "processor_class": "UniversalActionProcessor",
+  "processor_class": "WaveletActionProcessor",
   "tokenizer_class": "PreTrainedTokenizerFast"
 }
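
With the updated auto_map, the processor loads through AutoProcessor (a sketch; "<user>/<repo>" is a placeholder for this repository's Hub id, and loading custom code requires trust_remote_code=True):

import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("<user>/<repo>", trust_remote_code=True)

action_chunk = np.zeros((10, 6), dtype=np.float32)  # (time_horizon, action_dim) per processor_config.json
tokens = processor(action_chunk)
actions = processor.decode(tokens)  # (1, 10, 6)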