import logging
from typing import ClassVar, List, Optional

import numpy as np
import pywt
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin

logger = logging.getLogger(__name__)


class WaveletActionProcessor(ProcessorMixin):
    """Tokenize action chunks with a per-dimension wavelet transform followed by BPE.

    Encoding pipeline for one ``(T, D)`` chunk:

    1. Each action dimension is decomposed with ``pywt.wavedec`` and flattened
       with ``pywt.coeffs_to_array``, giving an ``(n_coeff, D)`` matrix.
    2. The matrix is flattened row-major (coefficient index major, dimension
       minor), multiplied by ``scale`` and rounded to integers.
    3. Integers are shifted by ``-min_token`` and mapped to characters with
       ``chr``; the resulting string is encoded by the BPE tokenizer.

    ``decode`` inverts every step except the rounding.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        wavelet: str = "db1",
        level: int = 2,
        scale: float = 10.0,
        min_token: int = 0,
        *,
        action_dim: Optional[int] = None,
        time_horizon: Optional[int] = None,
    ):
        """Create the processor.

        Args:
            bpe_tokenizer: tokenizer that maps coefficient strings to token ids.
            wavelet: PyWavelets wavelet name passed to ``wavedec``/``waverec``.
            level: decomposition level passed to ``wavedec``.
            scale: multiplier applied to coefficients before rounding.
            min_token: smallest quantized coefficient seen during fitting; it
                is subtracted before ``chr`` so character codes start at 0.
            action_dim: default D used by ``decode`` when none is given.
            time_horizon: default T used by ``decode`` when none is given.
        """
        self.wavelet = wavelet
        self.level = level
        self.scale = scale
        self.min_token = int(min_token)

        # Fallback shape for decode(); the called_* pair tracks the most recent call.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Cached wavelet coefficient layout needed for decoding (see _ensure_coeff_layout).
        self._coeff_slices_per_dim = None  # one coeffs_to_array slice structure per dimension
        self._n_coeff = None  # number of wavelet coefficients per dimension
        # (T, D, wavelet, level) the cached layout was built for. Kept separately
        # from called_* because callers update called_* before refreshing the layout.
        self._coeff_layout_key = None

        super().__init__(bpe_tokenizer)

    @staticmethod
    def _wavelet_coeff_matrix(chunk: np.ndarray, wavelet: str, level: int) -> np.ndarray:
        """Return the ``(n_coeff, D)`` matrix of flattened wavelet coefficients of a ``(T, D)`` chunk."""
        columns = []
        for d in range(chunk.shape[1]):
            coeffs = pywt.wavedec(chunk[:, d], wavelet, level=level)
            flat, _ = pywt.coeffs_to_array(coeffs)  # shape (n_coeff,)
            columns.append(flat)
        return np.stack(columns, axis=1)

    def _ensure_coeff_layout(self, T: int, D: int) -> None:
        """Cache the coefficient layout needed to invert a ``(T, D)`` chunk.

        Sets ``self._coeff_slices_per_dim`` (one ``coeffs_to_array`` slice
        structure per dimension) and ``self._n_coeff`` (coefficients per
        dimension). The cache is rebuilt whenever T, D, the wavelet or the
        level differ from the values it was computed for.
        """
        key = (int(T), int(D), self.wavelet, self.level)
        if (
            getattr(self, "_coeff_layout_key", None) == key
            and self._coeff_slices_per_dim is not None
            and self._n_coeff is not None
        ):
            return

        # The layout depends only on the signal length, so a zero signal suffices.
        coeffs = pywt.wavedec(np.zeros(T, dtype=np.float32), self.wavelet, level=self.level)
        arr, slices = pywt.coeffs_to_array(coeffs)
        # All dimensions share the same length T and therefore the same layout.
        self._coeff_slices_per_dim = [slices for _ in range(D)]
        self._n_coeff = int(arr.shape[0])
        self._coeff_layout_key = key

    def __call__(self, action_chunk: np.ndarray) -> List[List[int]]:
        """Encode action chunks into BPE token ids.

        Args:
            action_chunk: array of shape ``(T, D)`` or ``(B, T, D)``.

        Returns:
            One list of token ids per chunk in the batch.

        Raises:
            ValueError: if ``action_chunk`` has fewer than two dimensions, or a
                shifted quantized coefficient is negative (``min_token`` too
                high) or exceeds the Unicode maximum ``0x10FFFF``.
        """
        assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim < 2:
            raise ValueError(
                f"action_chunk must have shape (T, D) or (B, T, D), got shape {action_chunk.shape}."
            )
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]
        B, T, D = action_chunk.shape

        # Remember the shape so decode() can default to it.
        self.called_time_horizon, self.called_action_dim = T, D
        self._ensure_coeff_layout(T, D)

        batch_tokens: List[List[int]] = []
        for chunk in action_chunk:
            coeff_mat = self._wavelet_coeff_matrix(chunk, self.wavelet, self.level)  # (n_coeff, D)
            quant = np.around(coeff_mat.reshape(-1) * self.scale).astype(int)  # (n_coeff * D,)
            shifted = quant - self.min_token

            # chr() needs non-negative code points no larger than the Unicode maximum.
            if shifted.min() < 0:
                # This means min_token was not low enough for these coeffs.
                raise ValueError(
                    f"Shifted tokens became negative (min={shifted.min()}). "
                    f"Your min_token={self.min_token} is too high. Re-fit or lower min_token."
                )
            if shifted.max() > 0x10FFFF:
                raise ValueError(
                    f"Shifted tokens exceed Unicode max (max={shifted.max()}). "
                    f"Reduce scale or re-fit min/max range."
                )

            token_str = "".join(chr(int(x)) for x in shifted)
            batch_tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
        return batch_tokens

    def decode(
        self,
        tokens: List[List[int]],
        *,
        time_horizon: Optional[int] = None,
        action_dim: Optional[int] = None,
    ) -> np.ndarray:
        """Decode BPE token ids back into action chunks.

        Args:
            tokens: one list of token ids per chunk.
            time_horizon: chunk length T. Falls back to ``self.time_horizon``,
                then to the T of the most recent call.
            action_dim: action dimension D, with the same fallback order.

        Returns:
            ``float32`` array of shape ``(len(tokens), T, D)``.

        Raises:
            ValueError: if a token list decodes to a coefficient stream whose
                length differs from ``n_coeff * D`` for the resolved T and D.
        """
        T = time_horizon or self.time_horizon or self.called_time_horizon
        D = action_dim or self.action_dim or self.called_action_dim
        assert T is not None and D is not None, (
            "Tokenizer not initialized: call encode() once or pass time_horizon and action_dim."
        )

        # Remember the resolved shape for the next call, then (re)build the layout for it.
        self.time_horizon, self.action_dim = T, D
        self.called_time_horizon, self.called_action_dim = T, D
        self._ensure_coeff_layout(T, D)
        expected = self._n_coeff * D

        decoded_actions = []
        for tok_list in tokens:
            text = self.bpe_tokenizer.decode(tok_list, clean_up_tokenization_spaces=False)
            ints = np.array([ord(c) for c in text], dtype=np.int64)

            # Undo the shift and the scaling (the rounding cannot be undone).
            flat_coeffs = (ints + self.min_token).astype(np.float32) / self.scale  # (n_coeff * D,)

            if flat_coeffs.shape[0] != expected:
                raise ValueError(
                    f"Decoded coeff length mismatch: got {flat_coeffs.shape[0]}, expected {expected}. "
                    f"(T={T}, D={D}, n_coeff={self._n_coeff}). "
                    "This usually means you decoded with different T/D than encoding."
                )
            coeff_mat = flat_coeffs.reshape(self._n_coeff, D)

            # Inverse wavelet transform per dimension.
            recon = np.zeros((T, D), dtype=np.float32)
            for d in range(D):
                coeff_list = pywt.array_to_coeffs(
                    coeff_mat[:, d],
                    self._coeff_slices_per_dim[d],
                    output_format="wavedec",
                )
                sig = pywt.waverec(coeff_list, self.wavelet)
                recon[:, d] = sig[:T]  # waverec can return one extra sample for odd lengths
            decoded_actions.append(recon)

        if not decoded_actions:
            return np.zeros((0, T, D), dtype=np.float32)
        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: List[np.ndarray],  # each (T,D)
        wavelet: str = "db1",
        level: int = 2,
        scale: float = 10.0,
        vocab_size: int = 1024,
        *,
        time_horizon: Optional[int] = None,
        action_dim: Optional[int] = None,
    ) -> "WaveletActionProcessor":
        """Fit a BPE tokenizer on wavelet-quantized coefficient streams.

        Args:
            action_data: non-empty list of ``(T, D)`` action arrays.
            wavelet: PyWavelets wavelet name.
            level: wavelet decomposition level.
            scale: multiplier applied to coefficients before rounding.
            vocab_size: target BPE vocabulary size; must be at least the number
                of distinct quantized values (``max - min + 1``).
            time_horizon: default T stored on the processor; inferred from
                ``action_data[0]`` when omitted.
            action_dim: default D stored on the processor; inferred from
                ``action_data[0]`` when omitted.

        Returns:
            A processor wrapping the trained tokenizer, with ``min_token`` set
            to the smallest quantized coefficient in ``action_data``.

        Raises:
            ValueError: if ``action_data`` is empty or ``vocab_size`` is smaller
                than the quantized value range.
        """
        if len(action_data) == 0:
            raise ValueError("action_data must contain at least one (T, D) array.")

        # Quantized coefficient streams, also used to find the token value range.
        all_streams = []
        for a in action_data:
            assert a.ndim == 2, "Each item must be (T,D)"
            coeff_mat = cls._wavelet_coeff_matrix(a, wavelet, level)  # (n_coeff, D)
            all_streams.append(np.around(coeff_mat.reshape(-1) * scale).astype(int))

        all_vals = np.concatenate(all_streams)
        min_token = int(all_vals.min())
        max_token = int(all_vals.max())
        token_range = max_token - min_token + 1

        if token_range > vocab_size:
            raise ValueError(
                f"Vocab size {vocab_size} too small for token range {token_range}. "
                "Increase vocab_size or reduce scale."
            )
        if token_range + 100 > vocab_size:
            logger.warning(
                "Initial alphabet size %d is close to vocab_size %d. "
                "Consider increasing vocab_size for better BPE merges.",
                token_range,
                vocab_size,
            )

        def _token_iter():
            for stream in all_streams:
                shifted = stream - min_token  # >= 0 by construction of min_token
                yield "".join(chr(int(x)) for x in shifted)

        # Train BPE. The inner tokenizer is trained directly so that this
        # BpeTrainer, with its custom initial_alphabet, is the one used.
        bpe = ByteLevelBPETokenizer()
        alphabet = [chr(i) for i in range(token_range)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        # Infer T/D defaults from the first sample if not provided.
        if time_horizon is None:
            time_horizon = int(action_data[0].shape[0])
        if action_dim is None:
            action_dim = int(action_data[0].shape[1])

        return cls(
            PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
            wavelet=wavelet,
            level=level,
            scale=scale,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )