"""Character tokenizer and vocabulary utilities for swipe keyboard dataset."""
import hashlib
import torch
class CharacterTokenizer:
"""Character-level tokenizer for swipe keyboard words."""
    def __init__(self, vocab: set[str] | None = None):
        """
        Initialize tokenizer with vocabulary.

        Args:
            vocab: Optional set of extra characters merged into the default
                vocabulary of lowercase letters a-z and digits 0-9.
        """
# Special tokens
self.pad_token = "[PAD]"
self.cls_token = "[CLS]"
self.sep_token = "[SEP]"
self.mask_token = "[MASK]"
self.unk_token = "[UNK]"
self.eos_token = "[EOS]" # End of word token
self.punc_token = "[PUNC]"
self.special_tokens = [
self.pad_token, # 0
self.cls_token, # 1
self.sep_token, # 2
self.mask_token, # 3
self.unk_token, # 4
self.eos_token, # 5
self.punc_token, # 6
]
# Build vocabulary deterministically (lowercase letters + digits).
chars = set(chr(i) for i in range(ord("a"), ord("z") + 1))
chars.update(str(d) for d in range(10))
if vocab is not None:
# Allow explicit extension for special cases
chars.update(vocab)
self.char_to_id = {token: idx for idx, token in enumerate(self.special_tokens)}
for idx, char in enumerate(sorted(chars), start=len(self.special_tokens)):
self.char_to_id[char] = idx
self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
self.vocab_size = len(self.char_to_id)
def encode_char(self, char: str) -> int:
"""Encode a single character to a token id (case-insensitive; punctuation -> [PUNC])."""
char = char.lower()
if char.isalpha() or char.isdigit():
return self.char_to_id.get(char, self.unk_token_id)
return self.punc_token_id
def token_to_id(self, token: str) -> int:
"""Map a token string to its id (supports specials and single characters)."""
direct = self.char_to_id.get(token)
if direct is not None:
return direct
if len(token) == 1:
return self.encode_char(token)
return self.unk_token_id
def encode(self, text: str) -> list[int]:
"""Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
        return [self.encode_char(char) for char in text.lower()]
def decode(self, token_ids: list[int]) -> str:
"""Decode token IDs to text, stopping at EOS token."""
chars = []
for token_id in token_ids:
if token_id in self.id_to_char:
char = self.id_to_char[token_id]
# Stop at EOS token
if char == self.eos_token:
break
                # Skip all remaining special tokens ([PAD], [UNK], ...)
                if char not in self.special_tokens:
                    chars.append(char)
return "".join(chars)
@property
def pad_token_id(self) -> int:
return self.char_to_id[self.pad_token]
@property
def cls_token_id(self) -> int:
return self.char_to_id[self.cls_token]
@property
def sep_token_id(self) -> int:
return self.char_to_id[self.sep_token]
@property
def mask_token_id(self) -> int:
return self.char_to_id[self.mask_token]
@property
def unk_token_id(self) -> int:
return self.char_to_id[self.unk_token]
@property
def eos_token_id(self) -> int:
return self.char_to_id[self.eos_token]
@property
def punc_token_id(self) -> int:
return self.char_to_id[self.punc_token]
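# A minimal usage sketch (the ids assume the default vocabulary: the 7
# special tokens at ids 0-6, then digits "0"-"9", then letters "a"-"z"):
#
#     >>> tok = CharacterTokenizer()
#     >>> tok.vocab_size
#     43
#     >>> tok.encode("Cab1!")  # case-folded; "!" collapses to [PUNC]
#     [19, 17, 18, 8, 6]
#     >>> tok.decode([19, 17, 18, tok.eos_token_id, 19])  # stops at [EOS]
#     'cab'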
def vocab_hash(tokenizer: CharacterTokenizer) -> str:
"""Stable hash of the tokenizer's id->token mapping (includes specials)."""
ordered_tokens = [tokenizer.id_to_char[i] for i in range(tokenizer.vocab_size)]
joined = "\n".join(ordered_tokens).encode("utf-8")
return hashlib.sha256(joined).hexdigest()
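# One way to use the hash: store it alongside a training checkpoint and
# compare at load time to catch vocabulary drift (the checkpoint key below
# is illustrative, not a fixed schema):
#
#     >>> tok = CharacterTokenizer()
#     >>> ckpt = {"vocab_sha256": vocab_hash(tok)}
#     >>> assert ckpt["vocab_sha256"] == vocab_hash(tok), "vocabulary drift"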
def compute_char_frequency_weights(
tokenizer: CharacterTokenizer,
dataset,
max_samples: int | None = None,
weight_exponent: float = 1.0,
) -> torch.Tensor:
"""Compute inverse log frequency weights for characters.
Args:
tokenizer: CharacterTokenizer used for encoding
dataset: HF dataset or iterable of samples with a 'word' field
max_samples: Optional cap on samples to scan
Returns:
torch.Tensor of shape [vocab_size] with weights normalized to mean=1.
Padding token weight is set to the non-pad mean (not zero) so min>0.
"""
counts = torch.ones(tokenizer.vocab_size, dtype=torch.float) # start at 1 for smoothing
# Collect all token IDs first for vectorized counting
all_token_ids = []
for idx, sample in enumerate(dataset):
if max_samples is not None and idx >= max_samples:
break
# Encode lowercase characters and append EOS (matches training labels)
token_ids = tokenizer.encode(sample["word"]) + [tokenizer.eos_token_id]
all_token_ids.extend(token_ids)
# Use bincount for efficient vectorized counting
if all_token_ids:
token_tensor = torch.tensor(all_token_ids, dtype=torch.long)
bincount_result = torch.bincount(token_tensor, minlength=tokenizer.vocab_size).float()
counts = counts + bincount_result
    # Padding is never a supervised label; its count stays at the smoothing
    # value of 1 and its weight is overwritten with the non-pad mean below
    pad_id = tokenizer.pad_token_id
# Inverse log weighting; add 1 inside log to avoid div-by-zero
weights = 1.0 / torch.log1p(counts)
    # Optional tempering (an exponent < 1 flattens extremes)
    if weight_exponent != 1.0:
        weights = torch.pow(weights, weight_exponent)
    # Give the pad token the mean non-pad weight (avoids a zero/inf entry),
    # then normalize so the mean over non-pad tokens is 1
    non_pad_mask = torch.ones_like(weights, dtype=torch.bool)
    non_pad_mask[pad_id] = False
    non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
    weights[pad_id] = non_pad_mean
    weights = weights / non_pad_mean
return weights
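# Sketch of how the weights plug into a training loss (the two-sample
# dataset literal is illustrative; any iterable of {"word": ...} dicts works):
#
#     >>> tok = CharacterTokenizer()
#     >>> w = compute_char_frequency_weights(tok, [{"word": "hello"}, {"word": "aa"}])
#     >>> loss_fn = torch.nn.CrossEntropyLoss(weight=w, ignore_index=tok.pad_token_id)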