dleemiller committed on
Commit
d22b384
·
verified ·
1 Parent(s): db69a20

Upload folder using huggingface_hub

Browse files
conversion_metadata.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "original_checkpoint": "checkpoints/base_20251213_164813/best.pt",
3
  "original_config": "embedded_in_checkpoint",
4
- "converted_at": "2025-12-15 08:28:11.703039",
5
  "model_type": "base",
6
  "vocab_size": 43,
7
  "epoch": 38,
 
1
  {
2
  "original_checkpoint": "checkpoints/base_20251213_164813/best.pt",
3
  "original_config": "embedded_in_checkpoint",
4
+ "converted_at": "2025-12-15 12:38:28.627410",
5
  "model_type": "base",
6
  "vocab_size": 43,
7
  "epoch": 38,
processing_swipe.py CHANGED
@@ -111,9 +111,11 @@ class SwipeProcessor(ProcessorMixin):
111
  path_coords = torch.cat([path_coords, torch.zeros(batch_size, pad_len, 3)], dim=1)
112
 
113
  # Create path mask (1 = real data, 0 = padding)
 
114
  path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
115
- if padding and current_path_len < self.max_path_len:
116
- path_mask[:, current_path_len:] = 0
 
117
 
118
  result["path_coords"] = path_coords
119
  # Store path_mask internally for attention_mask construction
 
111
  path_coords = torch.cat([path_coords, torch.zeros(batch_size, pad_len, 3)], dim=1)
112
 
113
  # Create path mask (1 = real data, 0 = padding)
114
+ # Detect padding by checking for all-zero coordinates
115
  path_mask = torch.ones(batch_size, self.max_path_len, dtype=torch.long)
116
+ # A point is padding if all its coordinates (x, y, t) are zero
117
+ is_padding = (path_coords == 0).all(dim=-1) # [batch, path_len]
118
+ path_mask[is_padding] = 0
119
 
120
  result["path_coords"] = path_coords
121
  # Store path_mask internally for attention_mask construction
tokenization_swipe.py CHANGED
@@ -141,7 +141,7 @@ class SwipeTokenizer(PreTrainedTokenizer):
141
  Returns:
142
  str: Concatenated string
143
  """
144
- # Filter out special tokens
145
  special_tokens = {
146
  self.pad_token,
147
  self.cls_token,
@@ -149,6 +149,7 @@ class SwipeTokenizer(PreTrainedTokenizer):
149
  self.mask_token,
150
  self.unk_token,
151
  self.eos_token,
 
152
  }
153
  filtered = [t for t in tokens if t not in special_tokens]
154
  return "".join(filtered)
 
141
  Returns:
142
  str: Concatenated string
143
  """
144
+ # Filter out special tokens (must include [PUNC] which represents punctuation)
145
  special_tokens = {
146
  self.pad_token,
147
  self.cls_token,
 
149
  self.mask_token,
150
  self.unk_token,
151
  self.eos_token,
152
+ "[PUNC]", # Punctuation token from CharacterTokenizer
153
  }
154
  filtered = [t for t in tokens if t not in special_tokens]
155
  return "".join(filtered)