File size: 12,937 Bytes
f3b0016
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
# processing_action_tokenizer_residual.py

import logging
from typing import ClassVar, Iterable

import numpy as np
from scipy.fft import dct, idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin


class ResidualFASTActionProcessor(ProcessorMixin):
    """
    Residual FAST: intent + residual tokenization built on top of FAST's DCT+BPE scheme.

    Encodes an action chunk (B, T, D) into tokens by:
      1) DCT over time axis
      2) Split coeffs:
           intent  = coeff[:k_intent, :]
           residual= coeff[k_intent:, :]
      3) Quantize to ints
      4) Convert to a string of characters via chr(int - min_token)
      5) Wrap with special markers: <INTENT> ... <RESIDUAL> ...
      6) BPE-tokenize the resulting string

    Decoding reverses the above.

    Notes:
      - Assumes input actions are already normalized to roughly [-1, 1] (same as FAST).
      - Uses a single BPE tokenizer to keep the interface identical to FAST.
      - Markers are special tokens so decode can reliably split streams.
    """

    attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
    bpe_tokenizer_class: str = "AutoTokenizer"

    # Special markers separating the low-frequency "intent" stream from the
    # high-frequency "residual" stream inside one encoded string.
    INTENT_MARKER = "<INTENT>"
    RESIDUAL_MARKER = "<RESIDUAL>"

    def __init__(
        self,
        bpe_tokenizer: PreTrainedTokenizerFast,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        min_token: int = 0,
        action_dim: int | None = None,
        time_horizon: int | None = None,
    ):
        """
        Args:
            bpe_tokenizer: trained BPE tokenizer over Residual FAST strings.
            k_intent: number of leading DCT coefficients treated as "intent".
            scale: quantization scale applied before rounding coefficients.
            vocab_size: BPE vocabulary size (stored for reference).
            min_token: smallest quantized coefficient seen at fit time; used to
                shift quantized values into non-negative char codes.
            action_dim: D, needed for decode when no encode call preceded it.
            time_horizon: T, needed for decode when no encode call preceded it.
        """
        self.k_intent = int(k_intent)
        self.scale = float(scale)
        self.vocab_size = int(vocab_size)
        self.min_token = int(min_token)

        # Needed for decoding: either configured here or cached by __call__.
        self.time_horizon = time_horizon
        self.action_dim = action_dim
        self.called_time_horizon = time_horizon
        self.called_action_dim = action_dim

        # Ensure markers exist as special tokens in the tokenizer (robust
        # against tokenizers loaded from disk without them registered).
        self._ensure_special_tokens(bpe_tokenizer)

        super().__init__(bpe_tokenizer)

    @staticmethod
    def _ensure_special_tokens(tok: PreTrainedTokenizerFast) -> None:
        """Register the intent/residual markers as special tokens if missing."""
        special = set(tok.all_special_tokens)
        to_add = [
            marker
            for marker in (
                ResidualFASTActionProcessor.INTENT_MARKER,
                ResidualFASTActionProcessor.RESIDUAL_MARKER,
            )
            if marker not in special
        ]
        if to_add:
            tok.add_special_tokens({"additional_special_tokens": to_add})

    def __call__(self, action_chunk: np.ndarray) -> list[list[int]]:
        """
        Encode a batch of action chunks into BPE token ids.

        Args:
            action_chunk: np.ndarray with shape (T, D) or (B, T, D).

        Returns:
            list of token-id lists, length B.
        """
        assert action_chunk.ndim <= 3, "Only up to 3 dims supported: [batch, timesteps, action_dim]"
        if action_chunk.ndim == 2:
            action_chunk = action_chunk[None, ...]

        B, T, D = action_chunk.shape
        if self.k_intent < 0 or self.k_intent > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={self.k_intent}, T={T}")

        # Cache dimensions for decode
        self.called_time_horizon = T
        self.called_action_dim = D

        # DCT over time axis (axis=1)
        coeff = dct(action_chunk, axis=1, norm="ortho")  # (B, T, D)

        # Split frequencies: low frequencies carry coarse motion ("intent"),
        # the rest is fine detail ("residual").
        intent_coeff = coeff[:, : self.k_intent, :]       # (B, K, D)
        residual_coeff = coeff[:, self.k_intent :, :]     # (B, T-K, D)

        # Quantize
        intent_q = np.around(intent_coeff * self.scale).astype(int)
        residual_q = np.around(residual_coeff * self.scale).astype(int)

        tokens: list[list[int]] = []
        for b in range(B):
            # Convert quantized ints to chars (shifted by min_token).
            # np.maximum clips below the trained range; values above the range
            # are not clipped (byte-level BPE can still represent them).
            intent_chars = "".join(
                map(chr, np.maximum(intent_q[b].flatten() - self.min_token, 0).astype(int))
            )
            residual_chars = "".join(
                map(chr, np.maximum(residual_q[b].flatten() - self.min_token, 0).astype(int))
            )

            # Insert markers; remove any whitespace in tokenizer decode later, so no need to add separators
            token_str = f"{self.INTENT_MARKER}{intent_chars}{self.RESIDUAL_MARKER}{residual_chars}"

            # IMPORTANT: add_special_tokens=False so we don't inject BOS/EOS etc.
            ids = self.bpe_tokenizer(token_str, add_special_tokens=False)["input_ids"]
            tokens.append(ids)

        return tokens

    def decode(
        self,
        tokens: list[list[int]],
        *,
        time_horizon: int | None = None,
        action_dim: int | None = None,
        k_intent: int | None = None,
    ) -> np.ndarray:
        """
        Decode BPE token ids back into action chunks.

        Args:
            tokens: list of token-id lists (batch).
            time_horizon: override T; defaults to configured/cached value.
            action_dim: override D; defaults to configured/cached value.
            k_intent: override K; defaults to the configured value.

        Returns:
            np.ndarray of shape (B, T, D). Rows that fail to decode are zeros.
        """
        # NOTE(review): `or` treats 0 as missing; assumes T and D are never 0.
        self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
        self.action_dim = action_dim or self.action_dim or self.called_action_dim
        K = int(k_intent) if k_intent is not None else self.k_intent

        # Cache for next call
        self.called_time_horizon = self.time_horizon
        self.called_action_dim = self.action_dim

        assert self.time_horizon is not None and self.action_dim is not None, (
            "Tokenizer not initialized. Call encode() once or pass time_horizon and action_dim."
        )

        T = int(self.time_horizon)
        D = int(self.action_dim)
        if K < 0 or K > T:
            raise ValueError(f"k_intent must be in [0, T]. Got k_intent={K}, T={T}")

        decoded_actions = []
        for token_ids in tokens:
            try:
                # Decode back to the original string (markers kept so we can split).
                decoded = self.bpe_tokenizer.decode(token_ids, skip_special_tokens=False)

                # FAST-style safety: the encoded stream has no spaces; remove whitespace defensively
                # NOTE(review): if min_token < 0, a payload char can itself be a
                # whitespace codepoint (e.g. chr(32)); stripping would then corrupt
                # the stream and trigger the size-mismatch path below — confirm
                # against the trained coefficient range.
                decoded = "".join(decoded.split())

                # Find markers and split
                i0 = decoded.find(self.INTENT_MARKER)
                i1 = decoded.find(self.RESIDUAL_MARKER)
                if i0 == -1 or i1 == -1 or i1 < i0:
                    raise ValueError("Missing or misordered <INTENT>/<RESIDUAL> markers in decoded string.")

                intent_str = decoded[i0 + len(self.INTENT_MARKER) : i1]
                residual_str = decoded[i1 + len(self.RESIDUAL_MARKER) :]

                # Convert chars back to quantized ints (undo the min_token shift)
                intent_vals = np.array(list(map(ord, intent_str)), dtype=int) + self.min_token
                residual_vals = np.array(list(map(ord, residual_str)), dtype=int) + self.min_token

                # Validate sizes before reshaping to (K, D) and (T-K, D)
                if intent_vals.size != K * D:
                    raise ValueError(f"Intent size mismatch: got {intent_vals.size}, expected {K*D}")
                if residual_vals.size != (T - K) * D:
                    raise ValueError(f"Residual size mismatch: got {residual_vals.size}, expected {(T-K)*D}")

                intent_q = intent_vals.reshape(K, D)
                residual_q = residual_vals.reshape(T - K, D)

                # Reconstruct full DCT coefficient matrix (T, D)
                coeff_q = np.zeros((T, D), dtype=float)
                coeff_q[:K, :] = intent_q
                coeff_q[K:, :] = residual_q

                # Inverse DCT (time axis is axis=0 now because coeff_q is (T, D))
                action = idct(coeff_q / self.scale, axis=0, norm="ortho")
            except Exception:
                # Keep batch shape intact: log the failure and fall back to zeros,
                # consistent with the logging module already used by fit().
                logging.exception("[ResidualFAST] Error decoding tokens.")
                logging.error("[ResidualFAST] Tokens: %s", token_ids)
                action = np.zeros((T, D), dtype=float)

            decoded_actions.append(action)

        return np.stack(decoded_actions, axis=0)

    @classmethod
    def fit(
        cls,
        action_data: list[np.ndarray] | np.ndarray,
        *,
        k_intent: int = 5,
        scale: float = 10.0,
        vocab_size: int = 1024,
        time_horizon: int | None = None,
        action_dim: int | None = None,
    ) -> "ResidualFASTActionProcessor":
        """
        Train the internal BPE tokenizer on Residual FAST strings.

        action_data can be:
          - list of arrays, each (T, D)
          - or a single array (N, T, D)

        NOTE:
          - We keep the FAST alphabet trick: all possible quantized values are present in initial_alphabet.
          - We reserve room in vocab_size for the special marker tokens.

        Raises:
            ValueError: on empty data, inconsistent D, or k_intent out of range.
            AssertionError: when vocab_size cannot hold the quantized alphabet.
        """
        if isinstance(action_data, np.ndarray):
            assert action_data.ndim == 3, "If passing np.ndarray, expected shape (N, T, D)."
            chunks = [action_data[i] for i in range(action_data.shape[0])]
        else:
            chunks = action_data

        if len(chunks) == 0:
            raise ValueError("Empty action_data passed to fit().")

        # Validate shapes (allow varying T, but D should be consistent for easiest decoding)
        Ds = [c.shape[1] for c in chunks]
        if len(set(Ds)) != 1 and action_dim is None:
            raise ValueError("Varying action_dim in fit() data. Pass action_dim=... or standardize D.")
        D = action_dim if action_dim is not None else Ds[0]

        # First pass: quantize everything to find the min/max coefficient range.
        all_q_vals = []

        for a in chunks:
            assert a.ndim == 2, "Each chunk must be (T, D)."
            T, d = a.shape
            if d != D:
                raise ValueError(f"Chunk action_dim={d} != expected D={D}.")
            if k_intent < 0 or k_intent > T:
                raise ValueError(f"k_intent must be in [0, T]. Got k_intent={k_intent}, T={T}")

            coeff = dct(a, axis=0, norm="ortho")  # (T, D)
            intent = coeff[:k_intent, :]
            residual = coeff[k_intent:, :]

            # Quantize
            intent_q = np.around(intent * scale).astype(int)
            residual_q = np.around(residual * scale).astype(int)

            all_q_vals.append(intent_q.flatten())
            all_q_vals.append(residual_q.flatten())

        all_q = np.concatenate(all_q_vals, axis=0)
        max_token = int(all_q.max())
        min_token = int(all_q.min())

        # FAST constraint: the full inclusive quantized range plus the special
        # marker tokens must fit inside vocab_size.
        n_special = 2  # <INTENT>, <RESIDUAL>
        required_vocab = (max_token - min_token + 1) + n_special
        if required_vocab > vocab_size:
            raise AssertionError(
                f"vocab_size={vocab_size} too small. Need >= (range+special) = {required_vocab} "
                f"(range={max_token-min_token+1}, special={n_special})."
            )

        if (max_token - min_token + 1) + 100 > vocab_size:
            logging.warning(
                "Initial alphabet size is close to vocab_size. Consider increasing vocab_size "
                "for better BPE compression."
            )

        # Iterator producing Residual FAST strings (second pass over the data).
        def _token_iter() -> Iterable[str]:
            for a in chunks:
                T, d = a.shape
                coeff = dct(a, axis=0, norm="ortho")

                intent = coeff[:k_intent, :]
                residual = coeff[k_intent:, :]

                # Shift by min_token so every char code is non-negative.
                intent_q = (np.around(intent * scale) - min_token).astype(int)
                residual_q = (np.around(residual * scale) - min_token).astype(int)

                intent_str = "".join(map(chr, intent_q.flatten()))
                residual_str = "".join(map(chr, residual_q.flatten()))

                yield f"{cls.INTENT_MARKER}{intent_str}{cls.RESIDUAL_MARKER}{residual_str}"

        # Train BPE tokenizer (byte-level)
        bpe = ByteLevelBPETokenizer()

        # Alphabet for the quantized chars.
        # NOTE(review): ByteLevel pre-tokenization remaps raw bytes to its own
        # printable alphabet; passing raw chr(i) here follows the FAST trick but
        # verify it matches the byte-level representation seen during training.
        alphabet = [chr(i) for i in range(max_token - min_token + 1)]

        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            show_progress=True,
            special_tokens=[cls.INTENT_MARKER, cls.RESIDUAL_MARKER],
            initial_alphabet=alphabet,
            # Effectively unlimited merge length: whole coefficient runs may fuse.
            max_token_length=10000,
        )

        # Train inner tokenizer (same trick as FAST): bypass the wrapper to use
        # our custom trainer with the special tokens and alphabet above.
        bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)

        hf_tok = PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False)
        # Ensure special tokens registered (defensive)
        cls._ensure_special_tokens(hf_tok)

        return cls(
            hf_tok,
            k_intent=k_intent,
            scale=scale,
            vocab_size=vocab_size,
            min_token=min_token,
            time_horizon=time_horizon,
            action_dim=action_dim,
        )