---
license: apache-2.0
library_name: transformers
tags:
- multimodal
- swipe-keyboard
- gesture-recognition
- text-prediction
- character-prediction
- embeddings
- feature-extraction
language:
- en
datasets:
- futo-org/swipe.futo.org
metrics:
- accuracy
---
# SwipeALot Base Model
> [!IMPORTANT]
> This model is currently in beta status and is subject to change.
> Last updated 2025-12-19
Multimodal, multi-objective transformer for swipe keyboard prediction.
Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.
This model is trained with the following objectives:
- Masked character prediction (MLM)
- Masked path prediction
- Text length prediction (CLS token)
- Path/text embedding (SEP token, contrastive + Matryoshka at 64, 128, 384, 768)
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/OV87xy-_ID0TqKW0bfvVq.png" style="width: 400px;">
</p>
## Quick Start (Length Prediction)
```python
from datasets import load_dataset
from transformers import AutoModel, AutoProcessor
model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
model.eval()
processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
# Load a sample row from the dataset.
ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0] # "Brahmas"
# Length-only inference:
# `encode_path(...)` preprocesses the swipe path to fixed-length motion features and sets text attention to 0.
inputs = processor.encode_path(row["data"], return_tensors="pt")
outputs = model(**inputs, return_dict=True)
# Length prediction is a regression scalar (float); round it for an integer length.
pred_len = float(outputs.length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
true_len = sum(1 for c in row["word"].lower() if c.isalpha() or c.isdigit())
print(f'Word: "{row["word"]}"')
print(f"Length (true): {true_len}")
print(f"Length (pred): {pred_len:.3f}")
print(f"Length (pred rounded): {pred_len_rounded}")
```
```text
Word: "Brahmas"
Length (true): 7
Length (pred): 7.483
Length (pred rounded): 7
```
## Model Details
- **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
- **Parameters**: 87M
- **Training Data**: futo-org/swipe.futo.org dataset
- **Max Path Length**: 128 points (paths are interpolated down or padded up to this length)
- **Max Word Length**: 48 characters (words are truncated or padded to this length)
- **Vocab Size**: 43 (a-z, 0-9, special tokens)
**Input Constraints:**
- Path coordinates must be normalized to [0, 1] range for x, y
- Timestamps must be normalized to [0, 1] range
- Paths longer than 128 points are downsampled using linear interpolation
- Text longer than 48 characters is truncated with EOS preserved
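The constraints above can be sketched in NumPy. `prepare_path` below is a hypothetical helper for illustration, not the processor's actual preprocessing (which may differ in detail, e.g. in how padding is masked):

```python
import numpy as np

def prepare_path(xs, ys, ts, max_len=128):
    """Normalize a raw swipe path and resample/pad it to exactly `max_len` points.

    Illustrative sketch of the stated input constraints; use
    `processor.encode_path(...)` for real inference.
    """
    pts = np.stack([xs, ys, ts], axis=1).astype(np.float64)
    # Normalize x, y, t independently to [0, 1].
    lo, hi = pts.min(axis=0), pts.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)
    pts = (pts - lo) / span
    n = len(pts)
    if n > max_len:
        # Downsample with linear interpolation over a uniform index grid.
        src = np.linspace(0.0, n - 1, max_len)
        idx = np.arange(n)
        pts = np.stack([np.interp(src, idx, pts[:, d]) for d in range(3)], axis=1)
        mask = np.ones(max_len, dtype=np.int64)
    else:
        # Pad short paths with zeros and mark padding in the attention mask.
        pad = np.zeros((max_len - n, 3))
        mask = np.concatenate([np.ones(n, dtype=np.int64),
                               np.zeros(max_len - n, dtype=np.int64)])
        pts = np.concatenate([pts, pad], axis=0)
    return pts.astype(np.float32), mask
```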
## Capabilities
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/H0uEoluxh4FG22XeSeQUI.png" style="width: 800px;">
</p>
### 1. Character Prediction
Predict characters from swipe paths with partial text context.
Trained via masked language modeling with a pairwise masking strategy that creates two augmented views of each input for contrastive learning. Training uses focal loss to concentrate on hard-to-predict characters, and frequency-based weighting to handle character imbalance (rare letters like 'z' versus common letters like 'e').
**Pairwise Masking Strategy:**
- **Inverted Mode (80%)**: Asymmetric augmentation pairs
- Query view: Heavy masking (50-70% of path points and characters randomly masked) with gradients
- Key view: Light masking (10-20% of path points and characters randomly masked) with stop gradient
- Teaches robust representations invariant to noise and occlusion
- **Modality Mode (20%)**: Cross-modal alignment pairs
- Query view: Text fully masked, path visible (teaches path → semantic representation) with gradients
- Key view: Path fully masked, text visible (provides alignment target) with stop gradient
- Teaches correspondence between path geometry and text meaning
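The mode/rate sampling above can be sketched as follows. This is an illustrative reimplementation using the rates stated in the text, not the actual training code:

```python
import numpy as np

def pairwise_masks(path_len, char_len, rng):
    """Sample query/key mask pairs following the described strategy.

    Returns ((query_path_mask, query_char_mask), (key_path_mask, key_char_mask)),
    boolean arrays where True means "position is masked". Exact sampling
    details are assumptions for illustration.
    """
    if rng.random() < 0.8:  # Inverted mode: asymmetric augmentation pair.
        heavy = rng.uniform(0.5, 0.7)   # query view (with gradients)
        light = rng.uniform(0.1, 0.2)   # key view (stop-gradient)
        query = (rng.random(path_len) < heavy, rng.random(char_len) < heavy)
        key = (rng.random(path_len) < light, rng.random(char_len) < light)
    else:  # Modality mode: cross-modal alignment pair.
        query = (np.zeros(path_len, bool), np.ones(char_len, bool))  # path visible, text masked
        key = (np.ones(path_len, bool), np.zeros(char_len, bool))    # text visible, path masked
    return query, key

rng = np.random.default_rng(0)
(query_path, query_chars), (key_path, key_chars) = pairwise_masks(128, 48, rng)
```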
### 2. Length Prediction
Predict word length from swipe path alone.
Trained as an auxiliary task where the CLS token aggregates path information to predict word length (0-48 characters). This helps the model learn geometric properties of swipe gestures that correlate with word length, such as path extent and complexity.
Length supervision occurs only during modality mode when text attention is fully zeroed (10% of training batches: 20% modality mode × 50% zero-attention probability). This trains the model to predict length from path geometry alone without any text length cues. Uses 10% of the total loss weight to encourage learning without dominating the primary objectives.
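The gating arithmetic above (20% modality mode x 50% zero-attention = 10% of batches) can be checked with a quick Monte-Carlo sketch:

```python
import numpy as np

# Length supervision fires only in modality mode AND when text attention
# is fully zeroed, i.e. roughly 20% x 50% = 10% of training batches.
rng = np.random.default_rng(0)
n = 100_000
modality = rng.random(n) < 0.20       # modality-mode batches
zero_text = rng.random(n) < 0.50      # zero-text-attention sub-sampling
supervise_length = modality & zero_text
print(f"length-supervised fraction: {supervise_length.mean():.3f}")
```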
### 3. Path Reconstruction
Reconstruct missing path coordinates.
Trained via masked path prediction as part of the pairwise masking strategy. During inverted mode (80% of batches), path points are randomly masked at 50-70% for heavy augmentation and 10-20% for light augmentation. During modality mode (20% of batches), either all path points are masked (key view) or none are masked (query view). The model learns to reconstruct spatial-temporal structure from partial path information and text context, teaching it the geometric and temporal patterns of swipe gestures. Uses 50% of the character prediction loss weight, making it a significant secondary objective.
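A minimal sketch of a loss restricted to masked path positions, in the spirit of the objective described above (the actual training loss may weight the x/y/t dimensions differently):

```python
import numpy as np

def masked_path_mse(pred, target, mask):
    """MSE over masked path positions only.

    pred, target: [path_len, 3] arrays of (x, y, t); mask: [path_len] bool,
    True where the point was masked out. Illustrative sketch only.
    """
    per_point = ((pred - target) ** 2).mean(axis=-1)  # average over x, y, t
    return float(per_point[mask].mean()) if mask.any() else 0.0
```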
### 4. Embedding Extraction
Extract fixed-size embeddings for similarity search.
**Dimension**: 768
Trained via contrastive learning where the SEP token produces fixed-size embeddings for path-text pairs. The pairwise masking strategy is central to embedding training:
- **Inverted mode (80%)**: Pulls embeddings of heavily-masked and lightly-masked versions of the same input close together, teaching invariance to noise and occlusion
- **Modality mode (20%)**: Pulls embeddings of path-only and text-only views of the same word close together, teaching cross-modal alignment between gesture geometry and semantic meaning
The contrastive loss (10-20% weight, temperature 0.07) pulls matching pairs together in embedding space while pushing non-matches apart. Uses Matryoshka embeddings to create nested representations at multiple dimensions (64, 128, 384, 768), with stronger weight on lower-dimensional representations (2.0×, 1.5×, 1.0×, 1.0×) to ensure the first 64 dimensions are highly informative on their own.
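The combined objective can be sketched as a weighted InfoNCE over nested embedding prefixes. The temperature (0.07), dimensions, and weights come from the text above; everything else is an illustrative reimplementation, not the training code:

```python
import numpy as np

def info_nce(q, k, temperature=0.07):
    """InfoNCE over a batch of paired embeddings (row i of q matches row i of k)."""
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    k = k / np.linalg.norm(k, axis=1, keepdims=True)
    logits = q @ k.T / temperature                 # [B, B] similarity matrix
    logits -= logits.max(axis=1, keepdims=True)    # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))            # diagonal = positive pairs

def matryoshka_contrastive(q, k, dims=(64, 128, 384, 768),
                           weights=(2.0, 1.5, 1.0, 1.0)):
    """Weighted InfoNCE over nested prefix slices of the embeddings."""
    total = sum(w * info_nce(q[:, :d], k[:, :d]) for d, w in zip(dims, weights))
    return total / sum(weights)
```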
## More Usage Examples
### Embedding Similarity
Modality attention masking adds an alignment task similar to CLIP-style models: the model can output vector representations of words or paths (or both), and matching path/word pairs have high similarity.
```python
# Continuing from above (reuses `model` and `processor`):
#
# Goal: show matching via embeddings.
# - "word-only" embedding: `processor.encode_text(...)` (equivalent to `text=...` + `path_coords=None`)
# -> path attention is all zeros.
# - "path-only" embedding: `processor.encode_path(...)` (equivalent to `path_coords=...` + `text=None`)
# -> text attention is all zeros.
#
# We then compare cosine similarity:
# sim(path(row0), word(row0)) should be higher than sim(path(row0), word(row1)).
import numpy as np
ds = load_dataset("futo-org/swipe.futo.org", split="test[:200]")
row0 = ds[0] # "Brahmas"
row1 = ds[7] # "central"
word_inputs = processor.encode_text([row0["word"], row1["word"]], return_tensors="pt")
word_out = model(**word_inputs, return_dict=True)
word_emb = word_out.pooler_output.detach().cpu().numpy() # shape: [2, d_model]
path_inputs = processor.encode_path(row0["data"], return_tensors="pt")
path_out = model(**path_inputs, return_dict=True)
path_emb0 = path_out.pooler_output.detach().cpu().numpy()[0] # shape: [d_model]
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
sim_pos = cosine_similarity(path_emb0, word_emb[0])
sim_neg = cosine_similarity(path_emb0, word_emb[1])
print(f'Row0 word: "{row0["word"]}"')
print(f'Row1 word: "{row1["word"]}"')
print(f"Cosine similarity [positive]: {sim_pos:.4f}")
print(f"Cosine similarity [negative]: {sim_neg:.4f}")
```
```text
Row0 word: "Brahmas"
Row1 word: "central"
Cosine similarity [positive]: 0.7927
Cosine similarity [negative]: -0.0117
```
### Word Reconstruction ("Blind Reconstruction")
Here's how to run a two-step prediction: first predict the word length (to determine the number of `[MASK]` tokens),
then use masked character prediction to fill in the word.
```python
# Continuing from above (reuses `model` and `processor`):
#
# Word reconstruction (unknown length):
# 1) Run path-only inference to predict the length.
# 2) Create a text segment of `[MASK] * predicted_length + [EOS]`, enable text attention for it,
# then reconstruct the characters from the path.
tokenizer = processor.tokenizer
ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0] # "Brahmas"
inputs = processor.encode_path(row["data"], return_tensors="pt")
pad_id = int(tokenizer.pad_token_id)
mask_id = int(tokenizer.mask_token_id)
eos_id = int(getattr(tokenizer, "eos_token_id", -1))
pred_len = float(model(**inputs, return_dict=True).length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
pred_len_rounded = min(pred_len_rounded, int(processor.max_char_len) - 1) # reserve 1 for EOS
# Overwrite the padded text segment from `encode_path(...)` with `[MASK]... [EOS]`.
inputs["input_ids"].fill_(pad_id)
inputs["input_ids"][:, :pred_len_rounded].fill_(mask_id)
inputs["input_ids"][:, pred_len_rounded].fill_(eos_id)
# Enable text attention up to and including EOS.
char_start = 1 + int(processor.max_path_len) + 1 # [CLS] + path + [SEP]
inputs["attention_mask"][:, char_start:].fill_(0)
inputs["attention_mask"][:, char_start : char_start + pred_len_rounded + 1].fill_(1)
outputs = model(
**inputs,
return_dict=True,
)
pred_ids = outputs.char_logits.argmax(dim=-1)[0].detach().cpu().tolist()
pred_word = tokenizer.decode(pred_ids[:pred_len_rounded]).strip().lower()
print(f'Word: "{row["word"]}"')
print(f'Reconstructed word: "{pred_word}"')
```
```text
Word: "Brahmas"
Reconstructed word: "brahmas"
```
## Performance Metrics
Evaluated on 49,970 test samples:
| Task | Metric | Score |
|------|--------|-------|
| Masked Prediction (30%) | Character Accuracy | 98.4% |
| | Top-3 Accuracy | 99.9% |
| | Word Accuracy | 97.2% |
| Full Reconstruction (100%) | Character Accuracy | 95.6% |
| | Word Accuracy | 89.3% |
| Blind Reconstruction (2-pass) | Character Accuracy | 92.8% |
| | Word Accuracy | 87.0% |
| Length Prediction | Exact Accuracy | 93.0% |
| | Within ±1 | 99.4% |
| | Within ±2 | 99.9% |
| Path Reconstruction | MSE (masked; dims=x/y) | 0.000090 |
## Model Outputs
```python
outputs = model(**inputs)
# Available outputs:
outputs.char_logits # [batch, char_len, vocab_size] - Character predictions
outputs.length_logits # [batch, 1] - Length predictions
outputs.path_logits # [batch, path_len, 3] - Path coordinate predictions
outputs.pooler_output # [batch, d_model] - SEP token embeddings for similarity
outputs.last_hidden_state # [batch, seq_len, d_model] - Hidden representations
```
```text
char_len = 48
path_len = 128
seq_len = 178
d_model = 768
```
## Citation
```bibtex
@software{swipealot2025,
title={SwipeALot: Multimodal Swipe Keyboard Transformer},
author={Lee Miller},
year={2025},
url={https://huggingface.co/dleemiller/SwipeALot-base}
}
```
<p align="center">
<img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/gJLt6WE5iJiLofbSiyIt8.png" style="width: 400px;">
</p>
## License
Apache 2.0