dleemiller commited on
Commit
fac0dcc
·
verified ·
1 Parent(s): 137450c

Upload folder using huggingface_hub

Browse files
config.json CHANGED
@@ -19,6 +19,7 @@
19
  "predict_char": true,
20
  "predict_length": true,
21
  "predict_path": true,
 
22
  "sep_token_id": 2,
23
  "transformers_version": "4.57.3",
24
  "unk_token_id": 4,
 
19
  "predict_char": true,
20
  "predict_length": true,
21
  "predict_path": true,
22
+ "predict_path_uncertainty": true,
23
  "sep_token_id": 2,
24
  "transformers_version": "4.57.3",
25
  "unk_token_id": 4,
configuration_swipe.py CHANGED
@@ -22,6 +22,8 @@ class SwipeTransformerConfig(PretrainedConfig):
22
  max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
23
  path_input_dim (int, optional): Path feature dimension. Defaults to 6 for (x, y, dx, dy, ds, log_dt).
24
  predict_path (bool, optional): Whether to predict path coordinates. Defaults to True.
 
 
25
  pad_token_id (int, optional): Padding token ID. Defaults to 0.
26
  cls_token_id (int, optional): CLS token ID. Defaults to 1.
27
  sep_token_id (int, optional): SEP token ID. Defaults to 2.
@@ -45,6 +47,7 @@ class SwipeTransformerConfig(PretrainedConfig):
45
  path_input_dim: int = 6,
46
  predict_char: bool = True,
47
  predict_path: bool = True,
 
48
  predict_length: bool = True,
49
  pad_token_id: int = 0,
50
  cls_token_id: int = 1,
@@ -72,6 +75,7 @@ class SwipeTransformerConfig(PretrainedConfig):
72
  # Model capabilities
73
  self.predict_char = predict_char
74
  self.predict_path = predict_path
 
75
  self.predict_length = predict_length
76
 
77
  # Special tokens
 
22
  max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
23
  path_input_dim (int, optional): Path feature dimension. Defaults to 6 for (x, y, dx, dy, ds, log_dt).
24
  predict_path (bool, optional): Whether to predict path coordinates. Defaults to True.
25
+ predict_path_uncertainty (bool, optional): Whether to predict log sigma for path coords.
26
+ Defaults to False.
27
  pad_token_id (int, optional): Padding token ID. Defaults to 0.
28
  cls_token_id (int, optional): CLS token ID. Defaults to 1.
29
  sep_token_id (int, optional): SEP token ID. Defaults to 2.
 
47
  path_input_dim: int = 6,
48
  predict_char: bool = True,
49
  predict_path: bool = True,
50
+ predict_path_uncertainty: bool = False,
51
  predict_length: bool = True,
52
  pad_token_id: int = 0,
53
  cls_token_id: int = 1,
 
75
  # Model capabilities
76
  self.predict_char = predict_char
77
  self.predict_path = predict_path
78
+ self.predict_path_uncertainty = predict_path_uncertainty
79
  self.predict_length = predict_length
80
 
81
  # Special tokens
heads.py CHANGED
@@ -58,7 +58,10 @@ class PathPredictionHead(nn.Module):
58
  # - ds is non-negative
59
  # - log_dt is non-negative
60
  if features.shape[-1] == 6:
61
- x_y = torch.sigmoid(features[..., 0:2])
 
 
 
62
  dx_dy = torch.tanh(features[..., 2:4])
63
  ds = torch.nn.functional.softplus(features[..., 4:5])
64
  log_dt = torch.nn.functional.softplus(features[..., 5:6])
@@ -68,6 +71,30 @@ class PathPredictionHead(nn.Module):
68
  return features
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  class LengthPredictionHead(nn.Module):
72
  """Regress sequence length (e.g., swipable character count) from CLS embedding."""
73
 
 
58
  # - ds is non-negative
59
  # - log_dt is non-negative
60
  if features.shape[-1] == 6:
61
+ # Use sigmoid(2x) to avoid center bias that sigmoid(x) has
62
+ # Mathematical identity: sigmoid(2x) = 0.5(tanh(x)+1)
63
+ # The 2x scaling provides steeper gradients, helping escape center attraction
64
+ x_y = torch.sigmoid(2.0 * features[..., 0:2])
65
  dx_dy = torch.tanh(features[..., 2:4])
66
  ds = torch.nn.functional.softplus(features[..., 4:5])
67
  log_dt = torch.nn.functional.softplus(features[..., 5:6])
 
71
  return features
72
 
73
 
74
+ class PathUncertaintyHead(nn.Module):
75
+ """Prediction head for log sigma of path coordinates."""
76
+
77
+ def __init__(self, d_model: int, output_dim: int = 6):
78
+ super().__init__()
79
+ self.dense = nn.Linear(d_model, d_model)
80
+ self.layer_norm = nn.LayerNorm(d_model)
81
+ self.decoder = nn.Linear(d_model, output_dim)
82
+ self.activation = nn.GELU()
83
+
84
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
85
+ """
86
+ Args:
87
+ hidden_states: [batch, seq_len, d_model]
88
+
89
+ Returns:
90
+ [batch, seq_len, output_dim] log sigma values.
91
+ """
92
+ x = self.dense(hidden_states)
93
+ x = self.activation(x)
94
+ x = self.layer_norm(x)
95
+ return self.decoder(x)
96
+
97
+
98
  class LengthPredictionHead(nn.Module):
99
  """Regress sequence length (e.g., swipable character count) from CLS embedding."""
100
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:452880eddd71bf06c0c61750ae3f5ba65670c15aa97da8233df856ac74bd3e56
3
- size 348207344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffae9d3acc0266e3856bc25d739afd98000f3b09aa5c1bec6e0c4e9503b6ddbf
3
+ size 350594936
modeling_swipe.py CHANGED
@@ -22,6 +22,8 @@ class SwipeTransformerOutput(ModelOutput):
22
  Prediction scores of the character prediction head (text segment only).
