|
|
"""Character tokenizer and vocabulary utilities for swipe keyboard dataset.""" |
|
|
|
|
|
import hashlib |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class CharacterTokenizer: |
|
|
"""Character-level tokenizer for swipe keyboard words.""" |
|
|
|
|
|
def __init__(self, vocab: set | None = None): |
|
|
""" |
|
|
Initialize tokenizer with vocabulary. |
|
|
|
|
|
Args: |
|
|
vocab: Optional set of characters. If None, will use printable ASCII. |
|
|
""" |
|
|
|
|
|
self.pad_token = "[PAD]" |
|
|
self.cls_token = "[CLS]" |
|
|
self.sep_token = "[SEP]" |
|
|
self.mask_token = "[MASK]" |
|
|
self.unk_token = "[UNK]" |
|
|
self.eos_token = "[EOS]" |
|
|
self.punc_token = "[PUNC]" |
|
|
|
|
|
self.special_tokens = [ |
|
|
self.pad_token, |
|
|
self.cls_token, |
|
|
self.sep_token, |
|
|
self.mask_token, |
|
|
self.unk_token, |
|
|
self.eos_token, |
|
|
self.punc_token, |
|
|
] |
|
|
|
|
|
|
|
|
chars = set(chr(i) for i in range(ord("a"), ord("z") + 1)) |
|
|
chars.update(str(d) for d in range(10)) |
|
|
if vocab is not None: |
|
|
|
|
|
chars.update(vocab) |
|
|
|
|
|
self.char_to_id = {token: idx for idx, token in enumerate(self.special_tokens)} |
|
|
for idx, char in enumerate(sorted(chars), start=len(self.special_tokens)): |
|
|
self.char_to_id[char] = idx |
|
|
|
|
|
self.id_to_char = {idx: char for char, idx in self.char_to_id.items()} |
|
|
self.vocab_size = len(self.char_to_id) |
|
|
|
|
|
def encode_char(self, char: str) -> int: |
|
|
"""Encode a single character to a token id (case-insensitive; punctuation -> [PUNC]).""" |
|
|
char = char.lower() |
|
|
if char.isalpha() or char.isdigit(): |
|
|
return self.char_to_id.get(char, self.unk_token_id) |
|
|
return self.punc_token_id |
|
|
|
|
|
def token_to_id(self, token: str) -> int: |
|
|
"""Map a token string to its id (supports specials and single characters).""" |
|
|
direct = self.char_to_id.get(token) |
|
|
if direct is not None: |
|
|
return direct |
|
|
if len(token) == 1: |
|
|
return self.encode_char(token) |
|
|
return self.unk_token_id |
|
|
|
|
|
def encode(self, text: str) -> list[int]: |
|
|
"""Encode text to token IDs (case-insensitive, punctuation -> [PUNC]).""" |
|
|
tokens = [] |
|
|
for char in text.lower(): |
|
|
tokens.append(self.encode_char(char)) |
|
|
return tokens |
|
|
|
|
|
def decode(self, token_ids: list[int]) -> str: |
|
|
"""Decode token IDs to text, stopping at EOS token.""" |
|
|
chars = [] |
|
|
for token_id in token_ids: |
|
|
if token_id in self.id_to_char: |
|
|
char = self.id_to_char[token_id] |
|
|
|
|
|
if char == self.eos_token: |
|
|
break |
|
|
|
|
|
if char not in self.special_tokens or char == " ": |
|
|
chars.append(char) |
|
|
return "".join(chars) |
|
|
|
|
|
@property |
|
|
def pad_token_id(self) -> int: |
|
|
return self.char_to_id[self.pad_token] |
|
|
|
|
|
@property |
|
|
def cls_token_id(self) -> int: |
|
|
return self.char_to_id[self.cls_token] |
|
|
|
|
|
@property |
|
|
def sep_token_id(self) -> int: |
|
|
return self.char_to_id[self.sep_token] |
|
|
|
|
|
@property |
|
|
def mask_token_id(self) -> int: |
|
|
return self.char_to_id[self.mask_token] |
|
|
|
|
|
@property |
|
|
def unk_token_id(self) -> int: |
|
|
return self.char_to_id[self.unk_token] |
|
|
|
|
|
@property |
|
|
def eos_token_id(self) -> int: |
|
|
return self.char_to_id[self.eos_token] |
|
|
|
|
|
@property |
|
|
def punc_token_id(self) -> int: |
|
|
return self.char_to_id[self.punc_token] |
|
|
|
|
|
|
|
|
def vocab_hash(tokenizer: CharacterTokenizer) -> str: |
|
|
"""Stable hash of the tokenizer's id->token mapping (includes specials).""" |
|
|
ordered_tokens = [tokenizer.id_to_char[i] for i in range(tokenizer.vocab_size)] |
|
|
joined = "\n".join(ordered_tokens).encode("utf-8") |
|
|
return hashlib.sha256(joined).hexdigest() |
|
|
|
|
|
|
|
|
def compute_char_frequency_weights( |
|
|
tokenizer: CharacterTokenizer, |
|
|
dataset, |
|
|
max_samples: int | None = None, |
|
|
weight_exponent: float = 1.0, |
|
|
): |
|
|
"""Compute inverse log frequency weights for characters. |
|
|
|
|
|
Args: |
|
|
tokenizer: CharacterTokenizer used for encoding |
|
|
dataset: HF dataset or iterable of samples with a 'word' field |
|
|
max_samples: Optional cap on samples to scan |
|
|
|
|
|
Returns: |
|
|
torch.Tensor of shape [vocab_size] with weights normalized to mean=1. |
|
|
Padding token weight is set to the non-pad mean (not zero) so min>0. |
|
|
""" |
|
|
counts = torch.ones(tokenizer.vocab_size, dtype=torch.float) |
|
|
|
|
|
|
|
|
all_token_ids = [] |
|
|
for idx, sample in enumerate(dataset): |
|
|
if max_samples is not None and idx >= max_samples: |
|
|
break |
|
|
|
|
|
|
|
|
token_ids = tokenizer.encode(sample["word"]) + [tokenizer.eos_token_id] |
|
|
all_token_ids.extend(token_ids) |
|
|
|
|
|
|
|
|
if all_token_ids: |
|
|
token_tensor = torch.tensor(all_token_ids, dtype=torch.long) |
|
|
bincount_result = torch.bincount(token_tensor, minlength=tokenizer.vocab_size).float() |
|
|
counts = counts + bincount_result |
|
|
|
|
|
|
|
|
pad_id = tokenizer.pad_token_id |
|
|
counts[pad_id] = counts[pad_id] |
|
|
|
|
|
|
|
|
weights = 1.0 / torch.log1p(counts) |
|
|
|
|
|
|
|
|
non_pad_mask = torch.ones_like(weights, dtype=torch.bool) |
|
|
non_pad_mask[pad_id] = False |
|
|
non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8) |
|
|
weights[pad_id] = non_pad_mean |
|
|
|
|
|
|
|
|
if weight_exponent != 1.0: |
|
|
weights = torch.pow(weights, weight_exponent) |
|
|
|
|
|
|
|
|
non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8) |
|
|
weights[pad_id] = non_pad_mean |
|
|
weights = weights / non_pad_mean |
|
|
|
|
|
return weights |
|
|
|