File size: 19,846 Bytes
bf31071
 
b121266
 
 
 
bf31071
 
 
 
b121266
 
bf31071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b121266
 
 
 
 
 
 
 
bf31071
 
 
b121266
 
bf31071
 
 
 
 
 
 
 
b121266
 
 
 
 
 
 
 
bf31071
 
 
 
 
b121266
bf31071
 
 
 
 
b121266
 
 
 
 
 
 
 
 
 
 
bf31071
 
 
 
 
 
 
 
b121266
 
bf31071
b121266
bf31071
 
 
 
 
 
 
 
b121266
 
bf31071
b121266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf31071
 
 
 
 
b121266
bf31071
 
 
 
b121266
bf31071
 
 
 
 
 
 
 
 
 
 
 
 
b121266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf31071
b121266
 
bf31071
b121266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf31071
b121266
 
 
 
 
 
bf31071
 
 
 
b121266
bf31071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b121266
 
 
bf31071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc08380
b121266
 
 
bc08380
b121266
 
 
bc08380
b121266
 
bc08380
b121266
 
 
 
 
 
bc08380
b121266
bc08380
b121266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc08380
b121266
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
"""Processor for handling multimodal swipe inputs (path + text)."""

from __future__ import annotations

from typing import Any

import numpy as np
import torch
from transformers import ProcessorMixin

from .preprocessing import preprocess_raw_path_to_features