23
  path_logits (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
24
  Prediction scores of the path prediction head (path segment only, if enabled).
 
 
25
  length_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
26
  Predicted length from the length head (if enabled).
27
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -39,6 +41,7 @@ class SwipeTransformerOutput(ModelOutput):
39
  loss: torch.FloatTensor | None = None
40
  char_logits: torch.FloatTensor | None = None
41
  path_logits: torch.FloatTensor | None = None
 
42
  length_logits: torch.FloatTensor | None = None
43
  last_hidden_state: torch.FloatTensor | None = None
44
  pooler_output: torch.FloatTensor | None = None
@@ -88,7 +91,12 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
88
 
89
  # Import existing components
90
  from .embeddings import MixedEmbedding
91
- from .heads import CharacterPredictionHead, LengthPredictionHead, PathPredictionHead
 
 
 
 
 
92
 
93
  # Embeddings
94
  self.embeddings = MixedEmbedding(
@@ -130,8 +138,14 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
130
  self.path_head = PathPredictionHead(
131
  d_model=config.d_model, output_dim=config.path_input_dim
132
  )
 
 
 
 
 
133
  else:
134
  self.path_head = None
 
135
 
136
  # Length prediction head (predicts word length from path)
137
  # Max length is max_char_len (including EOS)
@@ -142,87 +156,37 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
142
  # Initialize weights
143
  self.post_init()
144
 
145
- def forward(
146
- self,
147
- input_ids: torch.Tensor,
148
- path_coords: torch.Tensor,
149
- attention_mask: torch.Tensor | None = None,
150
- labels: torch.Tensor | dict | None = None,
151
- return_dict: bool | None = None,
152
- output_hidden_states: bool | None = None,
153
- output_attentions: bool | None = None,
154
- **kwargs,
155
- ):
156
- """
157
- Forward pass of the model.
158
-
159
- Args:
160
- input_ids (torch.Tensor): Character token IDs [batch, char_len]
161
- path_coords (torch.Tensor): Path features [batch, path_len, path_input_dim]
162
- Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
163
- attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
164
- labels (torch.Tensor or dict, optional): Labels for loss calculation
165
- Can be tensor [batch, char_len] or dict with keys like char_labels, path_labels
166
- return_dict (bool, optional): Whether to return ModelOutput object
167
- output_hidden_states (bool, optional): Whether to output hidden states
168
- output_attentions (bool, optional): Whether to output attention weights
169
- **kwargs: Additional arguments (for compatibility)
170
-
171
- Returns:
172
- SwipeTransformerOutput or tuple: Model outputs with:
173
- - loss: Optional loss value
174
- - char_logits: Character prediction logits [batch, char_len, vocab_size] (if enabled)
175
- - path_logits: Path prediction logits [batch, path_len, path_input_dim] (if enabled)
176
- - length_logits: Length regression output [batch] (if enabled)
177
- - last_hidden_state: Hidden states [batch, seq_len, d_model]
178
- - pooler_output: SEP token embedding [batch, d_model] for similarity/embedding tasks
179
- - hidden_states: Tuple of per-layer hidden states (if output_hidden_states=True)
180
- - attentions: Tuple of per-layer attention weights (if output_attentions=True)
181
- """
182
- # Validate required inputs
183
- if input_ids is None or path_coords is None:
184
- raise ValueError("Both input_ids and path_coords are required")
185
-
186
- # Extract labels if dict (used by custom trainers)
187
- if isinstance(labels, dict):
188
- char_labels = labels.get("char_labels")
189
- # Can handle other label types in the future (path_labels, etc.)
190
- else:
191
- char_labels = labels
192
-
193
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
194
- output_hidden_states = (
195
- output_hidden_states
196
- if output_hidden_states is not None
197
- else self.config.output_hidden_states
198
- )
199
- output_attentions = (
200
- output_attentions if output_attentions is not None else self.config.output_attentions
201
- )
202
-
203
- batch_size = path_coords.shape[0]
204
- device = path_coords.device
205
-
206
- # Create [CLS] and [SEP] tokens
207
  cls_token = torch.full(
208
- (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
 
 
 
209
  )
210
  sep_token = torch.full(
211
- (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
 
 
 
212
  )
 
213
 
214
- # Get embeddings
215
- embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
216
-
217
- # Prepare attention mask for encoder
218
- if attention_mask is not None:
219
- # Convert attention mask: 1 = attend, 0 = ignore
220
- # PyTorch expects: False = attend, True = ignore
221
- src_key_padding_mask = attention_mask == 0
222
- else:
223
- src_key_padding_mask = None
224
 
225
- # Encode while optionally capturing attentions and per-layer hidden states.
 
 
 
 
 
 
 
 
226
  attentions: tuple[torch.Tensor, ...] | None = None
227
  hidden_states_by_layer: list[torch.Tensor] | None = [] if output_hidden_states else None
228
 
@@ -296,69 +260,165 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
296
  if idx in original_forwards:
297
  layer.self_attn.forward = original_forwards[idx]
298
 
299
- path_len = path_coords.shape[1]
300
- char_len = input_ids.shape[1]
 
 
 
301
 
302
- # Character prediction (text segment only)
 
 
 
 
 
 
 
 
 
 
 
 
303
  char_logits = None
304
  if self.char_head is not None:
305
- # Sequence is: [CLS] + path + [SEP] + chars
306
  char_start = 1 + path_len + 1
307
  char_hidden = hidden_states[:, char_start : char_start + char_len, :]
308
  char_logits = self.char_head(char_hidden)
309
 
310
- # Path prediction (path segment only, if enabled)
311
  path_logits = None
 
312
  if self.path_head is not None:
313
  path_hidden = hidden_states[:, 1 : 1 + path_len, :]
314
  path_logits = self.path_head(path_hidden)
 
 
315
 
316
- # Length prediction from CLS token
317
- cls_hidden = hidden_states[:, 0, :] # [batch, d_model] - CLS at position 0
318
  length_logits = self.length_head(cls_hidden) if self.length_head is not None else None
319
 
320
- # Extract SEP token embedding for pooler output (embeddings/similarity tasks)
321
- # SEP is at position 1 + path_len
322
  sep_position = 1 + path_len
323
- pooler_output = hidden_states[:, sep_position, :] # [batch, d_model]
324
-
325
- # Compute loss if labels provided (masked-only; -100 = ignore)
326
- loss = None
327
- if char_labels is not None and self.char_head is not None:
328
- # Predict only the text segment
329
- char_pred = char_logits # [B, char_len, V]
330
- labels_flat = char_labels.reshape(-1)
331
- mask = labels_flat != -100
332
- if mask.any():
333
- logits_flat = char_pred.reshape(-1, self.config.vocab_size)[mask]
334
- labels_flat = labels_flat[mask]
335
- loss = nn.functional.cross_entropy(logits_flat, labels_flat, reduction="mean")
336
- else:
337
- loss = torch.tensor(0.0, device=hidden_states.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  if not return_dict:
340
- hidden_tuple = None
341
- if hidden_states_by_layer is not None:
342
- hidden_tuple = (embeddings,) + tuple(hidden_states_by_layer)
343
  output = (
344
  char_logits,
345
  path_logits,
346
  length_logits,
347
  hidden_states,
348
  pooler_output,
349
- hidden_tuple,
350
  attentions,
 
351
  )
352
  return (loss,) + output if loss is not None else output
353
 
354
- all_hidden_states = None
355
- if hidden_states_by_layer is not None:
356
- all_hidden_states = (embeddings,) + tuple(hidden_states_by_layer)
357
-
358
  return SwipeTransformerOutput(
359
  loss=loss,
360
  char_logits=char_logits,
361
  path_logits=path_logits,
 
362
  length_logits=length_logits,
363
  last_hidden_state=hidden_states,
364
  pooler_output=pooler_output,
 
22
  Prediction scores of the character prediction head (text segment only).
23
  path_logits (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
24
  Prediction scores of the path prediction head (path segment only, if enabled).
25
+ path_log_sigma (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
26
+ Predicted log sigma for path coordinates (path segment only, if enabled).
27
  length_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
28
  Predicted length from the length head (if enabled).
29
  last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
 
41
  loss: torch.FloatTensor | None = None
42
  char_logits: torch.FloatTensor | None = None
43
  path_logits: torch.FloatTensor | None = None
44
+ path_log_sigma: torch.FloatTensor | None = None
45
  length_logits: torch.FloatTensor | None = None
46
  last_hidden_state: torch.FloatTensor | None = None
47
  pooler_output: torch.FloatTensor | None = None
 
91
 
92
  # Import existing components
93
  from .embeddings import MixedEmbedding
94
+ from .heads import (
95
+ CharacterPredictionHead,
96
+ LengthPredictionHead,
97
+ PathPredictionHead,
98
+ PathUncertaintyHead,
99
+ )
100
 
101
  # Embeddings
102
  self.embeddings = MixedEmbedding(
 
138
  self.path_head = PathPredictionHead(
139
  d_model=config.d_model, output_dim=config.path_input_dim
140
  )
141
+ self.path_log_sigma_head = (
142
+ PathUncertaintyHead(d_model=config.d_model, output_dim=config.path_input_dim)
143
+ if config.predict_path_uncertainty
144
+ else None
145
+ )
146
  else:
147
  self.path_head = None
148
+ self.path_log_sigma_head = None
149
 
150
  # Length prediction head (predicts word length from path)
151
  # Max length is max_char_len (including EOS)
 
156
  # Initialize weights
157
  self.post_init()
158
 
159
+ def _make_special_tokens(
160
+ self, batch_size: int, *, device: torch.device
161
+ ) -> tuple[torch.Tensor, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  cls_token = torch.full(
163
+ (batch_size, 1),
164
+ fill_value=self.config.cls_token_id,
165
+ dtype=torch.long,
166
+ device=device,
167
  )
168
  sep_token = torch.full(
169
+ (batch_size, 1),
170
+ fill_value=self.config.sep_token_id,
171
+ dtype=torch.long,
172
+ device=device,
173
  )
174
+ return cls_token, sep_token
175
 
176
+ def _src_key_padding_mask(self, attention_mask: torch.Tensor | None) -> torch.Tensor | None:
177
+ if attention_mask is None:
178
+ return None
179
+ return attention_mask == 0
 
 
 
 
 
 
180
 
181
+ def _encode(
182
+ self,
183
+ embeddings: torch.Tensor,
184
+ *,
185
+ src_key_padding_mask: torch.Tensor | None,
186
+ output_hidden_states: bool,
187
+ output_attentions: bool,
188
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...] | None, tuple[torch.Tensor, ...] | None]:
189
+ """Run encoder with optional per-layer hidden-state + attention capture."""
190
  attentions: tuple[torch.Tensor, ...] | None = None
191
  hidden_states_by_layer: list[torch.Tensor] | None = [] if output_hidden_states else None
192
 
 
260
  if idx in original_forwards:
261
  layer.self_attn.forward = original_forwards[idx]
262
 
263
+ all_hidden_states = None
264
+ if hidden_states_by_layer is not None:
265
+ all_hidden_states = (embeddings,) + tuple(hidden_states_by_layer)
266
+
267
+ return hidden_states, all_hidden_states, attentions
268
 
269
+ def _heads(
270
+ self,
271
+ hidden_states: torch.Tensor,
272
+ *,
273
+ path_len: int,
274
+ char_len: int,
275
+ ) -> tuple[
276
+ torch.Tensor | None,
277
+ torch.Tensor | None,
278
+ torch.Tensor | None,
279
+ torch.Tensor | None,
280
+ torch.Tensor,
281
+ ]:
282
  char_logits = None
283
  if self.char_head is not None:
 
284
  char_start = 1 + path_len + 1
285
  char_hidden = hidden_states[:, char_start : char_start + char_len, :]
286
  char_logits = self.char_head(char_hidden)
287
 
 
288
  path_logits = None
289
+ path_log_sigma = None
290
  if self.path_head is not None:
291
  path_hidden = hidden_states[:, 1 : 1 + path_len, :]
292
  path_logits = self.path_head(path_hidden)
293
+ if self.path_log_sigma_head is not None:
294
+ path_log_sigma = self.path_log_sigma_head(path_hidden)
295
 
296
+ cls_hidden = hidden_states[:, 0, :]
 
297
  length_logits = self.length_head(cls_hidden) if self.length_head is not None else None
298
 
 
 
299
  sep_position = 1 + path_len
300
+ pooler_output = hidden_states[:, sep_position, :]
301
+ return char_logits, path_logits, path_log_sigma, length_logits, pooler_output
302
+
303
+ def _char_loss(
304
+ self,
305
+ *,
306
+ char_logits: torch.Tensor | None,
307
+ char_labels: torch.Tensor | None,
308
+ device: torch.device,
309
+ ) -> torch.Tensor | None:
310
+ if char_labels is None or self.char_head is None or char_logits is None:
311
+ return None
312
+
313
+ labels_flat = char_labels.reshape(-1)
314
+ mask = labels_flat != -100
315
+ if mask.any():
316
+ logits_flat = char_logits.reshape(-1, self.config.vocab_size)[mask]
317
+ labels_flat = labels_flat[mask]
318
+ return nn.functional.cross_entropy(logits_flat, labels_flat, reduction="mean")
319
+ return torch.tensor(0.0, device=device)
320
+
321
+ def forward(
322
+ self,
323
+ input_ids: torch.Tensor,
324
+ path_coords: torch.Tensor,
325
+ attention_mask: torch.Tensor | None = None,
326
+ labels: torch.Tensor | dict | None = None,
327
+ return_dict: bool | None = None,
328
+ output_hidden_states: bool | None = None,
329
+ output_attentions: bool | None = None,
330
+ **kwargs,
331
+ ):
332
+ """
333
+ Forward pass of the model.
334
+
335
+ Args:
336
+ input_ids (torch.Tensor): Character token IDs [batch, char_len]
337
+ path_coords (torch.Tensor): Path features [batch, path_len, path_input_dim]
338
+ Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
339
+ attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
340
+ labels (torch.Tensor or dict, optional): Labels for loss calculation
341
+ Can be tensor [batch, char_len] or dict with keys like char_labels, path_labels
342
+ return_dict (bool, optional): Whether to return ModelOutput object
343
+ output_hidden_states (bool, optional): Whether to output hidden states
344
+ output_attentions (bool, optional): Whether to output attention weights
345
+ **kwargs: Additional arguments (for compatibility)
346
+
347
+ Returns:
348
+ SwipeTransformerOutput or tuple: Model outputs with:
349
+ - loss: Optional loss value
350
+ - char_logits: Character prediction logits [batch, char_len, vocab_size] (if enabled)
351
+ - path_logits: Path prediction logits [batch, path_len, path_input_dim] (if enabled)
352
+ - path_log_sigma: Path log sigma [batch, path_len, path_input_dim] (if enabled)
353
+ - length_logits: Length regression output [batch] (if enabled)
354
+ - last_hidden_state: Hidden states [batch, seq_len, d_model]
355
+ - pooler_output: SEP token embedding [batch, d_model] for similarity/embedding tasks
356
+ - hidden_states: Tuple of per-layer hidden states (if output_hidden_states=True)
357
+ - attentions: Tuple of per-layer attention weights (if output_attentions=True)
358
+ """
359
+ if input_ids is None or path_coords is None:
360
+ raise ValueError("Both input_ids and path_coords are required")
361
+
362
+ if isinstance(labels, dict):
363
+ char_labels = labels.get("char_labels")
364
+ else:
365
+ char_labels = labels
366
+
367
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
368
+ output_hidden_states = (
369
+ output_hidden_states
370
+ if output_hidden_states is not None
371
+ else self.config.output_hidden_states
372
+ )
373
+ output_attentions = (
374
+ output_attentions if output_attentions is not None else self.config.output_attentions
375
+ )
376
+
377
+ batch_size = int(path_coords.shape[0])
378
+ device = path_coords.device
379
+ cls_token, sep_token = self._make_special_tokens(batch_size, device=device)
380
+ embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
381
+
382
+ src_key_padding_mask = self._src_key_padding_mask(attention_mask)
383
+ hidden_states, all_hidden_states, attentions = self._encode(
384
+ embeddings,
385
+ src_key_padding_mask=src_key_padding_mask,
386
+ output_hidden_states=bool(output_hidden_states),
387
+ output_attentions=bool(output_attentions),
388
+ )
389
+
390
+ path_len = int(path_coords.shape[1])
391
+ char_len = int(input_ids.shape[1])
392
+ char_logits, path_logits, path_log_sigma, length_logits, pooler_output = self._heads(
393
+ hidden_states,
394
+ path_len=path_len,
395
+ char_len=char_len,
396
+ )
397
+
398
+ loss = self._char_loss(
399
+ char_logits=char_logits,
400
+ char_labels=char_labels,
401
+ device=hidden_states.device,
402
+ )
403
 
404
  if not return_dict:
 
 
 
405
  output = (
406
  char_logits,
407
  path_logits,
408
  length_logits,
409
  hidden_states,
410
  pooler_output,
411
+ all_hidden_states,
412
  attentions,
413
+ path_log_sigma,
414
  )
415
  return (loss,) + output if loss is not None else output
416
 
 
 
 
 
417
  return SwipeTransformerOutput(
418
  loss=loss,
419
  char_logits=char_logits,
420
  path_logits=path_logits,
421
+ path_log_sigma=path_log_sigma,
422
  length_logits=length_logits,
423
  last_hidden_state=hidden_states,
424
  pooler_output=pooler_output,
processing_swipe.py CHANGED
@@ -40,11 +40,7 @@ class SwipeProcessor(ProcessorMixin):
40
  self.max_char_len = max_char_len
41
  self.path_input_dim = path_input_dim
42
  self.path_resample_mode = path_resample_mode
43
- # Attributes expected by newer transformers (not used for swipe models)
44
  self.chat_template = None
45
- self.audio_tokenizer = None
46
- self.feature_extractor = None
47
- self.image_processor = None
48
 
49
  def __call__(
50
  self,
@@ -94,47 +90,87 @@ class SwipeProcessor(ProcessorMixin):
94
  if path_coords is None and text is None:
95
  raise ValueError("Must provide either path_coords or text (or both)")
96
 
97
- # Determine batch size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  if path_coords is not None:
99
- # Handle path coordinates
100
  if isinstance(path_coords, (list, tuple)):
101
  if len(path_coords) == 0:
102
  batch_size = 1
103
  else:
104
  first = path_coords[0]
105
- # Raw single path: [{"x","y","t"}, ...]
106
  if isinstance(first, dict):
107
  batch_size = 1
108
- # Raw batch of paths: [[{"x","y","t"}, ...], ...]
109
  elif (
110
  isinstance(first, (list, tuple))
111
  and len(first) > 0
112
  and isinstance(first[0], dict)
113
  ):
114
  batch_size = len(path_coords)
115
- # Numeric batch: [[[...], ...], ...] where points are lists/tuples
116
  elif (
117
  isinstance(first, (list, tuple))
118
  and len(first) > 0
119
  and isinstance(first[0], (list, tuple))
120
  ):
121
  path_coords = torch.tensor(path_coords, dtype=torch.float32)
122
- batch_size = path_coords.shape[0]
123
  else:
124
- # Numeric single path: [[...], [...], ...]
125
  path_coords = torch.tensor([path_coords], dtype=torch.float32)
126
- batch_size = path_coords.shape[0]
127
  elif isinstance(path_coords, np.ndarray):
128
  path_coords = torch.from_numpy(path_coords).float()
129
  if path_coords.dim() == 2:
130
- # Single path, add batch dimension
131
  path_coords = path_coords.unsqueeze(0)
132
- batch_size = path_coords.shape[0]
133
  elif isinstance(path_coords, torch.Tensor):
134
  if path_coords.dim() == 2:
135
- # Single path, add batch dimension
136
  path_coords = path_coords.unsqueeze(0)
137
- batch_size = path_coords.shape[0]
 
 
138
  elif text is not None:
139
  if isinstance(text, str):
140
  batch_size = 1
@@ -144,238 +180,252 @@ class SwipeProcessor(ProcessorMixin):
144
  else:
145
  batch_size = 1
146
 
147
- result = {}
148
 
149
- # Process path coordinates
150
- if path_coords is not None:
151
- # Check if path_coords is raw data (list of dicts) or already a tensor
152
- if isinstance(path_coords, (list, tuple)) and len(path_coords) > 0:
153
- first_elem = path_coords[0]
154
-
155
- # Raw single path: [{"x","y","t"}, ...]
156
- if isinstance(first_elem, dict) and "x" in first_elem:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  path_feats, mask = preprocess_raw_path_to_features(
158
- path_coords,
159
  self.max_path_len,
160
  resample_mode=self.path_resample_mode,
161
  )
162
- if return_tensors == "pt":
163
- path_coords = torch.from_numpy(path_feats).float().unsqueeze(0)
164
- _path_mask = torch.from_numpy(mask).long().unsqueeze(0)
165
- else:
166
- path_coords = np.expand_dims(path_feats, axis=0)
167
- _path_mask = np.expand_dims(mask, axis=0)
168
-
169
- # Raw batch of paths: [[{"x","y","t"}, ...], ...]
170
- elif (
171
- isinstance(first_elem, (list, tuple))
172
- and len(first_elem) > 0
173
- and isinstance(first_elem[0], dict)
174
- and "x" in first_elem[0]
175
- ):
176
- processed_paths = []
177
- path_masks = []
178
- for path in path_coords:
179
- path_feats, mask = preprocess_raw_path_to_features(
180
- path,
181
- self.max_path_len,
182
- resample_mode=self.path_resample_mode,
183
- )
184
- processed_paths.append(path_feats)
185
- path_masks.append(mask)
186
-
187
- path_coords = np.stack(processed_paths) # [batch, max_path_len, 6]
188
- _path_mask = np.stack(path_masks) # [batch, max_path_len]
189
-
190
- if return_tensors == "pt":
191
- path_coords = torch.from_numpy(path_coords).float()
192
- _path_mask = torch.from_numpy(_path_mask).long()
193
-
194
- else:
195
- # Numeric list input; process as before
196
- path_coords = torch.tensor(path_coords, dtype=torch.float32)
197
- if path_coords.dim() == 2:
198
- path_coords = path_coords.unsqueeze(0)
199
-
200
- current_path_len = path_coords.shape[1]
201
- if truncation and current_path_len > self.max_path_len:
202
- path_coords = path_coords[:, : self.max_path_len, :]
203
- if padding and current_path_len < self.max_path_len:
204
- pad_len = self.max_path_len - current_path_len
205
- pad_shape = (batch_size, pad_len, self.path_input_dim)
206
- path_coords = torch.cat([path_coords, torch.zeros(pad_shape)], dim=1)
207
-
208
- _path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
209
- is_padding = (path_coords == 0).all(dim=-1)
210
- _path_mask[is_padding] = 0
211
- elif isinstance(path_coords, np.ndarray):
212
- path_coords = torch.from_numpy(path_coords).float()
213
- if path_coords.dim() == 2:
214
- path_coords = path_coords.unsqueeze(0)
215
- # If user provided raw (x,y,t) triples but model expects engineered features,
216
- # convert to motion features and resample.
217
- if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
218
- processed_paths = []
219
- path_masks = []
220
- for path in path_coords.cpu().numpy():
221
- raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
222
- path_feats, mask = preprocess_raw_path_to_features(
223
- raw,
224
- self.max_path_len,
225
- resample_mode=self.path_resample_mode,
226
- )
227
- processed_paths.append(path_feats)
228
- path_masks.append(mask)
229
-
230
- path_coords = torch.from_numpy(np.stack(processed_paths)).float()
231
- _path_mask = torch.from_numpy(np.stack(path_masks)).long()
232
- else:
233
- _path_mask = torch.ones(
234
- path_coords.shape[0], self.max_path_len, dtype=torch.long
235
- )
236
- elif isinstance(path_coords, torch.Tensor):
237
- if path_coords.dim() == 2:
238
- path_coords = path_coords.unsqueeze(0)
239
- # If user provided raw (x,y,t) triples but model expects engineered features,
240
- # convert to motion features and resample.
241
- if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
242
- processed_paths = []
243
- path_masks = []
244
- for path in path_coords.detach().cpu().numpy():
245
- raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
246
- path_feats, mask = preprocess_raw_path_to_features(
247
- raw,
248
- self.max_path_len,
249
- resample_mode=self.path_resample_mode,
250
- )
251
- processed_paths.append(path_feats)
252
- path_masks.append(mask)
253
-
254
- path_coords = torch.from_numpy(np.stack(processed_paths)).float()
255
- _path_mask = torch.from_numpy(np.stack(path_masks)).long()
256
- else:
257
- _path_mask = torch.ones(
258
- path_coords.shape[0], self.max_path_len, dtype=torch.long
259
  )
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- result["path_coords"] = path_coords
262
- else:
263
- # No path coords provided, create empty/zero tensors
264
- path_coords = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
265
- _path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
266
- result["path_coords"] = path_coords
267
-
268
- # Process text
269
- if text is not None:
270
- # Ensure text is a list
271
- if isinstance(text, str):
272
- text = [text]
273
-
274
- # Tokenize text
275
- text_max_length = max_length if max_length is not None else self.max_char_len
276
-
277
- # First tokenize without padding/truncation to add EOS
278
- encoded_raw = self.tokenizer(
279
- text,
280
- padding=False,
281
- truncation=False,
282
- return_tensors=None, # Get lists first
283
- **kwargs,
284
  )
 
 
 
285
 
286
- # Add EOS token after each word (matching training dataset behavior)
287
- eos_id = self.tokenizer.eos_token_id
288
- for i in range(len(encoded_raw["input_ids"])):
289
- # Add EOS if not already present
290
- if encoded_raw["input_ids"][i][-1] != eos_id:
291
- encoded_raw["input_ids"][i].append(eos_id)
292
-
293
- # Now apply padding and truncation
294
- max_len_needed = max(len(ids) for ids in encoded_raw["input_ids"])
295
- if truncation and max_len_needed > text_max_length:
296
- # Truncate but preserve EOS at the end
297
- for i in range(len(encoded_raw["input_ids"])):
298
- if len(encoded_raw["input_ids"][i]) > text_max_length:
299
- encoded_raw["input_ids"][i] = encoded_raw["input_ids"][i][
300
- : text_max_length - 1
301
- ] + [eos_id]
302
-
303
- # Pad sequences
304
- if padding:
305
- pad_id = self.tokenizer.pad_token_id
306
- for i in range(len(encoded_raw["input_ids"])):
307
- seq_len = len(encoded_raw["input_ids"][i])
308
- if seq_len < text_max_length:
309
- encoded_raw["input_ids"][i].extend([pad_id] * (text_max_length - seq_len))
310
-
311
- # Create attention mask (1 for real tokens + EOS, 0 for padding)
312
- _char_mask = []
313
- for ids in encoded_raw["input_ids"]:
314
- mask = [1 if token_id != self.tokenizer.pad_token_id else 0 for token_id in ids]
315
- _char_mask.append(mask)
316
-
317
- # Convert to tensors if requested
318
- if return_tensors == "pt":
319
- result["input_ids"] = torch.tensor(encoded_raw["input_ids"], dtype=torch.long)
320
- _char_mask = torch.tensor(_char_mask, dtype=torch.long)
321
- elif return_tensors == "np":
322
- result["input_ids"] = np.array(encoded_raw["input_ids"], dtype=np.int64)
323
- _char_mask = np.array(_char_mask, dtype=np.int64)
324
- else:
325
- result["input_ids"] = encoded_raw["input_ids"]
326
- else:
327
- # No text provided, create padding tokens
328
  if return_tensors == "pt":
329
  char_tokens = torch.full(
330
- (batch_size, self.max_char_len), self.tokenizer.pad_token_id, dtype=torch.long
 
 
331
  )
332
- _char_mask = torch.zeros(batch_size, self.max_char_len, dtype=torch.long)
333
  elif return_tensors == "np":
334
  char_tokens = np.full(
335
- (batch_size, self.max_char_len), self.tokenizer.pad_token_id, dtype=np.int64
 
 
336
  )
337
- _char_mask = np.zeros((batch_size, self.max_char_len), dtype=np.int64)
338
  else:
339
  char_tokens = [
340
  [self.tokenizer.pad_token_id] * self.max_char_len for _ in range(batch_size)
341
  ]
342
- _char_mask = [[0] * self.max_char_len for _ in range(batch_size)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
- result["input_ids"] = char_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
- # Create combined attention mask: [CLS] + path + [SEP] + chars
347
- # Sequence structure: [CLS:1] + _path_mask + [SEP:1] + _char_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  if return_tensors == "pt":
349
  cls_mask = torch.ones(batch_size, 1, dtype=torch.long)
350
  sep_mask = torch.ones(batch_size, 1, dtype=torch.long)
351
- attention_mask = torch.cat([cls_mask, _path_mask, sep_mask, _char_mask], dim=1)
352
- elif return_tensors == "np":
353
  cls_mask = np.ones((batch_size, 1), dtype=np.int64)
354
  sep_mask = np.ones((batch_size, 1), dtype=np.int64)
355
- attention_mask = np.concatenate([cls_mask, _path_mask, sep_mask, _char_mask], axis=1)
356
- else:
357
- cls_mask = [[1] for _ in range(batch_size)]
358
- sep_mask = [[1] for _ in range(batch_size)]
359
- attention_mask = [
360
- cls + path.tolist() + sep + char
361
- for cls, path, sep, char in zip(
362
- cls_mask, _path_mask, sep_mask, _char_mask, strict=False
363
- )
364
- ]
365
-
366
- result["attention_mask"] = attention_mask
367
-
368
- # Convert to requested format
369
  if return_tensors == "np":
370
- for key in result:
371
- if isinstance(result[key], torch.Tensor):
372
- result[key] = result[key].numpy()
373
  elif return_tensors is None:
374
- for key in result:
375
- if isinstance(result[key], torch.Tensor):
376
- result[key] = result[key].tolist()
377
-
378
- return result
379
 
380
  def batch_decode(self, token_ids, **kwargs):
381
  """
 
40
  self.max_char_len = max_char_len
41
  self.path_input_dim = path_input_dim
42
  self.path_resample_mode = path_resample_mode
 
43
  self.chat_template = None
 
 
 
44
 
45
  def __call__(
46
  self,
 
90
  if path_coords is None and text is None:
91
  raise ValueError("Must provide either path_coords or text (or both)")
92
 
93
+ batch_size, path_coords, text = self._infer_batch_size(path_coords, text)
94
+
95
+ result: dict[str, Any] = {}
96
+
97
+ path_coords_out, path_mask = self._process_path_coords(
98
+ path_coords=path_coords,
99
+ batch_size=batch_size,
100
+ truncation=truncation,
101
+ padding=padding,
102
+ return_tensors=return_tensors,
103
+ )
104
+ result["path_coords"] = path_coords_out
105
+
106
+ input_ids, char_mask = self._process_text(
107
+ text=text,
108
+ batch_size=batch_size,
109
+ padding=padding,
110
+ truncation=truncation,
111
+ max_length=max_length,
112
+ return_tensors=return_tensors,
113
+ **kwargs,
114
+ )
115
+ result["input_ids"] = input_ids
116
+
117
+ result["attention_mask"] = self._build_attention_mask(
118
+ path_mask=path_mask,
119
+ char_mask=char_mask,
120
+ batch_size=batch_size,
121
+ return_tensors=return_tensors,
122
+ )
123
+
124
+ self._convert_result_in_place(result, return_tensors=return_tensors)
125
+ return result
126
+
127
+ def _infer_batch_size(
128
+ self,
129
+ path_coords: (
130
+ list[dict[str, float]]
131
+ | list[list[dict[str, float]]]
132
+ | list[list[list[float]]]
133
+ | torch.Tensor
134
+ | np.ndarray
135
+ | None
136
+ ),
137
+ text: str | list[str] | None,
138
+ ) -> tuple[int, Any, str | list[str] | None]:
139
  if path_coords is not None:
 
140
  if isinstance(path_coords, (list, tuple)):
141
  if len(path_coords) == 0:
142
  batch_size = 1
143
  else:
144
  first = path_coords[0]
 
145
  if isinstance(first, dict):
146
  batch_size = 1
 
147
  elif (
148
  isinstance(first, (list, tuple))
149
  and len(first) > 0
150
  and isinstance(first[0], dict)
151
  ):
152
  batch_size = len(path_coords)
 
153
  elif (
154
  isinstance(first, (list, tuple))
155
  and len(first) > 0
156
  and isinstance(first[0], (list, tuple))
157
  ):
158
  path_coords = torch.tensor(path_coords, dtype=torch.float32)
159
+ batch_size = int(path_coords.shape[0])
160
  else:
 
161
  path_coords = torch.tensor([path_coords], dtype=torch.float32)
162
+ batch_size = int(path_coords.shape[0])
163
  elif isinstance(path_coords, np.ndarray):
164
  path_coords = torch.from_numpy(path_coords).float()
165
  if path_coords.dim() == 2:
 
166
  path_coords = path_coords.unsqueeze(0)
167
+ batch_size = int(path_coords.shape[0])
168
  elif isinstance(path_coords, torch.Tensor):
169
  if path_coords.dim() == 2:
 
170
  path_coords = path_coords.unsqueeze(0)
171
+ batch_size = int(path_coords.shape[0])
172
+ else:
173
+ batch_size = 1
174
  elif text is not None:
175
  if isinstance(text, str):
176
  batch_size = 1
 
180
  else:
181
  batch_size = 1
182
 
183
+ return batch_size, path_coords, text
184
 
185
+ def _process_path_coords(
186
+ self,
187
+ *,
188
+ path_coords,
189
+ batch_size: int,
190
+ truncation: bool,
191
+ padding: bool | str,
192
+ return_tensors: str | None,
193
+ ) -> tuple[Any, Any]:
194
+ if path_coords is None:
195
+ path_coords_out = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
196
+ path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
197
+ return path_coords_out, path_mask
198
+
199
+ if isinstance(path_coords, (list, tuple)) and len(path_coords) > 0:
200
+ first_elem = path_coords[0]
201
+
202
+ if isinstance(first_elem, dict) and "x" in first_elem:
203
+ path_feats, mask = preprocess_raw_path_to_features(
204
+ path_coords,
205
+ self.max_path_len,
206
+ resample_mode=self.path_resample_mode,
207
+ )
208
+ if return_tensors == "pt":
209
+ return (
210
+ torch.from_numpy(path_feats).float().unsqueeze(0),
211
+ torch.from_numpy(mask).long().unsqueeze(0),
212
+ )
213
+ return (np.expand_dims(path_feats, axis=0), np.expand_dims(mask, axis=0))
214
+
215
+ if (
216
+ isinstance(first_elem, (list, tuple))
217
+ and len(first_elem) > 0
218
+ and isinstance(first_elem[0], dict)
219
+ and "x" in first_elem[0]
220
+ ):
221
+ processed_paths = []
222
+ path_masks = []
223
+ for path in path_coords:
224
  path_feats, mask = preprocess_raw_path_to_features(
225
+ path,
226
  self.max_path_len,
227
  resample_mode=self.path_resample_mode,
228
  )
229
+ processed_paths.append(path_feats)
230
+ path_masks.append(mask)
231
+
232
+ path_coords_np = np.stack(processed_paths)
233
+ path_mask_np = np.stack(path_masks)
234
+ if return_tensors == "pt":
235
+ return torch.from_numpy(path_coords_np).float(), torch.from_numpy(
236
+ path_mask_np
237
+ ).long()
238
+ return path_coords_np, path_mask_np
239
+
240
+ # Numeric list input
241
+ path_tensor = torch.tensor(path_coords, dtype=torch.float32)
242
+ if path_tensor.dim() == 2:
243
+ path_tensor = path_tensor.unsqueeze(0)
244
+
245
+ current_path_len = int(path_tensor.shape[1])
246
+ if truncation and current_path_len > self.max_path_len:
247
+ path_tensor = path_tensor[:, : self.max_path_len, :]
248
+ if padding and current_path_len < self.max_path_len:
249
+ pad_len = self.max_path_len - current_path_len
250
+ pad_shape = (batch_size, pad_len, self.path_input_dim)
251
+ path_tensor = torch.cat([path_tensor, torch.zeros(pad_shape)], dim=1)
252
+
253
+ path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
254
+ is_padding = (path_tensor == 0).all(dim=-1)
255
+ path_mask[is_padding] = 0
256
+ return path_tensor, path_mask
257
+
258
+ if isinstance(path_coords, np.ndarray):
259
+ path_coords = torch.from_numpy(path_coords).float()
260
+
261
+ if isinstance(path_coords, torch.Tensor):
262
+ if path_coords.dim() == 2:
263
+ path_coords = path_coords.unsqueeze(0)
264
+ if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
265
+ processed_paths = []
266
+ path_masks = []
267
+ for path in path_coords.detach().cpu().numpy():
268
+ raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
269
+ path_feats, mask = preprocess_raw_path_to_features(
270
+ raw,
271
+ self.max_path_len,
272
+ resample_mode=self.path_resample_mode,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  )
274
+ processed_paths.append(path_feats)
275
+ path_masks.append(mask)
276
+ return torch.from_numpy(np.stack(processed_paths)).float(), torch.from_numpy(
277
+ np.stack(path_masks)
278
+ ).long()
279
+
280
+ if int(path_coords.shape[-1]) != int(self.path_input_dim):
281
+ raise ValueError(
282
+ f"Expected path_coords.shape[-1] == path_input_dim ({self.path_input_dim}), "
283
+ f"got {int(path_coords.shape[-1])}. If your path is (x,y,t), pass D=3."
284
+ )
285
 
286
+ path_tensor = path_coords
287
+ current_path_len = int(path_tensor.shape[1])
288
+ if truncation and current_path_len > self.max_path_len:
289
+ path_tensor = path_tensor[:, : self.max_path_len, :]
290
+ if padding and current_path_len < self.max_path_len:
291
+ pad_len = self.max_path_len - current_path_len
292
+ pad_shape = (int(path_tensor.shape[0]), pad_len, int(path_tensor.shape[-1]))
293
+ pad = torch.zeros(pad_shape, dtype=path_tensor.dtype, device=path_tensor.device)
294
+ path_tensor = torch.cat([path_tensor, pad], dim=1)
295
+
296
+ path_mask = torch.ones(
297
+ int(path_tensor.shape[0]),
298
+ int(path_tensor.shape[1]),
299
+ dtype=torch.long,
300
+ device=path_tensor.device,
 
 
 
 
 
 
 
 
301
  )
302
+ is_padding = (path_tensor == 0).all(dim=-1)
303
+ path_mask[is_padding] = 0
304
+ return path_tensor, path_mask
305
 
306
+ # Fallback: treat unknown input as empty path.
307
+ path_coords_out = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
308
+ path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
309
+ return path_coords_out, path_mask
310
+
311
+ def _process_text(
312
+ self,
313
+ *,
314
+ text: str | list[str] | None,
315
+ batch_size: int,
316
+ padding: bool | str,
317
+ truncation: bool,
318
+ max_length: int | None,
319
+ return_tensors: str | None,
320
+ **kwargs: Any,
321
+ ) -> tuple[Any, Any]:
322
+ if text is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  if return_tensors == "pt":
324
  char_tokens = torch.full(
325
+ (batch_size, self.max_char_len),
326
+ self.tokenizer.pad_token_id,
327
+ dtype=torch.long,
328
  )
329
+ char_mask = torch.zeros(batch_size, self.max_char_len, dtype=torch.long)
330
  elif return_tensors == "np":
331
  char_tokens = np.full(
332
+ (batch_size, self.max_char_len),
333
+ self.tokenizer.pad_token_id,
334
+ dtype=np.int64,
335
  )
336
+ char_mask = np.zeros((batch_size, self.max_char_len), dtype=np.int64)
337
  else:
338
  char_tokens = [
339
  [self.tokenizer.pad_token_id] * self.max_char_len for _ in range(batch_size)
340
  ]
341
+ char_mask = [[0] * self.max_char_len for _ in range(batch_size)]
342
+ return char_tokens, char_mask
343
+
344
+ if isinstance(text, str):
345
+ text = [text]
346
+
347
+ text_max_length = max_length if max_length is not None else self.max_char_len
348
+
349
+ encoded_raw = self.tokenizer(
350
+ text,
351
+ padding=False,
352
+ truncation=False,
353
+ return_tensors=None,
354
+ **kwargs,
355
+ )
356
 
357
+ eos_id = self.tokenizer.eos_token_id
358
+ for i in range(len(encoded_raw["input_ids"])):
359
+ if encoded_raw["input_ids"][i][-1] != eos_id:
360
+ encoded_raw["input_ids"][i].append(eos_id)
361
+
362
+ max_len_needed = max(len(ids) for ids in encoded_raw["input_ids"])
363
+ if truncation and max_len_needed > text_max_length:
364
+ for i in range(len(encoded_raw["input_ids"])):
365
+ if len(encoded_raw["input_ids"][i]) > text_max_length:
366
+ encoded_raw["input_ids"][i] = encoded_raw["input_ids"][i][
367
+ : text_max_length - 1
368
+ ] + [eos_id]
369
+
370
+ if padding:
371
+ pad_id = self.tokenizer.pad_token_id
372
+ for i in range(len(encoded_raw["input_ids"])):
373
+ seq_len = len(encoded_raw["input_ids"][i])
374
+ if seq_len < text_max_length:
375
+ encoded_raw["input_ids"][i].extend([pad_id] * (text_max_length - seq_len))
376
+
377
+ char_mask_list = [
378
+ [1 if token_id != self.tokenizer.pad_token_id else 0 for token_id in ids]
379
+ for ids in encoded_raw["input_ids"]
380
+ ]
381
 
382
+ if return_tensors == "pt":
383
+ return (
384
+ torch.tensor(encoded_raw["input_ids"], dtype=torch.long),
385
+ torch.tensor(char_mask_list, dtype=torch.long),
386
+ )
387
+ if return_tensors == "np":
388
+ return (
389
+ np.array(encoded_raw["input_ids"], dtype=np.int64),
390
+ np.array(char_mask_list, dtype=np.int64),
391
+ )
392
+ return encoded_raw["input_ids"], char_mask_list
393
+
394
+ def _build_attention_mask(
395
+ self,
396
+ *,
397
+ path_mask,
398
+ char_mask,
399
+ batch_size: int,
400
+ return_tensors: str | None,
401
+ ):
402
  if return_tensors == "pt":
403
  cls_mask = torch.ones(batch_size, 1, dtype=torch.long)
404
  sep_mask = torch.ones(batch_size, 1, dtype=torch.long)
405
+ return torch.cat([cls_mask, path_mask, sep_mask, char_mask], dim=1)
406
+ if return_tensors == "np":
407
  cls_mask = np.ones((batch_size, 1), dtype=np.int64)
408
  sep_mask = np.ones((batch_size, 1), dtype=np.int64)
409
+ return np.concatenate([cls_mask, path_mask, sep_mask, char_mask], axis=1)
410
+
411
+ cls_mask = [[1] for _ in range(batch_size)]
412
+ sep_mask = [[1] for _ in range(batch_size)]
413
+ return [
414
+ cls + path.tolist() + sep + char
415
+ for cls, path, sep, char in zip(cls_mask, path_mask, sep_mask, char_mask, strict=False)
416
+ ]
417
+
418
+ def _convert_result_in_place(
419
+ self, result: dict[str, Any], *, return_tensors: str | None
420
+ ) -> None:
 
 
421
  if return_tensors == "np":
422
+ for key, value in list(result.items()):
423
+ if isinstance(value, torch.Tensor):
424
+ result[key] = value.numpy()
425
  elif return_tensors is None:
426
+ for key, value in list(result.items()):
427
+ if isinstance(value, torch.Tensor):
428
+ result[key] = value.tolist()
 
 
429
 
430
  def batch_decode(self, token_ids, **kwargs):
431
  """
special_tokens_map.json CHANGED
@@ -1,44 +1,8 @@
1
  {
2
- "cls_token": {
3
- "content": "[CLS]",
4
- "lstrip": false,
5
- "normalized": false,
6
- "rstrip": false,
7
- "single_word": false
8
- },
9
- "eos_token": {
10
- "content": "[EOS]",
11
- "lstrip": false,
12
- "normalized": false,
13
- "rstrip": false,
14
- "single_word": false
15
- },
16
- "mask_token": {
17
- "content": "[MASK]",
18
- "lstrip": false,
19
- "normalized": false,
20
- "rstrip": false,
21
- "single_word": false
22
- },
23
- "pad_token": {
24
- "content": "[PAD]",
25
- "lstrip": false,
26
- "normalized": false,
27
- "rstrip": false,
28
- "single_word": false
29
- },
30
- "sep_token": {
31
- "content": "[SEP]",
32
- "lstrip": false,
33
- "normalized": false,
34
- "rstrip": false,
35
- "single_word": false
36
- },
37
- "unk_token": {
38
- "content": "[UNK]",
39
- "lstrip": false,
40
- "normalized": false,
41
- "rstrip": false,
42
- "single_word": false
43
- }
44
  }
 
1
  {
2
+ "cls_token": "[CLS]",
3
+ "eos_token": "[EOS]",
4
+ "mask_token": "[MASK]",
5
+ "pad_token": "[PAD]",
6
+ "sep_token": "[SEP]",
7
+ "unk_token": "[UNK]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  }
tokenizer_config.json CHANGED
@@ -49,18 +49,22 @@
49
  "special": true
50
  }
51
  },
 
52
  "auto_map": {
53
  "AutoTokenizer": [
54
  "tokenization_swipe.SwipeTokenizer",
55
  null
56
  ]
57
  },
 
58
  "clean_up_tokenization_spaces": false,
59
  "cls_token": "[CLS]",
60
  "eos_token": "[EOS]",
61
- "extra_special_tokens": {},
 
62
  "mask_token": "[MASK]",
63
  "model_max_length": 1000000000000000019884624838656,
 
64
  "pad_token": "[PAD]",
65
  "processor_class": "SwipeProcessor",
66
  "sep_token": "[SEP]",
 
49
  "special": true
50
  }
51
  },
52
+ "additional_special_tokens": null,
53
  "auto_map": {
54
  "AutoTokenizer": [
55
  "tokenization_swipe.SwipeTokenizer",
56
  null
57
  ]
58
  },
59
+ "backend": "custom",
60
  "clean_up_tokenization_spaces": false,
61
  "cls_token": "[CLS]",
62
  "eos_token": "[EOS]",
63
+ "extra_special_tokens": [],
64
+ "is_local": true,
65
  "mask_token": "[MASK]",
66
  "model_max_length": 1000000000000000019884624838656,
67
+ "model_specific_special_tokens": {},
68
  "pad_token": "[PAD]",
69
  "processor_class": "SwipeProcessor",
70
  "sep_token": "[SEP]",