# processing_action_tokenizer_residual.py
import logging
from typing import ClassVar, Iterable

import numpy as np
from scipy.fft import dct, idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin


class ResidualFASTActionProcessor(ProcessorMixin):
    """
    Residual FAST: intent + residual tokenization built on top of FAST's DCT+BPE scheme.

    Encodes an action chunk (B, T, D) into tokens by:
      1) DCT over time axis
      2) Split coeffs:
           intent   = coeff[:k_intent, :]
           residual = coeff[k_intent:, :]
      3) Quantize to ints
      4) Convert to a string of characters via chr(int - min_token)
      5) Wrap with special markers: <|intent|>...<|residual|>...
      6) BPE-tokenize the resulting string

    Decoding reverses the above.

    Notes:
      - Assumes input actions are already normalized to roughly [-1, 1] (same as FAST).
      - Uses a single BPE tokenizer to keep the interface identical to FAST.
      - Markers are special tokens so decode can reliably split streams.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    # NOTE(review): in the mangled source both markers were empty strings, which
    # made decode() unable to split the stream (str.find("") == 0 for both) and
    # registered empty "special tokens". They were almost certainly angle-bracket
    # tags stripped by a text-sanitizing pass; restored here as distinctive
    # non-empty literals. Confirm the exact original spelling against checkpoints
    # trained with this processor before reusing saved tokenizers.
    INTENT_MARKER = "<|intent|>"
    RESIDUAL_MARKER = "<|residual|>"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        min_token: int = 0,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: trained BPE tokenizer (normally produced by `fit()`).
            k_intent: number of lowest-frequency DCT rows kept as the "intent" part.
            scale: quantization scale applied to DCT coefficients before rounding.
            vocab_size: BPE vocabulary size (stored for bookkeeping).
            min_token: smallest quantized coefficient observed at fit time; used to
                shift quantized ints into the non-negative range required by chr().
            action_dim: D, required for decoding if `decode()` is called before
                `__call__()` ever caches it.
            time_horizon: T, same purpose as `action_dim`.
        """
        self.k_intent = int(k_intent)
        self.scale = float(scale)
        self.vocab_size = int(vocab_size)
        self.min_token = int(min_token)

        # Needed for decoding when the caller does not pass T/D explicitly.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Ensure markers exist as special tokens in the tokenizer (robust against
        # tokenizers loaded from disk that predate the markers).
        self._ensure_special_tokens(bpe_tokenizer)

        super().__init__(bpe_tokenizer)

    @staticmethod
    def _ensure_special_tokens(tok: PreTrainedTokenizerFast) -> None:
        """Register the intent/residual markers as special tokens if missing."""
        special = set(tok.all_special_tokens)
        to_add = [
            marker
            for marker in (
                ResidualFASTActionProcessor.INTENT_MARKER,
                ResidualFASTActionProcessor.RESIDUAL_MARKER,
            )
            if marker not in special
        ]
        if to_add:
            tok.add_special_tokens({"additional_special_tokens": to_add})

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """
        Encode action chunks into BPE token ids.

        Args:
            action_chunk: np.ndarray with shape (T, D) or (B, T, D).

        Returns:
            list of token-id lists, length B.

        Raises:
            ValueError: if `k_intent` is outside [0, T].
        """
        assert action_chunk.ndim <= 3, "Only up to 3 dims supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        B, T, D = action_chunk.shape
        if self.k_intent < 0 or self.k_intent > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={self.k_intent}, T={T}")

        # Cache dimensions for decode
        self.called_time_horizon = T
        self.called_action_dim = D

        # DCT over time axis (axis=1)
        coeff = dct(action_chunk, axis=1, norm="ortho")  # (B, T, D)

        # Split frequencies
        intent_coeff = coeff[:, : self.k_intent, :]  # (B, K, D)
        residual_coeff = coeff[:, self.k_intent :, :]  # (B, T-K, D)

        # Quantize
        intent_q = np.around(intent_coeff * self.scale).astype(int)
        residual_q = np.around(residual_coeff * self.scale).astype(int)

        tokens: list[list[int]] = []
        for b in range(B):
            # Convert quantized ints to chars (shifted by min_token).
            # np.maximum(..., 0) defensively clamps out-of-distribution values
            # below min_token so chr() never sees a negative code point.
            intent_chars = "".join(
                map(chr, np.maximum(intent_q[b].flatten() - self.min_token, 0).astype(int))
            )
            residual_chars = "".join(
                map(chr, np.maximum(residual_q[b].flatten() - self.min_token, 0).astype(int))
            )

            # Insert markers; decode strips whitespace later, so no separators needed.
            token_str = f"{self.INTENT_MARKER}{intent_chars}{self.RESIDUAL_MARKER}{residual_chars}"

            # IMPORTANT: add_special_tokens=False so we don't inject BOS/EOS etc.
            ids = self.bpe_tokenizer(token_str, add_special_tokens=False)["input_ids"]
            tokens.append(ids)

        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        k_intent: int | None = None,
    ) -> np.ndarray:
        """
        Decode BPE token ids back into action chunks.

        Args:
            tokens: list of token-id lists (batch).
            time_horizon: T override; falls back to constructor / last encode call.
            action_dim: D override; same fallback chain.
            k_intent: K override; defaults to the configured k_intent.

        Returns:
            np.ndarray of shape (B, T, D). Elements that fail to decode are
            replaced with zeros (best-effort, matching FAST behavior).
        """
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim
        K = int(k_intent) if k_intent is not None else self.k_intent

        # Cache for next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized. Call encode() once or pass time_horizon and action_dim."
        )
        T = int(self.time_horizon)
        D = int(self.action_dim)
        if K < 0 or K > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={K}, T={T}")

        decoded_actions = []
        for token_ids in tokens:
            try:
                # Decode back to the original string; keep special tokens so the
                # markers survive and can be located below.
                decoded = self.bpe_tokenizer.decode(token_ids, skip_special_tokens=False)

                # FAST-style safety: the encoded stream has no spaces; remove
                # whitespace defensively.
                decoded = "".join(decoded.split())

                # Find markers and split
                i0 = decoded.find(self.INTENT_MARKER)
                i1 = decoded.find(self.RESIDUAL_MARKER)
                if i0 == -1 or i1 == -1 or i1 < i0:
                    raise ValueError("Missing or misordered intent/residual markers in decoded string.")

                intent_str = decoded[i0 + len(self.INTENT_MARKER) : i1]
                residual_str = decoded[i1 + len(self.RESIDUAL_MARKER) :]

                # Convert chars back to quantized ints (undo the min_token shift)
                intent_vals = np.array(list(map(ord, intent_str)), dtype=int) + self.min_token
                residual_vals = np.array(list(map(ord, residual_str)), dtype=int) + self.min_token

                # Reshape to (K, D) and (T-K, D)
                if intent_vals.size != K * D:
                    raise ValueError(f"Intent size mismatch: got {intent_vals.size}, expected {K*D}")
                if residual_vals.size != (T - K) * D:
                    raise ValueError(
                        f"Residual size mismatch: got {residual_vals.size}, expected {(T-K)*D}"
                    )
                intent_q = intent_vals.reshape(K, D)
                residual_q = residual_vals.reshape(T - K, D)

                # Reconstruct full DCT coefficient matrix (T, D)
                coeff_q = np.zeros((T, D), dtype=float)
                coeff_q[:K, :] = intent_q
                coeff_q[K:, :] = residual_q

                # Inverse DCT (time axis is axis=0 now because coeff_q is (T, D))
                action = idct(coeff_q / self.scale, axis=0, norm="ortho")
            except Exception as e:
                print(f"[ResidualFAST] Error decoding tokens: {e}")
                print(f"[ResidualFAST] Tokens: {token_ids}")
                action = np.zeros((T, D), dtype=float)

            decoded_actions.append(action)

        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray] | np.ndarray,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "ResidualFASTActionProcessor":
        """
        Train the internal BPE tokenizer on Residual FAST strings.

        action_data can be:
          - list of arrays, each (T, D)
          - or a single array (N, T, D)

        NOTE:
          - We keep the FAST alphabet trick: all possible quantized values are
            present in initial_alphabet.
          - We reserve room in vocab_size for the special marker tokens.

        Raises:
            ValueError: on empty data, inconsistent action dims, or invalid k_intent.
            AssertionError: if vocab_size cannot hold the quantized alphabet plus
                the marker tokens.
        """
        if isinstance(action_data, np.ndarray):
            assert action_data.ndim == 3, "If passing np.ndarray, expected shape (N, T, D)."
            chunks = [action_data[i] for i in range(action_data.shape[0])]
        else:
            chunks = action_data

        if len(chunks) == 0:
            raise ValueError("Empty action_data passed to fit().")

        # Validate shapes (allow varying T, but D should be consistent for easiest decoding)
        Ds = [c.shape[1] for c in chunks]
        if len(set(Ds)) != 1 and action_dim is None:
            raise ValueError("Varying action_dim in fit() data. Pass action_dim=... or standardize D.")
        D = action_dim if action_dim is not None else Ds[0]

        # Track min/max quantized coefficients across the corpus
        all_q_vals = []
        for a in chunks:
            assert a.ndim == 2, "Each chunk must be (T, D)."
            T, d = a.shape
            if d != D:
                raise ValueError(f"Chunk action_dim={d} != expected D={D}.")
            if k_intent < 0 or k_intent > T:
                raise ValueError(f"k_intent must be in [0, T]. Got k_intent={k_intent}, T={T}")

            coeff = dct(a, axis=0, norm="ortho")  # (T, D)
            intent = coeff[:k_intent, :]
            residual = coeff[k_intent:, :]

            # Quantize
            intent_q = np.around(intent * scale).astype(int)
            residual_q = np.around(residual * scale).astype(int)

            all_q_vals.append(intent_q.flatten())
            all_q_vals.append(residual_q.flatten())

        all_q = np.concatenate(all_q_vals, axis=0)
        max_token = int(all_q.max())
        min_token = int(all_q.min())

        # FAST constraint: alphabet size must fit in vocab_size minus special tokens.
        n_special = 2  # intent marker, residual marker
        required_vocab = (max_token - min_token + 1) + n_special  # inclusive range => +1
        if required_vocab > vocab_size:
            raise AssertionError(
                f"vocab_size={vocab_size} too small. Need >= (range+special) = {required_vocab} "
                f"(range={max_token-min_token+1}, special={n_special})."
            )
        if (max_token - min_token + 1) + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size is close to vocab_size. Consider increasing vocab_size "
                "for better BPE compression."
            )

        # Iterator producing Residual FAST strings
        def _token_iter() -> Iterable[str]:
            for a in chunks:
                coeff = dct(a, axis=0, norm="ortho")
                intent = coeff[:k_intent, :]
                residual = coeff[k_intent:, :]
                intent_q = (np.around(intent * scale) - min_token).astype(int)
                residual_q = (np.around(residual * scale) - min_token).astype(int)
                intent_str = "".join(map(chr, intent_q.flatten()))
                residual_str = "".join(map(chr, residual_q.flatten()))
                yield f"{cls.INTENT_MARKER}{intent_str}{cls.RESIDUAL_MARKER}{residual_str}"

        # Train BPE tokenizer (byte-level)
        bpe = ByteLevelBPETokenizer()

        # Alphabet for the quantized chars (shifted range starts at 0)
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[cls.INTENT_MARKER, cls.RESIDUAL_MARKER],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )

        # Train inner tokenizer (same trick as FAST)
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        hf_tok = PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False)

        # Ensure special tokens registered (defensive)
        cls._ensure_special_tokens(hf_tok)

        return cls(
            hf_tok,
            k_intent=k_intent,
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )