---
license: apache-2.0
library_name: transformers
tags:
- multimodal
- swipe-keyboard
- gesture-recognition
- text-prediction
- character-prediction
- embeddings
- feature-extraction
language:
- en
datasets:
- futo-org/swipe.futo.org
metrics:
- accuracy
---

# SwipeALot Base Model

> [!IMPORTANT]
> This model is currently in beta status and is subject to change.
> Last updated 2025-12-19

Multimodal, multi-objective transformer for swipe keyboard prediction.
Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.

This model is trained with the following objectives:
- Masked character prediction (MLM)
- Masked path prediction
- Text length prediction (CLS token)
- Path/text embedding (SEP token, contrastive + Matryoshka at 64, 128, 384, and 768 dimensions)

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/OV87xy-_ID0TqKW0bfvVq.png" style="width: 400px;">
</p>

## Quick Start (Length Prediction)

```python
from datasets import load_dataset
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
model.eval()
processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)

# Load a sample row from the dataset.
ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0]  # "Brahmas"

# Length-only inference:
# `encode_path(...)` preprocesses the swipe path to fixed-length motion features and sets text attention to 0.
inputs = processor.encode_path(row["data"], return_tensors="pt")
outputs = model(**inputs, return_dict=True)

# Length prediction is a regression scalar (float); round it for an integer length.
pred_len = float(outputs.length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
# Ground truth counts only letters and digits, matching the model's character vocabulary.
true_len = sum(1 for c in row["word"].lower() if c.isalpha() or c.isdigit())

print(f'Word: "{row["word"]}"')
print(f"Length (true): {true_len}")
print(f"Length (pred): {pred_len:.3f}")
print(f"Length (pred rounded): {pred_len_rounded}")
```

```text
Word: "Brahmas"
Length (true): 7
Length (pred): 7.483
Length (pred rounded): 7
```

## Model Details

- **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
- **Parameters**: 87M
- **Training Data**: futo-org/swipe.futo.org dataset
- **Max Path Length**: 128 points (paths are interpolated down or padded up to this length)
- **Max Word Length**: 48 characters (words are truncated or padded to this length)
- **Vocab Size**: 43 (a-z, 0-9, special tokens)

**Input Constraints:**
- Path coordinates (x, y) must be normalized to the [0, 1] range
- Timestamps must be normalized to the [0, 1] range
- Paths longer than 128 points are downsampled using linear interpolation
- Text longer than 48 characters is truncated with EOS preserved

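
The input constraints above can be illustrated with a minimal sketch for preparing raw touch data. It is not the processor's actual implementation: `normalize_and_resample`, the keyboard `width`/`height` in pixels, and the choice to resample uniformly by point index are all assumptions made for illustration. In practice, `processor.encode_path(...)` performs the model's own preprocessing.

```python
import numpy as np

MAX_PATH_LEN = 128  # maximum path length stated in this model card


def normalize_and_resample(x, y, t, width, height, max_len=MAX_PATH_LEN):
    """Scale raw points into [0, 1] and downsample long paths by linear interpolation.

    Hypothetical helper: assumes pixel coordinates and monotonically increasing timestamps.
    """
    x = np.asarray(x, dtype=np.float64) / width
    y = np.asarray(y, dtype=np.float64) / height
    t = np.asarray(t, dtype=np.float64)
    span = t[-1] - t[0]
    t = (t - t[0]) / span if span > 0 else np.zeros_like(t)
    pts = np.clip(np.stack([x, y, t], axis=1), 0.0, 1.0)

    if len(pts) > max_len:
        # Assumption: resample uniformly over the point index.
        old = np.linspace(0.0, 1.0, len(pts))
        new = np.linspace(0.0, 1.0, max_len)
        pts = np.stack([np.interp(new, old, pts[:, d]) for d in range(3)], axis=1)
    return pts


# Synthetic 300-point swipe on a 1080x600 keyboard, timestamps in milliseconds.
raw_x = np.linspace(0, 1080, 300)
raw_y = 300 + 100 * np.sin(np.linspace(0, 3.0, 300))
raw_t = np.linspace(0, 900, 300)

points = normalize_and_resample(raw_x, raw_y, raw_t, width=1080, height=600)
print(points.shape)  # (128, 3)
```

Paths shorter than 128 points are returned unchanged here; padding them to the fixed length is left to the processor.
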
## Capabilities

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/H0uEoluxh4FG22XeSeQUI.png" style="width: 800px;">
</p>

### 1. Character Prediction
Predict characters from swipe paths with partial text context.

Trained via masked language modeling with a pairwise masking strategy that creates two augmented views of each input for contrastive learning. Training uses focal loss to emphasize hard-to-predict characters and frequency-based weighting to handle character imbalance (rare letters such as 'z' versus common letters such as 'e').

**Pairwise Masking Strategy:**
- **Inverted Mode (80%)**: Asymmetric augmentation pairs
  - Query view: Heavy masking (50-70% of path points and characters randomly masked) with gradients
  - Key view: Light masking (10-20% of path points and characters randomly masked) with stop gradient
  - Teaches robust representations invariant to noise and occlusion

- **Modality Mode (20%)**: Cross-modal alignment pairs
  - Query view: Text fully masked, path visible (teaches path → semantic representation) with gradients
  - Key view: Path fully masked, text visible (provides alignment target) with stop gradient
  - Teaches correspondence between path geometry and text meaning

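
The pairwise masking strategy can be sketched roughly as follows. This is an illustrative reading of the description above, not the training code: the function name `sample_view_masks`, the uniform sampling of the mask ratio, and the boolean-list representation (`True` = masked) are assumptions.

```python
import random


def sample_view_masks(path_len, text_len, rng):
    """Return (mode, query, key); query and key are (path_mask, text_mask) pairs.

    Hypothetical sketch: True marks a masked position.
    """

    def random_mask(n, lo, hi):
        # Mask a random fraction of positions, with the fraction drawn from [lo, hi].
        k = round(rng.uniform(lo, hi) * n)
        masked = set(rng.sample(range(n), k))
        return [i in masked for i in range(n)]

    if rng.random() < 0.8:
        mode = "inverted"
        query = (random_mask(path_len, 0.5, 0.7), random_mask(text_len, 0.5, 0.7))  # heavy
        key = (random_mask(path_len, 0.1, 0.2), random_mask(text_len, 0.1, 0.2))    # light
    else:
        mode = "modality"
        query = ([False] * path_len, [True] * text_len)  # path visible, text fully masked
        key = ([True] * path_len, [False] * text_len)    # path fully masked, text visible
    return mode, query, key


rng = random.Random(0)
mode, (q_path, q_text), (k_path, k_text) = sample_view_masks(128, 7, rng)
print(mode, sum(q_path), sum(k_path))
```

In training, only the query view would receive gradients; the key view acts as a fixed target, which a plain Python sketch like this cannot express.
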
### 2. Length Prediction
Predict word length from swipe path alone.

Trained as an auxiliary task where the CLS token aggregates path information to predict word length (0-48 characters). This helps the model learn geometric properties of swipe gestures that correlate with word length, such as path extent and complexity.

Length supervision occurs only in modality mode, and only when text attention is fully zeroed. This covers 10% of training batches (20% modality mode × 50% zero-attention probability). It trains the model to predict length from path geometry alone, without any text length cues. The length loss receives 10% of the total loss weight, which encourages learning without dominating the primary objectives.

### 3. Path Reconstruction
Reconstruct missing path coordinates.

Trained via masked path prediction as part of the pairwise masking strategy. During inverted mode (80% of batches), path points are randomly masked at 50-70% for the heavily masked view and 10-20% for the lightly masked view. During modality mode (20% of batches), either all path points are masked (key view) or none are masked (query view). The model learns to reconstruct spatial-temporal structure from partial path information and text context, which teaches it the geometric and temporal patterns of swipe gestures. The path loss is weighted at 50% of the character prediction loss, making it a significant secondary objective.

### 4. Embedding Extraction
Extract fixed-size embeddings for similarity search.

**Dimension**: 768

Trained via contrastive learning where the SEP token produces fixed-size embeddings for path-text pairs. The pairwise masking strategy is central to embedding training:
- **Inverted mode (80%)**: Pulls embeddings of heavily masked and lightly masked versions of the same input close together, teaching invariance to noise and occlusion
- **Modality mode (20%)**: Pulls embeddings of path-only and text-only views of the same word close together, teaching cross-modal alignment between gesture geometry and semantic meaning

The contrastive loss (10-20% weight, temperature 0.07) pulls matching pairs together in embedding space while pushing non-matches apart. Matryoshka embeddings create nested representations at multiple dimensions (64, 128, 384, 768), with stronger weight on lower-dimensional representations (2.0×, 1.5×, 1.0×, 1.0×) so that the first 64 dimensions are highly informative on their own.

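
As a rough illustration of how a Matryoshka-weighted contrastive objective might be combined, the sketch below applies an InfoNCE loss to each embedding prefix and takes a weighted average. Only the dimensions, weights, and temperature come from this card. The one-directional InfoNCE form, the normalization by the weight sum, and the function names are assumptions; NumPy stands in for the training framework, so the stop gradient on the key view is not represented.

```python
import numpy as np

DIMS = (64, 128, 384, 768)      # Matryoshka prefix sizes from the card
WEIGHTS = (2.0, 1.5, 1.0, 1.0)  # per-dimension loss weights from the card
TEMPERATURE = 0.07


def info_nce(query, key, temperature=TEMPERATURE):
    """Cross-entropy over cosine-similarity logits; row i of `query` matches row i of `key`."""
    q = query / np.linalg.norm(query, axis=1, keepdims=True)
    k = key / np.linalg.norm(key, axis=1, keepdims=True)
    logits = q @ k.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def matryoshka_loss(query, key):
    """Weighted average of InfoNCE losses over nested embedding prefixes (assumed form)."""
    total = sum(w * info_nce(query[:, :d], key[:, :d]) for d, w in zip(DIMS, WEIGHTS))
    return total / sum(WEIGHTS)


rng = np.random.default_rng(0)
key_emb = rng.normal(size=(8, 768))
aligned_query = key_emb + 0.01 * rng.normal(size=(8, 768))  # near-identical views
random_query = rng.normal(size=(8, 768))                    # unrelated views

print(matryoshka_loss(aligned_query, key_emb) < matryoshka_loss(random_query, key_emb))  # True
```

Weighting the 64-dimensional prefix most heavily pushes discriminative information into the leading dimensions, which is what makes truncated embeddings usable on their own.
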
## More Usage Examples

### Embedding Similarity

Modality attention masking adds a training task similar to that of CLIP-style models.
The model can output vector representations of words, paths, or both, and the embeddings of a matching word and path have high similarity.

```python
# Continuing from above (reuses `model` and `processor`):
#
# Goal: show matching via embeddings.
# - "word-only" embedding: `processor.encode_text(...)` (equivalent to `text=...` + `path_coords=None`)
#     -> path attention is all zeros.
# - "path-only" embedding: `processor.encode_path(...)` (equivalent to `path_coords=...` + `text=None`)
#     -> text attention is all zeros.
#
# We then compare cosine similarity:
#   sim(path(row0), word(row0)) should be higher than sim(path(row0), word(row1)).

import numpy as np

ds = load_dataset("futo-org/swipe.futo.org", split="test[:200]")
row0 = ds[0]  # "Brahmas"
row1 = ds[7]  # "central"

word_inputs = processor.encode_text([row0["word"], row1["word"]], return_tensors="pt")
word_out = model(**word_inputs, return_dict=True)
word_emb = word_out.pooler_output.detach().cpu().numpy()  # shape: [2, d_model]

path_inputs = processor.encode_path(row0["data"], return_tensors="pt")
path_out = model(**path_inputs, return_dict=True)
path_emb0 = path_out.pooler_output.detach().cpu().numpy()[0]  # shape: [d_model]

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

sim_pos = cosine_similarity(path_emb0, word_emb[0])
sim_neg = cosine_similarity(path_emb0, word_emb[1])

print(f'Row0 word: "{row0["word"]}"')
print(f'Row1 word: "{row1["word"]}"')
print(f"Cosine similarity [positive]: {sim_pos:.4f}")
print(f"Cosine similarity [negative]: {sim_neg:.4f}")
```

```text
Row0 word: "Brahmas"
Row1 word: "central"
Cosine similarity [positive]: 0.7927
Cosine similarity [negative]: -0.0117
```

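
The same comparison extends to ranking a list of candidate words against one path embedding. The sketch below is self-contained and uses synthetic vectors in place of real `pooler_output` embeddings. `rank_candidates` and its optional `dim` argument (which keeps only a leading Matryoshka prefix, such as the first 64 dimensions) are hypothetical names introduced for illustration.

```python
import numpy as np


def rank_candidates(path_emb, word_embs, dim=None):
    """Return (indices, scores) of candidates sorted by descending cosine similarity."""
    if dim is not None:
        # Keep only the leading prefix of each embedding, then renormalize below.
        path_emb, word_embs = path_emb[:dim], word_embs[:, :dim]
    p = path_emb / np.linalg.norm(path_emb)
    w = word_embs / np.linalg.norm(word_embs, axis=1, keepdims=True)
    scores = w @ p
    order = np.argsort(-scores)
    return order, scores[order]


# Synthetic stand-ins: 5 candidate "word" embeddings and a noisy "path" embedding of candidate 2.
rng = np.random.default_rng(0)
word_embs = rng.normal(size=(5, 768))
path_emb = word_embs[2] + 0.3 * rng.normal(size=768)

order_full, _ = rank_candidates(path_emb, word_embs)
order_64, _ = rank_candidates(path_emb, word_embs, dim=64)
print(order_full[0], order_64[0])  # 2 2
```

With real embeddings, `word_embs` would be the stacked outputs of `processor.encode_text(...)` for the candidate vocabulary, computed once and cached.
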
### Word Reconstruction ("Blind Reconstruction")

Here's how you can do a 2-step prediction: first predict the word length (to get the number of masks),
and then use mask prediction to fill in the word.

```python
# Continuing from above (reuses `model` and `processor`):
#
# Word reconstruction (unknown length):
#   1) Run path-only inference to predict the length.
#   2) Create a text segment of `[MASK] * predicted_length + [EOS]`, enable text attention for it,
#      then reconstruct the characters from the path.

tokenizer = processor.tokenizer

ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
row = ds[0]  # "Brahmas"

inputs = processor.encode_path(row["data"], return_tensors="pt")

pad_id = int(tokenizer.pad_token_id)
mask_id = int(tokenizer.mask_token_id)
eos_id = int(getattr(tokenizer, "eos_token_id", -1))

pred_len = float(model(**inputs, return_dict=True).length_logits.item())
pred_len_rounded = max(0, int(round(pred_len)))
pred_len_rounded = min(pred_len_rounded, int(processor.max_char_len) - 1)  # reserve 1 for EOS

# Overwrite the padded text segment from `encode_path(...)` with `[MASK]... [EOS]`.
inputs["input_ids"].fill_(pad_id)
inputs["input_ids"][:, :pred_len_rounded].fill_(mask_id)
inputs["input_ids"][:, pred_len_rounded].fill_(eos_id)

# Enable text attention up to and including EOS.
char_start = 1 + int(processor.max_path_len) + 1  # [CLS] + path + [SEP]
inputs["attention_mask"][:, char_start:].fill_(0)
inputs["attention_mask"][:, char_start : char_start + pred_len_rounded + 1].fill_(1)

outputs = model(
    **inputs,
    return_dict=True,
)

# Take the most likely character at each masked position and decode it.
pred_ids = outputs.char_logits.argmax(dim=-1)[0].detach().cpu().tolist()
pred_word = tokenizer.decode(pred_ids[:pred_len_rounded]).strip().lower()

print(f'Word: "{row["word"]}"')
print(f'Reconstructed word: "{pred_word}"')
```

```text
Word: "Brahmas"
Reconstructed word: "brahmas"
```

## Performance Metrics

Evaluated on 49,970 test samples:

| Task | Metric | Score |
|------|--------|-------|
| Masked Prediction (30%) | Character Accuracy | 98.4% |
| | Top-3 Accuracy | 99.9% |
| | Word Accuracy | 97.2% |
| Full Reconstruction (100%) | Character Accuracy | 95.6% |
| | Word Accuracy | 89.3% |
| Blind Reconstruction (2-pass) | Character Accuracy | 92.8% |
| | Word Accuracy | 87.0% |
| Length Prediction | Exact Accuracy | 93.0% |
| | Within ±1 | 99.4% |
| | Within ±2 | 99.9% |
| Path Reconstruction | MSE (masked; dims=x/y) | 0.000090 |

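
For readers who want to compute comparable numbers on their own data, here is a minimal sketch of the accuracy-style metrics in the table. The exact definitions used for this card's evaluation are not stated, so these are assumptions: character accuracy is taken position-wise over the target characters, word accuracy as exact match, and length accuracy on the rounded prediction. The function names are hypothetical.

```python
def char_accuracy(pred_words, true_words):
    """Fraction of target characters matched at the same position (assumed definition)."""
    correct = total = 0
    for pred, true in zip(pred_words, true_words):
        total += len(true)
        correct += sum(a == b for a, b in zip(pred, true))
    return correct / total


def word_accuracy(pred_words, true_words):
    """Fraction of words reconstructed exactly."""
    return sum(p == t for p, t in zip(pred_words, true_words)) / len(true_words)


def length_within(pred_lens, true_lens, tolerance=0):
    """Fraction of rounded length predictions within `tolerance` of the true length."""
    hits = sum(abs(round(p) - t) <= tolerance for p, t in zip(pred_lens, true_lens))
    return hits / len(true_lens)


# Toy data, made up for illustration only.
preds = ["brahmas", "centrel", "the"]
trues = ["brahmas", "central", "then"]

print(f"char acc: {char_accuracy(preds, trues):.3f}")  # 16 of 18 characters match
print(f"word acc: {word_accuracy(preds, trues):.3f}")  # 1 of 3 words match
print(f"len exact: {length_within([7.483, 6.6, 3.2], [7, 7, 4]):.3f}")
print(f"len ±1: {length_within([7.483, 6.6, 3.2], [7, 7, 4], tolerance=1):.3f}")
```
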
## Model Outputs

```python
outputs = model(**inputs)

# Available outputs:
outputs.char_logits        # [batch, char_len, vocab_size] - Character predictions
outputs.length_logits      # [batch, 1] - Length predictions
outputs.path_logits        # [batch, path_len, 3] - Path coordinate predictions
outputs.pooler_output      # [batch, d_model] - SEP token embeddings for similarity
outputs.last_hidden_state  # [batch, seq_len, d_model] - Hidden representations
```

```text
char_len = 48
path_len = 128
seq_len = 178
d_model = 768
```

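
The sequence length of 178 follows from the token layout `[CLS] + path + [SEP] + characters` (1 + 128 + 1 + 48). This layout is inferred from the `char_start` comment in the word reconstruction example. The constant names below are hypothetical, and a zero array stands in for `outputs.last_hidden_state`.

```python
import numpy as np

MAX_PATH_LEN = 128
MAX_CHAR_LEN = 48

# Assumed layout: [CLS] | path points | [SEP] | characters
CLS_INDEX = 0
PATH_SLICE = slice(1, 1 + MAX_PATH_LEN)
SEP_INDEX = 1 + MAX_PATH_LEN
CHAR_SLICE = slice(SEP_INDEX + 1, SEP_INDEX + 1 + MAX_CHAR_LEN)
SEQ_LEN = CHAR_SLICE.stop

hidden = np.zeros((1, SEQ_LEN, 768))  # stand-in for outputs.last_hidden_state
path_states = hidden[:, PATH_SLICE]   # hidden states at path positions
char_states = hidden[:, CHAR_SLICE]   # hidden states at character positions

print(SEQ_LEN, path_states.shape, char_states.shape)  # 178 (1, 128, 768) (1, 48, 768)
```
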
## Citation

```bibtex
@software{swipealot2025,
  title={SwipeALot: Multimodal Swipe Keyboard Transformer},
  author={Lee Miller},
  year={2025},
  url={https://huggingface.co/dleemiller/SwipeALot-base}
}
```

<p align="center">
  <img src="https://cdn-uploads.huggingface.co/production/uploads/65ff92ea467d83751a727538/gJLt6WE5iJiLofbSiyIt8.png" style="width: 400px;">
</p>

## License

Apache 2.0