File size: 11,115 Bytes
bf31071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""Processor for handling multimodal swipe inputs (path + text)."""

import numpy as np
import torch
from transformers import ProcessorMixin


class SwipeProcessor(ProcessorMixin):
    """
    Processor for handling multimodal swipe inputs (path coordinates + text).

    This processor combines path coordinate preprocessing (batching, truncation,
    zero padding) with text tokenization, creating the inputs needed for
    SwipeTransformer models, including a combined attention mask laid out as
    ``[CLS] + path + [SEP] + chars``.

    Args:
        tokenizer: SwipeTokenizer instance
        max_path_len (int): Maximum path length. Defaults to 64.
        max_char_len (int): Maximum character length. Defaults to 38.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"  # Will use auto_map from tokenizer_config.json

    def __init__(self, tokenizer=None, max_path_len: int = 64, max_char_len: int = 38):
        # NOTE(review): ProcessorMixin.__init__ is not invoked here; the attributes
        # are assigned by hand instead -- confirm this is intentional.
        self.tokenizer = tokenizer
        self.max_path_len = max_path_len
        self.max_char_len = max_char_len
        # Attributes expected by newer transformers (not used for swipe models)
        self.chat_template = None
        self.audio_tokenizer = None
        self.feature_extractor = None
        self.image_processor = None

    def __call__(
        self,
        path_coords: list[list[list[float]]] | torch.Tensor | np.ndarray | None = None,
        text: str | list[str] | None = None,
        padding: bool | str = True,
        truncation: bool = True,
        max_length: int | None = None,
        return_tensors: str | None = "pt",
        **kwargs,
    ):
        """
        Process path coordinates and text into model inputs.

        Args:
            path_coords: List of paths or tensor/array [batch, path_len, 3]; a single
                        path [path_len, 3] gets a batch dimension added.
                        Each point is (x, y, time). Can be None if only processing text.
            text: String or list of strings to encode. Can be None if only processing paths.
            padding: Whether to pad sequences. Can be True/False or "max_length"
                     (any truthy value pads to the maximum length).
            truncation: Whether to truncate sequences
            max_length: Maximum sequence length for text (overrides max_char_len)
            return_tensors: "pt" for PyTorch, "np" for NumPy, None for lists
            **kwargs: Additional keyword arguments, forwarded to the tokenizer

        Returns:
            Dictionary with:
                - path_coords: [batch, path_len, 3]; all zeros with
                  path_len == max_path_len when no path_coords were provided
                - input_ids: [batch, char_len]; all pad tokens with
                  char_len == max_char_len when no text was provided
                - attention_mask: [batch, 1 + path_len + 1 + char_len]
                  (1 = attend, 0 = padding / missing modality)

        Raises:
            ValueError: If neither input is given, if ``return_tensors`` is not one
                of "pt", "np" or None, if the path and text batch sizes differ, if
                ``path_coords`` has an unsupported rank, or if the text length
                limit is smaller than 1.
            TypeError: If ``path_coords`` is not a list, tuple, ndarray or tensor.
        """
        if path_coords is None and text is None:
            raise ValueError("Must provide either path_coords or text (or both)")
        # Tuple membership compares with ==, so str-valued enums equal to "pt"/"np" pass.
        if return_tensors not in ("pt", "np", None):
            raise ValueError(
                f"Unsupported return_tensors={return_tensors!r}; expected 'pt', 'np' or None"
            )

        if isinstance(text, str):
            text = [text]

        # --- Paths -----------------------------------------------------------
        if path_coords is not None:
            paths = self._as_batched_paths(path_coords)
            batch_size = paths.shape[0]
            if text is not None and len(text) != batch_size:
                raise ValueError(
                    f"Batch size mismatch: got {batch_size} path(s) but {len(text)} text item(s)"
                )
            paths, path_mask_row = self._fit_paths(paths, padding, truncation)
        else:
            # No path provided: emit an all-zero placeholder that is fully masked out.
            batch_size = len(text)
            paths = torch.zeros(batch_size, self.max_path_len, 3)
            path_mask_row = [0] * self.max_path_len

        # --- Text ------------------------------------------------------------
        if text is not None:
            input_ids, char_mask = self._encode_text(
                text, padding, truncation, max_length, kwargs
            )
        else:
            # No text provided: emit pad tokens that are fully masked out.
            pad_id = self.tokenizer.pad_token_id
            input_ids = [[pad_id] * self.max_char_len for _ in range(batch_size)]
            char_mask = [[0] * self.max_char_len for _ in range(batch_size)]

        # Combined mask: [CLS:1] + path mask + [SEP:1] + char mask.
        # The path mask is identical for every row because all paths in a batch
        # share one length.
        attention_mask = [[1, *path_mask_row, 1, *row] for row in char_mask]

        if return_tensors == "pt":
            path_out = paths
        elif return_tensors == "np":
            # detach/cpu so tensors that require grad or live on an accelerator convert too.
            path_out = paths.detach().cpu().numpy()
        else:
            path_out = paths.tolist()

        return {
            "path_coords": path_out,
            "input_ids": self._as_int_output(input_ids, return_tensors),
            "attention_mask": self._as_int_output(attention_mask, return_tensors),
        }

    @staticmethod
    def _as_batched_paths(path_coords) -> torch.Tensor:
        """Convert list/tuple/ndarray/tensor paths to a [batch, path_len, features] tensor."""
        if isinstance(path_coords, torch.Tensor):
            # Tensors are passed through with their dtype and device preserved.
            paths = path_coords
        elif isinstance(path_coords, np.ndarray):
            paths = torch.from_numpy(path_coords).float()
        elif isinstance(path_coords, (list, tuple)):
            paths = torch.tensor(path_coords, dtype=torch.float32)
        else:
            raise TypeError(
                "path_coords must be a list, tuple, numpy.ndarray or torch.Tensor, "
                f"got {type(path_coords).__name__}"
            )

        # An empty, unbatched input (e.g. [] or an array of shape (0, 3)) is
        # interpreted as one path with zero points.
        if paths.dim() < 3 and paths.numel() == 0:
            return paths.new_zeros((1, 0, 3))
        if paths.dim() == 2:
            # Single path, add batch dimension
            paths = paths.unsqueeze(0)
        if paths.dim() != 3:
            raise ValueError(
                "path_coords must have shape [batch, path_len, 3] or [path_len, 3], "
                f"got {tuple(paths.shape)}"
            )
        return paths

    def _fit_paths(self, paths: torch.Tensor, padding, truncation):
        """Truncate/zero-pad paths to ``max_path_len``; return (paths, per-row mask list)."""
        real_len = paths.shape[1]

        if truncation and real_len > self.max_path_len:
            paths = paths[:, : self.max_path_len, :]
            real_len = self.max_path_len

        if padding and real_len < self.max_path_len:
            # new_zeros keeps the padding on the same dtype/device as the data and
            # follows the actual feature width instead of assuming 3.
            filler = paths.new_zeros(
                (paths.shape[0], self.max_path_len - real_len, paths.shape[2])
            )
            paths = torch.cat([paths, filler], dim=1)

        # Mask length always matches the returned path length (1 = real, 0 = padding),
        # including when padding or truncation is disabled.
        mask_row = [1] * real_len + [0] * (paths.shape[1] - real_len)
        return paths, mask_row

    def _encode_text(self, text, padding, truncation, max_length, tokenizer_kwargs):
        """Tokenize text, append EOS, truncate/pad; return (input_ids, mask) as lists."""
        text_max_length = max_length if max_length is not None else self.max_char_len
        if text_max_length < 1:
            raise ValueError(f"Text length limit must be >= 1, got {text_max_length}")

        # Tokenize without padding/truncation so EOS can be appended first.
        encoded_raw = self.tokenizer(
            text,
            padding=False,
            truncation=False,
            return_tensors=None,  # Get lists first
            **tokenizer_kwargs,
        )
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id

        sequences = []
        for ids in encoded_raw["input_ids"]:
            ids = list(ids)
            # Add EOS after each word if not already present (matching training
            # dataset behavior); an empty encoding becomes just [EOS].
            if not ids or ids[-1] != eos_id:
                ids.append(eos_id)
            # Truncate but preserve EOS at the end.
            if truncation and len(ids) > text_max_length:
                ids = ids[: text_max_length - 1] + [eos_id]
            sequences.append(ids)

        real_lengths = [len(ids) for ids in sequences]

        if padding:
            # With truncation disabled a sequence may exceed the limit; pad to the
            # longest one so the batch stays rectangular.
            target_len = max([text_max_length, *real_lengths])
            sequences = [ids + [pad_id] * (target_len - len(ids)) for ids in sequences]

        # Mask from lengths rather than token values, so a real token that happens
        # to share the pad id (e.g. EOS when pad == eos) is still attended to.
        char_mask = [
            [1] * real + [0] * (len(ids) - real)
            for ids, real in zip(sequences, real_lengths, strict=True)
        ]
        return sequences, char_mask

    @staticmethod
    def _as_int_output(rows, return_tensors):
        """Return integer rows as an int64 torch tensor, int64 ndarray, or unchanged lists."""
        if return_tensors == "pt":
            return torch.tensor(rows, dtype=torch.long)
        if return_tensors == "np":
            return np.array(rows, dtype=np.int64)
        return rows

    def batch_decode(self, token_ids, **kwargs):
        """
        Decode token IDs to strings.

        Thin wrapper that delegates to ``self.tokenizer.batch_decode``.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(token_ids, **kwargs)

    def decode(self, token_ids, **kwargs):
        """
        Decode single sequence of token IDs to string.

        Thin wrapper that delegates to ``self.tokenizer.decode``.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(token_ids, **kwargs)