---
license: apache-2.0
library_name: transformers
tags:
- multimodal
- swipe-keyboard
- gesture-recognition
- text-prediction
- character-prediction
- embeddings
- feature-extraction
language:
- en
datasets:
- futo-org/swipe.futo.org
metrics:
- accuracy
---

# SwipeALot Base Model

> [!IMPORTANT]
> This model is currently in beta status and is subject to change.
> Last updated 2025-12-19

Multimodal, multi-objective transformer for swipe keyboard prediction.
Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.

This model is trained with the following objectives:
- Masked character prediction (MLM)
- Masked path prediction
- Text length prediction (CLS token)
- Path/text embedding (SEP token, contrastive + Matryoshka @ 64, 128, 384, 768 dims)

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/OV87xy-_ID0TqKW0bfvVq.png" style="width: 400px;">
</p>



## Quick Start (Length Prediction)

```python
from datasets import load_dataset
from transformers import AutoModel, AutoProcessor


model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
model.eval()
processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)

# Load a sample row from the dataset.
ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0]  # "Brahmas"

# Length-only inference:
# `encode_path(...)` preprocesses the swipe path into fixed-length motion features
# and sets text attention to 0.
inputs = processor.encode_path(row["data"], return_tensors="pt")
outputs = model(**inputs, return_dict=True)

# Length prediction is a regression scalar (float); round it for an integer length.
pred_len = float(outputs.length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
true_len = sum(1 for c in row["word"].lower() if c.isalpha() or c.isdigit())

print(f'Word:                  "{row["word"]}"')
print(f"Length (true):         {true_len}")
print(f"Length (pred):         {pred_len:.3f}")
print(f"Length (pred rounded): {pred_len_rounded}")
```

```text
Word:                  "Brahmas"
Length (true):         7
Length (pred):         7.483
Length (pred rounded): 7
```


## Model Details

- **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
- **Parameters**: 87M
- **Training Data**: futo-org/swipe.futo.org dataset
- **Max Path Length**: 128 points (paths are interpolated down or padded up to this length)
- **Max Word Length**: 48 characters (words are truncated or padded to this length)
- **Vocab Size**: 43 (a-z, 0-9, special tokens)

**Input Constraints:**
- Path coordinates must be normalized to [0, 1] range for x, y
- Timestamps must be normalized to [0, 1] range
- Paths longer than 128 points are downsampled using linear interpolation
- Text longer than 48 characters is truncated with EOS preserved
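
The constraints above can be made concrete with a short preprocessing sketch. Note that `normalize_and_resample` is a hypothetical helper shown only for illustration; the released processor performs this internally via `encode_path(...)`.

```python
import numpy as np

MAX_PATH_LEN = 128

def normalize_and_resample(xs, ys, ts, width, height):
    """Map a raw swipe path into the model's expected input range."""
    pts = np.stack([
        np.asarray(xs, dtype=np.float64) / width,   # x -> [0, 1]
        np.asarray(ys, dtype=np.float64) / height,  # y -> [0, 1]
    ], axis=1)
    t = np.asarray(ts, dtype=np.float64)
    t = (t - t[0]) / max(t[-1] - t[0], 1e-9)        # timestamps -> [0, 1]
    path = np.column_stack([pts, t])                # [n, 3] = (x, y, t)
    if len(path) > MAX_PATH_LEN:
        # Downsample onto an even grid using linear interpolation.
        src = np.linspace(0.0, 1.0, len(path))
        dst = np.linspace(0.0, 1.0, MAX_PATH_LEN)
        path = np.stack([np.interp(dst, src, path[:, k]) for k in range(3)], axis=1)
    return path  # shorter paths are returned as-is; padding is left to the processor
```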

## Capabilities

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/H0uEoluxh4FG22XeSeQUI.png" style="width: 800px;">
</p>

### 1. Character Prediction
Predict characters from swipe paths with partial text context.

Trained via masked language modeling with a pairwise masking strategy that creates two augmented views of each input for contrastive learning. Training uses focal loss to focus on hard-to-predict characters and frequency-based weighting to handle character imbalance (rare letters like 'z' vs. common letters like 'e').

**Pairwise Masking Strategy:**
- **Inverted Mode (80%)**: Asymmetric augmentation pairs
  - Query view: Heavy masking (50-70% of path points and characters randomly masked) with gradients
  - Key view: Light masking (10-20% of path points and characters randomly masked) with stop gradient
  - Teaches robust representations invariant to noise and occlusion

- **Modality Mode (20%)**: Cross-modal alignment pairs
  - Query view: Text fully masked, path visible (teaches path → semantic representation) with gradients
  - Key view: Path fully masked, text visible (provides alignment target) with stop gradient
  - Teaches correspondence between path geometry and text meaning
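
The schedule above can be sketched in a few lines. This is an illustrative sketch only; `sample_mask_rates` is a hypothetical function, and the actual training code may differ in detail.

```python
import random

def sample_mask_rates(rng: random.Random):
    """Sample per-batch masking rates for the query/key views."""
    if rng.random() < 0.8:
        # Inverted mode (80%): heavy query view vs. light key view.
        query = {"path_mask": rng.uniform(0.5, 0.7), "text_mask": rng.uniform(0.5, 0.7)}
        key = {"path_mask": rng.uniform(0.1, 0.2), "text_mask": rng.uniform(0.1, 0.2)}
    else:
        # Modality mode (20%): path-only query vs. text-only key.
        query = {"path_mask": 0.0, "text_mask": 1.0}  # path visible, text masked
        key = {"path_mask": 1.0, "text_mask": 0.0}    # text visible, path masked
    return query, key
```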

### 2. Length Prediction
Predict word length from swipe path alone.

Trained as an auxiliary task where the CLS token aggregates path information to predict word length (0-48 characters). This helps the model learn geometric properties of swipe gestures that correlate with word length, such as path extent and complexity.

Length supervision occurs only during modality mode when text attention is fully zeroed (10% of training batches: 20% modality mode × 50% zero-attention probability). This trains the model to predict length from path geometry alone without any text length cues. Uses 10% of the total loss weight to encourage learning without dominating the primary objectives.

### 3. Path Reconstruction
Reconstruct missing path coordinates.

Trained via masked path prediction as part of the pairwise masking strategy. During inverted mode (80% of batches), path points are randomly masked at 50-70% for heavy augmentation and 10-20% for light augmentation. During modality mode (20% of batches), either all path points are masked (key view) or none are masked (query view). The model learns to reconstruct spatial-temporal structure from partial path information and text context, teaching it the geometric and temporal patterns of swipe gestures. Uses 50% of the character prediction loss weight, making it a significant secondary objective.
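
The relative weights above (plus the contrastive weight described under Embedding Extraction below) can be combined into a single objective. This is a hedged sketch, not the actual training code: `total_loss` is hypothetical, and the 0.15 contrastive weight assumes the midpoint of the quoted 10-20% range.

```python
def total_loss(char_loss: float, path_loss: float, length_loss: float,
               contrastive_loss: float, supervise_length: bool = False) -> float:
    loss = char_loss                 # primary objective: masked character prediction
    loss += 0.5 * path_loss          # path reconstruction: half the char weight
    if supervise_length:             # only in modality mode with text attention zeroed
        loss += 0.1 * length_loss    # 10% weight
    loss += 0.15 * contrastive_loss  # assumed midpoint of the 10-20% range
    return loss
```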

### 4. Embedding Extraction
Extract fixed-size embeddings for similarity search.

**Dimension**: 768

Trained via contrastive learning where the SEP token produces fixed-size embeddings for path-text pairs. The pairwise masking strategy is central to embedding training:
- **Inverted mode (80%)**: Pulls embeddings of heavily-masked and lightly-masked versions of the same input close together, teaching invariance to noise and occlusion
- **Modality mode (20%)**: Pulls embeddings of path-only and text-only views of the same word close together, teaching cross-modal alignment between gesture geometry and semantic meaning

The contrastive loss (10-20% weight, temperature 0.07) pulls matching pairs together in embedding space while pushing non-matches apart. Uses Matryoshka embeddings to create nested representations at multiple dimensions (64, 128, 384, 768), with stronger weight on lower-dimensional representations (2.0×, 1.5×, 1.0×, 1.0×) to ensure the first 64 dimensions are highly informative on their own.
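
A minimal NumPy sketch of the Matryoshka weighting described above, assuming an InfoNCE-style loss over in-batch negatives (function name and exact formulation are illustrative, not the training code):

```python
import numpy as np

DIMS = (64, 128, 384, 768)
WEIGHTS = (2.0, 1.5, 1.0, 1.0)  # stronger weight on lower dims

def matryoshka_contrastive(q: np.ndarray, k: np.ndarray, temperature: float = 0.07):
    """q, k: [batch, 768] paired embeddings; row i of q matches row i of k."""
    total = 0.0
    for dim, w in zip(DIMS, WEIGHTS):
        # Truncate to the nested dimension and L2-normalize.
        qd = q[:, :dim] / np.linalg.norm(q[:, :dim], axis=1, keepdims=True)
        kd = k[:, :dim] / np.linalg.norm(k[:, :dim], axis=1, keepdims=True)
        logits = qd @ kd.T / temperature  # [batch, batch] similarity matrix
        # InfoNCE: each row's positive sits on the diagonal.
        log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        total += w * (-np.mean(np.diag(log_probs)))
    return total / sum(WEIGHTS)
```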

## More Usage Examples

### Embedding Similarity

Modality attention masking gives the model a CLIP-like capability: it can output vector representations of words or paths (or both), and matching path/word pairs score high cosine similarity.

```python
# Continuing from above (reuses `model` and `processor`):
#
# Goal: show matching via embeddings.
# - "word-only" embedding: `processor.encode_text(...)` (equivalent to `text=...` + `path_coords=None`)
#   -> path attention is all zeros.
# - "path-only" embedding: `processor.encode_path(...)` (equivalent to `path_coords=...` + `text=None`)
#   -> text attention is all zeros.
#
# We then compare cosine similarity:
#   sim(path(row0), word(row0))  should be higher than  sim(path(row0), word(row1)).

import numpy as np

ds = load_dataset("futo-org/swipe.futo.org", split="test[:200]")
row0 = ds[0]  # "Brahmas"
row1 = ds[7]  # "central"

word_inputs = processor.encode_text([row0["word"], row1["word"]], return_tensors="pt")
word_out = model(**word_inputs, return_dict=True)
word_emb = word_out.pooler_output.detach().cpu().numpy()  # shape: [2, d_model]

path_inputs = processor.encode_path(row0["data"], return_tensors="pt")
path_out = model(**path_inputs, return_dict=True)
path_emb0 = path_out.pooler_output.detach().cpu().numpy()[0]  # shape: [d_model]

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

sim_pos = cosine_similarity(path_emb0, word_emb[0])
sim_neg = cosine_similarity(path_emb0, word_emb[1])

print(f'Row0 word: "{row0["word"]}"')
print(f'Row1 word: "{row1["word"]}"')
print(f"Cosine similarity [positive]: {sim_pos:.4f}")
print(f"Cosine similarity [negative]: {sim_neg:.4f}")
```


```text
Row0 word: "Brahmas"
Row1 word: "central"
Cosine similarity [positive]: 0.7927
Cosine similarity [negative]: -0.0117
```


### Word Reconstruction ("Blind Reconstruction")

Here's how to do a two-step prediction: first predict the word length (to get the number of `[MASK]` tokens), then use masked character prediction to fill in the word.

```python
# Continuing from above (reuses `model` and `processor`):
#
# Word reconstruction (unknown length):
# 1) Run path-only inference to predict the length.
# 2) Create a text segment of `[MASK] * predicted_length + [EOS]`, enable text attention for it,
#    then reconstruct the characters from the path.

tokenizer = processor.tokenizer

ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0]  # "Brahmas"

inputs = processor.encode_path(row["data"], return_tensors="pt")

pad_id = int(tokenizer.pad_token_id)
mask_id = int(tokenizer.mask_token_id)
eos_id = int(getattr(tokenizer, "eos_token_id", -1))

pred_len = float(model(**inputs, return_dict=True).length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
pred_len_rounded = min(pred_len_rounded, int(processor.max_char_len) - 1)  # reserve 1 for EOS

# Overwrite the padded text segment from `encode_path(...)` with `[MASK]... [EOS]`.
inputs["input_ids"].fill_(pad_id)
inputs["input_ids"][:, :pred_len_rounded].fill_(mask_id)
inputs["input_ids"][:, pred_len_rounded].fill_(eos_id)

# Enable text attention up to and including EOS.
char_start = 1 + int(processor.max_path_len) + 1  # [CLS] + path + [SEP]
inputs["attention_mask"][:, char_start:].fill_(0)
inputs["attention_mask"][:, char_start : char_start + pred_len_rounded + 1].fill_(1)

outputs = model(
    **inputs,
    return_dict=True,
)

pred_ids = outputs.char_logits.argmax(dim=-1)[0].detach().cpu().tolist()
pred_word = tokenizer.decode(pred_ids[:pred_len_rounded]).strip().lower()

print(f'Word:                 "{row["word"]}"')
print(f'Reconstructed word:   "{pred_word}"')
```

```text
Word:                 "Brahmas"
Reconstructed word:   "brahmas"
```


## Performance Metrics

Evaluated on 49,970 test samples:

| Task | Metric | Score |
|------|--------|-------|
| Masked Prediction (30%) | Character Accuracy | 98.4% |
|  | Top-3 Accuracy | 99.9% |
|  | Word Accuracy | 97.2% |
| Full Reconstruction (100%) | Character Accuracy | 95.6% |
|  | Word Accuracy | 89.3% |
| Blind Reconstruction (2-pass) | Character Accuracy | 92.8% |
|  | Word Accuracy | 87.0% |
| Length Prediction | Exact Accuracy | 93.0% |
|  | Within ±1 | 99.4% |
|  | Within ±2 | 99.9% |
| Path Reconstruction | MSE (masked; dims=x/y) | 0.000090 |


## Model Outputs

```python
outputs = model(**inputs)

# Available outputs:
outputs.char_logits       # [batch, char_len, vocab_size] - Character predictions
outputs.length_logits     # [batch, 1] - Length predictions
outputs.path_logits       # [batch, path_len, 3] - Path coordinate predictions
outputs.pooler_output     # [batch, d_model] - SEP token embeddings for similarity
outputs.last_hidden_state # [batch, seq_len, d_model] - Hidden representations
```

```text
char_len = 48
path_len = 128
seq_len = 178
d_model = 768
```
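
The shape constants above decompose into the token layout used throughout the examples; a quick sanity check (pure arithmetic, no model needed):

```python
# The 178-token sequence decomposes as [CLS] + path + [SEP] + characters.
CLS, SEP = 1, 1
PATH_LEN, CHAR_LEN = 128, 48

seq_len = CLS + PATH_LEN + SEP + CHAR_LEN  # 178
char_start = CLS + PATH_LEN + SEP          # 130: where the character segment begins
                                           # (the offset sliced in the blind-reconstruction example)
print(seq_len, char_start)  # 178 130
```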

## Citation

```bibtex
@software{swipealot2025,
  title={SwipeALot: Multimodal Swipe Keyboard Transformer},
  author={Lee Miller},
  year={2025},
  url={https://huggingface.co/dleemiller/SwipeALot-base}
}
```

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/gJLt6WE5iJiLofbSiyIt8.png" style="width: 400px;">
</p>

## License

Apache 2.0