dleemiller committed on
Commit bf31071 · verified · 1 parent: 8dde163

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,3 +1,155 @@
- ---
- license: apache-2.0
- ---
+ # SwipeALot Base Model
+
+ Multimodal transformer for swipe keyboard prediction. Trained on the [futo-org/swipe.futo.org](https://huggingface.co/datasets/futo-org/swipe.futo.org) dataset.
+
+ ## Quick Start
+
+ ```python
+ from transformers import AutoModel, AutoProcessor
+ import torch
+
+ # Load model
+ model = AutoModel.from_pretrained("path/to/model", trust_remote_code=True)
+ processor = AutoProcessor.from_pretrained("path/to/model", trust_remote_code=True)
+ model.eval()
+
+ # Example: predict word length from a swipe path
+ from datasets import load_dataset
+ from swipealot.data.dataset import normalize_coordinates, sample_path_points
+
+ # Load a sample
+ dataset = load_dataset("futo-org/swipe.futo.org", split="test[:1]")
+ item = dataset[0]
+
+ # Preprocess the path
+ normalized = normalize_coordinates(item["data"], item["canvas_width"], item["canvas_height"])
+ path_coords, _ = sample_path_points(normalized, processor.max_path_len)
+ path = torch.tensor([path_coords], dtype=torch.float32)
+
+ # Get predictions
+ inputs = processor(path_coords=path, text=None, return_tensors="pt")
+
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ # Length prediction
+ predicted_length = outputs.length_logits.argmax(dim=-1).item()
+ print(f"Predicted word length: {predicted_length}")
+ ```
+
+ ## Model Details
+
+ - **Architecture**: Transformer encoder (768-dim, 12 layers, 12 heads)
+ - **Parameters**: 87M
+ - **Training Data**: futo-org/swipe.futo.org dataset
+ - **Max Path Length**: 128 points
+ - **Max Word Length**: 48 characters
+ - **Vocab Size**: 43 (a-z, 0-9, special tokens)
+
+ ## Capabilities
+
+ ### 1. Character Prediction
+ Predict characters from swipe paths with partial text context.
+
+ **Use Case**: Autocorrection, suggestion ranking
+
+ ### 2. Length Prediction
+ Predict word length from the swipe path alone.
+
+ **Accuracy**: 89% exact, 96% within ±1
+
+ **Use Case**: Pre-filtering candidate words
+
+ ### 3. Path Reconstruction
+ Reconstruct missing path coordinates.
+
+ **MSE**: 0.005 on masked points
+
+ **Use Case**: Noise reduction, gesture smoothing
+
+ ### 4. Embedding Extraction
+ Extract fixed-size embeddings for similarity search (see the example under Usage Examples below).
+
+ **Dimension**: 768
+
+ **Use Case**: Similar gesture search, deduplication
+
+ ## Usage Examples
+
+ See the [full documentation](https://github.com/dleemiller/legendary-waffle) for detailed examples of each capability.
+
+ ### Masked Character Prediction
+
+ ```python
+ # Process the full word, then manually mask positions
+ inputs = processor(path_coords=path, text="hello", return_tensors="pt")
+ mask_token_id = processor.tokenizer.mask_token_id
+ char_ids = inputs["input_ids"][0].tolist()
+ char_ids[2] = mask_token_id  # Mask 'l' at position 2
+ inputs["input_ids"] = torch.tensor([char_ids], dtype=torch.long)
+
+ # The model predicts the masked character from path + context
+ ```
+
+ ### Full Word Reconstruction
+
+ ```python
+ # Process the word, then mask all character positions
+ inputs = processor(path_coords=path, text="hello", return_tensors="pt")
+ char_ids = inputs["input_ids"][0].tolist()
+ mask_token_id = processor.tokenizer.mask_token_id
+ masked_ids = [mask_token_id if cid != 0 else 0 for cid in char_ids]
+ inputs["input_ids"] = torch.tensor([masked_ids], dtype=torch.long)
+
+ # Predicting from the path alone achieves 94% character accuracy
+ ```
+
+ ### Length Prediction
+
+ ```python
+ inputs = processor(path_coords=path, text=None, return_tensors="pt")
+ with torch.no_grad():
+     outputs = model(**inputs)
+ predicted_length = outputs.length_logits.argmax(dim=-1).item()
+ ```
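+
+ ### Embedding Extraction
+
+ A minimal sketch of the embedding capability, assuming `path_a` and `path_b` are preprocessed path tensors built as in the Quick Start:
+
+ ```python
+ import torch.nn.functional as F
+
+ with torch.no_grad():
+     emb_a = model(**processor(path_coords=path_a, text=None, return_tensors="pt")).pooler_output
+     emb_b = model(**processor(path_coords=path_b, text=None, return_tensors="pt")).pooler_output
+
+ # Cosine similarity between the two 768-dim gesture embeddings
+ similarity = F.cosine_similarity(emb_a, emb_b).item()
+ ```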
+
+ ## Performance Metrics
+
+ Evaluated on 200 test samples:
+
+ | Task | Metric | Score |
+ |------|--------|-------|
+ | Masked Prediction (30%) | Character Accuracy | 98.7% |
+ | | Top-3 Accuracy | 100% |
+ | | Word Accuracy | 97.3% |
+ | Full Reconstruction (100%) | Character Accuracy | 94% |
+ | | Word Accuracy | 83.7% |
+ | Length Prediction | Exact Accuracy | 89% |
+ | | Within ±1 | 96% |
+ | | Within ±2 | 99% |
+ | Path Reconstruction | MSE (masked) | 0.005 |
+
+ ## Model Outputs
+
+ ```python
+ outputs = model(**inputs)
+
+ # Available outputs:
+ outputs.char_logits        # [batch, seq_len, vocab_size] - Character predictions
+ outputs.length_logits      # [batch, max_length + 1] - Length predictions
+ outputs.path_logits        # [batch, seq_len, 3] - Path coordinate predictions
+ outputs.pooler_output      # [batch, d_model] - SEP token embedding for similarity
+ outputs.last_hidden_state  # [batch, seq_len, d_model] - Hidden representations
+ ```
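+
+ To turn `char_logits` into a word, one option (a sketch, assuming the path was padded to `max_path_len`, which the processor does by default) is to take the argmax over the character positions, which follow `[CLS]` + path + `[SEP]` in the packed sequence:
+
+ ```python
+ char_start = 1 + processor.max_path_len + 1  # skip [CLS], path, [SEP]
+ pred_ids = outputs.char_logits[0, char_start:].argmax(dim=-1).tolist()
+ word = processor.decode(pred_ids)  # special tokens are filtered during decoding
+ ```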
+
+ ## Citation
+
+ ```bibtex
+ @software{swipealot2024,
+   title={SwipeALot: Multimodal Swipe Keyboard Transformer},
+   year={2024},
+   url={https://github.com/dleemiller/legendary-waffle}
+ }
+ ```
+
+ ## License
+
+ MIT License
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "architectures": [
+     "SwipeTransformerModel"
+   ],
+   "cls_token_id": 1,
+   "d_ff": 3072,
+   "d_model": 768,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "eos_token_id": 5,
+   "mask_token_id": 3,
+   "max_char_len": 48,
+   "max_path_len": 128,
+   "model_type": "swipe_transformer",
+   "n_heads": 12,
+   "n_layers": 12,
+   "pad_token_id": 0,
+   "predict_path": true,
+   "sep_token_id": 2,
+   "transformers_version": "4.57.3",
+   "unk_token_id": 4,
+   "vocab_size": 43,
+   "auto_map": {
+     "AutoConfig": "configuration_swipe.SwipeTransformerConfig",
+     "AutoModel": "modeling_swipe.SwipeTransformerModel"
+   }
+ }
configuration_swipe.py ADDED
@@ -0,0 +1,148 @@
+ """Configuration classes for SwipeTransformer HuggingFace models."""
+
+ from transformers import PretrainedConfig
+
+
+ class SwipeTransformerConfig(PretrainedConfig):
+     """
+     Configuration class for SwipeTransformerModel.
+
+     This configuration stores all the parameters needed to instantiate a
+     SwipeTransformerModel. This is the base configuration for the multimodal
+     swipe keyboard transformer that processes path coordinates and text.
+
+     Args:
+         d_model (int, optional): Hidden dimension size. Defaults to 256.
+         n_layers (int, optional): Number of transformer layers. Defaults to 4.
+         n_heads (int, optional): Number of attention heads. Defaults to 4.
+         d_ff (int, optional): Feedforward dimension. Defaults to 1024.
+         dropout (float, optional): Dropout rate. Defaults to 0.1.
+         vocab_size (int, optional): Size of vocabulary. Defaults to 100.
+         max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
+         max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
+         predict_path (bool, optional): Whether to predict path coordinates. Defaults to True.
+         pad_token_id (int, optional): Padding token ID. Defaults to 0.
+         cls_token_id (int, optional): CLS token ID. Defaults to 1.
+         sep_token_id (int, optional): SEP token ID. Defaults to 2.
+         mask_token_id (int, optional): MASK token ID. Defaults to 3.
+         unk_token_id (int, optional): Unknown token ID. Defaults to 4.
+         eos_token_id (int, optional): End-of-sequence token ID. Defaults to 5.
+     """
+
+     model_type = "swipe_transformer"
+
+     def __init__(
+         self,
+         d_model: int = 256,
+         n_layers: int = 4,
+         n_heads: int = 4,
+         d_ff: int = 1024,
+         dropout: float = 0.1,
+         vocab_size: int = 100,
+         max_path_len: int = 64,
+         max_char_len: int = 38,
+         predict_path: bool = True,
+         pad_token_id: int = 0,
+         cls_token_id: int = 1,
+         sep_token_id: int = 2,
+         mask_token_id: int = 3,
+         unk_token_id: int = 4,
+         eos_token_id: int = 5,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)
+
+         # Model architecture parameters
+         self.d_model = d_model
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.d_ff = d_ff
+         self.dropout = dropout
+
+         # Vocabulary and sequence length
+         self.vocab_size = vocab_size
+         self.max_path_len = max_path_len
+         self.max_char_len = max_char_len
+
+         # Model capabilities
+         self.predict_path = predict_path
+
+         # Special tokens
+         self.cls_token_id = cls_token_id
+         self.sep_token_id = sep_token_id
+         self.mask_token_id = mask_token_id
+         self.unk_token_id = unk_token_id
+
+
+ class SwipeCrossEncoderConfig(PretrainedConfig):
+     """
+     Configuration class for SwipeCrossEncoderForSequenceClassification.
+
+     This configuration extends the base SwipeTransformer config for use in
+     cross-encoder tasks (e.g., path-word similarity scoring).
+
+     Args:
+         d_model (int, optional): Hidden dimension size. Defaults to 256.
+         n_layers (int, optional): Number of transformer layers. Defaults to 4.
+         n_heads (int, optional): Number of attention heads. Defaults to 4.
+         d_ff (int, optional): Feedforward dimension. Defaults to 1024.
+         dropout (float, optional): Dropout rate. Defaults to 0.1.
+         vocab_size (int, optional): Size of vocabulary. Defaults to 100.
+         max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
+         max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
+         num_labels (int, optional): Number of classification labels. Defaults to 1.
+         problem_type (str, optional): Problem type ("regression" or "single_label_classification"). Defaults to "regression".
+         pad_token_id (int, optional): Padding token ID. Defaults to 0.
+         cls_token_id (int, optional): CLS token ID. Defaults to 1.
+         sep_token_id (int, optional): SEP token ID. Defaults to 2.
+         mask_token_id (int, optional): MASK token ID. Defaults to 3.
+         unk_token_id (int, optional): Unknown token ID. Defaults to 4.
+         eos_token_id (int, optional): End-of-sequence token ID. Defaults to 5.
+     """
+
+     model_type = "swipe_cross_encoder"
+
+     def __init__(
+         self,
+         d_model: int = 256,
+         n_layers: int = 4,
+         n_heads: int = 4,
+         d_ff: int = 1024,
+         dropout: float = 0.1,
+         vocab_size: int = 100,
+         max_path_len: int = 64,
+         max_char_len: int = 38,
+         num_labels: int = 1,
+         problem_type: str = "regression",
+         pad_token_id: int = 0,
+         cls_token_id: int = 1,
+         sep_token_id: int = 2,
+         mask_token_id: int = 3,
+         unk_token_id: int = 4,
+         eos_token_id: int = 5,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id, num_labels=num_labels, eos_token_id=eos_token_id, **kwargs
+         )
+
+         # Model architecture parameters
+         self.d_model = d_model
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.d_ff = d_ff
+         self.dropout = dropout
+
+         # Vocabulary and sequence length
+         self.vocab_size = vocab_size
+         self.max_path_len = max_path_len
+         self.max_char_len = max_char_len
+
+         # Classification parameters
+         self.problem_type = problem_type
+
+         # Special tokens
+         self.cls_token_id = cls_token_id
+         self.sep_token_id = sep_token_id
+         self.mask_token_id = mask_token_id
+         self.unk_token_id = unk_token_id
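
As a sanity check, the shipped config.json earlier in this commit corresponds to constructing the base config with the non-default values spelled out. A sketch, assuming the file is importable as a flat module (it has no package-relative imports):

```python
# Sketch: rebuild the shipped config.json values in code
from configuration_swipe import SwipeTransformerConfig  # assumed flat-module import

config = SwipeTransformerConfig(
    d_model=768, n_layers=12, n_heads=12, d_ff=3072,
    vocab_size=43, max_path_len=128, max_char_len=48,
    predict_path=True,
)
assert config.model_type == "swipe_transformer"
assert config.mask_token_id == 3  # default special-token IDs match config.json
```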
conversion_metadata.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "original_checkpoint": "checkpoints/base_20251213_164813/best.pt",
+   "original_config": "embedded_in_checkpoint",
+   "converted_at": "2025-12-15 08:03:26.074041",
+   "model_type": "base",
+   "vocab_size": 43,
+   "epoch": 38,
+   "global_step": 70000,
+   "metrics": {
+     "loss": 0.03484317846596241,
+     "char_accuracy": 0.9536226987838745,
+     "word_accuracy": 0.771,
+     "char_loss_mean": 0.03484317846596241,
+     "total_loss_mean": 0.03484317846596241
+   }
+ }
embeddings.py ADDED
@@ -0,0 +1,179 @@
+ """Embedding layers for SwipeTransformer."""
+
+ import torch
+ import torch.nn as nn
+
+
+ class PathEmbedding(nn.Module):
+     """Embeds path coordinates (x, y, t) to d_model dimension."""
+
+     def __init__(self, d_model: int = 256):
+         super().__init__()
+         self.projection = nn.Linear(3, d_model)
+
+     def forward(self, path_coords: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             path_coords: [batch, seq_len, 3] - (x, y, t) coordinates
+
+         Returns:
+             [batch, seq_len, d_model] embeddings
+         """
+         return self.projection(path_coords)
+
+
+ class CharacterEmbedding(nn.Module):
+     """Embeds character tokens."""
+
+     def __init__(self, vocab_size: int, d_model: int = 256, padding_idx: int = 0):
+         super().__init__()
+         self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+
+     def forward(self, char_tokens: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             char_tokens: [batch, seq_len] character token IDs
+
+         Returns:
+             [batch, seq_len, d_model] embeddings
+         """
+         return self.embedding(char_tokens)
+
+
+ class PositionalEmbedding(nn.Module):
+     """Learned positional embeddings."""
+
+     def __init__(self, max_position: int, d_model: int = 256):
+         super().__init__()
+         self.embedding = nn.Embedding(max_position, d_model)
+
+     def forward(self, positions: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             positions: [batch, seq_len] position indices
+
+         Returns:
+             [batch, seq_len, d_model] positional embeddings
+         """
+         return self.embedding(positions)
+
+
+ class TypeEmbedding(nn.Module):
+     """Token type embeddings to distinguish PATH (0) vs TEXT (1) tokens."""
+
+     def __init__(self, d_model: int = 256):
+         super().__init__()
+         # 0 = PATH, 1 = TEXT
+         self.embedding = nn.Embedding(2, d_model)
+
+     def forward(self, token_types: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             token_types: [batch, seq_len] type indices (0 or 1)
+
+         Returns:
+             [batch, seq_len, d_model] type embeddings
+         """
+         return self.embedding(token_types)
+
+
+ class MixedEmbedding(nn.Module):
+     """
+     Combines path and character embeddings with positional and type information.
+     Constructs sequence: [CLS] + path_tokens + [SEP] + char_tokens
+     """
+
+     def __init__(
+         self,
+         vocab_size: int,
+         max_path_len: int,
+         max_char_len: int,
+         d_model: int = 256,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.d_model = d_model
+
+         # Content embeddings
+         self.path_embedding = PathEmbedding(d_model)
+         self.char_embedding = CharacterEmbedding(vocab_size, d_model, padding_idx=0)
+
+         # Positional embeddings
+         max_seq_len = 1 + max_path_len + 1 + max_char_len  # [CLS] + path + [SEP] + chars
+         self.positional_embedding = PositionalEmbedding(max_seq_len, d_model)
+
+         # Type embeddings
+         self.type_embedding = TypeEmbedding(d_model)
+
+         # Layer norm and dropout
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         path_coords: torch.Tensor,
+         char_tokens: torch.Tensor,
+         cls_token: torch.Tensor,
+         sep_token: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Create mixed sequence with embeddings.
+
+         Args:
+             path_coords: [batch, path_len, 3] path coordinates
+             char_tokens: [batch, char_len] character token IDs
+             cls_token: [batch, 1] CLS token IDs
+             sep_token: [batch, 1] SEP token IDs
+
+         Returns:
+             [batch, total_seq_len, d_model] embeddings where
+             total_seq_len = 1 + path_len + 1 + char_len
+         """
+         batch_size = path_coords.shape[0]
+         path_len = path_coords.shape[1]
+         char_len = char_tokens.shape[1]
+         device = path_coords.device
+
+         # Embed [CLS]
+         cls_emb = self.char_embedding(cls_token)  # [batch, 1, d_model]
+
+         # Embed path
+         path_emb = self.path_embedding(path_coords)  # [batch, path_len, d_model]
+
+         # Embed [SEP]
+         sep_emb = self.char_embedding(sep_token)  # [batch, 1, d_model]
+
+         # Embed characters
+         char_emb = self.char_embedding(char_tokens)  # [batch, char_len, d_model]
+
+         # Concatenate: [CLS] + PATH + [SEP] + CHARS
+         sequence = torch.cat(
+             [cls_emb, path_emb, sep_emb, char_emb], dim=1
+         )  # [batch, seq_len, d_model]
+         seq_len = sequence.shape[1]
+
+         # Add positional embeddings
+         positions = (
+             torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
+         )  # [batch, seq_len]
+         pos_emb = self.positional_embedding(positions)
+
+         # Add type embeddings
+         # Type 0 for [CLS] + path + [SEP], Type 1 for chars
+         type_ids = torch.cat(
+             [
+                 torch.zeros(
+                     batch_size, 1 + path_len + 1, dtype=torch.long, device=device
+                 ),  # [CLS], path, [SEP]
+                 torch.ones(batch_size, char_len, dtype=torch.long, device=device),  # chars
+             ],
+             dim=1,
+         )  # [batch, seq_len]
+         type_emb = self.type_embedding(type_ids)
+
+         # Combine: content + position + type
+         embeddings = sequence + pos_emb + type_emb
+         embeddings = self.layer_norm(embeddings)
+         embeddings = self.dropout(embeddings)
+
+         return embeddings
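
A quick shape check for `MixedEmbedding`, as a minimal sketch assuming the file above is importable as a flat module `embeddings` and using the base model's dimensions:

```python
import torch
from embeddings import MixedEmbedding  # assumed flat-module import

emb = MixedEmbedding(vocab_size=43, max_path_len=128, max_char_len=48, d_model=768)
path = torch.rand(2, 128, 3)                       # normalized (x, y, t) points
chars = torch.zeros(2, 48, dtype=torch.long)       # [PAD]-filled character slots
cls_tok = torch.full((2, 1), 1, dtype=torch.long)  # cls_token_id = 1
sep_tok = torch.full((2, 1), 2, dtype=torch.long)  # sep_token_id = 2

out = emb(path, chars, cls_tok, sep_tok)
assert out.shape == (2, 1 + 128 + 1 + 48, 768)     # [CLS] + path + [SEP] + chars
```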
heads.py ADDED
@@ -0,0 +1,109 @@
+ """Prediction heads for SwipeTransformer."""
+
+ import torch
+ import torch.nn as nn
+
+
+ class CharacterPredictionHead(nn.Module):
+     """Prediction head for masked characters."""
+
+     def __init__(self, d_model: int, vocab_size: int):
+         super().__init__()
+         self.dense = nn.Linear(d_model, d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.decoder = nn.Linear(d_model, vocab_size)
+         self.activation = nn.GELU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             hidden_states: [batch, seq_len, d_model]
+
+         Returns:
+             [batch, seq_len, vocab_size] logits
+         """
+         x = self.dense(hidden_states)
+         x = self.activation(x)
+         x = self.layer_norm(x)
+         logits = self.decoder(x)
+         return logits
+
+
+ class PathPredictionHead(nn.Module):
+     """Prediction head for masked path coordinates."""
+
+     def __init__(self, d_model: int):
+         super().__init__()
+         self.dense = nn.Linear(d_model, d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.decoder = nn.Linear(d_model, 3)  # Predict (x, y, t)
+         self.activation = nn.GELU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             hidden_states: [batch, seq_len, d_model]
+
+         Returns:
+             [batch, seq_len, 3] coordinates in [0, 1] range
+         """
+         x = self.dense(hidden_states)
+         x = self.activation(x)
+         x = self.layer_norm(x)
+         coords = self.decoder(x)
+         coords = torch.sigmoid(coords)  # Ensure [0, 1] range
+         return coords
+
+
+ class ClassificationHead(nn.Module):
+     """
+     Classification head for cross-encoder.
+
+     Follows SBERT architecture: Dense → GELU → LayerNorm → Linear(→1)
+     Outputs a single similarity score per input.
+     """
+
+     def __init__(self, d_model: int, num_labels: int = 1):
+         super().__init__()
+         self.dense = nn.Linear(d_model, d_model)
+         self.activation = nn.GELU()
+         self.norm = nn.LayerNorm(d_model)
+         self.classifier = nn.Linear(d_model, num_labels)
+
+     def forward(self, features: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             features: [batch, d_model] - typically SEP token embeddings
+
+         Returns:
+             [batch, num_labels] similarity scores
+         """
+         x = self.dense(features)
+         x = self.activation(x)  # GELU
+         x = self.norm(x)  # LayerNorm
+         logits = self.classifier(x)  # [batch, 1] or [batch, num_labels]
+         return logits
+
+
+ class LengthPredictionHead(nn.Module):
+     """Predict sequence length (e.g., swipable character count) from CLS embedding."""
+
+     def __init__(self, d_model: int, max_length: int):
+         super().__init__()
+         self.dense = nn.Linear(d_model, d_model)
+         self.activation = nn.GELU()
+         self.norm = nn.LayerNorm(d_model)
+         self.classifier = nn.Linear(d_model, max_length + 1)  # classes: 0..max_length
+
+     def forward(self, cls_features: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             cls_features: [batch, d_model] CLS embeddings
+
+         Returns:
+             [batch, max_length+1] logits over lengths
+         """
+         x = self.dense(cls_features)
+         x = self.activation(x)
+         x = self.norm(x)
+         return self.classifier(x)
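
`LengthPredictionHead` classifies over `max_length + 1` classes (lengths 0 through `max_length`). A minimal sketch, assuming a flat-module import of the file above:

```python
import torch
from heads import LengthPredictionHead  # assumed flat-module import

head = LengthPredictionHead(d_model=768, max_length=48)
cls_features = torch.randn(4, 768)         # CLS embeddings from the encoder
logits = head(cls_features)
assert logits.shape == (4, 49)             # 48 + 1 length classes
predicted_lengths = logits.argmax(dim=-1)  # one length per example
```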
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b00a7a35dc99485501db127c4433afee95bd4f0c34d6c6db8ed6c69eec43404b
+ size 348336548
modeling_swipe.py ADDED
@@ -0,0 +1,516 @@
+ """HuggingFace-compatible model classes for SwipeTransformer."""
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPooling,
+     ModelOutput,
+     SequenceClassifierOutput,
+ )
+
+ from .configuration_swipe import SwipeCrossEncoderConfig, SwipeTransformerConfig
+
+
+ @dataclass
+ class SwipeTransformerOutput(ModelOutput):
+     """
+     Output type for SwipeTransformerModel.
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Language modeling loss (character prediction).
+         char_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
+             Prediction scores of the character prediction head.
+         path_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 3)`, *optional*):
+             Prediction scores of the path prediction head (if enabled).
+         length_logits (`torch.FloatTensor` of shape `(batch_size, max_length+1)`, *optional*):
+             Prediction scores of the length prediction head (if enabled).
+         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+             Sequence of hidden-states at the output of the last layer of the model.
+         pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+             SEP token embeddings for similarity/embedding tasks.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     char_logits: torch.FloatTensor = None
+     path_logits: Optional[torch.FloatTensor] = None
+     length_logits: Optional[torch.FloatTensor] = None
+     last_hidden_state: torch.FloatTensor = None
+     pooler_output: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class SwipeTransformerPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface
+     for downloading and loading pretrained models.
+     """
+
+     config_class = SwipeTransformerConfig
+     base_model_prefix = "swipe_transformer"
+     supports_gradient_checkpointing = False
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, nn.Linear):
+             nn.init.xavier_uniform_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             nn.init.zeros_(module.bias)
+
+
+ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
+     """
+     HuggingFace-compatible SwipeTransformerModel.
+
+     This model reuses the existing components from src/swipealot/models/
+     and wraps them in a HuggingFace-compatible interface.
+
+     Args:
+         config (SwipeTransformerConfig): Model configuration
+     """
+
+     def __init__(self, config: SwipeTransformerConfig):
+         super().__init__(config)
+         self.config = config
+
+         # Import existing components
+         from .embeddings import MixedEmbedding
+         from .heads import CharacterPredictionHead, LengthPredictionHead, PathPredictionHead
+
+         # Embeddings
+         self.embeddings = MixedEmbedding(
+             vocab_size=config.vocab_size,
+             max_path_len=config.max_path_len,
+             max_char_len=config.max_char_len,
+             d_model=config.d_model,
+             dropout=config.dropout,
+         )
+
+         # Transformer encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.d_model,
+             nhead=config.n_heads,
+             dim_feedforward=config.d_ff,
+             dropout=config.dropout,
+             activation="gelu",
+             batch_first=True,
+             norm_first=True,  # Pre-LayerNorm
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=config.n_layers,
+             enable_nested_tensor=False,
+         )
+
+         # Prediction heads
+         self.char_head = CharacterPredictionHead(
+             d_model=config.d_model,
+             vocab_size=config.vocab_size,
+         )
+
+         if config.predict_path:
+             self.path_head = PathPredictionHead(d_model=config.d_model)
+         else:
+             self.path_head = None
+
+         # Length prediction head (predicts word length from path)
+         # Max length is max_char_len (including EOS)
+         self.length_head = LengthPredictionHead(
+             d_model=config.d_model,
+             max_length=config.max_char_len,
+         )
+
+         # Initialize weights
+         self.post_init()
+
+     def forward(
+         self,
+         path_coords: torch.Tensor,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         labels: torch.Tensor | None = None,
+         return_dict: bool | None = None,
+         output_hidden_states: bool | None = None,
+     ):
+         """
+         Forward pass of the model.
+
+         Args:
+             path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
+             input_ids (torch.Tensor): Character token IDs [batch, char_len]
+             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
+             labels (torch.Tensor, optional): Labels for loss calculation [batch, char_len]
+             return_dict (bool, optional): Whether to return ModelOutput object
+             output_hidden_states (bool, optional): Whether to output hidden states
+
+         Returns:
+             SwipeTransformerOutput or tuple: Model outputs with:
+                 - loss: Optional loss value
+                 - char_logits: Character prediction logits [batch, seq_len, vocab_size]
+                 - path_logits: Path prediction logits [batch, seq_len, 3] (if predict_path=True)
+                 - length_logits: Length prediction logits [batch, max_length + 1]
+                 - last_hidden_state: Hidden states [batch, seq_len, d_model]
+                 - pooler_output: SEP token embeddings [batch, d_model] for similarity/embedding tasks
+                 - hidden_states: Tuple of hidden states (if output_hidden_states=True)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size = path_coords.shape[0]
+         device = path_coords.device
+
+         # Create [CLS] and [SEP] tokens
+         cls_token = torch.full(
+             (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
+         )
+         sep_token = torch.full(
+             (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
+         )
+
+         # Get embeddings
+         embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
+
+         # Prepare attention mask for encoder
+         if attention_mask is not None:
+             # Convert attention mask: 1 = attend, 0 = ignore
+             # PyTorch expects: False = attend, True = ignore
+             src_key_padding_mask = attention_mask == 0
+         else:
+             src_key_padding_mask = None
+
+         # Encode (batch_first=True is set in TransformerEncoderLayer)
+         hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
+
+         # Character prediction
+         char_logits = self.char_head(hidden_states)
+
+         # Path prediction (if enabled)
+         path_logits = None
+         if self.path_head is not None:
+             path_logits = self.path_head(hidden_states)
+
+         # Length prediction from CLS token
+         cls_hidden = hidden_states[:, 0, :]  # [batch, d_model] - CLS at position 0
+         length_logits = self.length_head(cls_hidden)  # [batch, max_length + 1]
+
+         # Extract SEP token embedding for pooler output (embeddings/similarity tasks)
+         # SEP is at position 1 + path_len
+         path_len = path_coords.shape[1]
+         sep_position = 1 + path_len
+         pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]
+
+         # Compute loss if labels provided
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+             # Extract character positions from hidden states
+             # Sequence is: [CLS] + path + [SEP] + chars
+             char_start = 1 + path_len + 1  # After [CLS], path, and [SEP]
+             char_hidden = hidden_states[:, char_start : char_start + labels.shape[1], :]
+             char_pred = self.char_head(char_hidden)
+             loss = loss_fct(char_pred.reshape(-1, self.config.vocab_size), labels.reshape(-1))
+
+         if not return_dict:
+             output = (hidden_states, char_logits, length_logits, pooler_output)
+             if path_logits is not None:
+                 output = output + (path_logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return SwipeTransformerOutput(
+             loss=loss,
+             char_logits=char_logits,
+             path_logits=path_logits,
+             length_logits=length_logits,
+             last_hidden_state=hidden_states,
+             pooler_output=pooler_output,
+             hidden_states=(hidden_states,) if output_hidden_states else None,
+         )
+
+
+ class SwipeCrossEncoderForSequenceClassification(SwipeTransformerPreTrainedModel):
+     """
+     HuggingFace-compatible cross-encoder for sequence classification.
+
+     This model is designed for similarity scoring between swipe paths and words.
+     It extracts the SEP token embedding and passes it through a classification head.
+
+     Args:
+         config (SwipeCrossEncoderConfig): Model configuration
+     """
+
+     config_class = SwipeCrossEncoderConfig
+     base_model_prefix = "swipe_cross_encoder"
+
+     def __init__(self, config: SwipeCrossEncoderConfig):
+         super().__init__(config)
+         self.config = config
+         self.num_labels = config.num_labels
+
+         # Import existing components
+         from .embeddings import MixedEmbedding
+         from .heads import ClassificationHead
+
+         # Embeddings
+         self.embeddings = MixedEmbedding(
+             vocab_size=config.vocab_size,
+             max_path_len=config.max_path_len,
+             max_char_len=config.max_char_len,
+             d_model=config.d_model,
+             dropout=config.dropout,
+         )
+
+         # Transformer encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.d_model,
+             nhead=config.n_heads,
+             dim_feedforward=config.d_ff,
+             dropout=config.dropout,
+             activation="gelu",
+             batch_first=True,
+             norm_first=True,  # Pre-LayerNorm
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=config.n_layers,
+             enable_nested_tensor=False,
+         )
+
+         # Classification head
+         self.classifier = ClassificationHead(
+             d_model=config.d_model,
+             num_labels=config.num_labels,
+         )
+
+         # Initialize weights
+         self.post_init()
+
+     def forward(
+         self,
+         path_coords: torch.Tensor,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         labels: torch.Tensor | None = None,
+         return_dict: bool | None = None,
+     ):
+         """
+         Forward pass for cross-encoder.
+
+         Args:
+             path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
+             input_ids (torch.Tensor): Character token IDs [batch, char_len]
+             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
+             labels (torch.Tensor, optional): Labels for loss calculation [batch, num_labels]
+             return_dict (bool, optional): Whether to return ModelOutput object
+
+         Returns:
+             SequenceClassifierOutput or tuple: Model outputs with logits and optional loss
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size = path_coords.shape[0]
+         device = path_coords.device
+
+         # Create [CLS] and [SEP] tokens
+         cls_token = torch.full(
+             (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
+         )
+         sep_token = torch.full(
+             (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
+         )
+
+         # Get embeddings
+         embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
+
+         # Prepare attention mask
+         if attention_mask is not None:
+             src_key_padding_mask = attention_mask == 0
+         else:
+             src_key_padding_mask = None
+
+         # Encode (batch_first=True is set in TransformerEncoderLayer)
+         hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
+
+         # Extract SEP token embedding
+         # SEP is at position 1 + path_len
+         path_len = path_coords.shape[1]
+         sep_position = 1 + path_len
+         sep_embedding = hidden_states[:, sep_position, :]  # [batch, d_model]
+
+         # Classification
+         logits = self.classifier(sep_embedding)  # [batch, num_labels]
+
+         # Compute loss if labels provided
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 else:
+                     self.config.problem_type = "single_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + (hidden_states,)
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=(hidden_states,),
+         )
+
+
+ class SwipeModel(SwipeTransformerPreTrainedModel):
+     """
+     Base Swipe model for extracting embeddings.
+
+     .. deprecated::
+         This class is deprecated. Use SwipeTransformerModel instead, which now
+         includes pooler_output for embeddings alongside prediction heads.
+         SwipeTransformerModel provides both predictions AND embeddings in a single model.
+
+     This model returns the SEP token embedding, which can be used for:
+         - Vector databases
+         - Semantic search
+         - Similarity computation
+
+     The SEP token embedding represents the joint encoding of the path and text.
+
+     Usage (Deprecated - use SwipeTransformerModel instead):
+         ```python
+         from transformers import AutoModel
+
+         model = AutoModel.from_pretrained(
+             "your-username/swipe-model",
+             trust_remote_code=True
+         )
+
+         # Get embeddings
+         outputs = model(path_coords=paths, input_ids=tokens)
+         embeddings = outputs.pooler_output  # SEP token embeddings
+         ```
+
+     Args:
+         config (SwipeTransformerConfig or SwipeCrossEncoderConfig): Model configuration
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # Import existing components
+         from .embeddings import MixedEmbedding
+
+         # Embeddings
+         self.embeddings = MixedEmbedding(
+             vocab_size=config.vocab_size,
+             max_path_len=config.max_path_len,
+             max_char_len=config.max_char_len,
+             d_model=config.d_model,
+             dropout=config.dropout,
+         )
+
+         # Transformer encoder
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.d_model,
+             nhead=config.n_heads,
+             dim_feedforward=config.d_ff,
+             dropout=config.dropout,
+             activation="gelu",
+             batch_first=True,
+             norm_first=True,  # Pre-LayerNorm
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer,
+             num_layers=config.n_layers,
+             enable_nested_tensor=False,
+         )
+
+         # Initialize weights
+         self.post_init()
+
+     def forward(
+         self,
+         path_coords: torch.Tensor,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         return_dict: bool | None = None,
+         output_hidden_states: bool | None = None,
+     ):
+         """
+         Forward pass that returns embeddings.
+
+         Args:
+             path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
+             input_ids (torch.Tensor): Character token IDs [batch, char_len]
+             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
+             return_dict (bool, optional): Whether to return ModelOutput object
+             output_hidden_states (bool, optional): Whether to output all hidden states
+
+         Returns:
+             BaseModelOutputWithPooling with:
+                 - last_hidden_state: Full sequence hidden states [batch, seq_len, d_model]
+                 - pooler_output: SEP token embeddings [batch, d_model]
+                 - hidden_states: Tuple of hidden states (if output_hidden_states=True)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size = path_coords.shape[0]
+         device = path_coords.device
+
+         # Create [CLS] and [SEP] tokens
+         cls_token = torch.full(
+             (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
+         )
+         sep_token = torch.full(
+             (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
+         )
+
+         # Get embeddings
+         embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
+
+         # Prepare attention mask
+         if attention_mask is not None:
+             src_key_padding_mask = attention_mask == 0
+         else:
+             src_key_padding_mask = None
+
+         # Encode (batch_first=True is set in TransformerEncoderLayer)
+         hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
+
+         # Extract SEP token embedding (pooler output)
+         # SEP is at position 1 + path_len
+         path_len = path_coords.shape[1]
+         sep_position = 1 + path_len
+         pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]
+
+         if not return_dict:
+             return (hidden_states, pooler_output)
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=hidden_states,
+             pooler_output=pooler_output,
+             hidden_states=(hidden_states,) if output_hidden_states else None,
+         )
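
The SEP pooling described above can be verified end to end. A sketch, assuming the model is loaded via `AutoModel` with `trust_remote_code=True` as in the README (`"path/to/model"` is a placeholder) and that the processor pads paths to `max_path_len=128`:

```python
import torch
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("path/to/model", trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained("path/to/model", trust_remote_code=True)

inputs = processor(path_coords=torch.rand(1, 128, 3), text="hello", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)

# pooler_output is the SEP token's hidden state at position 1 + path_len
assert torch.equal(out.pooler_output, out.last_hidden_state[:, 1 + 128, :])
```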
processing_swipe.py ADDED
@@ -0,0 +1,263 @@
+ """Processor for handling multimodal swipe inputs (path + text)."""
+
+ import numpy as np
+ import torch
+ from transformers import ProcessorMixin
+
+
+ class SwipeProcessor(ProcessorMixin):
+     """
+     Processor for handling multimodal swipe inputs (path coordinates + text).
+
+     This processor combines path coordinate preprocessing with text tokenization,
+     creating the inputs needed for SwipeTransformer models.
+
+     Args:
+         tokenizer: SwipeTokenizer instance
+         max_path_len (int): Maximum path length. Defaults to 64.
+         max_char_len (int): Maximum character length. Defaults to 38.
+     """
+
+     attributes = ["tokenizer"]
+     tokenizer_class = "AutoTokenizer"  # Will use auto_map from tokenizer_config.json
+
+     def __init__(self, tokenizer=None, max_path_len: int = 64, max_char_len: int = 38):
+         self.tokenizer = tokenizer
+         self.max_path_len = max_path_len
+         self.max_char_len = max_char_len
+         # Attributes expected by newer transformers (not used for swipe models)
+         self.chat_template = None
+         self.audio_tokenizer = None
+         self.feature_extractor = None
+         self.image_processor = None
+
+     def __call__(
+         self,
+         path_coords: list[list[list[float]]] | torch.Tensor | np.ndarray | None = None,
+         text: str | list[str] | None = None,
+         padding: bool | str = True,
+         truncation: bool = True,
+         max_length: int | None = None,
+         return_tensors: str | None = "pt",
+         **kwargs,
+     ):
+         """
+         Process path coordinates and text into model inputs.
+
+         Args:
+             path_coords: List of paths or tensor [batch, path_len, 3]
+                 Each point is (x, y, time). Can be None if only processing text.
+             text: String or list of strings to encode. Can be None if only processing paths.
+             padding: Whether to pad sequences. Can be True/False or "max_length"
+             truncation: Whether to truncate sequences
+             max_length: Maximum sequence length for text (overrides max_char_len)
+             return_tensors: "pt" for PyTorch, "np" for NumPy, None for lists
+             **kwargs: Additional keyword arguments
+
+         Returns:
+             Dictionary with:
+                 - path_coords: [batch, max_path_len, 3] (if path_coords provided)
+                 - input_ids: [batch, max_char_len] (if text provided)
+                 - attention_mask: [batch, total_seq_len]
+         """
+         if path_coords is None and text is None:
+             raise ValueError("Must provide either path_coords or text (or both)")
+
+         # Determine batch size
+         if path_coords is not None:
+             # Handle path coordinates
+             if isinstance(path_coords, (list, tuple)):
+                 # Check if it's a batch or single path
+                 if len(path_coords) > 0 and isinstance(path_coords[0][0], (list, tuple)):
+                     # Batch of paths [[path1], [path2], ...]
+                     path_coords = torch.tensor(path_coords, dtype=torch.float32)
+                 else:
+                     # Single path [[x,y,t], [x,y,t], ...]
+                     path_coords = torch.tensor([path_coords], dtype=torch.float32)
+             elif isinstance(path_coords, np.ndarray):
+                 path_coords = torch.from_numpy(path_coords).float()
+                 if path_coords.dim() == 2:
+                     # Single path, add batch dimension
+                     path_coords = path_coords.unsqueeze(0)
+             elif isinstance(path_coords, torch.Tensor):
+                 if path_coords.dim() == 2:
+                     # Single path, add batch dimension
+                     path_coords = path_coords.unsqueeze(0)
+
+             batch_size = path_coords.shape[0]
+         elif text is not None:
+             if isinstance(text, str):
+                 batch_size = 1
+                 text = [text]
+             else:
+                 batch_size = len(text)
+         else:
+             batch_size = 1
+
+         result = {}
+
+         # Process path coordinates
+         if path_coords is not None:
+             current_path_len = path_coords.shape[1]
+
+             # Truncate if needed
+             if truncation and current_path_len > self.max_path_len:
+                 path_coords = path_coords[:, : self.max_path_len, :]
+                 current_path_len = self.max_path_len
+
+             # Pad if needed
+             if padding and current_path_len < self.max_path_len:
+                 pad_len = self.max_path_len - current_path_len
+                 path_coords = torch.cat([path_coords, torch.zeros(batch_size, pad_len, 3)], dim=1)
+
+             # Create path mask (1 = real data, 0 = padding)
+             path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
+             if padding and current_path_len < self.max_path_len:
+                 path_mask[:, current_path_len:] = 0
+
+             result["path_coords"] = path_coords
+             # Store path_mask internally for attention_mask construction
+             _path_mask = path_mask
+         else:
+             # No path coords provided, create empty/zero tensors
+             path_coords = torch.zeros(batch_size, self.max_path_len, 3)
+             _path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
+             result["path_coords"] = path_coords
+
+         # Process text
+         if text is not None:
+             # Ensure text is a list
+             if isinstance(text, str):
+                 text = [text]
+
+             # Tokenize text
+             text_max_length = max_length if max_length is not None else self.max_char_len
+
+             # First tokenize without padding/truncation to add EOS
+             encoded_raw = self.tokenizer(
+                 text,
+                 padding=False,
+                 truncation=False,
+                 return_tensors=None,  # Get lists first
+                 **kwargs,
+             )
+
+             # Add EOS token after each word (matching training dataset behavior)
+             eos_id = self.tokenizer.eos_token_id
+             for i in range(len(encoded_raw["input_ids"])):
+                 # Add EOS if not already present
+                 if encoded_raw["input_ids"][i][-1] != eos_id:
+                     encoded_raw["input_ids"][i].append(eos_id)
+
+             # Now apply padding and truncation
+             max_len_needed = max(len(ids) for ids in encoded_raw["input_ids"])
+             if truncation and max_len_needed > text_max_length:
+                 # Truncate but preserve EOS at the end
+                 for i in range(len(encoded_raw["input_ids"])):
+                     if len(encoded_raw["input_ids"][i]) > text_max_length:
+                         encoded_raw["input_ids"][i] = (
+                             encoded_raw["input_ids"][i][: text_max_length - 1] + [eos_id]
+                         )
+
+             # Pad sequences
+             if padding:
+                 pad_id = self.tokenizer.pad_token_id
+                 for i in range(len(encoded_raw["input_ids"])):
+                     seq_len = len(encoded_raw["input_ids"][i])
+                     if seq_len < text_max_length:
+                         encoded_raw["input_ids"][i].extend([pad_id] * (text_max_length - seq_len))
+
+             # Create attention mask (1 for real tokens + EOS, 0 for padding)
+             _char_mask = []
+             for ids in encoded_raw["input_ids"]:
+                 mask = [1 if token_id != self.tokenizer.pad_token_id else 0 for token_id in ids]
+                 _char_mask.append(mask)
+
+             # Convert to tensors if requested
+             if return_tensors == "pt":
+                 result["input_ids"] = torch.tensor(encoded_raw["input_ids"], dtype=torch.long)
+                 _char_mask = torch.tensor(_char_mask, dtype=torch.long)
+             elif return_tensors == "np":
+                 result["input_ids"] = np.array(encoded_raw["input_ids"], dtype=np.int64)
+                 _char_mask = np.array(_char_mask, dtype=np.int64)
+             else:
+                 result["input_ids"] = encoded_raw["input_ids"]
+         else:
+             # No text provided, create padding tokens
+             if return_tensors == "pt":
+                 char_tokens = torch.full(
+                     (batch_size, self.max_char_len), self.tokenizer.pad_token_id, dtype=torch.long
+                 )
+                 _char_mask = torch.zeros(batch_size, self.max_char_len, dtype=torch.long)
+             elif return_tensors == "np":
+                 char_tokens = np.full(
+                     (batch_size, self.max_char_len), self.tokenizer.pad_token_id, dtype=np.int64
+                 )
+                 _char_mask = np.zeros((batch_size, self.max_char_len), dtype=np.int64)
+             else:
+                 char_tokens = [
+                     [self.tokenizer.pad_token_id] * self.max_char_len for _ in range(batch_size)
+                 ]
+                 _char_mask = [[0] * self.max_char_len for _ in range(batch_size)]
+
+             result["input_ids"] = char_tokens
+
+         # Create combined attention mask: [CLS] + path + [SEP] + chars
+         # Sequence structure: [CLS:1] + _path_mask + [SEP:1] + _char_mask
+         if return_tensors == "pt":
+             cls_mask = torch.ones(batch_size, 1, dtype=torch.long)
+             sep_mask = torch.ones(batch_size, 1, dtype=torch.long)
+             attention_mask = torch.cat([cls_mask, _path_mask, sep_mask, _char_mask], dim=1)
+         elif return_tensors == "np":
+             cls_mask = np.ones((batch_size, 1), dtype=np.int64)
+             sep_mask = np.ones((batch_size, 1), dtype=np.int64)
+             attention_mask = np.concatenate([cls_mask, _path_mask, sep_mask, _char_mask], axis=1)
+         else:
+             cls_mask = [[1] for _ in range(batch_size)]
+             sep_mask = [[1] for _ in range(batch_size)]
+             attention_mask = [
+                 cls + path.tolist() + sep + char
+                 for cls, path, sep, char in zip(
+                     cls_mask, _path_mask, sep_mask, _char_mask, strict=False
+                 )
+             ]
+
+         result["attention_mask"] = attention_mask
+
+         # Convert to requested format
+         if return_tensors == "np":
+             for key in result:
+                 if isinstance(result[key], torch.Tensor):
+                     result[key] = result[key].numpy()
+         elif return_tensors is None:
+             for key in result:
+                 if isinstance(result[key], torch.Tensor):
+                     result[key] = result[key].tolist()
+
+         return result
+
+     def batch_decode(self, token_ids, **kwargs):
+         """
+         Decode token IDs to strings.
+
+         Args:
+             token_ids: Token IDs to decode
+             **kwargs: Additional arguments passed to tokenizer
+
+         Returns:
+             List of decoded strings
+         """
+         return self.tokenizer.batch_decode(token_ids, **kwargs)
+
+     def decode(self, token_ids, **kwargs):
+         """
+         Decode single sequence of token IDs to string.
+
+         Args:
+             token_ids: Token IDs to decode
+             **kwargs: Additional arguments passed to tokenizer
+
+         Returns:
+             Decoded string
+         """
+         return self.tokenizer.decode(token_ids, **kwargs)
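
The processor's combined attention mask covers the full packed sequence `[CLS] + path + [SEP] + chars`. A sketch of the resulting shapes, assuming the processor is loaded as in the README (`"path/to/model"` is a placeholder) with the base model's config (max_path_len=128, max_char_len=48):

```python
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/model", trust_remote_code=True)

inputs = processor(path_coords=torch.rand(1, 90, 3), text="hello", return_tensors="pt")
assert inputs["path_coords"].shape == (1, 128, 3)    # path padded to max_path_len
assert inputs["input_ids"].shape == (1, 48)          # chars padded to max_char_len
assert inputs["attention_mask"].shape == (1, 1 + 128 + 1 + 48)
```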
processor_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "max_char_len": 48,
+   "max_path_len": 128,
+   "processor_class": "SwipeProcessor",
+   "auto_map": {
+     "AutoProcessor": "processing_swipe.SwipeProcessor"
+   }
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "cls_token": "[CLS]",
+   "eos_token": "[EOS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenization_swipe.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """HuggingFace-compatible tokenizer for SwipeTransformer."""
+
+ import json
+ import os
+
+ from transformers import PreTrainedTokenizer
+
+ from .tokenizer import CharacterTokenizer
+
+
+ class SwipeTokenizer(PreTrainedTokenizer):
+     """
+     HuggingFace-compatible tokenizer that wraps the existing CharacterTokenizer.
+
+     This tokenizer provides a HuggingFace-compatible interface for the custom
+     character-level tokenization used in the swipe keyboard model.
+
+     Args:
+         vocab_file (str, optional): Path to vocabulary file
+         unk_token (str): Unknown token. Defaults to "[UNK]"
+         sep_token (str): Separator token. Defaults to "[SEP]"
+         pad_token (str): Padding token. Defaults to "[PAD]"
+         cls_token (str): Classification token. Defaults to "[CLS]"
+         mask_token (str): Mask token. Defaults to "[MASK]"
+         eos_token (str): End-of-sequence token. Defaults to "[EOS]"
+     """
+
+     vocab_files_names = {"vocab_file": "vocab.json"}
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file: str | None = None,
+         unk_token: str = "[UNK]",
+         sep_token: str = "[SEP]",
+         pad_token: str = "[PAD]",
+         cls_token: str = "[CLS]",
+         mask_token: str = "[MASK]",
+         eos_token: str = "[EOS]",
+         **kwargs,
+     ):
+         # Initialize the internal CharacterTokenizer BEFORE calling super().__init__(),
+         # because super().__init__() calls get_vocab(), which needs self._tokenizer.
+         if vocab_file is not None and os.path.exists(vocab_file):
+             # Load from vocab file
+             with open(vocab_file, encoding="utf-8") as f:
+                 vocab_data = json.load(f)
+
+             # Special tokens must NOT be passed to CharacterTokenizer as regular
+             # characters; convert AddedToken objects to strings before comparing.
+             special_tokens_to_exclude = {
+                 str(pad_token),
+                 str(cls_token),
+                 str(sep_token),
+                 str(mask_token),
+                 str(unk_token),
+                 str(eos_token),
+                 "[PUNC]",
+             }
+
+             if "chars" in vocab_data:
+                 # Filter special tokens out of the chars list
+                 vocab = set(c for c in vocab_data["chars"] if c not in special_tokens_to_exclude)
+             elif "char_to_id" in vocab_data:
+                 # Keep all characters except special tokens
+                 vocab = set(
+                     c for c in vocab_data["char_to_id"].keys() if c not in special_tokens_to_exclude
+                 )
+             else:
+                 vocab = None
+
+             self._tokenizer = CharacterTokenizer(vocab=vocab)
+         else:
+             # Default vocab (built from the dataset during conversion)
+             self._tokenizer = CharacterTokenizer()
+
+         super().__init__(
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self) -> int:
+         """Return the size of the vocabulary."""
+         return self._tokenizer.vocab_size
+
+     def get_vocab(self):
+         """Return the vocabulary as a dict."""
+         return self._tokenizer.char_to_id.copy()
+
+     def _tokenize(self, text: str) -> list[str]:
+         """
+         Tokenize a string into tokens (characters).
+
+         Args:
+             text (str): Text to tokenize
+
+         Returns:
+             List[str]: List of character tokens
+         """
+         # Convert to lowercase and split into characters
+         return list(text.lower())
+
+     def _convert_token_to_id(self, token: str) -> int:
+         """
+         Convert a token (character) to an id using the vocabulary.
+
+         Args:
+             token (str): Token to convert
+
+         Returns:
+             int: Token ID
+         """
+         return self._tokenizer.char_to_id.get(token, self._tokenizer.unk_token_id)
+
+     def _convert_id_to_token(self, index: int) -> str:
+         """
+         Convert an index to a token using the vocabulary.
+
+         Args:
+             index (int): Token ID
+
+         Returns:
+             str: Token (character)
+         """
+         return self._tokenizer.id_to_char.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens: list[str]) -> str:
+         """
+         Convert a list of tokens (characters) to a string.
+
+         Args:
+             tokens (List[str]): List of tokens
+
+         Returns:
+             str: Concatenated string
+         """
+         # Filter out special tokens
+         special_tokens = {
+             self.pad_token,
+             self.cls_token,
+             self.sep_token,
+             self.mask_token,
+             self.unk_token,
+             self.eos_token,
+         }
+         filtered = [t for t in tokens if t not in special_tokens]
+         return "".join(filtered)
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
+         """
+         Save the tokenizer vocabulary to a directory.
+
+         Args:
+             save_directory (str): Directory to save the vocabulary
+             filename_prefix (str, optional): Optional prefix for the vocabulary file
+
+         Returns:
+             tuple: Tuple containing the path to the saved vocabulary file
+         """
+         if not os.path.isdir(save_directory):
+             os.makedirs(save_directory, exist_ok=True)
+
+         vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
+         )
+
+         # Save vocabulary and mappings
+         vocab_data = {
+             "chars": sorted(set(self._tokenizer.char_to_id.keys())),
+             "char_to_id": self._tokenizer.char_to_id,
+             "special_tokens": {
+                 "pad_token": self.pad_token,
+                 "cls_token": self.cls_token,
+                 "sep_token": self.sep_token,
+                 "mask_token": self.mask_token,
+                 "unk_token": self.unk_token,
+                 "eos_token": self.eos_token,
+             },
+         }
+
+         with open(vocab_file, "w", encoding="utf-8") as f:
+             json.dump(vocab_data, f, ensure_ascii=False, indent=2)
+
+         return (vocab_file,)
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: list[int], token_ids_1: list[int] | None = None
+     ) -> list[int]:
+         """
+         Build model inputs from a sequence by adding special tokens.
+
+         For swipe models, no special tokens are added here; CLS and SEP are
+         managed by the model/processor.
+
+         Args:
+             token_ids_0 (List[int]): First sequence
+             token_ids_1 (List[int], optional): Second sequence
+
+         Returns:
+             List[int]: Sequence with special tokens
+         """
+         if token_ids_1 is None:
+             return token_ids_0
+         return token_ids_0 + token_ids_1
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: list[int],
+         token_ids_1: list[int] | None = None,
+         already_has_special_tokens: bool = False,
+     ) -> list[int]:
+         """
+         Retrieve a special-tokens mask for a token list.
+
+         Args:
+             token_ids_0 (List[int]): First sequence
+             token_ids_1 (List[int], optional): Second sequence
+             already_has_special_tokens (bool): Whether tokens already have special tokens
+
+         Returns:
+             List[int]: Mask (1 for special tokens, 0 for normal tokens)
+         """
+         # All special token handling is done by the processor, so return all zeros.
+         if already_has_special_tokens:
+             if token_ids_1 is not None:
+                 raise ValueError(
+                     "You should not supply a second sequence if the provided "
+                     "sequence already has special tokens."
+                 )
+             return [0] * len(token_ids_0)
+
+         if token_ids_1 is None:
+             return [0] * len(token_ids_0)
+         return [0] * len(token_ids_0) + [0] * len(token_ids_1)
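
As a quick sanity check, the tokenizer above can be exercised directly. A minimal sketch follows; the flat import and the local `vocab.json` path are illustrative assumptions, not part of this upload:

```python
# Minimal sketch: character-level round-trip through SwipeTokenizer.
# Assumes tokenization_swipe.py and vocab.json sit in the working directory.
from tokenization_swipe import SwipeTokenizer

tok = SwipeTokenizer(vocab_file="vocab.json")
tokens = tok.tokenize("Hello")            # lowercased: ['h', 'e', 'l', 'l', 'o']
ids = tok.convert_tokens_to_ids(tokens)   # one id per character
print(ids)
print(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(ids)))  # "hello"
```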
tokenizer.py ADDED
@@ -0,0 +1,170 @@
+ """Character tokenizer and vocabulary utilities for swipe keyboard dataset."""
+
+ import hashlib
+
+ import torch
+
+
+ class CharacterTokenizer:
+     """Character-level tokenizer for swipe keyboard words."""
+
+     def __init__(self, vocab: set | None = None):
+         """
+         Initialize tokenizer with vocabulary.
+
+         Args:
+             vocab: Optional set of extra characters. The base vocabulary is
+                 always lowercase a-z plus digits 0-9; characters passed here
+                 extend that base set.
+         """
+         # Special tokens
+         self.pad_token = "[PAD]"
+         self.cls_token = "[CLS]"
+         self.sep_token = "[SEP]"
+         self.mask_token = "[MASK]"
+         self.unk_token = "[UNK]"
+         self.eos_token = "[EOS]"  # End of word token
+         self.punc_token = "[PUNC]"
+
+         self.special_tokens = [
+             self.pad_token,   # 0
+             self.cls_token,   # 1
+             self.sep_token,   # 2
+             self.mask_token,  # 3
+             self.unk_token,   # 4
+             self.eos_token,   # 5
+             self.punc_token,  # 6
+         ]
+
+         # Build vocabulary deterministically (lowercase letters + digits).
+         chars = set(chr(i) for i in range(ord("a"), ord("z") + 1))
+         chars.update(str(d) for d in range(10))
+         if vocab is not None:
+             # Allow explicit extension for special cases
+             chars.update(vocab)
+
+         self.char_to_id = {token: idx for idx, token in enumerate(self.special_tokens)}
+         for idx, char in enumerate(sorted(chars), start=len(self.special_tokens)):
+             self.char_to_id[char] = idx
+
+         self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
+         self.vocab_size = len(self.char_to_id)
+
+     def encode(self, text: str) -> list[int]:
+         """Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
+         unk_id = self.char_to_id[self.unk_token]
+         punc_id = self.char_to_id[self.punc_token]
+         tokens = []
+         for char in text.lower():
+             if char.isalpha() or char.isdigit():
+                 tokens.append(self.char_to_id.get(char, unk_id))
+             else:
+                 tokens.append(punc_id)
+         return tokens
+
+     def decode(self, token_ids: list[int]) -> str:
+         """Decode token IDs to text, stopping at the EOS token."""
+         chars = []
+         for token_id in token_ids:
+             if token_id in self.id_to_char:
+                 char = self.id_to_char[token_id]
+                 # Stop at EOS token
+                 if char == self.eos_token:
+                     break
+                 # Skip all other special tokens
+                 if char not in self.special_tokens:
+                     chars.append(char)
+         return "".join(chars)
+
+     @property
+     def pad_token_id(self) -> int:
+         return self.char_to_id[self.pad_token]
+
+     @property
+     def cls_token_id(self) -> int:
+         return self.char_to_id[self.cls_token]
+
+     @property
+     def sep_token_id(self) -> int:
+         return self.char_to_id[self.sep_token]
+
+     @property
+     def mask_token_id(self) -> int:
+         return self.char_to_id[self.mask_token]
+
+     @property
+     def unk_token_id(self) -> int:
+         return self.char_to_id[self.unk_token]
+
+     @property
+     def eos_token_id(self) -> int:
+         return self.char_to_id[self.eos_token]
+
+     @property
+     def punc_token_id(self) -> int:
+         return self.char_to_id[self.punc_token]
+
+
+ def vocab_hash(tokenizer: CharacterTokenizer) -> str:
+     """Stable hash of the tokenizer's id->token mapping (includes specials)."""
+     ordered_tokens = [tokenizer.id_to_char[i] for i in range(tokenizer.vocab_size)]
+     joined = "\n".join(ordered_tokens).encode("utf-8")
+     return hashlib.sha256(joined).hexdigest()
+
+
+ def compute_char_frequency_weights(
+     tokenizer: CharacterTokenizer,
+     dataset,
+     max_samples: int | None = None,
+     weight_exponent: float = 1.0,
+ ):
+     """Compute inverse log frequency weights for characters.
+
+     Args:
+         tokenizer: CharacterTokenizer used for encoding
+         dataset: HF dataset or iterable of samples with a 'word' field
+         max_samples: Optional cap on samples to scan
+         weight_exponent: Optional tempering exponent (values < 1 flatten extremes)
+
+     Returns:
+         torch.Tensor of shape [vocab_size] with weights normalized so the
+         non-pad mean is 1. The padding token weight is set to that mean
+         (not zero), so the minimum weight stays positive.
+     """
+     counts = torch.ones(tokenizer.vocab_size, dtype=torch.float)  # start at 1 for smoothing
+
+     # Collect all token IDs first for vectorized counting
+     all_token_ids = []
+     for idx, sample in enumerate(dataset):
+         if max_samples is not None and idx >= max_samples:
+             break
+
+         # Encode lowercase characters and append EOS (matches training labels)
+         token_ids = tokenizer.encode(sample["word"]) + [tokenizer.eos_token_id]
+         all_token_ids.extend(token_ids)
+
+     # Use bincount for efficient vectorized counting
+     if all_token_ids:
+         token_tensor = torch.tensor(all_token_ids, dtype=torch.long)
+         counts = counts + torch.bincount(token_tensor, minlength=tokenizer.vocab_size).float()
+
+     # Padding is never a supervised label; its count keeps only the smoothing value.
+     pad_id = tokenizer.pad_token_id
+
+     # Inverse log weighting; log1p adds 1 inside the log to avoid division by zero
+     weights = 1.0 / torch.log1p(counts)
+
+     # Optional tempering (e.g., exponent < 1 flattens extremes)
+     if weight_exponent != 1.0:
+         weights = torch.pow(weights, weight_exponent)
+
+     # Pin the pad weight to the non-pad mean to avoid a zero/inf outlier, then
+     # normalize so the non-pad mean is 1 to keep the loss scale stable.
+     non_pad_mask = torch.ones_like(weights, dtype=torch.bool)
+     non_pad_mask[pad_id] = False
+     non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
+     weights[pad_id] = non_pad_mean
+     weights = weights / non_pad_mean
+
+     return weights
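
For reference, a small sketch of how these utilities behave; the flat `tokenizer` import and the toy two-word dataset are assumptions for illustration (in the package the module is imported relatively):

```python
# Minimal sketch: CharacterTokenizer round-trip, vocab hashing, and weights.
from tokenizer import CharacterTokenizer, compute_char_frequency_weights, vocab_hash

tok = CharacterTokenizer()
ids = tok.encode("Hi, there!")               # case-folded; ',' ' ' '!' -> [PUNC]
print(tok.decode(ids + [tok.eos_token_id]))  # "hithere": specials dropped, stops at EOS
print(vocab_hash(tok))                       # stable sha256 over the id->token mapping

# Toy in-memory "dataset" with the expected 'word' field
samples = [{"word": "hello"}, {"word": "world"}]
weights = compute_char_frequency_weights(tok, samples, weight_exponent=0.5)
print(weights.shape)                         # torch.Size([43])
print(weights[tok.pad_token_id])             # 1.0: pad pinned to the non-pad mean
```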
tokenizer_config.json ADDED
@@ -0,0 +1,69 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "[EOS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "[CLS]",
+   "eos_token": "[EOS]",
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "processor_class": "SwipeProcessor",
+   "sep_token": "[SEP]",
+   "tokenizer_class": "SwipeTokenizer",
+   "unk_token": "[UNK]",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_swipe.SwipeTokenizer",
+       null
+     ]
+   }
+ }
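
Note that `[PUNC]` (id 6) is handled inside `CharacterTokenizer` rather than registered in `added_tokens_decoder`, and the `auto_map` entry is what lets `AutoTokenizer` resolve the custom class from `tokenization_swipe.py`. A minimal loading sketch, with the model path as a placeholder:

```python
# Minimal sketch: auto_map resolves SwipeTokenizer from tokenization_swipe.py,
# so loading the repo requires trust_remote_code=True.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
print(tok.mask_token, tok.mask_token_id)  # "[MASK]" 3, matching added_tokens_decoder
```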
vocab.json ADDED
@@ -0,0 +1,100 @@
+ {
+   "chars": [
+     "0",
+     "1",
+     "2",
+     "3",
+     "4",
+     "5",
+     "6",
+     "7",
+     "8",
+     "9",
+     "[CLS]",
+     "[EOS]",
+     "[MASK]",
+     "[PAD]",
+     "[PUNC]",
+     "[SEP]",
+     "[UNK]",
+     "a",
+     "b",
+     "c",
+     "d",
+     "e",
+     "f",
+     "g",
+     "h",
+     "i",
+     "j",
+     "k",
+     "l",
+     "m",
+     "n",
+     "o",
+     "p",
+     "q",
+     "r",
+     "s",
+     "t",
+     "u",
+     "v",
+     "w",
+     "x",
+     "y",
+     "z"
+   ],
+   "char_to_id": {
+     "[PAD]": 0,
+     "[CLS]": 1,
+     "[SEP]": 2,
+     "[MASK]": 3,
+     "[UNK]": 4,
+     "[EOS]": 5,
+     "[PUNC]": 6,
+     "0": 7,
+     "1": 8,
+     "2": 9,
+     "3": 10,
+     "4": 11,
+     "5": 12,
+     "6": 13,
+     "7": 14,
+     "8": 15,
+     "9": 16,
+     "a": 17,
+     "b": 18,
+     "c": 19,
+     "d": 20,
+     "e": 21,
+     "f": 22,
+     "g": 23,
+     "h": 24,
+     "i": 25,
+     "j": 26,
+     "k": 27,
+     "l": 28,
+     "m": 29,
+     "n": 30,
+     "o": 31,
+     "p": 32,
+     "q": 33,
+     "r": 34,
+     "s": 35,
+     "t": 36,
+     "u": 37,
+     "v": 38,
+     "w": 39,
+     "x": 40,
+     "y": 41,
+     "z": 42
+   },
+   "special_tokens": {
+     "pad_token": "[PAD]",
+     "cls_token": "[CLS]",
+     "sep_token": "[SEP]",
+     "mask_token": "[MASK]",
+     "unk_token": "[UNK]",
+     "eos_token": "[EOS]"
+   }
+ }
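
A short consistency check against this file's 43-token vocabulary; the local file path is an assumption:

```python
# Minimal sketch: verify vocab.json matches the 43-token vocabulary above.
import json

with open("vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)

char_to_id = vocab["char_to_id"]
assert len(char_to_id) == 43                            # 7 specials + 26 letters + 10 digits
assert sorted(char_to_id.values()) == list(range(43))   # ids are contiguous from 0
assert set(vocab["chars"]) == set(char_to_id)           # chars list mirrors the mapping
```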