File size: 3,574 Bytes
7ac5f7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
BPE tokenizer wrapper for H4 RAG.

Uses tiktoken's GPT-2 BPE but maps to a smaller vocabulary suitable for
our model size. The full 50K GPT-2 vocab is too large for a 64-256 dim model.
We build a restricted vocab from the training data — the most frequent
tokens that actually appear, capped at max_vocab — and map everything else
to UNK.

This gives ~5x compression over character-level (512 BPE tokens ≈ 2500 chars)
which means seq_len=512 fits virtually all SQuAD contexts.
"""

import tiktoken
import numpy as np
from typing import Dict, List, Tuple, Optional
from collections import Counter


class BPETokenizer:
    """
    Restricted-vocabulary BPE tokenizer.

    Uses tiktoken GPT-2 as the base BPE, then maps token IDs to a
    smaller vocabulary built from training data. Unknown tokens map to UNK.

    Special tokens:
        0 = PAD
        1 = UNK
        2 = SEP (separator between context, question, answer)

    Local IDs from N_SPECIAL upward are handed out by build_vocab() in
    descending order of token frequency. Until build_vocab() has been
    called the mapping is empty, so encode() yields only UNK.
    """

    # Reserved local IDs. Named so that encode()/encode_qa() do not repeat
    # them as bare literals.
    PAD_ID = 0
    UNK_ID = 1
    SEP_ID = 2
    N_SPECIAL = 3  # PAD, UNK, SEP

    def __init__(self, max_vocab: int = 4096):
        """
        Create a tokenizer with an empty local vocabulary.

        Args:
            max_vocab: Upper bound on the local vocabulary size, counting
                the N_SPECIAL reserved IDs.
        """
        self.max_vocab = max_vocab
        self.base_enc = tiktoken.get_encoding('gpt2')

        # Will be populated by build_vocab()
        self.base_to_local: Dict[int, int] = {}
        self.local_to_base: Dict[int, int] = {}
        self.vocab_size = self.N_SPECIAL  # updated after build_vocab

    def build_vocab(self, texts: List[str]):
        """
        Build restricted vocabulary from training texts.

        Keeps the most frequent max_vocab - N_SPECIAL base tokens, replaces
        any previously built mapping, updates vocab_size, and prints a
        one-line summary with the resulting size and token coverage.

        Args:
            texts: Training texts to count base-BPE tokens over.
        """
        counter = Counter()
        for text in texts:
            # disallowed_special=() makes tiktoken encode special-token
            # strings found in the data as ordinary text instead of raising.
            counter.update(self.base_enc.encode(text, disallowed_special=()))

        # Clamp at zero so a max_vocab below N_SPECIAL keeps no regular
        # tokens explicitly, rather than passing a negative n downstream.
        n_keep = max(self.max_vocab - self.N_SPECIAL, 0)
        most_common = counter.most_common(n_keep)

        # Rank 0 (most frequent) gets the first ID after the special tokens.
        self.base_to_local = {
            base_id: rank + self.N_SPECIAL
            for rank, (base_id, _count) in enumerate(most_common)
        }
        self.local_to_base = {
            local_id: base_id
            for base_id, local_id in self.base_to_local.items()
        }

        self.vocab_size = len(self.base_to_local) + self.N_SPECIAL
        total = sum(counter.values())
        # Share of all token occurrences that the kept tokens account for.
        coverage = sum(c for _, c in most_common) / total if total else 0
        print(f"BPE vocab: {self.vocab_size} tokens "
              f"(from {len(counter)} unique, {coverage:.1%} coverage)")

    def encode(self, text: str) -> List[int]:
        """
        Encode text to local token IDs.

        Base tokens that are not in the restricted vocabulary become UNK_ID.
        """
        base_tokens = self.base_enc.encode(text, disallowed_special=())
        unk = self.UNK_ID
        return [self.base_to_local.get(t, unk) for t in base_tokens]

    def decode(self, ids: List[int]) -> str:
        """
        Decode local token IDs back to text.

        PAD, UNK and SEP, as well as IDs with no entry in the local
        vocabulary, are dropped, so decoding is lossy for sequences that
        contain them.
        """
        base_ids = []
        for local_id in ids:
            if local_id < self.N_SPECIAL:
                continue  # skip PAD, UNK, SEP
            base_id = self.local_to_base.get(local_id)
            if base_id is not None:
                base_ids.append(base_id)
        return self.base_enc.decode(base_ids)

    def encode_qa(self, context: str, question: str, answer: Optional[str] = None) -> Tuple[List[int], Optional[List[int]]]:
        """
        Encode a QA pair.

        Args:
            context: Passage text.
            question: Question text.
            answer: Optional answer text; encoded separately when given.

        Returns:
            input_ids: [context_tokens, SEP, question_tokens, SEP]
            answer_ids: [answer_tokens] or None
        """
        sep = [self.SEP_ID]
        input_ids = self.encode(context) + sep + self.encode(question) + sep

        answer_ids = None
        if answer is not None:
            answer_ids = self.encode(answer)

        return input_ids, answer_ids