Upload folder using huggingface_hub
- config.json +5 -1
- configuration_swipe.py +7 -74
- embeddings.py +18 -6
- heads.py +22 -39
- model.safetensors +2 -2
- modeling_swipe.py +191 -334
- preprocessing.py +275 -0
- processing_swipe.py +215 -132
- processor_config.json +2 -0
- special_tokens_map.json +42 -6
- tokenization_swipe.py +37 -1
- tokenizer.py +17 -6
- tokenizer_config.json +7 -7
config.json
CHANGED

@@ -15,6 +15,9 @@
   "n_heads": 12,
   "n_layers": 12,
   "pad_token_id": 0,
+  "path_input_dim": 6,
+  "predict_char": true,
+  "predict_length": true,
   "predict_path": true,
   "sep_token_id": 2,
   "transformers_version": "4.57.3",
@@ -22,6 +25,7 @@
   "vocab_size": 43,
   "auto_map": {
     "AutoConfig": "configuration_swipe.SwipeTransformerConfig",
-    "AutoModel": "modeling_swipe.SwipeTransformerModel"
+    "AutoModel": "modeling_swipe.SwipeTransformerModel",
+    "AutoModelForCausalLM": "modeling_swipe.SwipeTransformerModel"
   }
 }
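For reference, the updated `auto_map` lets the checkpoint load through either entry point (both map to `modeling_swipe.SwipeTransformerModel`). A minimal sketch; the repo id is a placeholder taken from the old docstrings:

```python
from transformers import AutoConfig, AutoModel

# Hypothetical repo id; trust_remote_code pulls configuration_swipe.py / modeling_swipe.py.
config = AutoConfig.from_pretrained("your-username/swipe-model", trust_remote_code=True)
print(config.path_input_dim, config.predict_char, config.predict_length)  # 6 True True

model = AutoModel.from_pretrained("your-username/swipe-model", trust_remote_code=True)
```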
configuration_swipe.py
CHANGED

@@ -20,6 +20,7 @@ class SwipeTransformerConfig(PretrainedConfig):
         vocab_size (int, optional): Size of vocabulary. Defaults to 100.
         max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
         max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
+        path_input_dim (int, optional): Path feature dimension. Defaults to 6 for (x, y, dx, dy, ds, log_dt).
         predict_path (bool, optional): Whether to predict path coordinates. Defaults to True.
         pad_token_id (int, optional): Padding token ID. Defaults to 0.
         cls_token_id (int, optional): CLS token ID. Defaults to 1.
@@ -41,7 +42,10 @@ class SwipeTransformerConfig(PretrainedConfig):
         vocab_size: int = 100,
         max_path_len: int = 64,
         max_char_len: int = 38,
+        path_input_dim: int = 6,
+        predict_char: bool = True,
         predict_path: bool = True,
+        predict_length: bool = True,
         pad_token_id: int = 0,
         cls_token_id: int = 1,
         sep_token_id: int = 2,
@@ -63,83 +67,12 @@ class SwipeTransformerConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.max_path_len = max_path_len
         self.max_char_len = max_char_len
+        self.path_input_dim = path_input_dim
 
         # Model capabilities
+        self.predict_char = predict_char
         self.predict_path = predict_path
+        self.predict_length = predict_length
-
-        # Special tokens
-        self.cls_token_id = cls_token_id
-        self.sep_token_id = sep_token_id
-        self.mask_token_id = mask_token_id
-        self.unk_token_id = unk_token_id
-
-
-class SwipeCrossEncoderConfig(PretrainedConfig):
-    """
-    Configuration class for SwipeCrossEncoderForSequenceClassification.
-
-    This configuration extends the base SwipeTransformer config for use in
-    cross-encoder tasks (e.g., path-word similarity scoring).
-
-    Args:
-        d_model (int, optional): Hidden dimension size. Defaults to 256.
-        n_layers (int, optional): Number of transformer layers. Defaults to 4.
-        n_heads (int, optional): Number of attention heads. Defaults to 4.
-        d_ff (int, optional): Feedforward dimension. Defaults to 1024.
-        dropout (float, optional): Dropout rate. Defaults to 0.1.
-        vocab_size (int, optional): Size of vocabulary. Defaults to 100.
-        max_path_len (int, optional): Maximum path sequence length. Defaults to 64.
-        max_char_len (int, optional): Maximum character sequence length. Defaults to 38.
-        num_labels (int, optional): Number of classification labels. Defaults to 1.
-        problem_type (str, optional): Problem type ('regression' or 'single_label_classification'). Defaults to "regression".
-        pad_token_id (int, optional): Padding token ID. Defaults to 0.
-        cls_token_id (int, optional): CLS token ID. Defaults to 1.
-        sep_token_id (int, optional): SEP token ID. Defaults to 2.
-        mask_token_id (int, optional): MASK token ID. Defaults to 3.
-        unk_token_id (int, optional): Unknown token ID. Defaults to 4.
-        eos_token_id (int, optional): End-of-sequence token ID. Defaults to 5.
-    """
-
-    model_type = "swipe_cross_encoder"
-
-    def __init__(
-        self,
-        d_model: int = 256,
-        n_layers: int = 4,
-        n_heads: int = 4,
-        d_ff: int = 1024,
-        dropout: float = 0.1,
-        vocab_size: int = 100,
-        max_path_len: int = 64,
-        max_char_len: int = 38,
-        num_labels: int = 1,
-        problem_type: str = "regression",
-        pad_token_id: int = 0,
-        cls_token_id: int = 1,
-        sep_token_id: int = 2,
-        mask_token_id: int = 3,
-        unk_token_id: int = 4,
-        eos_token_id: int = 5,
-        **kwargs,
-    ):
-        super().__init__(
-            pad_token_id=pad_token_id, num_labels=num_labels, eos_token_id=eos_token_id, **kwargs
-        )
-
-        # Model architecture parameters
-        self.d_model = d_model
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.d_ff = d_ff
-        self.dropout = dropout
-
-        # Vocabulary and sequence length
-        self.vocab_size = vocab_size
-        self.max_path_len = max_path_len
-        self.max_char_len = max_char_len
-
-        # Classification parameters
-        self.problem_type = problem_type
 
         # Special tokens
         self.cls_token_id = cls_token_id
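As a sanity check, the three new fields behave like any other `PretrainedConfig` attribute and serialize to config.json. A minimal sketch, assuming the module is importable locally:

```python
from configuration_swipe import SwipeTransformerConfig

config = SwipeTransformerConfig(path_input_dim=6, predict_char=True, predict_length=True)
assert config.path_input_dim == 6

# Round-trips through save/load like any other config attribute.
config.save_pretrained("./swipe-config")  # writes config.json including the new keys
```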
embeddings.py
CHANGED

@@ -5,16 +5,26 @@ import torch.nn as nn
 
 
 class PathEmbedding(nn.Module):
-    """Embeds path coordinates to d_model dimension."""
+    """Embeds path features (x, y, dx, dy, ds, log_dt) to d_model dimension."""
 
-    def __init__(self, d_model: int = 256):
+    def __init__(self, d_model: int = 256, input_dim: int = 6):
+        """
+        Initialize path embedding layer.
+
+        Args:
+            d_model: Output dimension
+            input_dim: Input feature dimension (default: 6 for x, y, dx, dy, ds, log_dt)
+        """
         super().__init__()
-        self.projection = nn.Linear(3, d_model)
+        self.projection = nn.Linear(input_dim, d_model)
 
     def forward(self, path_coords: torch.Tensor) -> torch.Tensor:
         """
+        Project path features to d_model dimension.
+
         Args:
-            path_coords: [batch, seq_len, 3] path coordinates
+            path_coords: [batch, seq_len, input_dim] - path features
+                Default: (x, y, dx, dy, ds, log_dt) with input_dim=6
 
         Returns:
             [batch, seq_len, d_model] embeddings
@@ -90,12 +100,13 @@ class MixedEmbedding(nn.Module):
         max_char_len: int,
         d_model: int = 256,
         dropout: float = 0.1,
+        path_input_dim: int = 6,
     ):
         super().__init__()
         self.d_model = d_model
 
         # Content embeddings
-        self.path_embedding = PathEmbedding(d_model)
+        self.path_embedding = PathEmbedding(d_model, input_dim=path_input_dim)
         self.char_embedding = CharacterEmbedding(vocab_size, d_model, padding_idx=0)
 
         # Positional embeddings
@@ -120,7 +131,8 @@ class MixedEmbedding(nn.Module):
         Create mixed sequence with embeddings.
 
         Args:
-            path_coords: [batch, path_len, 3] path coordinates
+            path_coords: [batch, path_len, path_input_dim] path features
+                Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
             char_tokens: [batch, char_len] character token IDs
             cls_token: [batch, 1] CLS token IDs
             sep_token: [batch, 1] SEP token IDs
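A quick shape check for the widened projection; a minimal sketch, assuming `embeddings.py` is importable as a local module:

```python
import torch
from embeddings import PathEmbedding  # assumes the repo files are on the path

emb = PathEmbedding(d_model=256, input_dim=6)
path = torch.randn(2, 64, 6)  # (x, y, dx, dy, ds, log_dt) per point
out = emb(path)
print(out.shape)  # torch.Size([2, 64, 256])
```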
heads.py
CHANGED

@@ -32,11 +32,11 @@ class CharacterPredictionHead(nn.Module):
 class PathPredictionHead(nn.Module):
     """Prediction head for masked path coordinates."""
 
-    def __init__(self, d_model: int):
+    def __init__(self, d_model: int, output_dim: int = 6):
         super().__init__()
         self.dense = nn.Linear(d_model, d_model)
         self.layer_norm = nn.LayerNorm(d_model)
-        self.decoder = nn.Linear(d_model, 3)
+        self.decoder = nn.Linear(d_model, output_dim)
         self.activation = nn.GELU()
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -45,55 +45,38 @@ class PathPredictionHead(nn.Module):
             hidden_states: [batch, seq_len, d_model]
 
         Returns:
-            [batch, seq_len, 3] coordinates
+            [batch, seq_len, output_dim] path features.
         """
         x = self.dense(hidden_states)
         x = self.activation(x)
         x = self.layer_norm(x)
-        coords = self.decoder(x)
-        coords = torch.sigmoid(coords)  # Ensure [0, 1] range
-        return coords
+        features = self.decoder(x)
+
+        # Per-feature constraints:
+        # - x, y are normalized to [0,1]
+        # - dx, dy are signed deltas (roughly [-1,1])
+        # - ds is non-negative
+        # - log_dt is non-negative
+        if features.shape[-1] == 6:
+            x_y = torch.sigmoid(features[..., 0:2])
+            dx_dy = torch.tanh(features[..., 2:4])
+            ds = torch.nn.functional.softplus(features[..., 4:5])
+            log_dt = torch.nn.functional.softplus(features[..., 5:6])
+            return torch.cat([x_y, dx_dy, ds, log_dt], dim=-1)
+
+        # Fallback: unconstrained regression for other output dims.
+        return features
-
-
-class ClassificationHead(nn.Module):
-    """Classification head for similarity scoring (applied to pooled features)."""
-
-    def __init__(self, d_model: int, num_labels: int = 1):
-        super().__init__()
-        self.dense = nn.Linear(d_model, d_model)
-        self.activation = nn.GELU()
-        self.norm = nn.LayerNorm(d_model)
-        self.classifier = nn.Linear(d_model, num_labels)
-
-    def forward(self, features: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            features: [batch, d_model] - typically SEP token embeddings
-
-        Returns:
-            [batch, num_labels] similarity scores
-        """
-        x = self.dense(features)
-        x = self.activation(x)  # GELU
-        x = self.norm(x)  # LayerNorm
-        logits = self.classifier(x)  # [batch, 1] or [batch, num_labels]
-        return logits
 
 
 class LengthPredictionHead(nn.Module):
-    """Predicts word length from the path encoding."""
+    """Regress sequence length (e.g., swipable character count) from CLS embedding."""
 
-    def __init__(self, d_model: int, max_length: int):
+    def __init__(self, d_model: int):
         super().__init__()
         self.dense = nn.Linear(d_model, d_model)
         self.activation = nn.GELU()
         self.norm = nn.LayerNorm(d_model)
-        self.classifier = nn.Linear(d_model, max_length)
+        self.regressor = nn.Linear(d_model, 1)  # predict expected length directly
 
     def forward(self, cls_features: torch.Tensor) -> torch.Tensor:
         """
@@ -101,9 +84,9 @@ class LengthPredictionHead(nn.Module):
             cls_features: [batch, d_model] CLS embeddings
 
         Returns:
-            [batch, max_length] length logits
+            [batch] predicted length
         """
         x = self.dense(cls_features)
         x = self.activation(x)
         x = self.norm(x)
-        return self.classifier(x)
+        return self.regressor(x).squeeze(-1)
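The per-feature activations can be sanity-checked standalone; a minimal sketch of the same constraint scheme (not the module itself):

```python
import torch
import torch.nn.functional as F

raw = torch.randn(2, 64, 6)  # unconstrained decoder output
constrained = torch.cat(
    [
        torch.sigmoid(raw[..., 0:2]),  # x, y squashed into [0, 1]
        torch.tanh(raw[..., 2:4]),     # dx, dy kept in (-1, 1)
        F.softplus(raw[..., 4:6]),     # ds, log_dt forced non-negative
    ],
    dim=-1,
)
assert constrained.shape == raw.shape
```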
model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:452880eddd71bf06c0c61750ae3f5ba65670c15aa97da8233df856ac74bd3e56
+size 348207344
modeling_swipe.py
CHANGED

@@ -1,19 +1,13 @@
 """HuggingFace-compatible model classes for SwipeTransformer."""
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ModelOutput,
-    SequenceClassifierOutput,
-)
+from transformers.modeling_outputs import ModelOutput
 
-from .configuration_swipe import SwipeCrossEncoderConfig, SwipeTransformerConfig
+from .configuration_swipe import SwipeTransformerConfig
 
 
 @dataclass
@@ -24,27 +18,32 @@ class SwipeTransformerOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
             Language modeling loss (character prediction).
-        char_logits (`torch.FloatTensor` of shape `(batch_size, …
-            Prediction scores of the character prediction head.
-        path_logits (`torch.FloatTensor` of shape `(batch_size, …
-            Prediction scores of the path prediction head (if enabled).
-        length_logits (`torch.FloatTensor` of shape `(batch_size, …
-            …
+        char_logits (`torch.FloatTensor` of shape `(batch_size, char_length, vocab_size)`):
+            Prediction scores of the character prediction head (text segment only).
+        path_logits (`torch.FloatTensor` of shape `(batch_size, path_length, path_input_dim)`, *optional*):
+            Prediction scores of the path prediction head (path segment only, if enabled).
+        length_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
+            Predicted length from the length head (if enabled).
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
             Sequence of hidden-states at the output of the last layer of the model.
         pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
             SEP token embeddings for similarity/embedding tasks.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*):
-            Tuple of `torch.FloatTensor` …
+            Tuple of `torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`.
+            When requested, this includes the input embeddings plus one entry per encoder layer.
+        attentions (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of attention tensors (one for each layer) of shape
+            `(batch_size, num_heads, sequence_length, sequence_length)`.
     """
 
-    loss: Optional[torch.FloatTensor] = None
-    char_logits: torch.FloatTensor = None
-    path_logits: Optional[torch.FloatTensor] = None
-    length_logits: Optional[torch.FloatTensor] = None
-    last_hidden_state: torch.FloatTensor = None
-    pooler_output: Optional[torch.FloatTensor] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    loss: torch.FloatTensor | None = None
+    char_logits: torch.FloatTensor | None = None
+    path_logits: torch.FloatTensor | None = None
+    length_logits: torch.FloatTensor | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    pooler_output: torch.FloatTensor | None = None
+    hidden_states: tuple[torch.FloatTensor] | None = None
+    attentions: tuple[torch.FloatTensor] | None = None
 
 
 class SwipeTransformerPreTrainedModel(PreTrainedModel):
@@ -98,6 +97,7 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
             max_char_len=config.max_char_len,
             d_model=config.d_model,
             dropout=config.dropout,
+            path_input_dim=config.path_input_dim,
         )
 
         # Transformer encoder
@@ -117,21 +117,26 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
         )
 
         # Prediction heads
-        self.char_head = CharacterPredictionHead(
-            d_model=config.d_model,
-            vocab_size=config.vocab_size,
+        self.char_head = (
+            CharacterPredictionHead(
+                d_model=config.d_model,
+                vocab_size=config.vocab_size,
+            )
+            if config.predict_char
+            else None
         )
 
         if config.predict_path:
-            self.path_head = PathPredictionHead(d_model=config.d_model)
+            self.path_head = PathPredictionHead(
+                d_model=config.d_model, output_dim=config.path_input_dim
+            )
         else:
            self.path_head = None
 
         # Length prediction head (predicts word length from path)
         # Max length is max_char_len (including EOS)
-        self.length_head = LengthPredictionHead(
-            d_model=config.d_model,
-            max_length=config.max_char_len,
+        self.length_head = (
+            LengthPredictionHead(d_model=config.d_model) if config.predict_length else None
         )
 
         # Initialize weights
@@ -139,35 +144,61 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
 
     def forward(
         self,
-        path_coords: torch.Tensor,
         input_ids: torch.Tensor,
+        path_coords: torch.Tensor,
         attention_mask: torch.Tensor | None = None,
-        labels: torch.Tensor | None = None,
+        labels: torch.Tensor | dict | None = None,
         return_dict: bool | None = None,
         output_hidden_states: bool | None = None,
+        output_attentions: bool | None = None,
+        **kwargs,
     ):
         """
         Forward pass of the model.
 
         Args:
-            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
             input_ids (torch.Tensor): Character token IDs [batch, char_len]
+            path_coords (torch.Tensor): Path features [batch, path_len, path_input_dim]
+                Default: [batch, path_len, 6] for (x, y, dx, dy, ds, log_dt)
             attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-            labels (torch.Tensor, optional): Labels for loss calculation
+            labels (torch.Tensor or dict, optional): Labels for loss calculation
+                Can be tensor [batch, char_len] or dict with keys like char_labels, path_labels
            return_dict (bool, optional): Whether to return ModelOutput object
            output_hidden_states (bool, optional): Whether to output hidden states
+            output_attentions (bool, optional): Whether to output attention weights
+            **kwargs: Additional arguments (for compatibility)
 
         Returns:
             SwipeTransformerOutput or tuple: Model outputs with:
                 - loss: Optional loss value
-                - char_logits: Character prediction logits [batch, …
-                - path_logits: Path prediction logits [batch, …
-                - length_logits: Length …
+                - char_logits: Character prediction logits [batch, char_len, vocab_size] (if enabled)
+                - path_logits: Path prediction logits [batch, path_len, path_input_dim] (if enabled)
+                - length_logits: Length regression output [batch] (if enabled)
                 - last_hidden_state: Hidden states [batch, seq_len, d_model]
-                - pooler_output: SEP token …
-                - hidden_states: Tuple of hidden states (if output_hidden_states=True)
+                - pooler_output: SEP token embedding [batch, d_model] for similarity/embedding tasks
+                - hidden_states: Tuple of per-layer hidden states (if output_hidden_states=True)
+                - attentions: Tuple of per-layer attention weights (if output_attentions=True)
         """
+        # Validate required inputs
+        if input_ids is None or path_coords is None:
+            raise ValueError("Both input_ids and path_coords are required")
+
+        # Extract labels if dict (used by custom trainers)
+        if isinstance(labels, dict):
+            char_labels = labels.get("char_labels")
+            # Can handle other label types in the future (path_labels, etc.)
+        else:
+            char_labels = labels
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
 
         batch_size = path_coords.shape[0]
         device = path_coords.device
@@ -191,43 +222,138 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
         else:
             src_key_padding_mask = None
 
-        # Encode
-        hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
+        # Encode while optionally capturing attentions and per-layer hidden states.
+        attentions: tuple[torch.Tensor, ...] | None = None
+        hidden_states_by_layer: list[torch.Tensor] | None = [] if output_hidden_states else None
+
+        hooks = []
+        original_forwards: dict[int, callable] = {}
+        attentions_buffer: list[torch.Tensor | None] | None = None
+
+        def make_patched_forward(original_forward):
+            def patched_forward(
+                query,
+                key,
+                value,
+                key_padding_mask=None,
+                need_weights=True,
+                attn_mask=None,
+                average_attn_weights=False,
+                is_causal=False,
+            ):
+                return original_forward(
+                    query,
+                    key,
+                    value,
+                    key_padding_mask=key_padding_mask,
+                    need_weights=True,
+                    attn_mask=attn_mask,
+                    average_attn_weights=False,
+                    is_causal=is_causal,
+                )
+
+            return patched_forward
+
+        def make_hook(layer_idx: int):
+            def hook(_module: nn.Module, _input: tuple, output: tuple):
+                if (
+                    attentions_buffer is not None
+                    and isinstance(output, tuple)
+                    and len(output) > 1
+                    and output[1] is not None
+                ):
+                    attentions_buffer[layer_idx] = output[1]
+
+            return hook
+
+        if output_attentions:
+            attentions_buffer = [None] * len(self.encoder.layers)
+            for idx, layer in enumerate(self.encoder.layers):
+                attn_module = layer.self_attn
+                original_forwards[idx] = attn_module.forward
+                attn_module.forward = make_patched_forward(original_forwards[idx])
+                hooks.append(attn_module.register_forward_hook(make_hook(idx)))
+
+        try:
+            x = embeddings
+            for layer in self.encoder.layers:
+                x = layer(x, src_key_padding_mask=src_key_padding_mask)
+                if hidden_states_by_layer is not None:
+                    hidden_states_by_layer.append(x)
+            hidden_states = x
+
+            if attentions_buffer is not None:
+                if any(a is None for a in attentions_buffer):
+                    missing = [i for i, a in enumerate(attentions_buffer) if a is None]
+                    raise RuntimeError(
+                        f"Failed to capture attention weights for layers: {missing}."
+                    )
+                attentions = tuple(attentions_buffer)  # type: ignore[assignment]
+        finally:
+            for hook in hooks:
+                hook.remove()
+            for idx, layer in enumerate(self.encoder.layers):
+                if idx in original_forwards:
+                    layer.self_attn.forward = original_forwards[idx]
+
+        path_len = path_coords.shape[1]
+        char_len = input_ids.shape[1]
 
-        # Character prediction
-        char_logits = self.char_head(hidden_states)
+        # Character prediction (text segment only)
+        char_logits = None
+        if self.char_head is not None:
+            # Sequence is: [CLS] + path + [SEP] + chars
+            char_start = 1 + path_len + 1
+            char_hidden = hidden_states[:, char_start : char_start + char_len, :]
+            char_logits = self.char_head(char_hidden)
 
-        # Path prediction (if enabled)
+        # Path prediction (path segment only, if enabled)
         path_logits = None
         if self.path_head is not None:
-            path_logits = self.path_head(hidden_states)
+            path_hidden = hidden_states[:, 1 : 1 + path_len, :]
+            path_logits = self.path_head(path_hidden)
 
         # Length prediction from CLS token
         cls_hidden = hidden_states[:, 0, :]  # [batch, d_model] - CLS at position 0
-        length_logits = self.length_head(cls_hidden)
+        length_logits = self.length_head(cls_hidden) if self.length_head is not None else None
 
         # Extract SEP token embedding for pooler output (embeddings/similarity tasks)
         # SEP is at position 1 + path_len
-        path_len = path_coords.shape[1]
         sep_position = 1 + path_len
         pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]
 
-        # Compute loss if labels provided
+        # Compute loss if labels provided (masked-only; -100 = ignore)
         loss = None
-        if labels is not None:
-            …
+        if char_labels is not None and self.char_head is not None:
+            # Predict only the text segment
+            char_pred = char_logits  # [B, char_len, V]
+            labels_flat = char_labels.reshape(-1)
+            mask = labels_flat != -100
+            if mask.any():
+                logits_flat = char_pred.reshape(-1, self.config.vocab_size)[mask]
+                labels_flat = labels_flat[mask]
+                loss = nn.functional.cross_entropy(logits_flat, labels_flat, reduction="mean")
+            else:
+                loss = torch.tensor(0.0, device=hidden_states.device)
 
         if not return_dict:
-            …
+            hidden_tuple = None
+            if hidden_states_by_layer is not None:
+                hidden_tuple = (embeddings,) + tuple(hidden_states_by_layer)
+            output = (
+                char_logits,
+                path_logits,
+                length_logits,
+                hidden_states,
+                pooler_output,
+                hidden_tuple,
+                attentions,
+            )
+            return (loss,) + output if loss is not None else output
+
+        all_hidden_states = None
+        if hidden_states_by_layer is not None:
+            all_hidden_states = (embeddings,) + tuple(hidden_states_by_layer)
 
         return SwipeTransformerOutput(
             loss=loss,
@@ -236,281 +362,12 @@ class SwipeTransformerModel(SwipeTransformerPreTrainedModel):
             length_logits=length_logits,
             last_hidden_state=hidden_states,
             pooler_output=pooler_output,
-            hidden_states=(hidden_states,) if output_hidden_states else None,
+            hidden_states=all_hidden_states,
+            attentions=attentions,
         )
-
-
-class SwipeCrossEncoderForSequenceClassification(SwipeTransformerPreTrainedModel):
-    """
-    HuggingFace-compatible cross-encoder for sequence classification.
-
-    This model is designed for similarity scoring between swipe paths and words.
-    It extracts the SEP token embedding and passes it through a classification head.
-
-    Args:
-        config (SwipeCrossEncoderConfig): Model configuration
-    """
-
-    config_class = SwipeCrossEncoderConfig
-    base_model_prefix = "swipe_cross_encoder"
-
-    def __init__(self, config: SwipeCrossEncoderConfig):
-        super().__init__(config)
-        self.config = config
-        self.num_labels = config.num_labels
-
-        # Import existing components
-        from .embeddings import MixedEmbedding
-        from .heads import ClassificationHead
-
-        # Embeddings
-        self.embeddings = MixedEmbedding(
-            vocab_size=config.vocab_size,
-            max_path_len=config.max_path_len,
-            max_char_len=config.max_char_len,
-            d_model=config.d_model,
-            dropout=config.dropout,
-        )
-
-        # Transformer encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=config.d_model,
-            nhead=config.n_heads,
-            dim_feedforward=config.d_ff,
-            dropout=config.dropout,
-            activation="gelu",
-            batch_first=True,
-            norm_first=True,  # Pre-LayerNorm
-        )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer,
-            num_layers=config.n_layers,
-            enable_nested_tensor=False,
-        )
-
-        # Classification head
-        self.classifier = ClassificationHead(
-            d_model=config.d_model,
-            num_labels=config.num_labels,
-        )
-
-        # Initialize weights
-        self.post_init()
-
-    def forward(
-        self,
-        path_coords: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-        labels: torch.Tensor | None = None,
-        return_dict: bool | None = None,
-    ):
-        """
-        Forward pass for cross-encoder.
-
-        Args:
-            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
-            input_ids (torch.Tensor): Character token IDs [batch, char_len]
-            attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-            labels (torch.Tensor, optional): Labels for loss calculation [batch, num_labels]
-            return_dict (bool, optional): Whether to return ModelOutput object
-
-        Returns:
-            SequenceClassifierOutput or tuple: Model outputs with logits and optional loss
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        batch_size = path_coords.shape[0]
-        device = path_coords.device
-
-        # Create [CLS] and [SEP] tokens
-        cls_token = torch.full(
-            (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
-        )
-        sep_token = torch.full(
-            (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
-        )
-
-        # Get embeddings
-        embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
-
-        # Prepare attention mask
-        if attention_mask is not None:
-            src_key_padding_mask = attention_mask == 0
-        else:
-            src_key_padding_mask = None
-
-        # Encode (batch_first=True is set in TransformerEncoderLayer)
-        hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
-
-        # Extract SEP token embedding
-        # SEP is at position 1 + path_len
-        path_len = path_coords.shape[1]
-        sep_position = 1 + path_len
-        sep_embedding = hidden_states[:, sep_position, :]  # [batch, d_model]
-
-        # Classification
-        logits = self.classifier(sep_embedding)  # [batch, num_labels]
-
-        # Compute loss if labels provided
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                else:
-                    self.config.problem_type = "single_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = nn.MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-
-            …
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=(hidden_states,),
-        )
-
-
-class SwipeModel(SwipeTransformerPreTrainedModel):
-    """
-    Base Swipe model for extracting embeddings.
-
-    .. deprecated::
-        This class is deprecated. Use SwipeTransformerModel instead, which now
-        includes pooler_output for embeddings alongside prediction heads.
-        SwipeTransformerModel provides both predictions AND embeddings in a single model.
-
-    This model returns the SEP token embedding, which can be used for:
-    - Vector databases
-    - Semantic search
-    - Similarity computation
-
-    The SEP token embedding represents the joint encoding of the path and text.
-
-    Usage (Deprecated - use SwipeTransformerModel instead):
-        ```python
-        from transformers import AutoModel
-
-        model = AutoModel.from_pretrained(
-            "your-username/swipe-model",
-            trust_remote_code=True
-        )
-
-        # Get embeddings
-        outputs = model(path_coords=paths, input_ids=tokens)
-        embeddings = outputs.pooler_output  # SEP token embeddings
-        ```
-
-    Args:
-        config (SwipeTransformerConfig or SwipeCrossEncoderConfig): Model configuration
-    """
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.config = config
-
-        # Import existing components
-        from .embeddings import MixedEmbedding
-
-        # Embeddings
-        self.embeddings = MixedEmbedding(
-            vocab_size=config.vocab_size,
-            max_path_len=config.max_path_len,
-            max_char_len=config.max_char_len,
-            d_model=config.d_model,
-            dropout=config.dropout,
-        )
-
-        # Transformer encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=config.d_model,
-            nhead=config.n_heads,
-            dim_feedforward=config.d_ff,
-            dropout=config.dropout,
-            activation="gelu",
-            batch_first=True,
-            norm_first=True,  # Pre-LayerNorm
-        )
-        self.encoder = nn.TransformerEncoder(
-            encoder_layer,
-            num_layers=config.n_layers,
-            enable_nested_tensor=False,
-        )
-
-        # Initialize weights
-        self.post_init()
-
-    def forward(
-        self,
-        path_coords: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
-        return_dict: bool | None = None,
-        output_hidden_states: bool | None = None,
-    ):
-        """
-        Forward pass that returns embeddings.
-
-        Args:
-            path_coords (torch.Tensor): Path coordinates [batch, path_len, 3]
-            input_ids (torch.Tensor): Character token IDs [batch, char_len]
-            attention_mask (torch.Tensor, optional): Attention mask [batch, seq_len]
-            return_dict (bool, optional): Whether to return ModelOutput object
-            output_hidden_states (bool, optional): Whether to output all hidden states
-
-        Returns:
-            BaseModelOutputWithPooling with:
-            - last_hidden_state: Full sequence hidden states [batch, seq_len, d_model]
-            - pooler_output: SEP token embeddings [batch, d_model]
-            - hidden_states: Tuple of hidden states (if output_hidden_states=True)
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        batch_size = path_coords.shape[0]
-        device = path_coords.device
-
-        # Create [CLS] and [SEP] tokens
-        cls_token = torch.full(
-            (batch_size, 1), fill_value=self.config.cls_token_id, dtype=torch.long, device=device
-        )
-        sep_token = torch.full(
-            (batch_size, 1), fill_value=self.config.sep_token_id, dtype=torch.long, device=device
-        )
-
-        # Get embeddings
-        embeddings = self.embeddings(path_coords, input_ids, cls_token, sep_token)
-
-        # Prepare attention mask
-        if attention_mask is not None:
-            src_key_padding_mask = attention_mask == 0
-        else:
-            src_key_padding_mask = None
-
-        # Encode (batch_first=True is set in TransformerEncoderLayer)
-        hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
-
-        # Extract SEP token embedding (pooler output)
-        # SEP is at position 1 + path_len
-        path_len = path_coords.shape[1]
-        sep_position = 1 + path_len
-        pooler_output = hidden_states[:, sep_position, :]  # [batch, d_model]
-
-        if not return_dict:
-            return (hidden_states, pooler_output)
-
-        return BaseModelOutputWithPooling(
-            last_hidden_state=hidden_states,
-            pooler_output=pooler_output,
-            hidden_states=(hidden_states,) if output_hidden_states else None,
-        )
+
+
+#
+# Legacy note:
+# `SwipeModel` (embeddings-only) has been removed; use `SwipeTransformerModel` and read
+# `outputs.pooler_output` for embeddings.
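From the caller's side, the reworked forward contract looks like this; a minimal sketch with illustrative shapes (64-point paths, 38-token text; the repo id is a placeholder):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("your-username/swipe-model", trust_remote_code=True)

path_coords = torch.randn(2, 64, 6)        # (x, y, dx, dy, ds, log_dt) per point
input_ids = torch.randint(6, 43, (2, 38))  # character token IDs (0-5 are specials)

outputs = model(
    input_ids=input_ids,
    path_coords=path_coords,
    output_hidden_states=True,
    output_attentions=True,
)
# Full mixed sequence is [CLS] + path + [SEP] + chars = 1 + 64 + 1 + 38 positions.
print(outputs.char_logits.shape)    # [2, 38, vocab_size]
print(outputs.pooler_output.shape)  # [2, d_model] - the SEP embedding
print(len(outputs.hidden_states))   # n_layers + 1 (input embeddings first)
print(len(outputs.attentions))      # n_layers, each [2, n_heads, seq, seq]
```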
preprocessing.py
ADDED
@@ -0,0 +1,275 @@
+"""Shared preprocessing utilities for swipe path data.
+
+This module provides a single source of truth for path preprocessing,
+used by both the training dataset and the HuggingFace processor.
+"""
+
+import numpy as np
+
+
+def preprocess_raw_path_to_features(
+    data_points: list[dict],
+    max_len: int,
+    *,
+    resample_mode: str = "spatial",
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """Convert a raw `{"x","y","t"}` path to fixed-length engineered features.
+
+    This is the fast path used by training and the HuggingFace processor. It avoids
+    building an intermediate list-of-dicts representation by:
+    1) extracting x/y/t arrays once,
+    2) resampling x/y using spatial- or time-uniform interpolation,
+    3) recomputing dx/dy/ds and log_dt on the resampled trajectory.
+
+    Args:
+        data_points: Raw path as a list of dicts with keys: "x", "y", "t".
+        max_len: Target length.
+        resample_mode: "spatial" (arc-length) or "time" (cumulative dt).
+        dt_clamp_min_ms: Clamp for dt feature after resampling (first dt remains 0).
+        dt_clamp_max_ms: Clamp for dt feature after resampling.
+
+    Returns:
+        (features, mask) where:
+        - features: [max_len, 6] float32 array (x, y, dx, dy, ds, log_dt)
+        - mask: [max_len] int64 array (1 for valid; all-ones for non-empty paths)
+    """
+    num_points = len(data_points)
+    if num_points == 0:
+        return (
+            np.zeros((max_len, 6), dtype=np.float32),
+            np.zeros(max_len, dtype=np.int64),
+        )
+
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+    t = np.fromiter((p["t"] for p in data_points), dtype=np.float64, count=num_points)
+
+    x = np.clip(x, 0.0, 1.0)
+    y = np.clip(y, 0.0, 1.0)
+
+    # Per-step deltas and axes for resampling
+    dx_in = np.concatenate([[0.0], np.diff(x)])
+    dy_in = np.concatenate([[0.0], np.diff(y)])
+    ds_in = np.hypot(dx_in, dy_in)
+    dt_raw_in = np.concatenate([[0.0], np.diff(t)])
+
+    s = np.cumsum(ds_in)
+    tau = np.cumsum(dt_raw_in)
+
+    if resample_mode not in {"spatial", "time"}:
+        raise ValueError(f"Unknown resample_mode={resample_mode!r} (use 'spatial' or 'time')")
+
+    eps = 1e-12
+    if resample_mode == "time" and tau[-1] > eps:
+        target_tau = np.linspace(0.0, float(tau[-1]), max_len, dtype=np.float64)
+        x_r = np.interp(target_tau, tau, x)
+        y_r = np.interp(target_tau, tau, y)
+        tau_r = target_tau
+    else:
+        # Spatial sampling (or fallback when time axis is degenerate).
+        if s[-1] <= eps:
+            original = np.arange(num_points, dtype=np.float64)
+            target = np.linspace(0, num_points - 1, max_len, dtype=np.float64)
+            x_r = np.interp(target, original, x)
+            y_r = np.interp(target, original, y)
+            tau_r = np.interp(target, original, tau)
+        else:
+            target_s = np.linspace(0.0, float(s[-1]), max_len, dtype=np.float64)
+            x_r = np.interp(target_s, s, x)
+            y_r = np.interp(target_s, s, y)
+            tau_r = np.interp(target_s, s, tau)
+
+    dx = np.concatenate([[0.0], np.diff(x_r)])
+    dy = np.concatenate([[0.0], np.diff(y_r)])
+    ds = np.hypot(dx, dy)
+    dt_raw_r = np.concatenate([[0.0], np.diff(tau_r)])
+    dt_feat = np.clip(dt_raw_r, dt_clamp_min_ms, dt_clamp_max_ms)
+    dt_feat[0] = 0.0
+    log_dt = np.log1p(np.maximum(0.0, dt_feat))
+
+    mask = np.ones(max_len, dtype=np.int64)
+    features = np.stack([x_r, y_r, dx, dy, ds, log_dt], axis=-1).astype(np.float32)
+    return features, mask
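A usage sketch for the fast path above (the three-point gesture is illustrative):

```python
import numpy as np
from preprocessing import preprocess_raw_path_to_features  # assumes local import

raw = [
    {"x": 0.10, "y": 0.50, "t": 0.0},
    {"x": 0.40, "y": 0.52, "t": 35.0},
    {"x": 0.80, "y": 0.55, "t": 90.0},
]
features, mask = preprocess_raw_path_to_features(raw, max_len=64)
print(features.shape, features.dtype)  # (64, 6) float32
print(int(mask.sum()))                 # 64 - non-empty paths get an all-ones mask
```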
+
+
+def normalize_and_compute_features(
+    data_points: list[dict],
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> list[dict]:
+    """
+    Normalize coordinates and compute motion features.
+
+    Computes delta features (dx, dy, dt) and log-scaled time deltas.
+    First point has dx=dy=dt=0 by convention.
+
+    Args:
+        data_points: List of {"x", "y", "t"} dicts
+        dt_clamp_min_ms: Minimum dt in milliseconds (inclusive).
+        dt_clamp_max_ms: Maximum dt in milliseconds (inclusive).
+
+    Returns:
+        List of dicts with keys:
+        - x, y: normalized coordinates in [0, 1]
+        - t: raw timestamp from input (passed through)
+        - dx, dy: deltas in x/y
+        - ds: sqrt(dx^2 + dy^2)
+        - dt_raw: raw time delta (unclamped)
+        - dt: clamped time delta used for feature stability
+        - log_dt: log1p(dt)
+    """
+    if not data_points:
+        return []
+
+    num_points = len(data_points)
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+    t = np.fromiter((p["t"] for p in data_points), dtype=np.float64, count=num_points)
+
+    x = np.clip(x, 0.0, 1.0)
+    y = np.clip(y, 0.0, 1.0)
+
+    dx = np.concatenate([[0.0], np.diff(x)])
+    dy = np.concatenate([[0.0], np.diff(y)])
+    ds = np.hypot(dx, dy)
+    dt_raw = np.concatenate([[0.0], np.diff(t)])
+
+    dt = np.clip(dt_raw, dt_clamp_min_ms, dt_clamp_max_ms)
+    dt[0] = 0.0
+    log_dt = np.log1p(np.maximum(0.0, dt))
+
+    out: list[dict] = []
+    for i in range(num_points):
+        out.append(
+            {
+                "x": float(x[i]),
+                "y": float(y[i]),
+                "t": float(t[i]),
+                "dx": float(dx[i]),
+                "dy": float(dy[i]),
+                "ds": float(ds[i]),
+                "dt_raw": float(dt_raw[i]),
+                "dt": float(dt[i]),
+                "log_dt": float(log_dt[i]),
+            }
+        )
+    return out
+
+
+def sample_path_points_with_features(
+    data_points: list[dict],
+    max_len: int,
+    *,
+    resample_mode: str = "spatial",
+    dt_clamp_min_ms: float = 1.0,
+    dt_clamp_max_ms: float = 200.0,
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Sample path points with motion features to fixed length using interpolation.
+
+    Always uses interpolation (no zero-padding) to preserve feature structure.
+    Paths shorter than max_len are upsampled; longer paths are downsampled.
+
+    Modes:
+    - resample_mode="spatial": sample approximately uniformly in arc length (distance).
+    - resample_mode="time": sample uniformly in time (dwell regions get more samples).
+
+    Args:
+        data_points: List of coordinate dicts. Expected keys: x, y and either:
+            - dx, dy (preferred), plus optional ds, dt, log_dt; or
+            - ds/log_dt/dt (ds can be derived from dx/dy; dt from log_dt).
+        max_len: Target length
+        resample_mode: "spatial" or "time"
+        dt_clamp_min_ms: Clamp for dt feature after resampling (first dt remains 0).
+        dt_clamp_max_ms: Clamp for dt feature after resampling.
+
+    Returns:
+        Tuple of (features, mask) where:
+        - features: [max_len, 6] array with (x, y, dx, dy, ds, log_dt)
+        - mask: [max_len] binary mask (all 1s since we always interpolate)
+    """
+    num_points = len(data_points)
+
+    if num_points == 0:
+        # Empty path - return zeros
+        return (
+            np.zeros((max_len, 6), dtype=np.float32),
+            np.zeros(max_len, dtype=np.int64),
+        )
+
+    # Extract base signals
+    x = np.fromiter((p["x"] for p in data_points), dtype=np.float64, count=num_points)
+    y = np.fromiter((p["y"] for p in data_points), dtype=np.float64, count=num_points)
+
+    # Prefer provided dx/dy, otherwise derive from x/y
+    if all("dx" in p for p in data_points) and all("dy" in p for p in data_points):
+        dx_in = np.fromiter((p["dx"] for p in data_points), dtype=np.float64, count=num_points)
+        dy_in = np.fromiter((p["dy"] for p in data_points), dtype=np.float64, count=num_points)
+    else:
+        dx_in = np.concatenate([[0.0], np.diff(x)])
+        dy_in = np.concatenate([[0.0], np.diff(y)])
+
+    # ds can be provided or derived from dx/dy
+    if all("ds" in p for p in data_points):
+        ds_in = np.fromiter((p["ds"] for p in data_points), dtype=np.float64, count=num_points)
+    else:
+        ds_in = np.sqrt(dx_in**2 + dy_in**2)
+
+    # Time axis for resampling: prefer dt_raw (unclamped) so "dwell" gets represented.
+    if all("dt_raw" in p for p in data_points):
+        dt_axis = np.fromiter(
+            (p["dt_raw"] for p in data_points), dtype=np.float64, count=num_points
+        )
+    elif all("dt" in p for p in data_points):
+        dt_axis = np.fromiter((p["dt"] for p in data_points), dtype=np.float64, count=num_points)
+    elif all("log_dt" in p for p in data_points):
+        log_dt_in_raw = np.fromiter(
+            (p["log_dt"] for p in data_points), dtype=np.float64, count=num_points
+        )
+        dt_axis = np.expm1(log_dt_in_raw)
+    else:
+        dt_axis = np.zeros(num_points, dtype=np.float64)
+
+    # Cumulative arc length (s) and cumulative time (tau) for resampling
+    s = np.cumsum(ds_in)
+    tau = np.cumsum(dt_axis)
+
+    if resample_mode not in {"spatial", "time"}:
+        raise ValueError(f"Unknown resample_mode={resample_mode!r} (use 'spatial' or 'time')")
+
+    eps = 1e-12
+
+    if resample_mode == "time" and tau[-1] > eps:
+        target_tau = np.linspace(0.0, float(tau[-1]), max_len, dtype=np.float64)
|
| 246 |
+
x_r = np.interp(target_tau, tau, x)
|
| 247 |
+
y_r = np.interp(target_tau, tau, y)
|
| 248 |
+
tau_r = target_tau
|
| 249 |
+
else:
|
| 250 |
+
# Spatial sampling (or fallback when time axis is degenerate).
|
| 251 |
+
# Handle degenerate paths (zero movement): fall back to index-based interpolation
|
| 252 |
+
if s[-1] <= eps:
|
| 253 |
+
original = np.arange(num_points, dtype=np.float64)
|
| 254 |
+
target = np.linspace(0, num_points - 1, max_len, dtype=np.float64)
|
| 255 |
+
x_r = np.interp(target, original, x)
|
| 256 |
+
y_r = np.interp(target, original, y)
|
| 257 |
+
tau_r = np.interp(target, original, tau)
|
| 258 |
+
else:
|
| 259 |
+
target_s = np.linspace(0.0, float(s[-1]), max_len, dtype=np.float64)
|
| 260 |
+
x_r = np.interp(target_s, s, x)
|
| 261 |
+
y_r = np.interp(target_s, s, y)
|
| 262 |
+
tau_r = np.interp(target_s, s, tau)
|
| 263 |
+
|
| 264 |
+
# Recompute deltas on the resampled path for consistency
|
| 265 |
+
dx = np.concatenate([[0.0], np.diff(x_r)])
|
| 266 |
+
dy = np.concatenate([[0.0], np.diff(y_r)])
|
| 267 |
+
ds = np.sqrt(dx**2 + dy**2)
|
| 268 |
+
dt_raw_r = np.concatenate([[0.0], np.diff(tau_r)])
|
| 269 |
+
dt_feat = np.clip(dt_raw_r, dt_clamp_min_ms, dt_clamp_max_ms)
|
| 270 |
+
dt_feat[0] = 0.0
|
| 271 |
+
log_dt = np.log1p(np.maximum(0.0, dt_feat))
|
| 272 |
+
|
| 273 |
+
mask = np.ones(max_len, dtype=np.int64)
|
| 274 |
+
features = np.stack([x_r, y_r, dx, dy, ds, log_dt], axis=-1).astype(np.float32)
|
| 275 |
+
return features, mask
|
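A minimal usage sketch for the resampler added above (illustrative values; assumes preprocessing.py is importable as a module and used stand-alone, outside the processor):

    from preprocessing import (
        normalize_and_compute_features,
        sample_path_points_with_features,
    )

    # Three raw points of a short swipe: x/y already normalized to [0, 1], t in ms.
    raw = [
        {"x": 0.10, "y": 0.50, "t": 0.0},
        {"x": 0.40, "y": 0.52, "t": 40.0},
        {"x": 0.80, "y": 0.55, "t": 120.0},
    ]

    # Step 1: clamp coordinates and attach dx/dy/ds/dt_raw/dt/log_dt to each point.
    enriched = normalize_and_compute_features(raw)

    # Step 2: interpolate to a fixed length. "time" mode places the 8 samples
    # uniformly in cumulative dt_raw, so slow (dwell) segments receive more of them.
    features, mask = sample_path_points_with_features(enriched, 8, resample_mode="time")
    print(features.shape)  # (8, 6): one (x, y, dx, dy, ds, log_dt) row per step
    print(mask)            # all ones: the function interpolates instead of zero-padding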
processing_swipe.py
CHANGED
@@ -1,9 +1,15 @@
"""Processor for handling multimodal swipe inputs (path + text)."""

+from __future__ import annotations
+
+from typing import Any
+
import numpy as np
import torch
from transformers import ProcessorMixin

+from .preprocessing import preprocess_raw_path_to_features
+

class SwipeProcessor(ProcessorMixin):
    """
@@ -21,10 +27,19 @@ class SwipeProcessor(ProcessorMixin):
    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"  # Will use auto_map from tokenizer_config.json

-    def __init__(
+    def __init__(
+        self,
+        tokenizer=None,
+        max_path_len: int = 64,
+        max_char_len: int = 38,
+        path_input_dim: int = 6,
+        path_resample_mode: str = "time",
+    ):
        self.tokenizer = tokenizer
        self.max_path_len = max_path_len
        self.max_char_len = max_char_len
+        self.path_input_dim = path_input_dim
+        self.path_resample_mode = path_resample_mode
        # Attributes expected by newer transformers (not used for swipe models)
        self.chat_template = None
        self.audio_tokenizer = None
@@ -33,21 +48,36 @@ class SwipeProcessor(ProcessorMixin):

    def __call__(
        self,
-        path_coords:
+        path_coords: (
+            list[dict[str, float]]
+            | list[list[dict[str, float]]]
+            | list[list[list[float]]]
+            | torch.Tensor
+            | np.ndarray
+            | None
+        ) = None,
        text: str | list[str] | None = None,
        padding: bool | str = True,
        truncation: bool = True,
        max_length: int | None = None,
        return_tensors: str | None = "pt",
-        **kwargs,
+        **kwargs: Any,
    ):
        """
        Process path coordinates and text into model inputs.

        Args:
-            path_coords:
+            path_coords:
+                Swipe paths in one of the supported formats:
+                - Raw path (single example): list of dicts like `{"x": ..., "y": ..., "t": ...}`
+                - Raw batch: list of raw paths
+                - Numeric arrays/tensors: `[batch, path_len, D]` or `[path_len, D]`
+                If `D==3` and `path_input_dim==6`, raw `(x,y,t)` triples are converted to engineered
+                `(x, y, dx, dy, ds, log_dt)` features and resampled to `max_path_len`.
+                If omitted, the processor emits a zero path with a zero path attention mask.
+            text:
+                String or list of strings to encode.
+                If omitted, the processor emits padded text tokens with a zero text attention mask.
            padding: Whether to pad sequences. Can be True/False or "max_length"
            truncation: Whether to truncate sequences
            max_length: Maximum sequence length for text (overrides max_char_len)
@@ -56,9 +86,10 @@ class SwipeProcessor(ProcessorMixin):

        Returns:
            Dictionary with:
-            - path_coords: [batch, max_path_len,
+            - path_coords: [batch, max_path_len, path_input_dim] (if path_coords provided)
+              Default: [batch, max_path_len, 6] for (x, y, dx, dy, ds, log_dt)
            - input_ids: [batch, max_char_len] (if text provided)
-            - attention_mask: [batch, total_seq_len]
+            - attention_mask: [batch, total_seq_len] (covers `[CLS] + path + [SEP] + text`)
        """
        if path_coords is None and text is None:
            raise ValueError("Must provide either path_coords or text (or both)")
@@ -67,24 +98,43 @@ class SwipeProcessor(ProcessorMixin):
        if path_coords is not None:
            # Handle path coordinates
            if isinstance(path_coords, (list, tuple)):
-                # Batch of paths [[path1], [path2], ...]
-                path_coords = torch.tensor(path_coords, dtype=torch.float32)
+                if len(path_coords) == 0:
+                    batch_size = 1
                else:
+                    first = path_coords[0]
+                    # Raw single path: [{"x","y","t"}, ...]
+                    if isinstance(first, dict):
+                        batch_size = 1
+                    # Raw batch of paths: [[{"x","y","t"}, ...], ...]
+                    elif (
+                        isinstance(first, (list, tuple))
+                        and len(first) > 0
+                        and isinstance(first[0], dict)
+                    ):
+                        batch_size = len(path_coords)
+                    # Numeric batch: [[[...], ...], ...] where points are lists/tuples
+                    elif (
+                        isinstance(first, (list, tuple))
+                        and len(first) > 0
+                        and isinstance(first[0], (list, tuple))
+                    ):
+                        path_coords = torch.tensor(path_coords, dtype=torch.float32)
+                        batch_size = path_coords.shape[0]
+                    else:
+                        # Numeric single path: [[...], [...], ...]
+                        path_coords = torch.tensor([path_coords], dtype=torch.float32)
+                        batch_size = path_coords.shape[0]
            elif isinstance(path_coords, np.ndarray):
                path_coords = torch.from_numpy(path_coords).float()
                if path_coords.dim() == 2:
                    # Single path, add batch dimension
                    path_coords = path_coords.unsqueeze(0)
+                batch_size = path_coords.shape[0]
            elif isinstance(path_coords, torch.Tensor):
                if path_coords.dim() == 2:
                    # Single path, add batch dimension
                    path_coords = path_coords.unsqueeze(0)
-
-            batch_size = path_coords.shape[0]
+                batch_size = path_coords.shape[0]
        elif text is not None:
            if isinstance(text, str):
                batch_size = 1
@@ -98,31 +148,120 @@ class SwipeProcessor(ProcessorMixin):

        # Process path coordinates
        if path_coords is not None:
-            path_coords = path_coords[:, : self.max_path_len, :]
-            current_path_len = self.max_path_len
+            # Check if path_coords is raw data (list of dicts) or already a tensor
+            if isinstance(path_coords, (list, tuple)) and len(path_coords) > 0:
+                first_elem = path_coords[0]
+
+                # Raw single path: [{"x","y","t"}, ...]
+                if isinstance(first_elem, dict) and "x" in first_elem:
+                    path_feats, mask = preprocess_raw_path_to_features(
+                        path_coords,
+                        self.max_path_len,
+                        resample_mode=self.path_resample_mode,
+                    )
+                    if return_tensors == "pt":
+                        path_coords = torch.from_numpy(path_feats).float().unsqueeze(0)
+                        _path_mask = torch.from_numpy(mask).long().unsqueeze(0)
+                    else:
+                        path_coords = np.expand_dims(path_feats, axis=0)
+                        _path_mask = np.expand_dims(mask, axis=0)
+
+                # Raw batch of paths: [[{"x","y","t"}, ...], ...]
+                elif (
+                    isinstance(first_elem, (list, tuple))
+                    and len(first_elem) > 0
+                    and isinstance(first_elem[0], dict)
+                    and "x" in first_elem[0]
+                ):
+                    processed_paths = []
+                    path_masks = []
+                    for path in path_coords:
+                        path_feats, mask = preprocess_raw_path_to_features(
+                            path,
+                            self.max_path_len,
+                            resample_mode=self.path_resample_mode,
+                        )
+                        processed_paths.append(path_feats)
+                        path_masks.append(mask)
+
+                    path_coords = np.stack(processed_paths)  # [batch, max_path_len, 6]
+                    _path_mask = np.stack(path_masks)  # [batch, max_path_len]
+
+                    if return_tensors == "pt":
+                        path_coords = torch.from_numpy(path_coords).float()
+                        _path_mask = torch.from_numpy(_path_mask).long()
+
+                else:
+                    # Numeric list input; process as before
+                    path_coords = torch.tensor(path_coords, dtype=torch.float32)
+                    if path_coords.dim() == 2:
+                        path_coords = path_coords.unsqueeze(0)
+
+                    current_path_len = path_coords.shape[1]
+                    if truncation and current_path_len > self.max_path_len:
+                        path_coords = path_coords[:, : self.max_path_len, :]
+                    if padding and current_path_len < self.max_path_len:
+                        pad_len = self.max_path_len - current_path_len
+                        pad_shape = (batch_size, pad_len, self.path_input_dim)
+                        path_coords = torch.cat([path_coords, torch.zeros(pad_shape)], dim=1)
+
+                    _path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
+                    is_padding = (path_coords == 0).all(dim=-1)
+                    _path_mask[is_padding] = 0
+            elif isinstance(path_coords, np.ndarray):
+                path_coords = torch.from_numpy(path_coords).float()
+                if path_coords.dim() == 2:
+                    path_coords = path_coords.unsqueeze(0)
+                # If user provided raw (x,y,t) triples but model expects engineered features,
+                # convert to motion features and resample.
+                if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
+                    processed_paths = []
+                    path_masks = []
+                    for path in path_coords.cpu().numpy():
+                        raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
+                        path_feats, mask = preprocess_raw_path_to_features(
+                            raw,
+                            self.max_path_len,
+                            resample_mode=self.path_resample_mode,
+                        )
+                        processed_paths.append(path_feats)
+                        path_masks.append(mask)
+
+                    path_coords = torch.from_numpy(np.stack(processed_paths)).float()
+                    _path_mask = torch.from_numpy(np.stack(path_masks)).long()
+                else:
+                    _path_mask = torch.ones(
+                        path_coords.shape[0], self.max_path_len, dtype=torch.long
+                    )
+            elif isinstance(path_coords, torch.Tensor):
+                if path_coords.dim() == 2:
+                    path_coords = path_coords.unsqueeze(0)
+                # If user provided raw (x,y,t) triples but model expects engineered features,
+                # convert to motion features and resample.
+                if path_coords.shape[-1] == 3 and self.path_input_dim == 6:
+                    processed_paths = []
+                    path_masks = []
+                    for path in path_coords.detach().cpu().numpy():
+                        raw = [{"x": float(p[0]), "y": float(p[1]), "t": float(p[2])} for p in path]
+                        path_feats, mask = preprocess_raw_path_to_features(
+                            raw,
+                            self.max_path_len,
+                            resample_mode=self.path_resample_mode,
+                        )
+                        processed_paths.append(path_feats)
+                        path_masks.append(mask)
+
+                    path_coords = torch.from_numpy(np.stack(processed_paths)).float()
+                    _path_mask = torch.from_numpy(np.stack(path_masks)).long()
+                else:
+                    _path_mask = torch.ones(
+                        path_coords.shape[0], self.max_path_len, dtype=torch.long
+                    )

            result["path_coords"] = path_coords
-            # Store path_mask internally for attention_mask construction
-            _path_mask = path_mask
        else:
            # No path coords provided, create empty/zero tensors
-            path_coords = torch.zeros(batch_size, self.max_path_len,
+            path_coords = torch.zeros(batch_size, self.max_path_len, self.path_input_dim)
            _path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)
            result["path_coords"] = path_coords

@@ -157,9 +296,9 @@ class SwipeProcessor(ProcessorMixin):
            # Truncate but preserve EOS at the end
            for i in range(len(encoded_raw["input_ids"])):
                if len(encoded_raw["input_ids"][i]) > text_max_length:
-                    encoded_raw["input_ids"][i] =
+                    encoded_raw["input_ids"][i] = encoded_raw["input_ids"][i][
+                        : text_max_length - 1
+                    ] + [eos_id]

        # Pad sequences
        if padding:
@@ -264,104 +403,48 @@ class SwipeProcessor(ProcessorMixin):
        """
        return self.tokenizer.decode(token_ids, **kwargs)

-    def
-        """
-        Normalize swipe coordinates and timestamps.
-
-        Args:
-            data_points: List of dicts with 'x', 'y', 't' keys
-            canvas_width: Canvas width (not used - kept for compatibility)
-            canvas_height: Canvas height (not used - kept for compatibility)
-
-        Returns:
-            List of normalized coordinate dicts with x, y in [0,1] and t in [0,1]
-
-        Note:
-            For futo-org/swipe.futo.org dataset, x and y are already normalized to [0,1].
-            This function clamps them to ensure they stay in bounds and normalizes timestamps.
-        """
-        if not data_points:
-            return []
-
-        # Extract timestamps for normalization
-        timestamps = [p["t"] for p in data_points]
-        t_min = min(timestamps)
-        t_max = max(timestamps)
-        t_range = t_max - t_min if t_max > t_min else 1.0
-
-        normalized = []
-        for point in data_points:
-            # x and y are already normalized to [0,1] in the dataset
-            # But sometimes they go slightly outside bounds, so clamp them
-            x_norm = max(0.0, min(1.0, point["x"]))
-            y_norm = max(0.0, min(1.0, point["y"]))
-
-            # Normalize timestamp to [0, 1]
-            t_norm = (point["t"] - t_min) / t_range
-
-    def
-        """
-        Args:
-            data_points: List of coordinate dicts with 'x', 'y', 't' keys
-            max_len: Target length (defaults to self.max_path_len if not specified)
-
-        Returns:
-            Tuple of (sampled_points, mask) where:
-            - sampled_points: numpy array of shape [max_len, 3] with (x, y, t) coordinates
-            - mask: numpy array of shape [max_len] indicating valid (1) vs padding (0) points
-
-        Note:
-            - If path has fewer points than max_len, it's zero-padded
-            - If path has more points than max_len, it's downsampled using linear interpolation
-            - If path has exactly max_len points, it's returned as-is
-        """
-        target_indices = np.linspace(0, num_points - 1, max_len)
-
-        # Interpolate each coordinate independently
-        x_interp = np.interp(target_indices, original_indices, x_coords)
-        y_interp = np.interp(target_indices, original_indices, y_coords)
-        t_interp = np.interp(target_indices, original_indices, t_coords)
-
-        # Reconstruct points
-        points = [
-            {"x": float(x), "y": float(y), "t": float(t)}
-            for x, y, t in zip(x_interp, y_interp, t_interp, strict=True)
-        ]
-        mask = [1] * max_len
-
-        coords = np.array([[p["x"], p["y"], p["t"]] for p in points], dtype=np.float32)
-        mask = np.array(mask, dtype=np.int64)
-
-        return coords, mask
+    def encode_path(self, path_coords, *, return_tensors: str | None = "pt", **kwargs: Any):
+        """Create model inputs from a swipe path only (no text)."""
+        return self(path_coords=path_coords, text=None, return_tensors=return_tensors, **kwargs)
+
+    def encode_text(self, text, *, return_tensors: str | None = "pt", **kwargs: Any):
+        """Create model inputs from text only (no path)."""
+        return self(path_coords=None, text=text, return_tensors=return_tensors, **kwargs)
+
+    # Preprocessing methods are now imported from shared preprocessing module
+    # See src/swipealot/data/preprocessing.py for the implementation
+
+    def save_pretrained(
+        self,
+        save_directory,
+        push_to_hub=False,
+        **kwargs,
+    ):
+        """
+        Save the processor to a directory, ensuring auto_map is included.
+        """
+        # Call parent save_pretrained
+        result = super().save_pretrained(
+            save_directory,
+            push_to_hub=push_to_hub,
+            **kwargs,
+        )
+
+        # Add auto_map to processor_config.json for AutoProcessor compatibility
+        import json
+        from pathlib import Path
+
+        # Try both possible config file names
+        for config_name in ["preprocessor_config.json", "processor_config.json"]:
+            processor_config_path = Path(save_directory) / config_name
+            if processor_config_path.exists():
+                with open(processor_config_path) as f:
+                    config = json.load(f)
+
+                config["auto_map"] = {"AutoProcessor": "processing_swipe.SwipeProcessor"}
+
+                with open(processor_config_path, "w") as f:
+                    json.dump(config, f, indent=2)
+                break
+
+        return result
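A usage sketch for the processor above (hedged: `your-org/swipe-model` is a placeholder repo id; any checkpoint shipping this processor behaves the same way under the config below, with max_path_len=128 and path_input_dim=6):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("your-org/swipe-model", trust_remote_code=True)

    # A raw single path of {"x","y","t"} dicts is converted to engineered
    # (x, y, dx, dy, ds, log_dt) features and resampled to max_path_len.
    inputs = processor(
        path_coords=[
            {"x": 0.10, "y": 0.50, "t": 0.0},
            {"x": 0.40, "y": 0.52, "t": 40.0},
            {"x": 0.80, "y": 0.55, "t": 120.0},
        ],
        text="hello",
    )
    print(inputs["path_coords"].shape)  # torch.Size([1, 128, 6]) with this repo's config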
processor_config.json
CHANGED
@@ -1,6 +1,8 @@
{
  "max_char_len": 48,
  "max_path_len": 128,
+  "path_input_dim": 6,
+  "path_resample_mode": "time",
  "processor_class": "SwipeProcessor",
  "auto_map": {
    "AutoProcessor": "processing_swipe.SwipeProcessor"
special_tokens_map.json
CHANGED
@@ -1,8 +1,44 @@
{
-  "cls_token":
+  "cls_token": {
+    "content": "[CLS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "[EOS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "[MASK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "[PAD]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "[SEP]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "[UNK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
}
tokenization_swipe.py
CHANGED
@@ -117,7 +117,8 @@ class SwipeTokenizer(PreTrainedTokenizer):
        Returns:
            int: Token ID
        """
+        # Delegate to the core tokenizer to keep token/id mapping logic in one place.
+        return self._tokenizer.token_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        """
@@ -154,6 +155,41 @@ class SwipeTokenizer(PreTrainedTokenizer):
        filtered = [t for t in tokens if t not in special_tokens]
        return "".join(filtered)

+    def save_pretrained(
+        self,
+        save_directory,
+        legacy_format=None,
+        filename_prefix=None,
+        push_to_hub=False,
+        **kwargs,
+    ):
+        """Save the tokenizer and write `auto_map` for `AutoTokenizer` loading."""
+        # Call parent save_pretrained
+        result = super().save_pretrained(
+            save_directory,
+            legacy_format=legacy_format,
+            filename_prefix=filename_prefix,
+            push_to_hub=push_to_hub,
+            **kwargs,
+        )
+
+        # Add auto_map to tokenizer_config.json for AutoTokenizer compatibility
+        from pathlib import Path
+
+        tokenizer_config_path = Path(save_directory) / "tokenizer_config.json"
+        if tokenizer_config_path.exists():
+            with open(tokenizer_config_path) as f:
+                config = json.load(f)
+
+            # For tokenizers, Transformers expects the 2-tuple form: [slow, fast].
+            # We only ship a slow tokenizer implementation, so fast is None.
+            config["auto_map"] = {"AutoTokenizer": ["tokenization_swipe.SwipeTokenizer", None]}
+
+            with open(tokenizer_config_path, "w") as f:
+                json.dump(config, f, indent=2)
+
+        return result
+
    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple:
        """
        Save the tokenizer vocabulary to a directory.
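With the `auto_map` written by `save_pretrained` above, the custom tokenizer loads through the standard auto class (sketch; `./swipe-checkpoint` is a placeholder path):

    from transformers import AutoTokenizer

    # trust_remote_code lets transformers import tokenization_swipe.SwipeTokenizer,
    # the slow-only entry named in the auto_map ([slow, fast] with fast = None).
    tok = AutoTokenizer.from_pretrained("./swipe-checkpoint", trust_remote_code=True)
    ids = tok("hello")["input_ids"]
    print(tok.decode(ids))  # decode filters special tokens, so this should print "hello"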
tokenizer.py
CHANGED
@@ -48,16 +48,27 @@ class CharacterTokenizer:
        self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
        self.vocab_size = len(self.char_to_id)

+    def encode_char(self, char: str) -> int:
+        """Encode a single character to a token id (case-insensitive; punctuation -> [PUNC])."""
+        char = char.lower()
+        if char.isalpha() or char.isdigit():
+            return self.char_to_id.get(char, self.unk_token_id)
+        return self.punc_token_id
+
+    def token_to_id(self, token: str) -> int:
+        """Map a token string to its id (supports specials and single characters)."""
+        direct = self.char_to_id.get(token)
+        if direct is not None:
+            return direct
+        if len(token) == 1:
+            return self.encode_char(token)
+        return self.unk_token_id
+
    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
-        unk_id = self.char_to_id[self.unk_token]
-        punc_id = self.char_to_id[self.punc_token]
        tokens = []
        for char in text.lower():
-            if char.isalpha() or char.isdigit():
-                tokens.append(self.char_to_id.get(char, unk_id))
-            else:
-                tokens.append(punc_id)
+            tokens.append(self.encode_char(char))
        return tokens

    def decode(self, token_ids: list[int]) -> str:
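A behavior sketch for the refactored helpers above (assumes a `CharacterTokenizer()` built with its default vocabulary; exact ids depend on vocab order):

    tok = CharacterTokenizer()

    assert tok.encode_char("A") == tok.encode_char("a")  # encoding is case-insensitive
    assert tok.encode_char("!") == tok.punc_token_id     # punctuation collapses to [PUNC]
    assert tok.token_to_id("[CLS]") == tok.char_to_id["[CLS]"]  # specials resolve directly
    # encode() now just maps encode_char over the text:
    assert tok.encode("Hi!") == [tok.encode_char("h"), tok.encode_char("i"), tok.punc_token_id]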
tokenizer_config.json
CHANGED
@@ -49,6 +49,12 @@
      "special": true
    }
  },
+  "auto_map": {
+    "AutoTokenizer": [
+      "tokenization_swipe.SwipeTokenizer",
+      null
+    ]
+  },
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "eos_token": "[EOS]",
@@ -59,11 +65,5 @@
  "processor_class": "SwipeProcessor",
  "sep_token": "[SEP]",
  "tokenizer_class": "SwipeTokenizer",
-  "unk_token": "[UNK]"
-  "auto_map": {
-    "AutoTokenizer": [
-      "tokenization_swipe.SwipeTokenizer",
-      null
-    ]
-  }
+  "unk_token": "[UNK]"
}