"""Character tokenizer and vocabulary utilities for swipe keyboard dataset.""" import hashlib import torch class CharacterTokenizer: """Character-level tokenizer for swipe keyboard words.""" def __init__(self, vocab: set | None = None): """ Initialize tokenizer with vocabulary. Args: vocab: Optional set of characters. If None, will use printable ASCII. """ # Special tokens self.pad_token = "[PAD]" self.cls_token = "[CLS]" self.sep_token = "[SEP]" self.mask_token = "[MASK]" self.unk_token = "[UNK]" self.eos_token = "[EOS]" # End of word token self.punc_token = "[PUNC]" self.special_tokens = [ self.pad_token, # 0 self.cls_token, # 1 self.sep_token, # 2 self.mask_token, # 3 self.unk_token, # 4 self.eos_token, # 5 self.punc_token, # 6 ] # Build vocabulary deterministically (lowercase letters + digits). chars = set(chr(i) for i in range(ord("a"), ord("z") + 1)) chars.update(str(d) for d in range(10)) if vocab is not None: # Allow explicit extension for special cases chars.update(vocab) self.char_to_id = {token: idx for idx, token in enumerate(self.special_tokens)} for idx, char in enumerate(sorted(chars), start=len(self.special_tokens)): self.char_to_id[char] = idx self.id_to_char = {idx: char for char, idx in self.char_to_id.items()} self.vocab_size = len(self.char_to_id) def encode_char(self, char: str) -> int: """Encode a single character to a token id (case-insensitive; punctuation -> [PUNC]).""" char = char.lower() if char.isalpha() or char.isdigit(): return self.char_to_id.get(char, self.unk_token_id) return self.punc_token_id def token_to_id(self, token: str) -> int: """Map a token string to its id (supports specials and single characters).""" direct = self.char_to_id.get(token) if direct is not None: return direct if len(token) == 1: return self.encode_char(token) return self.unk_token_id def encode(self, text: str) -> list[int]: """Encode text to token IDs (case-insensitive, punctuation -> [PUNC]).""" tokens = [] for char in text.lower(): tokens.append(self.encode_char(char)) return tokens def decode(self, token_ids: list[int]) -> str: """Decode token IDs to text, stopping at EOS token.""" chars = [] for token_id in token_ids: if token_id in self.id_to_char: char = self.id_to_char[token_id] # Stop at EOS token if char == self.eos_token: break # Skip other special tokens except for debugging if char not in self.special_tokens or char == " ": chars.append(char) return "".join(chars) @property def pad_token_id(self) -> int: return self.char_to_id[self.pad_token] @property def cls_token_id(self) -> int: return self.char_to_id[self.cls_token] @property def sep_token_id(self) -> int: return self.char_to_id[self.sep_token] @property def mask_token_id(self) -> int: return self.char_to_id[self.mask_token] @property def unk_token_id(self) -> int: return self.char_to_id[self.unk_token] @property def eos_token_id(self) -> int: return self.char_to_id[self.eos_token] @property def punc_token_id(self) -> int: return self.char_to_id[self.punc_token] def vocab_hash(tokenizer: CharacterTokenizer) -> str: """Stable hash of the tokenizer's id->token mapping (includes specials).""" ordered_tokens = [tokenizer.id_to_char[i] for i in range(tokenizer.vocab_size)] joined = "\n".join(ordered_tokens).encode("utf-8") return hashlib.sha256(joined).hexdigest() def compute_char_frequency_weights( tokenizer: CharacterTokenizer, dataset, max_samples: int | None = None, weight_exponent: float = 1.0, ): """Compute inverse log frequency weights for characters. 
def compute_char_frequency_weights(
    tokenizer: CharacterTokenizer,
    dataset,
    max_samples: int | None = None,
    weight_exponent: float = 1.0,
):
    """Compute inverse log frequency weights for characters.

    Args:
        tokenizer: CharacterTokenizer used for encoding
        dataset: HF dataset or iterable of samples with a 'word' field
        max_samples: Optional cap on samples to scan
        weight_exponent: Exponent applied to the weights; values < 1 flatten extremes

    Returns:
        torch.Tensor of shape [vocab_size] with weights normalized so the
        non-pad mean is 1. The padding token weight is set to the non-pad
        mean (not zero) so the minimum weight stays positive.
    """
    counts = torch.ones(tokenizer.vocab_size, dtype=torch.float)  # start at 1 for smoothing

    # Collect all token IDs first for vectorized counting
    all_token_ids = []
    for idx, sample in enumerate(dataset):
        if max_samples is not None and idx >= max_samples:
            break
        # Encode lowercase characters and append EOS (matches training labels)
        token_ids = tokenizer.encode(sample["word"]) + [tokenizer.eos_token_id]
        all_token_ids.extend(token_ids)

    # Use bincount for efficient vectorized counting
    if all_token_ids:
        token_tensor = torch.tensor(all_token_ids, dtype=torch.long)
        counts = counts + torch.bincount(token_tensor, minlength=tokenizer.vocab_size).float()

    # Padding is never a supervised label; its count keeps the smoothing value
    # of 1, which yields a finite (if large) raw weight below.
    pad_id = tokenizer.pad_token_id

    # Inverse log weighting; log1p keeps the argument strictly positive
    weights = 1.0 / torch.log1p(counts)

    # Use the non-pad mean for the pad token to avoid an extreme value
    non_pad_mask = torch.ones_like(weights, dtype=torch.bool)
    non_pad_mask[pad_id] = False
    non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
    weights[pad_id] = non_pad_mean

    # Optional tempering (e.g., exponent < 1 flattens extremes)
    if weight_exponent != 1.0:
        weights = torch.pow(weights, weight_exponent)

    # Normalize to keep loss scale stable (mean of non-pad tokens -> 1)
    non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
    weights[pad_id] = non_pad_mean
    weights = weights / non_pad_mean
    return weights
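

# Illustrative usage sketch (assumptions: the tiny in-memory `samples` list
# stands in for the HF dataset mentioned above, and the weights feed a
# class-weighted cross-entropy loss; neither is prescribed by this module).
if __name__ == "__main__":
    tokenizer = CharacterTokenizer()
    samples = [{"word": "hello"}, {"word": "world"}, {"word": "swipe"}]
    weights = compute_char_frequency_weights(tokenizer, samples)
    assert weights.shape == (tokenizer.vocab_size,)
    # Pad is pinned to the non-pad mean, so it normalizes to ~1.
    assert abs(weights[tokenizer.pad_token_id].item() - 1.0) < 1e-5
    # ignore_index keeps padded positions out of the loss even though the
    # pad weight is finite.
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=weights, ignore_index=tokenizer.pad_token_id
    )
    print(f"vocab_size={tokenizer.vocab_size} hash={vocab_hash(tokenizer)[:12]}")
    print(loss_fn)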