class SwipeProcessor(ProcessorMixin):
    """
    Processor for handling multimodal swipe inputs (path coordinates + text).

    This processor combines path coordinate preprocessing with text tokenization,
    creating the inputs needed for SwipeTransformer models. The combined attention
    mask it emits follows the sequence layout ``[CLS] + path + [SEP] + text``.

    Args:
        tokenizer: SwipeTokenizer instance
        max_path_len (int): Maximum path length. Defaults to 64.
        max_char_len (int): Maximum character length. Defaults to 38.
        path_input_dim (int): Number of features per path point. Defaults to 6.
        path_resample_mode (str): Resampling mode forwarded to
            ``preprocess_raw_path_to_features``. Defaults to ``"time"``.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"  # Will use auto_map from tokenizer_config.json

    def __init__(
        self,
        tokenizer=None,
        max_path_len: int = 64,
        max_char_len: int = 38,
        path_input_dim: int = 6,
        path_resample_mode: str = "time",
    ):
        # NOTE: ProcessorMixin.__init__ is not called; attributes are assigned directly.
        self.tokenizer = tokenizer
        self.max_path_len = max_path_len
        self.max_char_len = max_char_len
        self.path_input_dim = path_input_dim
        self.path_resample_mode = path_resample_mode
        # Attributes expected by newer transformers (not used for swipe models)
        self.chat_template = None
        self.audio_tokenizer = None
        self.feature_extractor = None
        self.image_processor = None

    def __call__(
        self,
        path_coords: (
            list[dict[str, float]]
            | list[list[dict[str, float]]]
            | list[list[list[float]]]
            | torch.Tensor
            | np.ndarray
            | None
        ) = None,
        text: str | list[str] | None = None,
        padding: bool | str = True,
        truncation: bool = True,
        max_length: int | None = None,
        return_tensors: str | None = "pt",
        **kwargs: Any,
    ):
        """
        Process path coordinates and text into model inputs.

        Args:
            path_coords:
                Swipe paths in one of the supported formats:
                - Raw path (single example): list of dicts like `{"x": ..., "y": ..., "t": ...}`
                - Raw batch: list of raw paths
                - Numeric arrays/tensors/lists: `[batch, path_len, D]` or `[path_len, D]`
                Raw paths are featurized and resampled to `max_path_len` by the shared
                preprocessing module.
                If `D==3` and `path_input_dim==6`, numeric `(x,y,t)` triples are converted to
                engineered `(x, y, dx, dy, ds, log_dt)` features the same way.
                Other numeric inputs are truncated/zero-padded along the path axis to
                `max_path_len` according to `truncation` / `padding`.
                If omitted (or an empty list), the processor emits a zero path with a zero
                path attention mask.
            text:
                String or list of strings to encode.
                If omitted, the processor emits padded text tokens with a zero text attention mask.
            padding: Whether to pad sequences. Can be True/False or "max_length"
                ("do_not_pad" is treated like False). Padding always targets the
                configured maximum length.
            truncation: Whether to truncate sequences ("do_not_truncate" is treated like False)
            max_length: Maximum sequence length for text (overrides max_char_len, also for
                the all-padding text emitted when `text` is omitted)
            return_tensors: "pt" for PyTorch, "np" for NumPy, None for lists
            **kwargs: Additional keyword arguments forwarded to the tokenizer

        Returns:
            Dictionary with:
                - path_coords: [batch, path_len, D]; `path_len == max_path_len` when padded.
                  Default: [batch, max_path_len, 6] for (x, y, dx, dy, ds, log_dt)
                - input_ids: [batch, text_len] (pad ids when no text was provided)
                - attention_mask: [batch, 1 + path_len + 1 + text_len]
                  (covers `[CLS] + path + [SEP] + text`)

        Raises:
            ValueError: if neither input is given, `return_tensors` is unsupported, the
                path and text batch sizes differ, the batch is empty, the text length
                limit is < 1, a numeric path has the wrong rank, or unpadded text of
                different lengths is requested as tensors.
            TypeError: if `path_coords` has an unsupported type.
        """
        if path_coords is None and text is None:
            raise ValueError("Must provide either path_coords or text (or both)")
        if return_tensors not in ("pt", "np", None):
            raise ValueError(
                f"Unsupported return_tensors={return_tensors!r}; expected 'pt', 'np' or None"
            )

        do_pad = self._flag_enabled(padding, "do_not_pad")
        do_truncate = self._flag_enabled(truncation, "do_not_truncate")
        text_max_length = max_length if max_length is not None else self.max_char_len
        if text_max_length < 1:
            raise ValueError(f"Text max length must be >= 1, got {text_max_length}")

        if text is None:
            texts = None
        elif isinstance(text, str):
            texts = [text]
        else:
            texts = list(text)

        path = self._prepare_path(path_coords, do_pad, do_truncate)

        # Determine batch size; path and text must describe the same examples.
        if path is not None:
            batch_size = path[0].shape[0]
            if texts is not None and len(texts) != batch_size:
                raise ValueError(
                    f"Batch size mismatch: {batch_size} path(s) but {len(texts)} text(s)"
                )
        elif texts is not None:
            batch_size = len(texts)
        else:
            # path_coords was an empty list and no text was given.
            batch_size = 1
        if batch_size == 0:
            raise ValueError("Received an empty batch")

        if path is None:
            # No path: zero features with a zero mask so only the text is attended to.
            path_feats = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
            path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
        else:
            path_feats, path_mask = path

        if texts is None:
            # No text: padding tokens with a zero mask.
            pad_id = self.tokenizer.pad_token_id
            input_ids = [[pad_id] * text_max_length for _ in range(batch_size)]
            char_mask = [[0] * text_max_length for _ in range(batch_size)]
        else:
            input_ids, char_mask = self._encode_texts(
                texts, text_max_length, do_pad, do_truncate, **kwargs
            )

        return self._assemble(path_feats, path_mask, input_ids, char_mask, return_tensors)

    @staticmethod
    def _flag_enabled(value, disabled_alias: str) -> bool:
        """Interpret a bool-or-strategy-string flag, honoring the explicit 'disabled' alias."""
        return bool(value) and value != disabled_alias

    def _prepare_path(self, path_coords, do_pad: bool, do_truncate: bool):
        """Convert any supported path input into ``(features, mask)`` tensors.

        Returns ``None`` when no path is available (``None`` or an empty list), so
        the caller can substitute a zero path with a zero mask.
        """
        if path_coords is None:
            return None

        if isinstance(path_coords, (list, tuple)):
            if len(path_coords) == 0:
                return None
            first = path_coords[0]
            # Raw single path: [{"x","y","t"}, ...]
            if isinstance(first, dict):
                return self._featurize_raw_paths([path_coords])
            # Raw batch of paths: [[{"x","y","t"}, ...], ...]
            if isinstance(first, (list, tuple)) and len(first) > 0 and isinstance(first[0], dict):
                return self._featurize_raw_paths(path_coords)
            feats = torch.tensor(path_coords, dtype=torch.float32)
        elif isinstance(path_coords, np.ndarray):
            feats = torch.from_numpy(path_coords).float()
        elif isinstance(path_coords, torch.Tensor):
            feats = path_coords
        else:
            raise TypeError(f"Unsupported path_coords type: {type(path_coords).__name__}")

        if feats.dim() == 2:
            # Single path, add batch dimension
            feats = feats.unsqueeze(0)
        if feats.dim() != 3:
            raise ValueError(
                "Numeric path_coords must be [path_len, D] or [batch, path_len, D]; "
                f"got shape {tuple(feats.shape)}"
            )

        # If user provided raw (x,y,t) triples but model expects engineered features,
        # convert to motion features and resample.
        if feats.shape[-1] == 3 and self.path_input_dim == 6:
            raw_paths = [
                [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
                for path in feats.detach().cpu().numpy()
            ]
            return self._featurize_raw_paths(raw_paths)

        return self._fit_path_length(feats, do_pad, do_truncate)

    def _featurize_raw_paths(self, raw_paths):
        """Run each raw ``{"x","y","t"}`` path through the shared preprocessing.

        Returns a float32 feature tensor and a long mask tensor, each stacked along
        a new leading batch axis.
        """
        features, masks = [], []
        for raw in raw_paths:
            path_feats, mask = preprocess_raw_path_to_features(
                raw,
                self.max_path_len,
                resample_mode=self.path_resample_mode,
            )
            features.append(path_feats)
            masks.append(mask)
        return (
            torch.from_numpy(np.stack(features)).float(),
            torch.from_numpy(np.stack(masks)).long(),
        )

    def _fit_path_length(self, feats, do_pad: bool, do_truncate: bool):
        """Truncate and/or zero-pad a ``[batch, path_len, D]`` tensor to ``max_path_len``.

        The returned mask is 1 for points taken from the input and 0 for the zero rows
        appended as padding, so it always matches ``feats.shape[:2]``.
        """
        batch, length, dim = feats.shape
        if do_truncate and length > self.max_path_len:
            feats = feats[:, : self.max_path_len, :]
            length = self.max_path_len
        mask = torch.ones(batch, length, dtype=torch.long)
        if do_pad and length < self.max_path_len:
            pad_len = self.max_path_len - length
            # new_zeros keeps the input's dtype and device.
            feats = torch.cat([feats, feats.new_zeros(batch, pad_len, dim)], dim=1)
            mask = torch.cat([mask, torch.zeros(batch, pad_len, dtype=torch.long)], dim=1)
        return feats, mask

    def _encode_texts(self, texts, max_len: int, do_pad: bool, do_truncate: bool, **kwargs):
        """Tokenize ``texts`` and apply EOS / truncation / padding.

        Returns ``(input_ids, char_mask)`` as lists of lists. The mask is 1 for real
        tokens (including EOS) and 0 for the padding appended here. It is derived from
        lengths rather than by comparing ids to the pad id, so EOS stays visible even
        if it shares an id with PAD.
        """
        # First tokenize without padding/truncation so EOS can be appended to raw ids.
        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation=False,
            return_tensors=None,  # Get lists first
            **kwargs,
        )
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id

        all_ids, all_masks = [], []
        for raw_ids in encoded["input_ids"]:
            ids = list(raw_ids)
            # Add EOS token after each word (matching training dataset behavior);
            # an empty encoding (e.g. "") still gets its EOS.
            if not ids or ids[-1] != eos_id:
                ids.append(eos_id)
            if do_truncate and len(ids) > max_len:
                # Truncate but preserve EOS at the end
                ids = ids[: max_len - 1] + [eos_id]
            n_real = len(ids)
            if do_pad and n_real < max_len:
                ids.extend([pad_id] * (max_len - n_real))
            all_ids.append(ids)
            all_masks.append([1] * n_real + [0] * (len(ids) - n_real))
        return all_ids, all_masks

    @staticmethod
    def _assemble(path_feats, path_mask, input_ids, char_mask, return_tensors):
        """Build the output dict and the combined ``[CLS] + path + [SEP] + text`` mask."""
        batch_size = path_mask.shape[0]

        if return_tensors is None:
            return {
                "path_coords": path_feats.detach().cpu().tolist(),
                "input_ids": input_ids,
                "attention_mask": [
                    [1, *path_row, 1, *char_row]
                    for path_row, char_row in zip(path_mask.tolist(), char_mask)
                ],
            }

        if len({len(row) for row in input_ids}) > 1:
            raise ValueError(
                "Text sequences have different lengths; use padding=True to return tensors"
            )

        if return_tensors == "pt":
            ones = torch.ones(batch_size, 1, dtype=torch.long)
            char_mask_t = torch.tensor(char_mask, dtype=torch.long)
            return {
                "path_coords": path_feats,
                "input_ids": torch.tensor(input_ids, dtype=torch.long),
                "attention_mask": torch.cat([ones, path_mask, ones, char_mask_t], dim=1),
            }

        # return_tensors == "np"
        ones_np = np.ones((batch_size, 1), dtype=np.int64)
        return {
            "path_coords": path_feats.detach().cpu().numpy(),
            "input_ids": np.array(input_ids, dtype=np.int64),
            "attention_mask": np.concatenate(
                [ones_np, path_mask.numpy(), ones_np, np.array(char_mask, dtype=np.int64)],
                axis=1,
            ),
        }

    def batch_decode(self, token_ids, **kwargs):
        """
        Decode token IDs to strings.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(token_ids, **kwargs)

    def decode(self, token_ids, **kwargs):
        """
        Decode single sequence of token IDs to string.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(token_ids, **kwargs)

    def encode_path(self, path_coords, *, return_tensors: str | None = "pt", **kwargs: Any):
        """Create model inputs from a swipe path only (no text)."""
        return self(path_coords=path_coords, text=None, return_tensors=return_tensors, **kwargs)

    def encode_text(self, text, *, return_tensors: str | None = "pt", **kwargs: Any):
        """Create model inputs from text only (no path)."""
        return self(path_coords=None, text=text, return_tensors=return_tensors, **kwargs)

    # Preprocessing methods are now imported from shared preprocessing module
    # See src/swipealot/data/preprocessing.py for the implementation

    def save_pretrained(
        self,
        save_directory,
        push_to_hub=False,
        **kwargs,
    ):
        """
        Save the processor to a directory, ensuring auto_map is included.

        After the parent save, the first existing processor config file gets an
        ``auto_map["AutoProcessor"]`` entry. Any other auto_map entries already in
        that file are preserved.

        NOTE(review): when push_to_hub=True the parent uploads from inside
        super().save_pretrained(), i.e. before auto_map is patched here, so the pushed
        config may lack auto_map — confirm against the installed transformers version.
        """
        # Call parent save_pretrained
        result = super().save_pretrained(
            save_directory,
            push_to_hub=push_to_hub,
            **kwargs,
        )

        # Add auto_map to processor_config.json for AutoProcessor compatibility
        import json
        from pathlib import Path

        # Try both possible config file names
        for config_name in ("preprocessor_config.json", "processor_config.json"):
            config_path = Path(save_directory) / config_name
            if not config_path.exists():
                continue

            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)

            auto_map = config.get("auto_map")
            if not isinstance(auto_map, dict):
                auto_map = {}
            auto_map["AutoProcessor"] = "processing_swipe.SwipeProcessor"
            config["auto_map"] = auto_map

            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
            break

        return result