dleemiller committed
Commit b121266 · verified · 1 Parent(s): 04c98bd

Upload folder using huggingface_hub
config.json CHANGED
@@ -15,6 +15,9 @@
   "n_heads": 12,
   "n_layers": 12,
   "pad_token_id": 0,
+  "path_input_dim": 6,
+  "predict_char": true,
+  "predict_length": true,
   "predict_path": true,
   "sep_token_id": 2,
   "transformers_version": "4.57.3",
@@ -22,6 +25,7 @@
   "vocab_size": 43,
   "auto_map": {
     "AutoConfig": "configuration_swipe.SwipeTransformerConfig",
-    "AutoModel": "modeling_swipe.SwipeTransformerModel"
+    "AutoModel": "modeling_swipe.SwipeTransformerModel",
+    "AutoModelForCausalLM": "modeling_swipe.SwipeTransformerModel"
   }
 }
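
The updated auto_map registers the same remote class for both AutoModel and AutoModelForCausalLM. A minimal loading sketch, not part of this commit; the repository id is a placeholder:

```python
# Hypothetical repo id; trust_remote_code is required for auto_map-resolved classes.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("dleemiller/swipe-model", trust_remote_code=True)
model = AutoModel.from_pretrained("dleemiller/swipe-model", trust_remote_code=True)
print(type(model).__name__)  # SwipeTransformerModel, per the AutoModel mapping above
```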
configuration_swipe.py CHANGED
@@ -20,6 +20,7 @@ class SwipeTransformerConfig(PretrainedConfig):
         vocab_size (int, optional): Size of vocabulary. Defaults to 100.
         max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
         max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
+        path_input_dim (int, optional): Path feature dimension. Defaults to 6 for (x, y, dx, dy, ds, log_dt).
         predict_path (bool, optional): Whether to predict path coordinates. Defaults to True.
         pad_token_id (int, optional): Padding token ID. Defaults to 0.
         cls_token_id (int, optional): CLS token ID. Defaults to 1.
@@ -41,7 +42,10 @@
         vocab_size: int = 100,
         max_path_len: int = 64,
         max_char_len: int = 38,
+        path_input_dim: int = 6,
+        predict_char: bool = True,
         predict_path: bool = True,
+        predict_length: bool = True,
         pad_token_id: int = 0,
         cls_token_id: int = 1,
         sep_token_id: int = 2,
@@ -63,83 +67,12 @@
         self.vocab_size = vocab_size
         self.max_path_len = max_path_len
         self.max_char_len = max_char_len
+        self.path_input_dim = path_input_dim

         # Model capabilities
+        self.predict_char = predict_char
         self.predict_path = predict_path
-
-        # Special tokens
-        self.cls_token_id = cls_token_id
-        self.sep_token_id = sep_token_id
-        self.mask_token_id = mask_token_id
-        self.unk_token_id = unk_token_id
-
-
- class SwipeCrossEncoderConfig(PretrainedConfig):
-     """
-     Configuration class for SwipeCrossEncoderForSequenceClassification.
-
-     This configuration extends the base SwipeTransformer config for use in
-     cross-encoder tasks (e.g., path-word similarity scoring).
-
-     Args:
-         d_model (int, optional): Hidden dimension size. Defaults to 256.
-         n_layers (int, optional): Number of transformer layers. Defaults to 4.
-         n_heads (int, optional): Number of attention heads. Defaults to 4.
-         d_ff (int, optional): Feedforward dimension. Defaults to 1024.
-         dropout (float, optional): Dropout rate. Defaults to 0.1.
-         vocab_size (int, optional): Size of vocabulary. Defaults to 100.
-         max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
-         max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
-         num_labels (int, optional): Number of classification labels. Defaults to 1.
-         problem_type (str, optional): Problem type ('regression' or 'single_label_classification'). Defaults to "regression".
-         pad_token_id (int, optional): Padding token ID. Defaults to 0.
-         cls_token_id (int, optional): CLS token ID. Defaults to 1.
-         sep_token_id (int, optional): SEP token ID. Defaults to 2.
-         mask_token_id (int, optional): MASK token ID. Defaults to 3.
-         unk_token_id (int, optional): Unknown token ID. Defaults to 4.
-         eos_token_id (int, optional): End-of-sequence token ID. Defaults to 5.
-     """
-
-     model_type = "swipe_cross_encoder"
-
-     def __init__(
-         self,
-         d_model: int = 256,
-         n_layers: int = 4,
-         n_heads: int = 4,
-         d_ff: int = 1024,
-         dropout: float = 0.1,
-         vocab_size: int = 100,
-         max_path_len: int = 64,
-         max_char_len: int = 38,
-         num_labels: int = 1,
-         problem_type: str = "regression",
-         pad_token_id: int = 0,
-         cls_token_id: int = 1,
-         sep_token_id: int = 2,
-         mask_token_id: int = 3,
-         unk_token_id: int = 4,
-         eos_token_id: int = 5,
-         **kwargs,
-     ):
-         super().__init__(
-             pad_token_id=pad_token_id, num_labels=num_labels, eos_token_id=eos_token_id, **kwargs
-         )
-
-         # Model architecture parameters
-         self.d_model = d_model
-         self.n_layers = n_layers
-         self.n_heads = n_heads
-         self.d_ff = d_ff
-         self.dropout = dropout
-
-         # Vocabulary and sequence length
-         self.vocab_size = vocab_size
-         self.max_path_len = max_path_len
-         self.max_char_len = max_char_len
-
-         # Classification parameters
-         self.problem_type = problem_type
+        self.predict_length = predict_length

         # Special tokens
         self.cls_token_id = cls_token_id
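
A minimal sketch of constructing the updated config with the new fields, assuming configuration_swipe.py is importable; values other than the new defaults are illustrative, not taken from this commit:

```python
from configuration_swipe import SwipeTransformerConfig

config = SwipeTransformerConfig(
    vocab_size=43,
    path_input_dim=6,      # (x, y, dx, dy, ds, log_dt)
    predict_char=True,
    predict_path=True,
    predict_length=True,
)
print(config.path_input_dim, config.predict_length)  # 6 True
```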
embeddings.py CHANGED
@@ -5,16 +5,26 @@ import torch.nn as nn


 class PathEmbedding(nn.Module):
-    """Embeds path coordinates (x, y, t) to d_model dimension."""
+    """Embeds path features (x, y, dx, dy, ds, log_dt) to d_model dimension."""

-    def __init__(self, d_model: int = 256):
+    def __init__(self, d_model: int = 256, input_dim: int = 6):
+        """
+        Initialize path embedding layer.
+
+        Args:
+            d_model: Output dimension
+            input_dim: Input feature dimension (default: 6 for x, y, dx, dy, ds, log_dt)
+        """
         super().__init__()
-        self.projection = nn.Linear(3, d_model)
+        self.projection = nn.Linear(input_dim, d_model)

     def forward(self, path_coords: torch.Tensor) -> torch.Tensor:
         """
+        Project path features to d_model dimension.
+
         Args:
-            path_coords: [batch, seq_len, 3] - (x, y, t) coordinates
+            path_coords: [batch, seq_len, input_dim] - path features
+                Default: (x, y, dx, dy, ds, log_dt) with input_dim=6

         Returns:
             [batch, seq_len, d_model] embeddings
@@ -90,12 +100,13 @@ class MixedEmbedding(nn.Module):
         max_char_len: int,
         d_model: int = 256,
         dropout: float = 0.1,
+        path_input_dim: int = 6,
     ):
         super().__init__()
         self.d_model = d_model

         # Content embeddings
-        self.path_embedding = PathEmbedding(d_model)
+        self.path_embedding = PathEmbedding(d_model, input_dim=path_input_dim)
         self.char_embedding = CharacterEmbedding(vocab_size, d_model, padding_idx=0)

         # Positional embeddings
@@ -120,7 +131,8 @@
         Create mixed sequence with embeddings.

         Args:
-            path_coords: [batch, path_len, 3] path coordinates
+            path_coords: [batch, path_len, path_input_dim] path features
+                Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
             char_tokens: [batch, char_len] character token IDs
             cls_token: [batch, 1] CLS token IDs
             sep_token: [batch, 1] SEP token IDs
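
A shape-only sketch of the widened path projection, assuming embeddings.py is importable; the tensor values are random placeholders:

```python
import torch
from embeddings import PathEmbedding

emb = PathEmbedding(d_model=256, input_dim=6)
path_feats = torch.randn(2, 64, 6)   # [batch, path_len, (x, y, dx, dy, ds, log_dt)]
out = emb(path_feats)
print(out.shape)                     # torch.Size([2, 64, 256])
```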
heads.py CHANGED
@@ -32,11 +32,11 @@ class CharacterPredictionHead(nn.Module):
 class PathPredictionHead(nn.Module):
     """Prediction head for masked path coordinates."""

-    def __init__(self, d_model: int):
+    def __init__(self, d_model: int, output_dim: int = 6):
         super().__init__()
         self.dense = nn.Linear(d_model, d_model)
         self.layer_norm = nn.LayerNorm(d_model)
-        self.decoder = nn.Linear(d_model, 3)  # Predict (x, y, t)
+        self.decoder = nn.Linear(d_model, output_dim)
         self.activation = nn.GELU()

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -45,55 +45,38 @@ class PathPredictionHead(nn.Module):
             hidden_states: [batch, seq_len, d_model]

         Returns:
-            [batch, seq_len, 3] coordinates in [0, 1] range
+            [batch, seq_len, output_dim] path features.
         """
         x = self.dense(hidden_states)
         x = self.activation(x)
         x = self.layer_norm(x)
-        coords = self.decoder(x)
-        coords = torch.sigmoid(coords)  # Ensure [0, 1] range
-        return coords
-
-
- class ClassificationHead(nn.Module):
-     """
-     Classification head for cross-encoder.
-
-     Follows SBERT architecture: Dense → GELU → LayerNorm → Linear(→1)
-     Outputs a single similarity score per input.
-     """
-
-     def __init__(self, d_model: int, num_labels: int = 1):
-         super().__init__()
-         self.dense = nn.Linear(d_model, d_model)
-         self.activation = nn.GELU()
-         self.norm = nn.LayerNorm(d_model)
-         self.classifier = nn.Linear(d_model, num_labels)
-
-     def forward(self, features: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             features: [batch, d_model] - typically SEP token embeddings
-
-         Returns:
-             [batch, num_labels] similarity scores
-         """
-         x = self.dense(features)
-         x = self.activation(x)  # GELU
-         x = self.norm(x)  # LayerNorm
-         logits = self.classifier(x)  # [batch, 1] or [batch, num_labels]
-         return logits
+        features = self.decoder(x)
+
+        # Per-feature constraints:
+        #   - x, y are normalized to [0,1]
+        #   - dx, dy are signed deltas (roughly [-1,1])
+        #   - ds is non-negative
+        #   - log_dt is non-negative
+        if features.shape[-1] == 6:
+            x_y = torch.sigmoid(features[..., 0:2])
+            dx_dy = torch.tanh(features[..., 2:4])
+            ds = torch.nn.functional.softplus(features[..., 4:5])
+            log_dt = torch.nn.functional.softplus(features[..., 5:6])
+            return torch.cat([x_y, dx_dy, ds, log_dt], dim=-1)
+
+        # Fallback: unconstrained regression for other output dims.
+        return features


 class LengthPredictionHead(nn.Module):
-    """Predict sequence length (e.g., swipable character count) from CLS embedding."""
+    """Regress sequence length (e.g., swipable character count) from CLS embedding."""

-    def __init__(self, d_model: int, max_length: int):
+    def __init__(self, d_model: int):
         super().__init__()
         self.dense = nn.Linear(d_model, d_model)
         self.activation = nn.GELU()
         self.norm = nn.LayerNorm(d_model)
-        self.classifier = nn.Linear(d_model, max_length + 1)  # classes: 0..max_length
+        self.regressor = nn.Linear(d_model, 1)  # predict expected length directly

     def forward(self, cls_features: torch.Tensor) -> torch.Tensor:
         """
@@ -101,9 +84,9 @@ class LengthPredictionHead(nn.Module):
             cls_features: [batch, d_model] CLS embeddings

         Returns:
-            [batch, max_length+1] logits over lengths
+            [batch] predicted length
         """
         x = self.dense(cls_features)
         x = self.activation(x)
         x = self.norm(x)
-        return self.classifier(x)
+        return self.regressor(x).squeeze(-1)
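
The new PathPredictionHead squashes each feature group with a different activation. A minimal range check, assuming heads.py is importable; inputs are random placeholders:

```python
import torch
from heads import PathPredictionHead

head = PathPredictionHead(d_model=256, output_dim=6)
pred = head(torch.randn(2, 64, 256))              # [2, 64, 6]
x_y, dx_dy, ds, log_dt = pred[..., :2], pred[..., 2:4], pred[..., 4:5], pred[..., 5:6]
assert ((x_y >= 0) & (x_y <= 1)).all()            # sigmoid-constrained
assert ((dx_dy >= -1) & (dx_dy <= 1)).all()       # tanh-constrained
assert (ds >= 0).all() and (log_dt >= 0).all()    # softplus-constrained
```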
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b00a7a35dc99485501db127c4433afee95bd4f0c34d6c6db8ed6c69eec43404b
-size 348336548
+oid sha256:452880eddd71bf06c0c61750ae3f5ba65670c15aa97da8233df856ac74bd3e56
+size 348207344
modeling_swipe.py CHANGED
@@ -1,19 +1,13 @@
 """HuggingFace-compatible model classes for SwipeTransformer."""

 from dataclasses import dataclass
-from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ModelOutput,
-    SequenceClassifierOutput,
-)
+from transformers.modeling_outputs import ModelOutput

-from .configuration_swipe import SwipeCrossEncoderConfig, SwipeTransformerConfig
+from .configuration_swipe import SwipeTransformerConfig


 @dataclass
@@ -24,27 +18,32 @@ class SwipeTransformerOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
             Language modeling loss (character prediction).
-        char_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
-            Prediction scores of the character prediction head.
-        path_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 3)`, *optional*):
-            Prediction scores of the path prediction head (if enabled).
-        length_logits (`torch.FloatTensor` of shape `(batch_size, max_length+1)`, *optional*):
-            Prediction scores of the length prediction head (if enabled).
+        char_logits (`torch.FloatTensor` of shape `(batch_size, char_length, vocab_size)`):
+            Prediction scores of the character prediction head (text segment only).
+        path_logits (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
+            Prediction scores of the path prediction head (path segment only, if enabled).
+        length_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
+            Predicted length from the length head (if enabled).
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
         pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
             SEP token embeddings for similarity/embedding tasks.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`.
+            When requested, this includes the input embeddings plus one entry per encoder layer.
+        attentions (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of attention tensors (one for each layer) of shape
+            `(batch_size, num_heads, sequence_length, sequence_length)`.
     """

-    loss: Optional[torch.FloatTensor] = None
-    char_logits: torch.FloatTensor = None
-    path_logits: Optional[torch.FloatTensor] = None
-    length_logits: Optional[torch.FloatTensor] = None
-    last_hidden_state: torch.FloatTensor = None
-    pooler_output: Optional[torch.FloatTensor] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    loss: torch.FloatTensor | None = None
+    char_logits: torch.FloatTensor | None = None
+    path_logits: torch.FloatTensor | None = None
+    length_logits: torch.FloatTensor | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    pooler_output: torch.FloatTensor | None = None
+    hidden_states: tuple[torch.FloatTensor] | None = None
+    attentions: tuple[torch.FloatTensor] | None = None


 class SwipeTransformerPreTrainedModel(PreTrainedModel):
@@ -98,6 +97,7 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
             max_char_len=config.max_char_len,
             d_model=config.d_model,
             dropout=config.dropout,
+            path_input_dim=config.path_input_dim,
         )

         # Transformer encoder
@@ -117,21 +117,26 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
         )

         # Prediction heads
-        self.char_head = CharacterPredictionHead(
-            d_model=config.d_model,
-            vocab_size=config.vocab_size,
+        self.char_head = (
+            CharacterPredictionHead(
+                d_model=config.d_model,
+                vocab_size=config.vocab_size,
+            )
+            if config.predict_char
+            else None
         )

         if config.predict_path:
-            self.path_head = PathPredictionHead(d_model=config.d_model)
+            self.path_head = PathPredictionHead(
+                d_model=config.d_model, output_dim=config.path_input_dim
+            )
         else:
             self.path_head = None

         # Length prediction head (predicts word length from path)
         # Max length is max_char_len (including EOS)
-        self.length_head = LengthPredictionHead(
-            d_model=config.d_model,
-            max_length=config.max_char_len,
+        self.length_head = (
+            LengthPredictionHead(d_model=config.d_model) if config.predict_length else None
         )

         # Initialize weights
@@ -139,35 +144,61 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):

     def forward(
         self,
-        path_coords: torch.Tensor,
         input_ids: torch.Tensor,
+        path_coords: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-        labels: torch.Tensor | None = None,
+        labels: torch.Tensor | dict | None = None,
         return_dict: bool | None = None,
         output_hidden_states: bool | None = None,
+        output_attentions: bool | None = None,
+        **kwargs,
     ):
         """
         Forward pass of the model.

         Args:
-            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
             input_ids (torch.Tensor): Character token IDs [batch, char_len]
+            path_coords (torch.Tensor): Path features [batch, path_len, path_input_dim]
+                Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-            labels (torch.Tensor, optional): Labels for loss calculation [batch, char_len]
+            labels (torch.Tensor or dict, optional): Labels for loss calculation
+                Can be tensor [batch, char_len] or dict with keys like char_labels, path_labels
             return_dict (bool, optional): Whether to return ModelOutput object
             output_hidden_states (bool, optional): Whether to output hidden states
+            output_attentions (bool, optional): Whether to output attention weights
+            **kwargs: Additional arguments (for compatibility)

         Returns:
             SwipeTransformerOutput or tuple: Model outputs with:
                 - loss: Optional loss value
-                - char_logits: Character prediction logits [batch, seq_len, vocab_size]
-                - path_logits: Path prediction logits [batch, seq_len, 3] (if predict_path=True)
-                - length_logits: Length prediction logits [batch, max_length]
+                - char_logits: Character prediction logits [batch, char_len, vocab_size] (if enabled)
+                - path_logits: Path prediction logits [batch, path_len, path_input_dim] (if enabled)
+                - length_logits: Length regression output [batch] (if enabled)
                 - last_hidden_state: Hidden states [batch, seq_len, d_model]
-                - pooler_output: SEP token embeddings [batch, d_model] for similarity/embedding tasks
-                - hidden_states: Tuple of hidden states (if output_hidden_states=True)
+                - pooler_output: SEP token embedding [batch, d_model] for similarity/embedding tasks
+                - hidden_states: Tuple of per-layer hidden states (if output_hidden_states=True)
+                - attentions: Tuple of per-layer attention weights (if output_attentions=True)
         """
+        # Validate required inputs
+        if input_ids is None or path_coords is None:
+            raise ValueError("Both input_ids and path_coords are required")
+
+        # Extract labels if dict (used by custom trainers)
+        if isinstance(labels, dict):
+            char_labels = labels.get("char_labels")
+            # Can handle other label types in the future (path_labels, etc.)
+        else:
+            char_labels = labels
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )

         batch_size = path_coords.shape[0]
         device = path_coords.device
@@ -191,43 +222,138 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
         else:
             src_key_padding_mask = None

-        # Encode (batch_first=True is set in TransformerEncoderLayer)
-        hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
+        # Encode while optionally capturing attentions and per-layer hidden states.
+        attentions: tuple[torch.Tensor, ...] | None = None
+        hidden_states_by_layer: list[torch.Tensor] | None = [] if output_hidden_states else None
+
+        hooks = []
+        original_forwards: dict[int, callable] = {}
+        attentions_buffer: list[torch.Tensor | None] | None = None
+
+        def make_patched_forward(original_forward):
+            def patched_forward(
+                query,
+                key,
+                value,
+                key_padding_mask=None,
+                need_weights=True,
+                attn_mask=None,
+                average_attn_weights=False,
+                is_causal=False,
+            ):
+                return original_forward(
+                    query,
+                    key,
+                    value,
+                    key_padding_mask=key_padding_mask,
+                    need_weights=True,
+                    attn_mask=attn_mask,
+                    average_attn_weights=False,
+                    is_causal=is_causal,
+                )
+
+            return patched_forward
+
+        def make_hook(layer_idx: int):
+            def hook(_module: nn.Module, _input: tuple, output: tuple):
+                if (
+                    attentions_buffer is not None
+                    and isinstance(output, tuple)
+                    and len(output) > 1
+                    and output[1] is not None
+                ):
+                    attentions_buffer[layer_idx] = output[1]
+
+            return hook
+
+        if output_attentions:
+            attentions_buffer = [None] * len(self.encoder.layers)
+            for idx, layer in enumerate(self.encoder.layers):
+                attn_module = layer.self_attn
+                original_forwards[idx] = attn_module.forward
+                attn_module.forward = make_patched_forward(original_forwards[idx])
+                hooks.append(attn_module.register_forward_hook(make_hook(idx)))
+
+        try:
+            x = embeddings
+            for layer in self.encoder.layers:
+                x = layer(x, src_key_padding_mask=src_key_padding_mask)
+                if hidden_states_by_layer is not None:
+                    hidden_states_by_layer.append(x)
+            hidden_states = x
+
+            if attentions_buffer is not None:
+                if any(a is None for a in attentions_buffer):
+                    missing = [i for i, a in enumerate(attentions_buffer) if a is None]
+                    raise RuntimeError(
+                        f"Failed to capture attention weights for layers: {missing}."
+                    )
+                attentions = tuple(attentions_buffer)  # type: ignore[assignment]
+        finally:
+            for hook in hooks:
+                hook.remove()
+            for idx, layer in enumerate(self.encoder.layers):
+                if idx in original_forwards:
+                    layer.self_attn.forward = original_forwards[idx]
+
+        path_len = path_coords.shape[1]
+        char_len = input_ids.shape[1]

-        # Character prediction
-        char_logits = self.char_head(hidden_states)
+        # Character prediction (text segment only)
+        char_logits = None
+        if self.char_head is not None:
+            # Sequence is: [CLS] + path + [SEP] + chars
+            char_start = 1 + path_len + 1
+            char_hidden = hidden_states[:, char_start : char_start + char_len, :]
+            char_logits = self.char_head(char_hidden)

-        # Path prediction (if enabled)
+        # Path prediction (path segment only, if enabled)
         path_logits = None
         if self.path_head is not None:
-            path_logits = self.path_head(hidden_states)
+            path_hidden = hidden_states[:, 1 : 1 + path_len, :]
+            path_logits = self.path_head(path_hidden)

         # Length prediction from CLS token
         cls_hidden = hidden_states[:, 0, :]  # [batch, d_model] - CLS at position 0
-        length_logits = self.length_head(cls_hidden)  # [batch, max_length]
+        length_logits = self.length_head(cls_hidden) if self.length_head is not None else None

         # Extract SEP token embedding for pooler output (embeddings/similarity tasks)
         # SEP is at position 1 + path_len
-        path_len = path_coords.shape[1]
         sep_position = 1 + path_len
         pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]

-        # Compute loss if labels provided
+        # Compute loss if labels provided (masked-only; -100 = ignore)
         loss = None
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-            # Extract character positions from hidden states
-            # Sequence is: [CLS] + path + [SEP] + chars
-            char_start = 1 + path_len + 1  # After [CLS], path, and [SEP]
-            char_hidden = hidden_states[:, char_start : char_start + labels.shape[1], :]
-            char_pred = self.char_head(char_hidden)
-            loss = loss_fct(char_pred.reshape(-1, self.config.vocab_size), labels.reshape(-1))
+        if char_labels is not None and self.char_head is not None:
+            # Predict only the text segment
+            char_pred = char_logits  # [B, char_len, V]
+            labels_flat = char_labels.reshape(-1)
+            mask = labels_flat != -100
+            if mask.any():
+                logits_flat = char_pred.reshape(-1, self.config.vocab_size)[mask]
+                labels_flat = labels_flat[mask]
+                loss = nn.functional.cross_entropy(logits_flat, labels_flat, reduction="mean")
+            else:
+                loss = torch.tensor(0.0, device=hidden_states.device)

         if not return_dict:
-            output = (hidden_states, char_logits, length_logits, pooler_output)
-            if path_logits is not None:
-                output = output + (path_logits,)
-            return ((loss,) + output) if loss is not None else output
+            hidden_tuple = None
+            if hidden_states_by_layer is not None:
+                hidden_tuple = (embeddings,) + tuple(hidden_states_by_layer)
+            output = (
+                char_logits,
+                path_logits,
+                length_logits,
+                hidden_states,
+                pooler_output,
+                hidden_tuple,
+                attentions,
+            )
+            return (loss,) + output if loss is not None else output
+
+        all_hidden_states = None
+        if hidden_states_by_layer is not None:
+            all_hidden_states = (embeddings,) + tuple(hidden_states_by_layer)

         return SwipeTransformerOutput(
             loss=loss,
@@ -236,281 +362,12 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
             length_logits=length_logits,
             last_hidden_state=hidden_states,
             pooler_output=pooler_output,
-            hidden_states=(hidden_states,) if output_hidden_states else None,
-        )
-
-
- class SwipeCrossEncoderForSequenceClassification(SwipeTransformerPreTrainedModel):
-     """
-     HuggingFace-compatible cross-encoder for sequence classification.
-
-     This model is designed for similarity scoring between swipe paths and words.
-     It extracts the SEP token embedding and passes it through a classification head.
-
-     Args:
-         config (SwipeCrossEncoderConfig): Model configuration
-     """
-
-     config_class = SwipeCrossEncoderConfig
-     base_model_prefix = "swipe_cross_encoder"
-
-     def __init__(self, config: SwipeCrossEncoderConfig):
-         super().__init__(config)
-         self.config = config
-         self.num_labels = config.num_labels
-
-         # Import existing components
-         from .embeddings import MixedEmbedding
-         from .heads import ClassificationHead
-
-         # Embeddings
-         self.embeddings = MixedEmbedding(
-             vocab_size=config.vocab_size,
-             max_path_len=config.max_path_len,
-             max_char_len=config.max_char_len,
-             d_model=config.d_model,
-             dropout=config.dropout,
-         )
-
-         # Transformer encoder
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=config.d_model,
-             nhead=config.n_heads,
-             dim_feedforward=config.d_ff,
-             dropout=config.dropout,
-             activation="gelu",
-             batch_first=True,
-             norm_first=True,  # Pre-LayerNorm
-         )
-         self.encoder = nn.TransformerEncoder(
-             encoder_layer,
-             num_layers=config.n_layers,
-             enable_nested_tensor=False,
-         )
-
-         # Classification head
-         self.classifier = ClassificationHead(
-             d_model=config.d_model,
-             num_labels=config.num_labels,
-         )
-
-         # Initialize weights
-         self.post_init()
-
-     def forward(
-         self,
-         path_coords: torch.Tensor,
-         input_ids: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         labels: torch.Tensor | None = None,
-         return_dict: bool | None = None,
-     ):
-         """
-         Forward pass for cross-encoder.
-
-         Args:
-             path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
-             input_ids (torch.Tensor): Character token IDs [batch, char_len]
-             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-             labels (torch.Tensor, optional): Labels for loss calculation [batch, num_labels]
-             return_dict (bool, optional): Whether to return ModelOutput object
-
-         Returns:
-             SequenceClassifierOutput or tuple: Model outputs with logits and optional loss
-         """
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         batch_size = path_coords.shape[0]
-         device = path_coords.device
-
-         # Create [CLS] and [SEP] tokens
-         cls_token = torch.full(
-             (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
-         )
-         sep_token = torch.full(
-             (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
-         )
-
-         # Get embeddings
-         embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
-
-         # Prepare attention mask
-         if attention_mask is not None:
-             src_key_padding_mask = attention_mask == 0
-         else:
-             src_key_padding_mask = None
-
-         # Encode (batch_first=True is set in TransformerEncoderLayer)
-         hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
-
-         # Extract SEP token embedding
-         # SEP is at position 1 + path_len
-         path_len = path_coords.shape[1]
-         sep_position = 1 + path_len
-         sep_embedding = hidden_states[:, sep_position, :]  # [batch, d_model]
-
-         # Classification
-         logits = self.classifier(sep_embedding)  # [batch, num_labels]
-
-         # Compute loss if labels provided
-         loss = None
-         if labels is not None:
-             if self.config.problem_type is None:
-                 if self.num_labels == 1:
-                     self.config.problem_type = "regression"
-                 else:
-                     self.config.problem_type = "single_label_classification"
-
-             if self.config.problem_type == "regression":
-                 loss_fct = nn.MSELoss()
-                 if self.num_labels == 1:
-                     loss = loss_fct(logits.squeeze(), labels.squeeze())
-                 else:
-                     loss = loss_fct(logits, labels)
-             elif self.config.problem_type == "single_label_classification":
-                 loss_fct = nn.CrossEntropyLoss()
-                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
-         if not return_dict:
-             output = (logits,) + (hidden_states,)
-             return ((loss,) + output) if loss is not None else output
-
-         return SequenceClassifierOutput(
-             loss=loss,
-             logits=logits,
-             hidden_states=(hidden_states,),
-         )
-
-
- class SwipeModel(SwipeTransformerPreTrainedModel):
-     """
-     Base Swipe model for extracting embeddings.
-
-     .. deprecated::
-         This class is deprecated. Use SwipeTransformerModel instead, which now
-         includes pooler_output for embeddings alongside prediction heads.
-         SwipeTransformerModel provides both predictions AND embeddings in a single model.
-
-     This model returns the SEP token embedding, which can be used for:
-     - Vector databases
-     - Semantic search
-     - Similarity computation
-
-     The SEP token embedding represents the joint encoding of the path and text.
-
-     Usage (Deprecated - use SwipeTransformerModel instead):
-     ```python
-     from transformers import AutoModel
-
-     model = AutoModel.from_pretrained(
-         "your-username/swipe-model",
-         trust_remote_code=True
-     )
-
-     # Get embeddings
-     outputs = model(path_coords=paths, input_ids=tokens)
-     embeddings = outputs.pooler_output  # SEP token embeddings
-     ```
-
-     Args:
-         config (SwipeTransformerConfig or SwipeCrossEncoderConfig): Model configuration
-     """
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.config = config
-
-         # Import existing components
-         from .embeddings import MixedEmbedding
-
-         # Embeddings
-         self.embeddings = MixedEmbedding(
-             vocab_size=config.vocab_size,
-             max_path_len=config.max_path_len,
-             max_char_len=config.max_char_len,
-             d_model=config.d_model,
-             dropout=config.dropout,
-         )
-
-         # Transformer encoder
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=config.d_model,
-             nhead=config.n_heads,
-             dim_feedforward=config.d_ff,
-             dropout=config.dropout,
-             activation="gelu",
-             batch_first=True,
-             norm_first=True,  # Pre-LayerNorm
-         )
-         self.encoder = nn.TransformerEncoder(
-             encoder_layer,
-             num_layers=config.n_layers,
-             enable_nested_tensor=False,
-         )
-
-         # Initialize weights
-         self.post_init()
-
-     def forward(
-         self,
-         path_coords: torch.Tensor,
-         input_ids: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         return_dict: bool | None = None,
-         output_hidden_states: bool | None = None,
-     ):
-         """
-         Forward pass that returns embeddings.
-
-         Args:
-             path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
-             input_ids (torch.Tensor): Character token IDs [batch, char_len]
-             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-             return_dict (bool, optional): Whether to return ModelOutput object
-             output_hidden_states (bool, optional): Whether to output all hidden states
-
-         Returns:
-             BaseModelOutputWithPooling with:
-                 - last_hidden_state: Full sequence hidden states [batch, seq_len, d_model]
-                 - pooler_output: SEP token embeddings [batch, d_model]
-                 - hidden_states: Tuple of hidden states (if output_hidden_states=True)
-         """
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         batch_size = path_coords.shape[0]
-         device = path_coords.device
-
-         # Create [CLS] and [SEP] tokens
-         cls_token = torch.full(
-             (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
-         )
-         sep_token = torch.full(
-             (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
-         )
-
-         # Get embeddings
-         embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
-
-         # Prepare attention mask
-         if attention_mask is not None:
-             src_key_padding_mask = attention_mask == 0
-         else:
-             src_key_padding_mask = None
-
-         # Encode (batch_first=True is set in TransformerEncoderLayer)
-         hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
-
-         # Extract SEP token embedding (pooler output)
-         # SEP is at position 1 + path_len
-         path_len = path_coords.shape[1]
-         sep_position = 1 + path_len
-         pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]
-
-         if not return_dict:
-             return (hidden_states, pooler_output)
-
-         return BaseModelOutputWithPooling(
-             last_hidden_state=hidden_states,
-             pooler_output=pooler_output,
-             hidden_states=(hidden_states,) if output_hidden_states else None,
+            hidden_states=all_hidden_states,
+            attentions=attentions,
         )


+#
+# Legacy note:
+# `SwipeModel` (embeddings-only) has been removed; use `SwipeTransformerModel` and read
+# `outputs.pooler_output` for embeddings.
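
A minimal usage sketch for the updated forward signature (input_ids first, six-feature path_coords, optional labels dict, optional attention capture). The repository id and all tensor values are placeholders, not taken from this commit:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("dleemiller/swipe-model", trust_remote_code=True)

path_coords = torch.randn(1, 64, 6)                   # (x, y, dx, dy, ds, log_dt) features
input_ids = torch.randint(5, 43, (1, 38))             # character token IDs
labels = {"char_labels": torch.full((1, 38), -100)}   # -100 marks ignored positions

outputs = model(
    input_ids=input_ids,
    path_coords=path_coords,
    labels=labels,
    output_attentions=True,
)
print(outputs.pooler_output.shape)  # [1, d_model] SEP embedding for similarity tasks
print(len(outputs.attentions))      # one attention map per encoder layer
```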
preprocessing.py ADDED
@@ -0,0 +1,275 @@
+"""Shared preprocessing utilities for swipe path data.
+
+This module provides a single source of truth for path preprocessing,
+used by both the training dataset and the HuggingFace processor.
+"""
+
+import numpy as np
+
+
+def preprocess_raw_path_to_features(
+    data_points: list[dict],
+    max_len: int,
+    *,
+    resample_mode: str = "spatial",
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Convert a raw `{"x","y","t"}` path to fixed-length engineered features.
+
+    This is the fast path used by training and the HuggingFace processor. It avoids
+    building an intermediate list-of-dicts representation by:
+    1) extracting x/y/t arrays once,
+    2) resampling x/y using spatial- or time-uniform interpolation,
+    3) recomputing dx/dy/ds and log_dt on the resampled trajectory.
+
+    Args:
+        data_points: Raw path as a list of dicts with keys: "x", "y", "t".
+        max_len: Target length.
+        resample_mode: "spatial" (arc-length) or "time" (cumulative dt).
+        dt_clamp_min_ms: Clamp for dt feature after resampling (first dt remains 0).
+        dt_clamp_max_ms: Clamp for dt feature after resampling.
+
+    Returns:
+        (features, mask) where:
+        - features: [max_len, 6] float32 array (x, y, dx, dy, ds, log_dt)
+        - mask: [max_len] int64 array (1 for valid; all-ones for non-empty paths)
+    """
+    num_points = len(data_points)
+    if num_points == 0:
+        return (
+            np.zeros((max_len, 6), dtype=np.float32),
+            np.zeros(max_len, dtype=np.int64),
+        )
+
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+    t = np.fromiter((p["t"] for p in data_points), dtype=np.float64, count=num_points)
+
+    x = np.clip(x, 0.0, 1.0)
+    y = np.clip(y, 0.0, 1.0)
+
+    # Per-step deltas and axes for resampling
+    dx_in = np.concatenate([[0.0], np.diff(x)])
+    dy_in = np.concatenate([[0.0], np.diff(y)])
+    ds_in = np.hypot(dx_in, dy_in)
+    dt_raw_in = np.concatenate([[0.0], np.diff(t)])
+
+    s = np.cumsum(ds_in)
+    tau = np.cumsum(dt_raw_in)
+
+    if resample_mode not in {"spatial", "time"}:
+        raise ValueError(f"Unknown resample_mode={resample_mode!r} (use 'spatial' or 'time')")
+
+    eps = 1e-12
+    if resample_mode == "time" and tau[-1] > eps:
+        target_tau = np.linspace(0.0, float(tau[-1]), max_len, dtype=np.float64)
+        x_r = np.interp(target_tau, tau, x)
+        y_r = np.interp(target_tau, tau, y)
+        tau_r = target_tau
+    else:
+        # Spatial sampling (or fallback when time axis is degenerate).
+        if s[-1] <= eps:
+            original = np.arange(num_points, dtype=np.float64)
+            target = np.linspace(0, num_points - 1, max_len, dtype=np.float64)
+            x_r = np.interp(target, original, x)
+            y_r = np.interp(target, original, y)
+            tau_r = np.interp(target, original, tau)
+        else:
+            target_s = np.linspace(0.0, float(s[-1]), max_len, dtype=np.float64)
+            x_r = np.interp(target_s, s, x)
+            y_r = np.interp(target_s, s, y)
+            tau_r = np.interp(target_s, s, tau)
+
+    dx = np.concatenate([[0.0], np.diff(x_r)])
+    dy = np.concatenate([[0.0], np.diff(y_r)])
+    ds = np.hypot(dx, dy)
+    dt_raw_r = np.concatenate([[0.0], np.diff(tau_r)])
+    dt_feat = np.clip(dt_raw_r, dt_clamp_min_ms, dt_clamp_max_ms)
+    dt_feat[0] = 0.0
+    log_dt = np.log1p(np.maximum(0.0, dt_feat))
+
+    mask = np.ones(max_len, dtype=np.int64)
+    features = np.stack([x_r, y_r, dx, dy, ds, log_dt], axis=-1).astype(np.float32)
+    return features, mask
+
+
+def normalize_and_compute_features(
+    data_points: list[dict],
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> list[dict]:
+    """
+    Normalize coordinates and compute motion features.
+
+    Computes delta features (dx, dy, dt) and log-scaled time deltas.
+    First point has dx=dy=dt=0 by convention.
+
+    Args:
+        data_points: List of {"x", "y", "t"} dicts
+        dt_clamp_min_ms: Minimum dt in milliseconds (inclusive).
+        dt_clamp_max_ms: Maximum dt in milliseconds (inclusive).
+
+    Returns:
+        List of dicts with keys:
+        - x, y: normalized coordinates in [0, 1]
+        - t: raw timestamp from input (passed through)
+        - dx, dy: deltas in x/y
+        - ds: sqrt(dx^2 + dy^2)
+        - dt_raw: raw time delta (unclamped)
+        - dt: clamped time delta used for feature stability
+        - log_dt: log1p(dt)
+    """
+    if not data_points:
+        return []
+
+    num_points = len(data_points)
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+    t = np.fromiter((p["t"] for p in data_points), dtype=np.float64, count=num_points)
+
+    x = np.clip(x, 0.0, 1.0)
+    y = np.clip(y, 0.0, 1.0)
+
+    dx = np.concatenate([[0.0], np.diff(x)])
+    dy = np.concatenate([[0.0], np.diff(y)])
+    ds = np.hypot(dx, dy)
+    dt_raw = np.concatenate([[0.0], np.diff(t)])
+
+    dt = np.clip(dt_raw, dt_clamp_min_ms, dt_clamp_max_ms)
+    dt[0] = 0.0
+    log_dt = np.log1p(np.maximum(0.0, dt))
+
+    out: list[dict] = []
+    for i in range(num_points):
+        out.append(
+            {
+                "x": float(x[i]),
+                "y": float(y[i]),
+                "t": float(t[i]),
+                "dx": float(dx[i]),
+                "dy": float(dy[i]),
+                "ds": float(ds[i]),
+                "dt_raw": float(dt_raw[i]),
+                "dt": float(dt[i]),
+                "log_dt": float(log_dt[i]),
+            }
+        )
+    return out
+
+
+def sample_path_points_with_features(
+    data_points: list[dict],
+    max_len: int,
+    *,
+    resample_mode: str = "spatial",
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Sample path points with motion features to fixed length using interpolation.
+
+    Always uses interpolation (no zero-padding) to preserve feature structure.
+    Paths shorter than max_len are upsampled; longer paths are downsampled.
+
+    Modes:
+    - resample_mode="spatial": sample approximately uniformly in arc length (distance).
+    - resample_mode="time": sample uniformly in time (dwell regions get more samples).
+
+    Args:
+        data_points: List of coordinate dicts. Expected keys: x, y and either:
+            - dx, dy (preferred), plus optional ds, dt, log_dt; or
+            - ds/log_dt/dt (ds can be derived from dx/dy; dt from log_dt).
+        max_len: Target length
+        resample_mode: "spatial" or "time"
+        dt_clamp_min_ms: Clamp for dt feature after resampling (first dt remains 0).
+        dt_clamp_max_ms: Clamp for dt feature after resampling.
+
+    Returns:
+        Tuple of (features, mask) where:
+        - features: [max_len, 6] array with (x, y, dx, dy, ds, log_dt)
+        - mask: [max_len] binary mask (all 1s since we always interpolate)
+    """
+    num_points = len(data_points)
+
+    if num_points == 0:
+        # Empty path - return zeros
+        return (
+            np.zeros((max_len, 6), dtype=np.float32),
+            np.zeros(max_len, dtype=np.int64),
+        )
+
+    # Extract base signals
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+
+    # Prefer provided dx/dy, otherwise derive from x/y
+    if all("dx" in p for p in data_points) and all("dy" in p for p in data_points):
+        dx_in = np.fromiter((p["dx"] for p in data_points), dtype=np.float64, count=num_points)
+        dy_in = np.fromiter((p["dy"] for p in data_points), dtype=np.float64, count=num_points)
+    else:
+        dx_in = np.concatenate([[0.0], np.diff(x)])
+        dy_in = np.concatenate([[0.0], np.diff(y)])
+
+    # ds can be provided or derived from dx/dy
+    if all("ds" in p for p in data_points):
+        ds_in = np.fromiter((p["ds"] for p in data_points), dtype=np.float64, count=num_points)
+    else:
+        ds_in = np.sqrt(dx_in**2 + dy_in**2)
+
+    # Time axis for resampling: prefer dt_raw (unclamped) so "dwell" gets represented.
+    if all("dt_raw" in p for p in data_points):
+        dt_axis = np.fromiter(
+            (p["dt_raw"] for p in data_points), dtype=np.float64, count=num_points
+        )
+    elif all("dt" in p for p in data_points):
+        dt_axis = np.fromiter((p["dt"] for p in data_points), dtype=np.float64, count=num_points)
+    elif all("log_dt" in p for p in data_points):
+        log_dt_in_raw = np.fromiter(
+            (p["log_dt"] for p in data_points), dtype=np.float64, count=num_points
+        )
+        dt_axis = np.expm1(log_dt_in_raw)
+    else:
+        dt_axis = np.zeros(num_points, dtype=np.float64)
+
+    # Cumulative arc length (s) and cumulative time (tau) for resampling
+    s = np.cumsum(ds_in)
+    tau = np.cumsum(dt_axis)
+
+    if resample_mode not in {"spatial", "time"}:
+        raise ValueError(f"Unknown resample_mode={resample_mode!r} (use 'spatial' or 'time')")
+
+    eps = 1e-12
+
+    if resample_mode == "time" and tau[-1] > eps:
+        target_tau = np.linspace(0.0, float(tau[-1]), max_len, dtype=np.float64)
+        x_r = np.interp(target_tau, tau, x)
+        y_r = np.interp(target_tau, tau, y)
+        tau_r = target_tau
+    else:
+        # Spatial sampling (or fallback when time axis is degenerate).
+        # Handle degenerate paths (zero movement): fall back to index-based interpolation
+        if s[-1] <= eps:
+            original = np.arange(num_points, dtype=np.float64)
+            target = np.linspace(0, num_points - 1, max_len, dtype=np.float64)
+            x_r = np.interp(target, original, x)
+            y_r = np.interp(target, original, y)
+            tau_r = np.interp(target, original, tau)
+        else:
+            target_s = np.linspace(0.0, float(s[-1]), max_len, dtype=np.float64)
+            x_r = np.interp(target_s, s, x)
+            y_r = np.interp(target_s, s, y)
+            tau_r = np.interp(target_s, s, tau)
+
+    # Recompute deltas on the resampled path for consistency
+    dx = np.concatenate([[0.0], np.diff(x_r)])
+    dy = np.concatenate([[0.0], np.diff(y_r)])
+    ds = np.sqrt(dx**2 + dy**2)
+    dt_raw_r = np.concatenate([[0.0], np.diff(tau_r)])
+    dt_feat = np.clip(dt_raw_r, dt_clamp_min_ms, dt_clamp_max_ms)
+    dt_feat[0] = 0.0
+    log_dt = np.log1p(np.maximum(0.0, dt_feat))
+
+    mask = np.ones(max_len, dtype=np.int64)
+    features = np.stack([x_r, y_r, dx, dy, ds, log_dt], axis=-1).astype(np.float32)
+    return features, mask
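
A minimal sketch of the new preprocessing entry point; the three-point path below is made-up sample data:

```python
from preprocessing import preprocess_raw_path_to_features

raw_path = [
    {"x": 0.10, "y": 0.50, "t": 0.0},
    {"x": 0.40, "y": 0.52, "t": 35.0},
    {"x": 0.80, "y": 0.55, "t": 90.0},
]
features, mask = preprocess_raw_path_to_features(raw_path, max_len=64, resample_mode="spatial")
print(features.shape, features.dtype)  # (64, 6) float32 -> (x, y, dx, dy, ds, log_dt)
print(int(mask.sum()))                 # 64: every position is valid after interpolation
```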
processing_swipe.py CHANGED
@@ -1,9 +1,15 @@
1
  """Processor for handling multimodal swipe inputs (path + text)."""
2
 
 
 
 
 
3
  import numpy as np
4
  import torch
5
  from transformers import ProcessorMixin
6
 
 
 
7
 
8
  class SwipeProcessor(ProcessorMixin):
9
  """
@@ -21,10 +27,19 @@ class SwipeProcessor(ProcessorMixin):
21
  attributes = ["tokenizer"]
22
  tokenizer_class = "AutoTokenizer" # Will use auto_map from tokenizer_config.json
23
 
24
- def __init__(self, tokenizer=None, max_path_len: int = 64, max_char_len: int = 38):
 
 
 
 
 
 
 
25
  self.tokenizer = tokenizer
26
  self.max_path_len = max_path_len
27
  self.max_char_len = max_char_len
 
 
28
  # Attributes expected by newer transformers (not used for swipe models)
29
  self.chat_template = None
30
  self.audio_tokenizer = None
@@ -33,21 +48,36 @@ class SwipeProcessor(ProcessorMixin):
33
 
34
  def __call__(
35
  self,
36
- path_coords: list[list[list[float]]] | torch.Tensor | np.ndarray | None = None,
 
 
 
 
 
 
 
37
  text: str | list[str] | None = None,
38
  padding: bool | str = True,
39
  truncation: bool = True,
40
  max_length: int | None = None,
41
  return_tensors: str | None = "pt",
42
- **kwargs,
43
  ):
44
  """
45
  Process path coordinates and text into model inputs.
46
 
47
  Args:
48
- path_coords: List of paths or tensor [batch, path_len, 3]
49
- Each point is (x, y, time). Can be None if only processing text.
50
- text: String or list of strings to encode. Can be None if only processing paths.
 
 
 
 
 
 
 
 
51
  padding: Whether to pad sequences. Can be True/False or "max_length"
52
  truncation: Whether to truncate sequences
53
  max_length: Maximum sequence length for text (overrides max_char_len)
@@ -56,9 +86,10 @@ class SwipeProcessor(ProcessorMixin):
56
 
57
  Returns:
58
  Dictionary with:
59
- - path_coords: [batch, max_path_len, 3] (if path_coords provided)
 
60
  - input_ids: [batch, max_char_len] (if text provided)
61
- - attention_mask: [batch, total_seq_len]
62
  """
63
  if path_coords is None and text is None:
64
  raise ValueError("Must provide either path_coords or text (or both)")
@@ -67,24 +98,43 @@ class SwipeProcessor(ProcessorMixin):
67
  if path_coords is not None:
68
  # Handle path coordinates
69
  if isinstance(path_coords, (list, tuple)):
70
- # Check if it's a batch or single path
71
- if len(path_coords) > 0 and isinstance(path_coords[0][0], (list, tuple)):
72
- # Batch of paths [[path1], [path2], ...]
73
- path_coords = torch.tensor(path_coords, dtype=torch.float32)
74
  else:
75
- # Single path [[x,y,t], [x,y,t], ...]
76
- path_coords = torch.tensor([path_coords], dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  elif isinstance(path_coords, np.ndarray):
78
  path_coords = torch.from_numpy(path_coords).float()
79
  if path_coords.dim() == 2:
80
  # Single path, add batch dimension
81
  path_coords = path_coords.unsqueeze(0)
 
82
  elif isinstance(path_coords, torch.Tensor):
83
  if path_coords.dim() == 2:
84
  # Single path, add batch dimension
85
  path_coords = path_coords.unsqueeze(0)
86
-
87
- batch_size = path_coords.shape[0]
88
  elif text is not None:
89
  if isinstance(text, str):
90
  batch_size = 1
@@ -98,31 +148,120 @@ class SwipeProcessor(ProcessorMixin):
98
 
99
  # Process path coordinates
100
  if path_coords is not None:
101
- current_path_len = path_coords.shape[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Truncate if needed
104
- if truncation and current_path_len > self.max_path_len:
105
- path_coords = path_coords[:, : self.max_path_len, :]
106
- current_path_len = self.max_path_len
107
 
108
- # Pad if needed
109
- if padding and current_path_len < self.max_path_len:
110
- pad_len = self.max_path_len - current_path_len
111
- path_coords = torch.cat([path_coords, torch.zeros(batch_size, pad_len, 3)], dim=1)
112
 
113
- # Create path mask (1 = real data, 0 = padding)
114
- # Detect padding by checking for all-zero coordinates
115
- path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
116
- # A point is padding if all its coordinates (x, y, t) are zero
117
- is_padding = (path_coords == 0).all(dim=-1) # [batch, path_len]
118
- path_mask[is_padding] = 0
119
 
120
  result["path_coords"] = path_coords
121
- # Store path_mask internally for attention_mask construction
122
- _path_mask = path_mask
123
  else:
124
  # No path coords provided, create empty/zero tensors
125
- path_coords = torch.zeros(batch_size, self.max_path_len, 3)
126
  _path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
127
  result["path_coords"] = path_coords
128
 
@@ -157,9 +296,9 @@ class SwipeProcessor(ProcessorMixin):
157
  # Truncate but preserve EOS at the end
158
  for i in range(len(encoded_raw["input_ids"])):
159
  if len(encoded_raw["input_ids"][i]) > text_max_length:
160
- encoded_raw["input_ids"][i] = (
161
- encoded_raw["input_ids"][i][: text_max_length - 1] + [eos_id]
162
- )
163
 
164
  # Pad sequences
165
  if padding:
@@ -264,104 +403,48 @@ class SwipeProcessor(ProcessorMixin):
264
  """
265
  return self.tokenizer.decode(token_ids, **kwargs)
266
 
267
- def normalize_coordinates(
268
- self, data_points: list[dict], canvas_width: float = None, canvas_height: float = None
269
- ) -> list[dict]:
270
- """
271
- Normalize swipe coordinates and timestamps.
272
-
273
- Args:
274
- data_points: List of dicts with 'x', 'y', 't' keys
275
- canvas_width: Canvas width (not used - kept for compatibility)
276
- canvas_height: Canvas height (not used - kept for compatibility)
277
-
278
- Returns:
279
- List of normalized coordinate dicts with x, y in [0,1] and t in [0,1]
280
-
281
- Note:
282
- For futo-org/swipe.futo.org dataset, x and y are already normalized to [0,1].
283
- This function clamps them to ensure they stay in bounds and normalizes timestamps.
284
- """
285
- if not data_points:
286
- return []
287
-
288
- # Extract timestamps for normalization
289
- timestamps = [p["t"] for p in data_points]
290
- t_min = min(timestamps)
291
- t_max = max(timestamps)
292
- t_range = t_max - t_min if t_max > t_min else 1.0
293
-
294
- normalized = []
295
- for point in data_points:
296
- # x and y are already normalized to [0,1] in the dataset
297
- # But sometimes they go slightly outside bounds, so clamp them
298
- x_norm = max(0.0, min(1.0, point["x"]))
299
- y_norm = max(0.0, min(1.0, point["y"]))
300
-
301
- # Normalize timestamp to [0, 1]
302
- t_norm = (point["t"] - t_min) / t_range
303
 
304
- normalized.append({"x": x_norm, "y": y_norm, "t": t_norm})
305
 
306
- return normalized
 
307
 
308
- def sample_path_points(self, data_points: list[dict], max_len: int = None) -> tuple:
309
  """
310
- Sample or pad path points to fixed length using linear interpolation.
311
-
312
- Args:
313
- data_points: List of coordinate dicts with 'x', 'y', 't' keys
314
- max_len: Target length (defaults to self.max_path_len if not specified)
315
-
316
- Returns:
317
- Tuple of (sampled_points, mask) where:
318
- - sampled_points: numpy array of shape [max_len, 3] with (x, y, t) coordinates
319
- - mask: numpy array of shape [max_len] indicating valid (1) vs padding (0) points
320
-
321
- Note:
322
- - If path has fewer points than max_len, it's zero-padded
323
- - If path has more points than max_len, it's downsampled using linear interpolation
324
- - If path has exactly max_len points, it's returned as-is
325
  """
326
- if max_len is None:
327
- max_len = self.max_path_len
328
-
329
- num_points = len(data_points)
330
-
331
- if num_points == max_len:
332
- points = data_points
333
- mask = [1] * max_len
334
- elif num_points < max_len:
335
- # Pad with zeros
336
- points = data_points + [{"x": 0.0, "y": 0.0, "t": 0.0}] * (max_len - num_points)
337
- mask = [1] * num_points + [0] * (max_len - num_points)
338
- else:
339
- # Downsample using linear interpolation
340
- # Extract coordinates as arrays
341
- x_coords = np.array([p["x"] for p in data_points])
342
- y_coords = np.array([p["y"] for p in data_points])
343
- t_coords = np.array([p["t"] for p in data_points])
344
-
345
- # Original indices (parameter for interpolation)
346
- original_indices = np.arange(num_points)
347
-
348
- # Target indices for interpolation (evenly spaced)
349
- target_indices = np.linspace(0, num_points - 1, max_len)
350
-
351
- # Interpolate each coordinate independently
352
- x_interp = np.interp(target_indices, original_indices, x_coords)
353
- y_interp = np.interp(target_indices, original_indices, y_coords)
354
- t_interp = np.interp(target_indices, original_indices, t_coords)
355
-
356
- # Reconstruct points
357
- points = [
358
- {"x": float(x), "y": float(y), "t": float(t)}
359
- for x, y, t in zip(x_interp, y_interp, t_interp, strict=True)
360
- ]
361
- mask = [1] * max_len
362
 
363
- # Convert to numpy arrays
364
- coords = np.array([[p["x"], p["y"], p["t"]] for p in points], dtype=np.float32)
365
- mask = np.array(mask, dtype=np.int64)
366
-
367
- return coords, mask
 
1
  """Processor for handling multimodal swipe inputs (path + text)."""
2
 
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
  import numpy as np
8
  import torch
9
  from transformers import ProcessorMixin
10
 
11
+ from .preprocessing import preprocess_raw_path_to_features
12
+
13
 
14
  class SwipeProcessor(ProcessorMixin):
15
  """
 
27
  attributes = ["tokenizer"]
28
  tokenizer_class = "AutoTokenizer" # Will use auto_map from tokenizer_config.json
29
 
30
+ def __init__(
31
+ self,
32
+ tokenizer=None,
33
+ max_path_len: int = 64,
34
+ max_char_len: int = 38,
35
+ path_input_dim: int = 6,
36
+ path_resample_mode: str = "time",
37
+ ):
38
  self.tokenizer = tokenizer
39
  self.max_path_len = max_path_len
40
  self.max_char_len = max_char_len
41
+ self.path_input_dim = path_input_dim
42
+ self.path_resample_mode = path_resample_mode
43
  # Attributes expected by newer transformers (not used for swipe models)
44
  self.chat_template = None
45
  self.audio_tokenizer = None
 
48
 
49
  def __call__(
50
  self,
51
+ path_coords: (
52
+ list[dict[str, float]]
53
+ | list[list[dict[str, float]]]
54
+ | list[list[list[float]]]
55
+ | torch.Tensor
56
+ | np.ndarray
57
+ | None
58
+ ) = None,
59
  text: str | list[str] | None = None,
60
  padding: bool | str = True,
61
  truncation: bool = True,
62
  max_length: int | None = None,
63
  return_tensors: str | None = "pt",
64
+ **kwargs: Any,
65
  ):
66
  """
67
  Process path coordinates and text into model inputs.
68
 
69
  Args:
70
+ path_coords:
71
+ Swipe paths in one of the supported formats:
72
+ - Raw path (single example): list of dicts like `{"x": ..., "y": ..., "t": ...}`
73
+ - Raw batch: list of raw paths
74
+ - Numeric arrays/tensors: `[batch, path_len, D]` or `[path_len, D]`
75
+ If `D==3` and `path_input_dim==6`, raw `(x,y,t)` triples are converted to engineered
76
+ `(x, y, dx, dy, ds, log_dt)` features and resampled to `max_path_len`.
77
+ If omitted, the processor emits a zero path with a zero path attention mask.
78
+ text:
79
+ String or list of strings to encode.
80
+ If omitted, the processor emits padded text tokens with a zero text attention mask.
81
  padding: Whether to pad sequences. Can be True/False or "max_length"
82
  truncation: Whether to truncate sequences
83
  max_length: Maximum sequence length for text (overrides max_char_len)
 
86
 
87
  Returns:
88
  Dictionary with:
89
+ - path_coords: [batch, max_path_len, path_input_dim] (if path_coords provided)
90
+ Default: [batch, max_path_len, 6] for (x, y, dx, dy, ds, log_dt)
91
  - input_ids: [batch, max_char_len] (if text provided)
92
+ - attention_mask: [batch, total_seq_len] (covers `[CLS] + path + [SEP] + text`)
93
  """
94
  if path_coords is None and text is None:
95
  raise ValueError("Must provide either path_coords or text (or both)")
 
98
  if path_coords is not None:
99
  # Handle path coordinates
100
  if isinstance(path_coords, (list, tuple)):
101
+ if len(path_coords) == 0:
102
+ batch_size = 1
103
  else:
104
+ first = path_coords[0]
105
+ # Raw single path: [{"x","y","t"}, ...]
106
+ if isinstance(first, dict):
107
+ batch_size = 1
108
+ # Raw batch of paths: [[{"x","y","t"}, ...], ...]
109
+ elif (
110
+ isinstance(first, (list, tuple))
111
+ and len(first) > 0
112
+ and isinstance(first[0], dict)
113
+ ):
114
+ batch_size = len(path_coords)
115
+ # Numeric batch: [[[...], ...], ...] where points are lists/tuples
116
+ elif (
117
+ isinstance(first, (list, tuple))
118
+ and len(first) > 0
119
+ and isinstance(first[0], (list, tuple))
120
+ ):
121
+ path_coords = torch.tensor(path_coords, dtype=torch.float32)
122
+ batch_size = path_coords.shape[0]
123
+ else:
124
+ # Numeric single path: [[...], [...], ...]
125
+ path_coords = torch.tensor([path_coords], dtype=torch.float32)
126
+ batch_size = path_coords.shape[0]
127
  elif isinstance(path_coords, np.ndarray):
128
  path_coords = torch.from_numpy(path_coords).float()
129
  if path_coords.dim() == 2:
130
  # Single path, add batch dimension
131
  path_coords = path_coords.unsqueeze(0)
132
+ batch_size = path_coords.shape[0]
133
  elif isinstance(path_coords, torch.Tensor):
134
  if path_coords.dim() == 2:
135
  # Single path, add batch dimension
136
  path_coords = path_coords.unsqueeze(0)
137
+ batch_size = path_coords.shape[0]
 
138
  elif text is not None:
139
  if isinstance(text, str):
140
  batch_size = 1
 
148
 
149
  # Process path coordinates
150
  if path_coords is not None:
151
+ # Check if path_coords is raw data (list of dicts) or already a tensor
152
+ if isinstance(path_coords, (list, tuple)) and len(path_coords) > 0:
153
+ first_elem = path_coords[0]
154
+
155
+ # Raw single path: [{"x","y","t"}, ...]
156
+ if isinstance(first_elem, dict) and "x" in first_elem:
157
+ path_feats, mask = preprocess_raw_path_to_features(
158
+ path_coords,
159
+ self.max_path_len,
160
+ resample_mode=self.path_resample_mode,
161
+ )
162
+ if return_tensors == "pt":
163
+ path_coords = torch.from_numpy(path_feats).float().unsqueeze(0)
164
+ _path_mask = torch.from_numpy(mask).long().unsqueeze(0)
165
+ else:
166
+ path_coords = np.expand_dims(path_feats, axis=0)
167
+ _path_mask = np.expand_dims(mask, axis=0)
168
+
169
+ # Raw batch of paths: [[{"x","y","t"}, ...], ...]
170
+ elif (
171
+ isinstance(first_elem, (list, tuple))
172
+ and len(first_elem) > 0
173
+ and isinstance(first_elem[0], dict)
174
+ and "x" in first_elem[0]
175
+ ):
176
+ processed_paths = []
177
+ path_masks = []
178
+ for path in path_coords:
179
+ path_feats, mask = preprocess_raw_path_to_features(
180
+ path,
181
+ self.max_path_len,
182
+ resample_mode=self.path_resample_mode,
183
+ )
184
+ processed_paths.append(path_feats)
185
+ path_masks.append(mask)
186
 
187
+ path_coords = np.stack(processed_paths) # [batch, max_path_len, 6]
188
+ _path_mask = np.stack(path_masks) # [batch, max_path_len]
189
 
190
+ if return_tensors == "pt":
191
+ path_coords = torch.from_numpy(path_coords).float()
192
+ _path_mask = torch.from_numpy(_path_mask).long()
193
+
194
+ else:
195
+ # Numeric list input; process as before
196
+ path_coords = torch.tensor(path_coords, dtype=torch.float32)
197
+ if path_coords.dim() == 2:
198
+ path_coords = path_coords.unsqueeze(0)
199
+
200
+ current_path_len = path_coords.shape[1]
201
+ if truncation and current_path_len > self.max_path_len:
202
+ path_coords = path_coords[:, : self.max_path_len, :]
203
+ if padding and current_path_len < self.max_path_len:
204
+ pad_len = self.max_path_len - current_path_len
205
+ pad_shape = (batch_size, pad_len, self.path_input_dim)
206
+ path_coords = torch.cat([path_coords, torch.zeros(pad_shape)], dim=1)
207
+
208
+ _path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
209
+ is_padding = (path_coords == 0).all(dim=-1)
210
+ _path_mask[is_padding] = 0
211
+ elif isinstance(path_coords, np.ndarray):
212
+ path_coords = torch.from_numpy(path_coords).float()
213
+ if path_coords.dim() == 2:
214
+ path_coords = path_coords.unsqueeze(0)
215
+ # If user provided raw (x,y,t) triples but model expects engineered features,
216
+ # convert to motion features and resample.
217
+ if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
218
+ processed_paths = []
219
+ path_masks = []
220
+ for path in path_coords.cpu().numpy():
221
+ raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
222
+ path_feats, mask = preprocess_raw_path_to_features(
223
+ raw,
224
+ self.max_path_len,
225
+ resample_mode=self.path_resample_mode,
226
+ )
227
+ processed_paths.append(path_feats)
228
+ path_masks.append(mask)
229
+
230
+ path_coords = torch.from_numpy(np.stack(processed_paths)).float()
231
+ _path_mask = torch.from_numpy(np.stack(path_masks)).long()
232
+ else:
233
+ _path_mask = torch.ones(
234
+ path_coords.shape[0], self.max_path_len, dtype=torch.long
235
+ )
236
+ elif isinstance(path_coords, torch.Tensor):
237
+ if path_coords.dim() == 2:
238
+ path_coords = path_coords.unsqueeze(0)
239
+ # If user provided raw (x,y,t) triples but model expects engineered features,
240
+ # convert to motion features and resample.
241
+ if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
242
+ processed_paths = []
243
+ path_masks = []
244
+ for path in path_coords.detach().cpu().numpy():
245
+ raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
246
+ path_feats, mask = preprocess_raw_path_to_features(
247
+ raw,
248
+ self.max_path_len,
249
+ resample_mode=self.path_resample_mode,
250
+ )
251
+ processed_paths.append(path_feats)
252
+ path_masks.append(mask)
253
 
254
+ path_coords = torch.from_numpy(np.stack(processed_paths)).float()
255
+ _path_mask = torch.from_numpy(np.stack(path_masks)).long()
256
+ else:
257
+ _path_mask = torch.ones(
258
+ path_coords.shape[0], self.max_path_len, dtype=torch.long
259
+ )
260
 
261
  result["path_coords"] = path_coords
262
  else:
263
  # No path coords provided, create empty/zero tensors
264
+ path_coords = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
265
  _path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
266
  result["path_coords"] = path_coords
267
 
 
296
  # Truncate but preserve EOS at the end
297
  for i in range(len(encoded_raw["input_ids"])):
298
  if len(encoded_raw["input_ids"][i]) > text_max_length:
299
+ encoded_raw["input_ids"][i] = encoded_raw["input_ids"][i][
300
+ : text_max_length - 1
301
+ ] + [eos_id]
302
 
303
  # Pad sequences
304
  if padding:
 
403
  """
404
  return self.tokenizer.decode(token_ids, **kwargs)
405
 
406
+ def encode_path(self, path_coords, *, return_tensors: str | None = "pt", **kwargs: Any):
407
+ """Create model inputs from a swipe path only (no text)."""
408
+ return self(path_coords=path_coords, text=None, return_tensors=return_tensors, **kwargs)
409
 
410
+ def encode_text(self, text, *, return_tensors: str | None = "pt", **kwargs: Any):
411
+ """Create model inputs from text only (no path)."""
412
+ return self(path_coords=None, text=text, return_tensors=return_tensors, **kwargs)
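The encode_path / encode_text helpers above are thin wrappers around __call__. Continuing the sketch (same placeholder processor and raw_path as above):

# Path-only inputs: text tokens come back fully padded with a zero text attention mask.
path_inputs = processor.encode_path(raw_path)

# Text-only inputs: path_coords is a zero tensor with a zero path attention mask.
text_inputs = processor.encode_text(["hello", "world"])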
413
 
414
+ # Preprocessing methods are now imported from shared preprocessing module
415
+ # See src/swipealot/data/preprocessing.py for the implementation
416
 
417
+ def save_pretrained(
418
+ self,
419
+ save_directory,
420
+ push_to_hub=False,
421
+ **kwargs,
422
+ ):
423
  """
424
+ Save the processor to a directory, ensuring auto_map is included.
 
425
  """
426
+ # Call parent save_pretrained
427
+ result = super().save_pretrained(
428
+ save_directory,
429
+ push_to_hub=push_to_hub,
430
+ **kwargs,
431
+ )
432
+
433
+ # Add auto_map to processor_config.json for AutoProcessor compatibility
434
+ import json
435
+ from pathlib import Path
436
+
437
+ # Try both possible config file names
438
+ for config_name in ["preprocessor_config.json", "processor_config.json"]:
439
+ processor_config_path = Path(save_directory) / config_name
440
+ if processor_config_path.exists():
441
+ with open(processor_config_path) as f:
442
+ config = json.load(f)
443
+
444
+ config["auto_map"] = {"AutoProcessor": "processing_swipe.SwipeProcessor"}
445
+
446
+ with open(processor_config_path, "w") as f:
447
+ json.dump(config, f, indent=2)
448
+ break
 
449
 
450
+ return result
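The save_pretrained override above is what lets a saved processor be reloaded through AutoProcessor without importing this package directly. A round-trip sketch, assuming processing_swipe.py sits next to the saved config files (as it does in this repository):

from transformers import AutoProcessor

processor.save_pretrained("./swipe-processor-out")
# preprocessor_config.json (or processor_config.json) now carries the auto_map entry,
# so AutoProcessor can resolve processing_swipe.SwipeProcessor again:
reloaded = AutoProcessor.from_pretrained("./swipe-processor-out", trust_remote_code=True)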
processor_config.json CHANGED
@@ -1,6 +1,8 @@
1
  {
2
  "max_char_len": 48,
3
  "max_path_len": 128,
 
  "processor_class": "SwipeProcessor",
5
  "auto_map": {
6
  "AutoProcessor": "processing_swipe.SwipeProcessor"
 
1
  {
2
  "max_char_len": 48,
3
  "max_path_len": 128,
4
+ "path_input_dim": 6,
5
+ "path_resample_mode": "time",
6
  "processor_class": "SwipeProcessor",
7
  "auto_map": {
8
  "AutoProcessor": "processing_swipe.SwipeProcessor"
special_tokens_map.json CHANGED
@@ -1,8 +1,44 @@
1
  {
2
- "cls_token": "[CLS]",
3
- "eos_token": "[EOS]",
4
- "mask_token": "[MASK]",
5
- "pad_token": "[PAD]",
6
- "sep_token": "[SEP]",
7
- "unk_token": "[UNK]"
8
  }
 
1
  {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "[EOS]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "mask_token": {
17
+ "content": "[MASK]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "pad_token": {
24
+ "content": "[PAD]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "sep_token": {
31
+ "content": "[SEP]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "unk_token": {
38
+ "content": "[UNK]",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ }
44
  }
tokenization_swipe.py CHANGED
@@ -117,7 +117,8 @@ class SwipeTokenizer(PreTrainedTokenizer):
117
  Returns:
118
  int: Token ID
119
  """
120
- return self._tokenizer.char_to_id.get(token, self._tokenizer.unk_token_id)
 
121
 
122
  def _convert_id_to_token(self, index: int) -> str:
123
  """
@@ -154,6 +155,41 @@ class SwipeTokenizer(PreTrainedTokenizer):
154
  filtered = [t for t in tokens if t not in special_tokens]
155
  return "".join(filtered)
156
 
157
  def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
158
  """
159
  Save the tokenizer vocabulary to a directory.
 
117
  Returns:
118
  int: Token ID
119
  """
120
+ # Delegate to the core tokenizer to keep token/id mapping logic in one place.
121
+ return self._tokenizer.token_to_id(token)
122
 
123
  def _convert_id_to_token(self, index: int) -> str:
124
  """
 
155
  filtered = [t for t in tokens if t not in special_tokens]
156
  return "".join(filtered)
157
 
158
+ def save_pretrained(
159
+ self,
160
+ save_directory,
161
+ legacy_format=None,
162
+ filename_prefix=None,
163
+ push_to_hub=False,
164
+ **kwargs,
165
+ ):
166
+ """Save the tokenizer and write `auto_map` for `AutoTokenizer` loading."""
167
+ # Call parent save_pretrained
168
+ result = super().save_pretrained(
169
+ save_directory,
170
+ legacy_format=legacy_format,
171
+ filename_prefix=filename_prefix,
172
+ push_to_hub=push_to_hub,
173
+ **kwargs,
174
+ )
175
+
176
+ # Add auto_map to tokenizer_config.json for AutoTokenizer compatibility
177
+ from pathlib import Path
178
+
179
+ tokenizer_config_path = Path(save_directory) / "tokenizer_config.json"
180
+ if tokenizer_config_path.exists():
181
+ with open(tokenizer_config_path) as f:
182
+ config = json.load(f)
183
+
184
+ # For tokenizers, Transformers expects the 2-tuple form: [slow, fast].
185
+ # We only ship a slow tokenizer implementation, so fast is None.
186
+ config["auto_map"] = {"AutoTokenizer": ["tokenization_swipe.SwipeTokenizer", None]}
187
+
188
+ with open(tokenizer_config_path, "w") as f:
189
+ json.dump(config, f, indent=2)
190
+
191
+ return result
192
+
193
  def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
194
  """
195
  Save the tokenizer vocabulary to a directory.
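The two-element auto_map written above ([slow, fast], with no fast implementation) is what AutoTokenizer needs to resolve the custom class. A short loading sketch, with the checkpoint path again a placeholder:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/swipe-checkpoint", trust_remote_code=True)
ids = tok("hello")["input_ids"]  # character-level ids from SwipeTokenizer
print(tok.decode(ids))           # convert_tokens_to_string above filters out special tokens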
tokenizer.py CHANGED
@@ -48,16 +48,27 @@ class CharacterTokenizer:
48
  self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
49
  self.vocab_size = len(self.char_to_id)
50
 
 
51
  def encode(self, text: str) -> list[int]:
52
  """Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
53
- unk_id = self.char_to_id[self.unk_token]
54
- punc_id = self.char_to_id[self.punc_token]
55
  tokens = []
56
  for char in text.lower():
57
- if char.isalpha() or char.isdigit():
58
- tokens.append(self.char_to_id.get(char, unk_id))
59
- else:
60
- tokens.append(punc_id)
61
  return tokens
62
 
63
  def decode(self, token_ids: list[int]) -> str:
 
48
  self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
49
  self.vocab_size = len(self.char_to_id)
50
 
51
+ def encode_char(self, char: str) -> int:
52
+ """Encode a single character to a token id (case-insensitive; punctuation -> [PUNC])."""
53
+ char = char.lower()
54
+ if char.isalpha() or char.isdigit():
55
+ return self.char_to_id.get(char, self.unk_token_id)
56
+ return self.punc_token_id
57
+
58
+ def token_to_id(self, token: str) -> int:
59
+ """Map a token string to its id (supports specials and single characters)."""
60
+ direct = self.char_to_id.get(token)
61
+ if direct is not None:
62
+ return direct
63
+ if len(token) == 1:
64
+ return self.encode_char(token)
65
+ return self.unk_token_id
66
+
67
  def encode(self, text: str) -> list[int]:
68
  """Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
 
  tokens = []
70
  for char in text.lower():
71
+ tokens.append(self.encode_char(char))
72
  return tokens
73
 
74
  def decode(self, token_ids: list[int]) -> str:
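To make the refactored mapping concrete, a small behavioral sketch. It assumes core is a CharacterTokenizer built from this repo's vocabulary (construction is not shown in this diff), and the exact ids depend on that vocabulary:

core.encode_char("A")       # same id as "a": encoding is case-insensitive
core.encode_char("!")       # any non-alphanumeric character maps to the [PUNC] id
core.token_to_id("[MASK]")  # special tokens resolve directly through char_to_id
core.token_to_id("q")       # single characters fall back to encode_char
core.token_to_id("word")    # unknown multi-character strings return unk_token_id
core.encode("Hi!")          # == [core.encode_char("h"), core.encode_char("i"), core.punc_token_id]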
tokenizer_config.json CHANGED
@@ -49,6 +49,12 @@
49
  "special": true
50
  }
51
  },
52
  "clean_up_tokenization_spaces": false,
53
  "cls_token": "[CLS]",
54
  "eos_token": "[EOS]",
@@ -59,11 +65,5 @@
59
  "processor_class": "SwipeProcessor",
60
  "sep_token": "[SEP]",
61
  "tokenizer_class": "SwipeTokenizer",
62
- "unk_token": "[UNK]",
63
- "auto_map": {
64
- "AutoTokenizer": [
65
- "tokenization_swipe.SwipeTokenizer",
66
- null
67
- ]
68
- }
69
  }
 
49
  "special": true
50
  }
51
  },
52
+ "auto_map": {
53
+ "AutoTokenizer": [
54
+ "tokenization_swipe.SwipeTokenizer",
55
+ null
56
+ ]
57
+ },
58
  "clean_up_tokenization_spaces": false,
59
  "cls_token": "[CLS]",
60
  "eos_token": "[EOS]",
 
65
  "processor_class": "SwipeProcessor",
66
  "sep_token": "[SEP]",
67
  "tokenizer_class": "SwipeTokenizer",
68
+ "unk_token": "[UNK]"
69
  }