|
|
|
|
|
|
|
|
import logging |
|
|
from typing import ClassVar, Iterable |
|
|
|
|
|
import numpy as np |
|
|
from scipy.fft import dct, idct |
|
|
from tokenizers import ByteLevelBPETokenizer |
|
|
from tokenizers.trainers import BpeTrainer |
|
|
from transformers import PreTrainedTokenizerFast |
|
|
from transformers.processing_utils import ProcessorMixin |
|
|
|
|
|
|
|
|
class ResidualFASTActionProcessor(ProcessorMixin):
    """
    Residual FAST: intent + residual tokenization built on top of FAST's DCT+BPE scheme.

    Encodes an action chunk (B, T, D) into tokens by:
      1) DCT over the time axis
      2) Splitting the coefficients:
           intent   = coeff[:k_intent, :]
           residual = coeff[k_intent:, :]
      3) Quantizing to ints
      4) Converting to a string of characters via chr(int - min_token)
      5) Wrapping with special markers: <INTENT> ... <RESIDUAL> ...
      6) BPE-tokenizing the resulting string

    Decoding reverses the above.

    Notes:
        - Assumes input actions are already normalized to roughly [-1, 1] (same as FAST).
        - Uses a single BPE tokenizer to keep the interface identical to FAST.
        - Markers are special tokens so decode can reliably split streams.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    INTENT_MARKER = "<INTENT>"
    RESIDUAL_MARKER = "<RESIDUAL>"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        min_token: int = 0,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: trained fast tokenizer used for both encode and decode.
            k_intent: number of leading (low-frequency) DCT rows treated as "intent".
            scale: quantization scale applied to DCT coefficients before rounding.
            vocab_size: BPE vocabulary size (stored for reference).
            min_token: minimum quantized coefficient seen at fit time; subtracted
                before chr() so codepoints are non-negative.
            action_dim: default D used by decode() when not given explicitly.
            time_horizon: default T used by decode() when not given explicitly.
        """
        self.k_intent = int(k_intent)
        self.scale = float(scale)
        self.vocab_size = int(vocab_size)
        self.min_token = int(min_token)

        # Configured defaults plus the shapes observed on the most recent
        # __call__(); decode() falls back to these when no explicit shape is given.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Register the markers as special tokens before ProcessorMixin takes
        # ownership of the tokenizer, so decode() can split on them reliably.
        self._ensure_special_tokens(bpe_tokenizer)

        super().__init__(bpe_tokenizer)

    @staticmethod
    def _ensure_special_tokens(tok: PreTrainedTokenizerFast) -> None:
        """Add the intent/residual markers to *tok* as special tokens if missing (idempotent)."""
        special = set(tok.all_special_tokens)
        to_add = []
        if ResidualFASTActionProcessor.INTENT_MARKER not in special:
            to_add.append(ResidualFASTActionProcessor.INTENT_MARKER)
        if ResidualFASTActionProcessor.RESIDUAL_MARKER not in special:
            to_add.append(ResidualFASTActionProcessor.RESIDUAL_MARKER)
        if to_add:
            tok.add_special_tokens({"additional_special_tokens": to_add})

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """
        Encode a batch of action chunks into BPE token ids.

        Args:
            action_chunk: np.ndarray with shape (T, D) or (B, T, D).

        Returns:
            list of token-id lists, length B.

        Raises:
            ValueError: if k_intent is outside [0, T].
        """
        assert action_chunk.ndim <= 3, "Only up to 3 dims supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            # Promote a single (T, D) chunk to a batch of one.
            action_chunk = action_chunk[None, ...]

        B, T, D = action_chunk.shape
        if self.k_intent < 0 or self.k_intent > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={self.k_intent}, T={T}")

        # Remember the shapes so a later decode() can run without explicit T/D.
        self.called_time_horizon = T
        self.called_action_dim = D

        # DCT along time; the leading (low-frequency) rows carry the "intent".
        coeff = dct(action_chunk, axis=1, norm="ortho")
        intent_coeff = coeff[:, : self.k_intent, :]
        residual_coeff = coeff[:, self.k_intent :, :]

        # Quantize: scale then round to nearest int.
        intent_q = np.around(intent_coeff * self.scale).astype(int)
        residual_q = np.around(residual_coeff * self.scale).astype(int)

        tokens: list[list[int]] = []
        for b in range(B):
            # Shift by min_token so codepoints are non-negative; values below
            # min_token (unseen at fit time) clamp to chr(0). Values above the
            # fitted range are not clamped and fall outside the trained
            # alphabet -- presumably handled by byte-level fallback; TODO confirm.
            intent_chars = "".join(
                map(chr, np.maximum(intent_q[b].flatten() - self.min_token, 0).astype(int))
            )
            residual_chars = "".join(
                map(chr, np.maximum(residual_q[b].flatten() - self.min_token, 0).astype(int))
            )

            token_str = f"{self.INTENT_MARKER}{intent_chars}{self.RESIDUAL_MARKER}{residual_chars}"

            ids = self.bpe_tokenizer(token_str, add_special_tokens=False)["input_ids"]
            tokens.append(ids)

        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        k_intent: int | None = None,
    ) -> np.ndarray:
        """
        Decode batches of token ids back into action chunks.

        Args:
            tokens: list of token-id lists (batch).
            time_horizon: T override; defaults to the configured / last-encoded value.
            action_dim: D override; defaults to the configured / last-encoded value.
            k_intent: intent-row count override; defaults to self.k_intent.

        Returns:
            np.ndarray of shape (B, T, D). A chunk that fails to decode is
            replaced by zeros rather than raising.

        Raises:
            ValueError: if the resolved k_intent is outside [0, T].
        """
        # Resolve shapes: explicit arguments win, then configured values, then
        # the shapes remembered from the last __call__(). Explicit `is None`
        # checks avoid the `or`-chain pitfall with falsy-but-set values.
        if time_horizon is not None:
            self.time_horizon = time_horizon
        elif self.time_horizon is None:
            self.time_horizon = self.called_time_horizon
        if action_dim is not None:
            self.action_dim = action_dim
        elif self.action_dim is None:
            self.action_dim = self.called_action_dim
        K = int(k_intent) if k_intent is not None else self.k_intent

        # Keep the "last seen" shapes in sync with what we are about to use.
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized. Call encode() once or pass time_horizon and action_dim."
        )

        T = int(self.time_horizon)
        D = int(self.action_dim)
        if K < 0 or K > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={K}, T={T}")

        decoded_actions = []
        for token_ids in tokens:
            try:
                # Keep special tokens: the markers are how we split the stream.
                decoded = self.bpe_tokenizer.decode(token_ids, skip_special_tokens=False)

                # Drop whitespace the HF decoder may insert around special tokens.
                # NOTE(review): this also removes payload characters that happen
                # to be whitespace codepoints (e.g. chr(32) when a coefficient
                # quantizes to min_token + 32) -- assumed not to occur in
                # practice; TODO confirm against the fitted token range.
                decoded = "".join(decoded.split())

                i0 = decoded.find(self.INTENT_MARKER)
                i1 = decoded.find(self.RESIDUAL_MARKER)
                if i0 == -1 or i1 == -1 or i1 < i0:
                    raise ValueError("Missing or misordered <INTENT>/<RESIDUAL> markers in decoded string.")

                intent_str = decoded[i0 + len(self.INTENT_MARKER) : i1]
                residual_str = decoded[i1 + len(self.RESIDUAL_MARKER) :]

                # Undo the chr() mapping: codepoint + min_token == quantized coeff.
                intent_vals = np.array(list(map(ord, intent_str)), dtype=int) + self.min_token
                residual_vals = np.array(list(map(ord, residual_str)), dtype=int) + self.min_token

                # Validate lengths before reshaping so errors are informative.
                if intent_vals.size != K * D:
                    raise ValueError(f"Intent size mismatch: got {intent_vals.size}, expected {K*D}")
                if residual_vals.size != (T - K) * D:
                    raise ValueError(f"Residual size mismatch: got {residual_vals.size}, expected {(T-K)*D}")

                intent_q = intent_vals.reshape(K, D)
                residual_q = residual_vals.reshape(T - K, D)

                # Reassemble the full coefficient matrix, then invert the
                # quantization and the DCT.
                coeff_q = np.zeros((T, D), dtype=float)
                coeff_q[:K, :] = intent_q
                coeff_q[K:, :] = residual_q

                action = idct(coeff_q / self.scale, axis=0, norm="ortho")
            except Exception as e:
                # Malformed streams are expected at inference time; emit a zero
                # action instead of crashing the caller. Use the module's
                # logging (matches fit()) rather than print.
                logging.error("[ResidualFAST] Error decoding tokens: %s", e)
                logging.error("[ResidualFAST] Tokens: %s", token_ids)
                action = np.zeros((T, D), dtype=float)

            decoded_actions.append(action)

        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray] | np.ndarray,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "ResidualFASTActionProcessor":
        """
        Train the internal BPE tokenizer on Residual FAST strings.

        action_data can be:
            - list of arrays, each (T, D)
            - or a single array (N, T, D)

        NOTE:
            - We keep the FAST alphabet trick: all possible quantized values are present in initial_alphabet.
            - We reserve room in vocab_size for the special marker tokens.

        Returns:
            A ResidualFASTActionProcessor wrapping the freshly trained tokenizer.

        Raises:
            ValueError: on empty data, inconsistent D, or k_intent out of range.
            AssertionError: if vocab_size cannot hold the alphabet plus markers.
        """
        if isinstance(action_data, np.ndarray):
            assert action_data.ndim == 3, "If passing np.ndarray, expected shape (N, T, D)."
            chunks = [action_data[i] for i in range(action_data.shape[0])]
        else:
            chunks = action_data

        if len(chunks) == 0:
            raise ValueError("Empty action_data passed to fit().")

        # All chunks must agree on D unless the caller pins it explicitly.
        Ds = [c.shape[1] for c in chunks]
        if len(set(Ds)) != 1 and action_dim is None:
            raise ValueError("Varying action_dim in fit() data. Pass action_dim=... or standardize D.")
        D = action_dim if action_dim is not None else Ds[0]

        # First pass: collect every quantized coefficient to find the value
        # range, which fixes min_token and the initial alphabet size.
        all_q_vals = []
        for a in chunks:
            assert a.ndim == 2, "Each chunk must be (T, D)."
            T, d = a.shape
            if d != D:
                raise ValueError(f"Chunk action_dim={d} != expected D={D}.")
            if k_intent < 0 or k_intent > T:
                raise ValueError(f"k_intent must be in [0, T]. Got k_intent={k_intent}, T={T}")

            coeff = dct(a, axis=0, norm="ortho")
            intent = coeff[:k_intent, :]
            residual = coeff[k_intent:, :]

            all_q_vals.append(np.around(intent * scale).astype(int).flatten())
            all_q_vals.append(np.around(residual * scale).astype(int).flatten())

        all_q = np.concatenate(all_q_vals, axis=0)
        max_token = int(all_q.max())
        min_token = int(all_q.min())

        # The vocabulary must at least hold the raw alphabet plus both markers.
        n_special = 2
        required_vocab = (max_token - min_token + 1) + n_special
        if required_vocab > vocab_size:
            raise AssertionError(
                f"vocab_size={vocab_size} too small. Need >= (range+special) = {required_vocab} "
                f"(range={max_token-min_token+1}, special={n_special})."
            )

        # With little headroom over the alphabet, BPE can learn few merges.
        if (max_token - min_token + 1) + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size is close to vocab_size. Consider increasing vocab_size "
                "for better BPE compression."
            )

        def _token_iter() -> Iterable[str]:
            # Second pass: re-quantize each chunk, shift into [0, range), and
            # render the marker-wrapped training string lazily.
            for a in chunks:
                coeff = dct(a, axis=0, norm="ortho")

                intent_q = (np.around(coeff[:k_intent, :] * scale) - min_token).astype(int)
                residual_q = (np.around(coeff[k_intent:, :] * scale) - min_token).astype(int)

                intent_str = "".join(map(chr, intent_q.flatten()))
                residual_str = "".join(map(chr, residual_q.flatten()))

                yield f"{cls.INTENT_MARKER}{intent_str}{cls.RESIDUAL_MARKER}{residual_str}"

        bpe = ByteLevelBPETokenizer()

        # Seed the trainer with every shifted value so rare coefficients still
        # get a base token (the FAST alphabet trick).
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]

        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[cls.INTENT_MARKER, cls.RESIDUAL_MARKER],
            initial_alphabet=alphabet,
            max_token_length=10000,
        )

        # NOTE(review): reaches into the private `_tokenizer` because
        # ByteLevelBPETokenizer does not expose train_from_iterator with a
        # custom trainer -- may break across tokenizers versions.
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        hf_tok = PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False)

        cls._ensure_special_tokens(hf_tok)

        return cls(
            hf_tok,
            k_intent=k_intent,
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )
|
|
|