"""Character tokenizer and vocabulary utilities for swipe keyboard dataset."""

import hashlib

import torch


class CharacterTokenizer:
    """Character-level tokenizer for swipe keyboard words."""

    def __init__(self, vocab: set | None = None):
        """
        Initialize tokenizer with vocabulary.

        Args:
            vocab: Optional set of extra characters added on top of the
                default vocabulary (lowercase ASCII letters and digits).
        """
        # Special tokens
        self.pad_token = "[PAD]"
        self.cls_token = "[CLS]"
        self.sep_token = "[SEP]"
        self.mask_token = "[MASK]"
        self.unk_token = "[UNK]"
        self.eos_token = "[EOS]"  # End of word token
        self.punc_token = "[PUNC]"

        self.special_tokens = [
            self.pad_token,  # 0
            self.cls_token,  # 1
            self.sep_token,  # 2
            self.mask_token,  # 3
            self.unk_token,  # 4
            self.eos_token,  # 5
            self.punc_token,  # 6
        ]

        # Build vocabulary deterministically (lowercase letters + digits).
        chars = set(chr(i) for i in range(ord("a"), ord("z") + 1))
        chars.update(str(d) for d in range(10))
        if vocab is not None:
            # Allow explicit extension for special cases
            chars.update(vocab)

        self.char_to_id = {token: idx for idx, token in enumerate(self.special_tokens)}
        for idx, char in enumerate(sorted(chars), start=len(self.special_tokens)):
            self.char_to_id[char] = idx

        self.id_to_char = {idx: char for char, idx in self.char_to_id.items()}
        self.vocab_size = len(self.char_to_id)

    def encode_char(self, char: str) -> int:
        """Encode a single character to a token id (case-insensitive; punctuation -> [PUNC])."""
        char = char.lower()
        if char.isalpha() or char.isdigit():
            return self.char_to_id.get(char, self.unk_token_id)
        return self.punc_token_id

    def token_to_id(self, token: str) -> int:
        """Map a token string to its id (supports specials and single characters)."""
        direct = self.char_to_id.get(token)
        if direct is not None:
            return direct
        if len(token) == 1:
            return self.encode_char(token)
        return self.unk_token_id

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs (case-insensitive, punctuation -> [PUNC])."""
        tokens = []
        for char in text.lower():
            tokens.append(self.encode_char(char))
        return tokens

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs to text, stopping at EOS token."""
        chars = []
        for token_id in token_ids:
            if token_id in self.id_to_char:
                char = self.id_to_char[token_id]
                # Stop at EOS token
                if char == self.eos_token:
                    break
                # Skip the remaining special tokens ([PAD], [CLS], [PUNC], ...)
                if char not in self.special_tokens:
                    chars.append(char)
        return "".join(chars)

    @property
    def pad_token_id(self) -> int:
        return self.char_to_id[self.pad_token]

    @property
    def cls_token_id(self) -> int:
        return self.char_to_id[self.cls_token]

    @property
    def sep_token_id(self) -> int:
        return self.char_to_id[self.sep_token]

    @property
    def mask_token_id(self) -> int:
        return self.char_to_id[self.mask_token]

    @property
    def unk_token_id(self) -> int:
        return self.char_to_id[self.unk_token]

    @property
    def eos_token_id(self) -> int:
        return self.char_to_id[self.eos_token]

    @property
    def punc_token_id(self) -> int:
        return self.char_to_id[self.punc_token]


def vocab_hash(tokenizer: CharacterTokenizer) -> str:
    """Stable hash of the tokenizer's id->token mapping (includes specials)."""
    ordered_tokens = [tokenizer.id_to_char[i] for i in range(tokenizer.vocab_size)]
    joined = "\n".join(ordered_tokens).encode("utf-8")
    return hashlib.sha256(joined).hexdigest()


def compute_char_frequency_weights(
    tokenizer: CharacterTokenizer,
    dataset,
    max_samples: int | None = None,
    weight_exponent: float = 1.0,
):
    """Compute inverse log frequency weights for characters.

    Args:
        tokenizer: CharacterTokenizer used for encoding
        dataset: HF dataset or iterable of samples with a 'word' field
        max_samples: Optional cap on samples to scan
        weight_exponent: Exponent applied to the raw weights before normalization
            (values < 1 flatten extremes; 1.0 is a no-op)

    Returns:
        torch.Tensor of shape [vocab_size] with weights normalized to mean=1.
        Padding token weight is set to the non-pad mean (not zero) so min>0.
    """
    counts = torch.ones(tokenizer.vocab_size, dtype=torch.float)  # start at 1 for smoothing

    # Collect all token IDs first for vectorized counting
    all_token_ids = []
    for idx, sample in enumerate(dataset):
        if max_samples is not None and idx >= max_samples:
            break

        # Encode lowercase characters and append EOS (matches training labels)
        token_ids = tokenizer.encode(sample["word"]) + [tokenizer.eos_token_id]
        all_token_ids.extend(token_ids)

    # Use bincount for efficient vectorized counting
    if all_token_ids:
        token_tensor = torch.tensor(all_token_ids, dtype=torch.long)
        bincount_result = torch.bincount(token_tensor, minlength=tokenizer.vocab_size).float()
        counts = counts + bincount_result

    # Padding is never a supervised label; its count stays at the smoothing value
    # of 1, and its weight is replaced with the non-pad mean further down.
    pad_id = tokenizer.pad_token_id

    # Inverse log weighting; add 1 inside log to avoid div-by-zero
    weights = 1.0 / torch.log1p(counts)

    # Pad never appears in the labels, so its smoothed count of 1 would give it
    # the largest raw weight; use the mean over non-pad tokens instead.
    non_pad_mask = torch.ones_like(weights, dtype=torch.bool)
    non_pad_mask[pad_id] = False
    non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
    weights[pad_id] = non_pad_mean

    # Optional tempering (e.g., exponent <1 flattens extremes)
    if weight_exponent != 1.0:
        weights = torch.pow(weights, weight_exponent)

    # Normalize to keep loss scale stable (mean of non-pad tokens -> 1)
    non_pad_mean = weights[non_pad_mask].mean().clamp_min(1e-8)
    weights[pad_id] = non_pad_mean
    weights = weights / non_pad_mean

    return weights
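

# The demo below is an illustrative usage sketch, not part of the training
# pipeline: it assumes only the public API defined above, with a tiny in-memory
# list of {"word": ...} dicts standing in for the HF dataset.
if __name__ == "__main__":
    tokenizer = CharacterTokenizer()

    # Round trip: punctuation and spaces map to [PUNC] and are dropped on decode.
    ids = tokenizer.encode("hello, world") + [tokenizer.eos_token_id]
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))  # -> "helloworld"

    print("vocab size:", tokenizer.vocab_size)
    print("vocab hash:", vocab_hash(tokenizer))

    # Frequency weights on a toy dataset; the words here are arbitrary placeholders.
    tiny_dataset = [{"word": w} for w in ("hello", "world", "swipe", "keyboard")]
    weights = compute_char_frequency_weights(tokenizer, tiny_dataset)
    non_pad = torch.ones(tokenizer.vocab_size, dtype=torch.bool)
    non_pad[tokenizer.pad_token_id] = False
    print("weights shape:", tuple(weights.shape))
    print("non-pad mean weight:", weights[non_pad].mean().item())  # ~1.0 by construction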