| import logging |
| from typing import ClassVar, List, Optional |
|
|
| import numpy as np |
| import pywt |
| from tokenizers import ByteLevelBPETokenizer |
| from tokenizers.trainers import BpeTrainer |
| from transformers import PreTrainedTokenizerFast |
| from transformers.processing_utils import ProcessorMixin |
|
|
|
|
class WaveletActionProcessor(ProcessorMixin):
    """Tokenizes continuous action chunks with a wavelet transform + BPE.

    Encoding pipeline (see ``__call__``):
      per-dim 1-D wavelet decomposition -> flatten coefficients ->
      quantize via ``round(coeff * scale)`` -> shift by ``-min_token`` so
      values are non-negative -> map each integer to one Unicode char ->
      BPE-encode the resulting string.

    ``decode`` inverts every step; ``fit`` trains the BPE vocabulary and
    derives ``min_token`` from a corpus of action chunks.
    """

    # ProcessorMixin wiring: the single stored attribute and its loader class.
    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        wavelet: str = "db1",
        level: int = 2,
        scale: float = 10.0,
        min_token: int = 0,
        *,
        action_dim: Optional[int] = None,
        time_horizon: Optional[int] = None,
    ):
        """
        Args:
            bpe_tokenizer: trained BPE tokenizer over coefficient characters.
            wavelet: pywt wavelet name (e.g. "db1").
            level: wavelet decomposition level.
            scale: quantization scale; larger values keep finer resolution.
            min_token: smallest quantized coefficient seen at fit time; used
                to shift quantized values into the non-negative range.
            action_dim / time_horizon: optional default (D, T) used by
                ``decode`` when not passed explicitly.
        """
        self.wavelet = wavelet
        self.level = level
        self.scale = scale
        self.min_token = int(min_token)

        # Defaults for decode() when the caller omits T/D.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        # Last (T, D) actually observed by __call__/decode.
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Cached coefficient layout, keyed by its own (T, D) tuple so it is
        # invalidated whenever the chunk shape changes.  The called_*
        # attributes cannot serve as the key — see _ensure_coeff_layout.
        self._coeff_slices_per_dim = None
        self._n_coeff: Optional[int] = None
        self._layout_key: Optional[tuple] = None

        super().__init__(bpe_tokenizer)

    def _ensure_coeff_layout(self, T: int, D: int):
        """Cache coeff slices and coeff vector length for given (T, wavelet, level).

        BUGFIX: this cache used to be validated against
        ``called_time_horizon``/``called_action_dim``, but both callers update
        those attributes *before* invoking this method, so the check always
        passed and a stale layout was served after T or D changed.  The cache
        now carries its own (T, D) key.
        """
        if self._layout_key == (T, D) and self._coeff_slices_per_dim is not None:
            return

        # The slice layout depends only on T (plus wavelet/level), so compute
        # it once on a dummy signal and replicate the read-only slices per dim.
        dummy = np.zeros(T, dtype=np.float32)
        coeffs = pywt.wavedec(dummy, self.wavelet, level=self.level)
        arr, slc = pywt.coeffs_to_array(coeffs)

        self._coeff_slices_per_dim = [slc] * D
        self._n_coeff = int(arr.shape[0])
        self._layout_key = (T, D)

    def __call__(self, action_chunk: np.ndarray) -> List[List[int]]:
        """
        Encode actions to BPE tokens.

        action_chunk: (T,D) or (B,T,D)
        returns: List[List[int]] (batch of token id lists)

        Raises:
            ValueError: if quantized values fall outside [0, 0x10FFFF]
                after shifting by ``min_token``.
        """
        assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        B, T, D = action_chunk.shape

        # Remember the shape for decode() defaults and refresh the layout.
        self.called_time_horizon, self.called_action_dim = T, D
        self._ensure_coeff_layout(T, D)

        batch_tokens: List[List[int]] = []
        for i in range(B):
            # Wavelet-decompose each action dimension independently.
            coeffs_by_dim = []
            for d in range(D):
                coeffs = pywt.wavedec(action_chunk[i, :, d], self.wavelet, level=self.level)
                flat, _ = pywt.coeffs_to_array(coeffs)
                coeffs_by_dim.append(flat)

            # (n_coeff, D) -> flat stream; reshape(-1) interleaves the D dims
            # per coefficient, matching decode()'s reshape(n_coeff, D) below.
            coeff_mat = np.stack(coeffs_by_dim, axis=1)
            flat_all = coeff_mat.reshape(-1)

            # Quantize, then shift into the non-negative range fixed at fit().
            quant = np.around(flat_all * self.scale).astype(int)
            shifted = (quant - self.min_token).astype(int)

            if shifted.min() < 0:
                raise ValueError(
                    f"Shifted tokens became negative (min={shifted.min()}). "
                    f"Your min_token={self.min_token} is too high. Re-fit or lower min_token."
                )
            if shifted.max() > 0x10FFFF:
                raise ValueError(
                    f"Shifted tokens exceed Unicode max (max={shifted.max()}). "
                    f"Reduce scale or re-fit min/max range."
                )

            # One Unicode char per quantized coefficient, then BPE-encode.
            token_str = "".join(chr(int(x)) for x in shifted)
            batch_tokens.append(self.bpe_tokenizer(token_str)["input_ids"])

        return batch_tokens

    def decode(
        self,
        tokens: List[List[int]],
        *,
        time_horizon: Optional[int] = None,
        action_dim: Optional[int] = None,
    ) -> np.ndarray:
        """
        Decode BPE tokens back to actions.

        tokens: List[List[int]] (batch)
        returns: (B, T, D)

        Raises:
            ValueError: if the decoded coefficient stream length does not
                match the expected ``n_coeff * D`` for the resolved (T, D).
        """
        # Resolve T/D: explicit argument > configured default > last call.
        # `is not None` (rather than `or`) so an explicit 0 is not skipped.
        T = next((v for v in (time_horizon, self.time_horizon, self.called_time_horizon) if v is not None), None)
        D = next((v for v in (action_dim, self.action_dim, self.called_action_dim) if v is not None), None)

        assert T is not None and D is not None, (
            "Tokenizer not initialized: call encode() once or pass time_horizon and action_dim."
        )

        # Persist the resolved shape and refresh the coeff layout.
        self.time_horizon, self.action_dim = T, D
        self.called_time_horizon, self.called_action_dim = T, D
        self._ensure_coeff_layout(T, D)

        decoded_actions = []
        for tok_list in tokens:
            # BPE ids -> coefficient string (one char per coefficient).
            s = self.bpe_tokenizer.decode(tok_list, clean_up_tokenization_spaces=False)
            ints = np.array([ord(c) for c in s], dtype=np.int64)

            # Undo the min_token shift and the quantization scale.
            quant = ints + self.min_token
            flat_coeffs = quant.astype(np.float32) / self.scale

            expected = self._n_coeff * D
            if flat_coeffs.shape[0] != expected:
                raise ValueError(
                    f"Decoded coeff length mismatch: got {flat_coeffs.shape[0]}, expected {expected}. "
                    f"(T={T}, D={D}, n_coeff={self._n_coeff}). "
                    "This usually means you decoded with different T/D than encoding."
                )

            # Invert the encode-side interleaving, then reconstruct each dim.
            coeff_mat = flat_coeffs.reshape(self._n_coeff, D)
            recon = np.zeros((T, D), dtype=np.float32)
            for d in range(D):
                coeff_list = pywt.array_to_coeffs(
                    coeff_mat[:, d],
                    self._coeff_slices_per_dim[d],
                    output_format="wavedec",
                )
                sig = pywt.waverec(coeff_list, self.wavelet)
                # waverec can return one extra sample for odd lengths; trim to T.
                recon[:, d] = sig[:T]

            decoded_actions.append(recon)

        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: List[np.ndarray],
        wavelet: str = "db1",
        level: int = 2,
        scale: float = 10.0,
        vocab_size: int = 1024,
        *,
        time_horizon: Optional[int] = None,
        action_dim: Optional[int] = None,
    ) -> "WaveletActionProcessor":
        """
        Fit BPE tokenizer on wavelet-quantized coefficient streams.

        Args:
            action_data: list of (T, D) arrays; shapes may vary per item.
            vocab_size: target BPE vocabulary size; must cover the quantized
                value range with headroom for learned merges.
            time_horizon / action_dim: stored as decode defaults; taken from
                the first chunk when omitted.

        Raises:
            ValueError: if the quantized value range exceeds ``vocab_size``.
        """
        # Quantize every chunk into a flat integer stream — same pipeline as
        # __call__, minus the min_token shift, which is determined here.
        all_streams = []
        for a in action_data:
            assert a.ndim == 2, "Each item must be (T,D)"
            T, D = a.shape
            coeffs_by_dim = []
            for d in range(D):
                coeffs = pywt.wavedec(a[:, d], wavelet, level=level)
                flat, _ = pywt.coeffs_to_array(coeffs)
                coeffs_by_dim.append(flat)
            coeff_mat = np.stack(coeffs_by_dim, axis=1)
            stream = np.around(coeff_mat.reshape(-1) * scale).astype(int)
            all_streams.append(stream)

        all_vals = np.concatenate(all_streams)
        min_token = int(all_vals.min())
        max_token = int(all_vals.max())

        # The initial alphabet must fit inside the vocab, ideally leaving
        # headroom for learned BPE merges.
        token_range = max_token - min_token + 1
        if token_range > vocab_size:
            raise ValueError(
                f"Vocab size {vocab_size} too small for token range {token_range}. "
                "Increase vocab_size or reduce scale."
            )
        if token_range + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size %d is close to vocab_size %d. "
                "Consider increasing vocab_size for better BPE merges.",
                token_range,
                vocab_size,
            )

        def _token_iter():
            # Yield one Unicode string per chunk (one char per coefficient).
            for stream in all_streams:
                shifted = (stream - min_token).astype(int)
                yield "".join(chr(int(x)) for x in shifted)

        bpe = ByteLevelBPETokenizer()
        alphabet = [chr(i) for i in range(token_range)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[],
            initial_alphabet=alphabet,
            # Allow very long merges so whole coefficient runs can fuse.
            max_token_length=10000,
        )
        # NOTE: reaches into the private backing Tokenizer so a custom
        # BpeTrainer (with initial_alphabet) can be supplied; the public
        # ByteLevelBPETokenizer.train_from_iterator builds its own trainer.
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        # Default decode shape: taken from the first chunk when omitted.
        if time_horizon is None:
            time_horizon = int(action_data[0].shape[0])
        if action_dim is None:
            action_dim = int(action_data[0].shape[1])

        return cls(
            PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
            wavelet=wavelet,
            level=level,
            scale=scale,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )
|
|