Commit 391c639 by dleemiller (verified) · 1 Parent(s): 58031d0

Update README.md

Files changed (1)
  1. README.md +40 -81
README.md CHANGED
@@ -21,6 +21,7 @@ metrics:

  > [!IMPORTANT]
  > This model is currently in beta status and is subject to change.

  Multimodal, multi-objective transformer for swipe keyboard prediction.
  Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.
@@ -36,47 +37,47 @@ This model is trained with the following objectives:
  </p>


- > [!NOTE]
- > This model should be further fine-tuned for a specific task, if not using the embedding mode.
- > For example, length prediction can be significantly improved in a single task setting.

- ## Quick Start

- ```python
- from datasets import load_dataset
- from transformers import AutoModel, AutoProcessor
- import torch
-
- # Load model
- model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
- processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
- model.eval()
-
- # Load sample
- dataset = load_dataset("futo-org/swipe.futo.org", split="test[:10]")
- item = dataset[4]
-
- # Preprocess swipe path using processor methods
- # 1. Normalize timestamps (x,y already normalized in futo dataset)
- normalized = processor.normalize_coordinates(item["data"], item["canvas_width"], item["canvas_height"])
-
- # 2. Resample to fixed length (max_path_len=128)
- # - Pads with zeros if path < 128 points
- # - Interpolates if path > 128 points
- path_coords, _ = processor.sample_path_points(normalized, processor.max_path_len)
- path = torch.tensor([path_coords], dtype=torch.float32)
-
- # Get predictions
- inputs = processor(path_coords=path, text=None, return_tensors="pt")
-
- with torch.no_grad():
-     outputs = model(**inputs)
-
- # Length prediction
- predicted_length = outputs.length_logits.argmax(dim=-1).item()
- print(f"Predicted word length: {predicted_length}")
  ```

  ## Model Details

  - **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
@@ -135,51 +136,9 @@ Trained via contrastive learning where the SEP token produces fixed-size embeddi
  - **Inverted mode (80%)**: Pulls embeddings of heavily-masked and lightly-masked versions of the same input close together, teaching invariance to noise and occlusion
  - **Modality mode (20%)**: Pulls embeddings of path-only and text-only views of the same word close together, teaching cross-modal alignment between gesture geometry and semantic meaning

- The contrastive loss (15% weight, temperature 0.07) pulls matching pairs together in embedding space while pushing non-matches apart. Uses Matryoshka embeddings to create nested representations at multiple dimensions (64, 128, 384, 768), with stronger weight on lower-dimensional representations (2.0×, 1.5×, 1.0×, 1.0×) to ensure the first 64 dimensions are highly informative on their own.
-
- ## Usage Examples

- ### Length Prediction
-
- This
-
- ```python
- from datasets import load_dataset
- from transformers import AutoModel, AutoProcessor
-
-
- model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
- model.eval()
- model.requires_grad_(False)
- processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
-
- # Load a sample row from the dataset.
- ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
- row = ds[0] # "Brahmas"
-
- # Length-only inference:
- # `encode_path(...)` preprocesses the swipe path to fixed-length motion features and sets text attention to 0.
- inputs = processor.encode_path(row["data"], return_tensors="pt")
- outputs = model(**inputs, return_dict=True)
-
- # Length prediction is a regression scalar (float); round it for an integer length.
- pred_len = float(outputs.length_logits.item())
- pred_len_rounded = max(0, int(round(pred_len)))
- true_len = sum(1 for c in row["word"].lower() if c.isalpha() or c.isdigit())
-
- print(f'Word: "{row["word"]}"')
- print(f"Length (true): {true_len}")
- print(f"Length (pred): {pred_len:.3f}")
- print(f"Length (pred rounded):{pred_len_rounded}")
- ```
-
-
- ```text
- Word: "Brahmas"
- Length (true): 7
- Length (pred): 7.483
- Length (pred rounded):7
- ```

  ### Embedding Similarity

 

  > [!IMPORTANT]
  > This model is currently in beta status and is subject to change.
+ > Last updated 2025-12-19

  Multimodal, multi-objective transformer for swipe keyboard prediction.
  Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.
 
  </p>



+ ## Quick Start (Length Prediction)

+ ```python
+ from datasets import load_dataset
+ from transformers import AutoModel, AutoProcessor
+
+
+ model = AutoModel.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
+ model.eval()
+ model.requires_grad_(False)
+ processor = AutoProcessor.from_pretrained("dleemiller/SwipeALot-base", trust_remote_code=True)
+
+ # Load a sample row from the dataset.
+ ds = load_dataset("futo-org/swipe.futo.org", split="test[:50]")
+ row = ds[0] # "Brahmas"
+
+ # Length-only inference:
+ # `encode_path(...)` preprocesses the swipe path to fixed-length motion features and sets text attention to 0.
+ inputs = processor.encode_path(row["data"], return_tensors="pt")
+ outputs = model(**inputs, return_dict=True)
+
+ # Length prediction is a regression scalar (float); round it for an integer length.
+ pred_len = float(outputs.length_logits.item())
+ pred_len_rounded = max(0, int(round(pred_len)))
+ true_len = sum(1 for c in row["word"].lower() if c.isalpha() or c.isdigit())
+
+ print(f'Word: "{row["word"]}"')
+ print(f"Length (true): {true_len}")
+ print(f"Length (pred): {pred_len:.3f}")
+ print(f"Length (pred rounded):{pred_len_rounded}")
+ ```
+
+ ```text
+ Word: "Brahmas"
+ Length (true): 7
+ Length (pred): 7.483
+ Length (pred rounded):7
  ```

+
  ## Model Details

  - **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
 
  - **Inverted mode (80%)**: Pulls embeddings of heavily-masked and lightly-masked versions of the same input close together, teaching invariance to noise and occlusion
  - **Modality mode (20%)**: Pulls embeddings of path-only and text-only views of the same word close together, teaching cross-modal alignment between gesture geometry and semantic meaning

+ The contrastive loss (10-20% weight, temperature 0.07) pulls matching pairs together in embedding space while pushing non-matches apart. Uses Matryoshka embeddings to create nested representations at multiple dimensions (64, 128, 384, 768), with stronger weight on lower-dimensional representations (2.0×, 1.5×, 1.0×, 1.0×) to ensure the first 64 dimensions are highly informative on their own.
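
Reviewer note: the paragraph above describes a weighted, nested contrastive objective without showing its shape, so here is a rough, dependency-free sketch of how such a loss could be computed. The dimensions, per-dimension weights, and temperature come from the README text. Everything else is an assumption made for illustration: the InfoNCE form with cosine similarity, the normalization by the sum of the weights, and the function names `info_nce` and `matryoshka_contrastive_loss`, none of which are taken from the model's code. The 10-20% loss weight applied when mixing with the other objectives is not shown.

```python
import math

# Values quoted in the README paragraph.
MATRYOSHKA_DIMS = (64, 128, 384, 768)
MATRYOSHKA_WEIGHTS = (2.0, 1.5, 1.0, 1.0)
TEMPERATURE = 0.07


def l2_normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def info_nce(anchors, positives, temperature=TEMPERATURE):
    """Mean cross-entropy where anchors[i] should match positives[i].

    Assumed form: cosine similarity divided by temperature, with the other
    rows of `positives` acting as in-batch negatives.
    """
    anchors = [l2_normalize(a) for a in anchors]
    positives = [l2_normalize(p) for p in positives]
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [
            sum(x * y for x, y in zip(anchor, pos)) / temperature
            for pos in positives
        ]
        # Numerically stable log-sum-exp over the batch.
        peak = max(logits)
        log_denom = peak + math.log(sum(math.exp(l - peak) for l in logits))
        total += log_denom - logits[i]
    return total / len(anchors)


def matryoshka_contrastive_loss(
    anchors, positives, dims=MATRYOSHKA_DIMS, weights=MATRYOSHKA_WEIGHTS
):
    """Weighted average of InfoNCE losses over nested embedding prefixes.

    Each prefix (first 64, 128, 384, 768 values) is scored on its own, so the
    shorter prefixes must carry useful signal by themselves.
    Dividing by sum(weights) is an assumption, not taken from the source.
    """
    weighted = sum(
        w * info_nce([a[:d] for a in anchors], [p[:d] for p in positives])
        for d, w in zip(dims, weights)
    )
    return weighted / sum(weights)
```

Calling `matryoshka_contrastive_loss(view_a, view_b)` with two lists of 768-value embeddings, where row `i` of each list comes from the same swipe, returns a single float that shrinks as matching rows become more similar than mismatched ones at every prefix length.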
 
 
+ ## More Usage Examples

  ### Embedding Similarity