# resfast-tokenizer / processing_action_tokenizer.py
# Uploaded by jadechoghari (HF Staff) via huggingface_hub, revision f3b0016 (verified).
# processing_action_tokenizer_residual.py
import logging
from typing import ClassVar, Iterable
import numpy as np
from scipy.fft import dct, idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
class ResidualFASTActionProcessor(ProcessorMixin):
    """
    Residual FAST: intent + residual tokenization built on top of FAST's DCT+BPE scheme.

    Encodes an action chunk (B, T, D) into tokens by:
      1) DCT over the time axis
      2) Splitting the coefficients:
           intent   = coeff[:k_intent, :]   (low-frequency rows)
           residual = coeff[k_intent:, :]   (remaining rows)
      3) Quantizing to ints (round(coeff * scale))
      4) Converting to a string of characters via chr(int - min_token)
      5) Wrapping with special markers: <INTENT> ... <RESIDUAL> ...
      6) BPE-tokenizing the resulting string

    Decoding reverses the above.

    Notes:
      - Assumes input actions are already normalized to roughly [-1, 1] (same as FAST).
      - Uses a single BPE tokenizer to keep the interface identical to FAST.
      - Markers are registered as special tokens so decode() can reliably split streams.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    INTENT_MARKER = "<INTENT>"
    RESIDUAL_MARKER = "<RESIDUAL>"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        min_token: int = 0,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: trained BPE tokenizer (see `fit`) for the string stage.
            k_intent: number of leading (low-frequency) DCT rows treated as intent.
            scale: quantization scale; larger means finer resolution.
            vocab_size: BPE vocabulary size (stored for reference).
            min_token: smallest quantized coefficient observed at fit time; used as
                the chr() offset so encoded characters are non-negative.
            action_dim: expected action dimensionality D, required to decode before
                the first encode call.
            time_horizon: expected chunk length T, required to decode before the
                first encode call.
        """
        self.k_intent = int(k_intent)
        self.scale = float(scale)
        self.vocab_size = int(vocab_size)
        self.min_token = int(min_token)

        # Needed for decoding; refreshed by every encode call as a fallback.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Ensure markers exist as special tokens in the tokenizer (robust to
        # tokenizers reloaded from disk without them).
        self._ensure_special_tokens(bpe_tokenizer)
        super().__init__(bpe_tokenizer)

    @staticmethod
    def _ensure_special_tokens(tok: PreTrainedTokenizerFast) -> None:
        """Register <INTENT>/<RESIDUAL> as special tokens if missing (idempotent)."""
        special = set(tok.all_special_tokens)
        to_add = []
        if ResidualFASTActionProcessor.INTENT_MARKER not in special:
            to_add.append(ResidualFASTActionProcessor.INTENT_MARKER)
        if ResidualFASTActionProcessor.RESIDUAL_MARKER not in special:
            to_add.append(ResidualFASTActionProcessor.RESIDUAL_MARKER)
        if to_add:
            tok.add_special_tokens({"additional_special_tokens": to_add})

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """
        Encode an action chunk into BPE token-id lists.

        Args:
            action_chunk: np.ndarray with shape (T, D) or (B, T, D).

        Returns:
            list of token-id lists, length B.

        Raises:
            ValueError: if the input is not 2-/3-dimensional or k_intent > T.
        """
        # Explicit validation: the previous `assert ndim <= 3` let 0-/1-D inputs
        # through only to crash on the shape unpack below, and asserts vanish
        # under `python -O`.
        if action_chunk.ndim not in (2, 3):
            raise ValueError(
                f"Expected (T, D) or (B, T, D) input, got ndim={action_chunk.ndim}"
            )
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        B, T, D = action_chunk.shape
        if self.k_intent < 0 or self.k_intent > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={self.k_intent}, T={T}")

        # Cache dimensions for decode
        self.called_time_horizon = T
        self.called_action_dim = D

        # DCT over time axis (axis=1)
        coeff = dct(action_chunk, axis=1, norm="ortho")  # (B, T, D)

        # Split frequencies
        intent_coeff = coeff[:, : self.k_intent, :]  # (B, K, D)
        residual_coeff = coeff[:, self.k_intent :, :]  # (B, T-K, D)

        # Quantize
        intent_q = np.around(intent_coeff * self.scale).astype(int)
        residual_q = np.around(residual_coeff * self.scale).astype(int)

        tokens: list[list[int]] = []
        for b in range(B):
            # Convert quantized ints to chars (shifted by min_token). Values below
            # min_token are clamped to 0 so chr() stays valid — lossy for
            # out-of-range inputs, same trade-off FAST makes.
            intent_chars = "".join(
                map(chr, np.maximum(intent_q[b].flatten() - self.min_token, 0).astype(int))
            )
            residual_chars = "".join(
                map(chr, np.maximum(residual_q[b].flatten() - self.min_token, 0).astype(int))
            )
            # Insert markers; decode strips whitespace defensively, so no separators needed.
            token_str = f"{self.INTENT_MARKER}{intent_chars}{self.RESIDUAL_MARKER}{residual_chars}"
            # IMPORTANT: add_special_tokens=False so we don't inject BOS/EOS etc.
            ids = self.bpe_tokenizer(token_str, add_special_tokens=False)["input_ids"]
            tokens.append(ids)
        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        k_intent: int | None = None,
    ) -> np.ndarray:
        """
        Decode BPE token-id lists back into an action chunk.

        Args:
            tokens: list of token-id lists (batch).
            time_horizon: override for T; defaults to the last encoded/configured T.
            action_dim: override for D; defaults to the last encoded/configured D.
            k_intent: override for the intent/residual split point.

        Returns:
            np.ndarray of shape (B, T, D). Batch entries that fail to decode come
            back as all-zero actions (errors are logged, not raised).

        Raises:
            ValueError: if T/D are unknown or k_intent is out of range.
        """
        # `is not None` tests (not `or`) so an explicit 0 override is respected.
        if time_horizon is not None:
            self.time_horizon = time_horizon
        elif self.time_horizon is None:
            self.time_horizon = self.called_time_horizon
        if action_dim is not None:
            self.action_dim = action_dim
        elif self.action_dim is None:
            self.action_dim = self.called_action_dim
        K = int(k_intent) if k_intent is not None else self.k_intent

        # Cache for next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        # Explicit raise instead of assert: this is input validation, which must
        # survive `python -O`.
        if self.time_horizon is None or self.action_dim is None:
            raise ValueError(
                "Tokenizer not initialized. Call encode() once or pass time_horizon and action_dim."
            )
        T = int(self.time_horizon)
        D = int(self.action_dim)
        if K < 0 or K > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={K}, T={T}")

        decoded_actions = []
        for token_ids in tokens:
            try:
                # Decode back to the original string; keep special tokens so the
                # markers can be located.
                decoded = self.bpe_tokenizer.decode(token_ids, skip_special_tokens=False)
                # FAST-style safety: the encoded stream has no spaces; remove whitespace defensively
                decoded = "".join(decoded.split())

                # Find markers and split
                i0 = decoded.find(self.INTENT_MARKER)
                i1 = decoded.find(self.RESIDUAL_MARKER)
                if i0 == -1 or i1 == -1 or i1 < i0:
                    raise ValueError("Missing or misordered <INTENT>/<RESIDUAL> markers in decoded string.")
                intent_str = decoded[i0 + len(self.INTENT_MARKER) : i1]
                residual_str = decoded[i1 + len(self.RESIDUAL_MARKER) :]

                # Convert chars back to quantized ints (undo the chr offset)
                intent_vals = np.array(list(map(ord, intent_str)), dtype=int) + self.min_token
                residual_vals = np.array(list(map(ord, residual_str)), dtype=int) + self.min_token

                # Validate sizes before reshaping to (K, D) and (T-K, D)
                if intent_vals.size != K * D:
                    raise ValueError(f"Intent size mismatch: got {intent_vals.size}, expected {K*D}")
                if residual_vals.size != (T - K) * D:
                    raise ValueError(f"Residual size mismatch: got {residual_vals.size}, expected {(T-K)*D}")
                intent_q = intent_vals.reshape(K, D)
                residual_q = residual_vals.reshape(T - K, D)

                # Reconstruct full DCT coefficient matrix (T, D)
                coeff_q = np.zeros((T, D), dtype=float)
                coeff_q[:K, :] = intent_q
                coeff_q[K:, :] = residual_q

                # Inverse DCT (time axis is axis=0 now because coeff_q is (T, D))
                action = idct(coeff_q / self.scale, axis=0, norm="ortho")
            except Exception as e:
                # Best-effort per-sample recovery: log (not print, matching the
                # rest of the module) and emit zeros rather than failing the batch.
                logging.warning("[ResidualFAST] Error decoding tokens: %s", e)
                logging.warning("[ResidualFAST] Tokens: %s", token_ids)
                action = np.zeros((T, D), dtype=float)
            decoded_actions.append(action)
        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray] | np.ndarray,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "ResidualFASTActionProcessor":
        """
        Train the internal BPE tokenizer on Residual FAST strings.

        Args:
            action_data: list of arrays, each (T, D), or a single array (N, T, D).
            k_intent, scale, vocab_size: see `__init__`.
            time_horizon, action_dim: optional defaults stored on the processor so
                decode() works before the first encode call.

        Returns:
            A ready-to-use ResidualFASTActionProcessor.

        NOTE:
            - We keep the FAST alphabet trick: all possible quantized values are
              present in initial_alphabet.
            - We reserve room in vocab_size for the special marker tokens.
        """
        if isinstance(action_data, np.ndarray):
            # Explicit raise (not assert) so validation survives `python -O`.
            if action_data.ndim != 3:
                raise ValueError("If passing np.ndarray, expected shape (N, T, D).")
            chunks = [action_data[i] for i in range(action_data.shape[0])]
        else:
            chunks = action_data
        if len(chunks) == 0:
            raise ValueError("Empty action_data passed to fit().")

        # Validate shapes (allow varying T, but D should be consistent for easiest decoding)
        Ds = [c.shape[1] for c in chunks]
        if len(set(Ds)) != 1 and action_dim is None:
            raise ValueError("Varying action_dim in fit() data. Pass action_dim=... or standardize D.")
        D = action_dim if action_dim is not None else Ds[0]

        # First pass: track min/max quantized coefficients over the corpus
        all_q_vals = []
        for a in chunks:
            if a.ndim != 2:
                raise ValueError("Each chunk must be (T, D).")
            T, d = a.shape
            if d != D:
                raise ValueError(f"Chunk action_dim={d} != expected D={D}.")
            if k_intent < 0 or k_intent > T:
                raise ValueError(f"k_intent must be in [0, T]. Got k_intent={k_intent}, T={T}")
            coeff = dct(a, axis=0, norm="ortho")  # (T, D)
            intent = coeff[:k_intent, :]
            residual = coeff[k_intent:, :]
            # Quantize
            intent_q = np.around(intent * scale).astype(int)
            residual_q = np.around(residual * scale).astype(int)
            all_q_vals.append(intent_q.flatten())
            all_q_vals.append(residual_q.flatten())

        all_q = np.concatenate(all_q_vals, axis=0)
        max_token = int(all_q.max())
        min_token = int(all_q.min())

        # FAST constraint: the initial alphabet (inclusive quantized range) plus
        # the two marker tokens must fit inside vocab_size.
        n_special = 2  # <INTENT>, <RESIDUAL>
        required_vocab = (max_token - min_token + 1) + n_special
        if required_vocab > vocab_size:
            # Kept as AssertionError for backward compatibility with existing callers.
            raise AssertionError(
                f"vocab_size={vocab_size} too small. Need >= (range+special) = {required_vocab} "
                f"(range={max_token-min_token+1}, special={n_special})."
            )
        if (max_token - min_token + 1) + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size is close to vocab_size. Consider increasing vocab_size "
                "for better BPE compression."
            )

        # Second pass (lazy): iterator producing Residual FAST strings
        def _token_iter() -> Iterable[str]:
            for a in chunks:
                coeff = dct(a, axis=0, norm="ortho")
                intent = coeff[:k_intent, :]
                residual = coeff[k_intent:, :]
                intent_q = (np.around(intent * scale) - min_token).astype(int)
                residual_q = (np.around(residual * scale) - min_token).astype(int)
                intent_str = "".join(map(chr, intent_q.flatten()))
                residual_str = "".join(map(chr, residual_q.flatten()))
                yield f"{cls.INTENT_MARKER}{intent_str}{cls.RESIDUAL_MARKER}{residual_str}"

        # Train BPE tokenizer (byte-level)
        bpe = ByteLevelBPETokenizer()
        # Alphabet covering every shifted quantized char: chr(0) .. chr(max-min)
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[cls.INTENT_MARKER, cls.RESIDUAL_MARKER],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )
        # Train inner tokenizer (same trick as FAST)
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
        # NOTE(review): `tokenizer_object` conventionally takes a `tokenizers.Tokenizer`;
        # here the ByteLevelBPETokenizer wrapper is passed and appears to rely on
        # method delegation — confirm, or pass `bpe._tokenizer` instead.
        hf_tok = PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False)
        # Ensure special tokens registered (defensive)
        cls._ensure_special_tokens(hf_tok)

        return cls(
            hf_tok,
            k_intent=k_intent,
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )