File size: 4,874 Bytes
e27ab6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from pathlib import Path
from datasets import Dataset
from tokenizers import (
    Tokenizer,
    models,
    normalizers,
    pre_tokenizers,
    decoders,
    trainers,
)
from tqdm.auto import tqdm
import wandb
from utils import get_raw_data


# Location of the raw IWSLT-15 English-Vietnamese corpus (Windows-style
# relative path; run this script from its own directory).
DATA_PATH = Path(r"..\data\IWSLT-15-en-vi")
# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
# Where the trained tokenizer JSON is written (and later logged to W&B).
TOKENIZER_SAVE_PATH = Path(r"..\artifacts\tokenizers") / TOKENIZER_NAME

# Target BPE vocabulary size — keep in sync with the suffix in TOKENIZER_NAME.
# VOCAB_SIZE: int = 16_000
VOCAB_SIZE: int = 32_000
# Reserved tokens placed at the start of the vocabulary.
SPECIAL_TOKENS: list[str] = ["[PAD]", "[UNK]", "[SOS]", "[EOS]"]

# Sentence pairs pulled per chunk while streaming the corpus to the trainer.
BATCH_SIZE_FOR_TOKENIZER: int = 10000
# NOTE(review): NUM_WORKERS is not referenced anywhere in this file —
# presumably consumed by another module or left over; confirm before removing.
NUM_WORKERS: int = 8


def get_training_corpus(dataset: Dataset, batch_size: int = 1000):
    """
    Yield alternating English/Vietnamese text batches for tokenizer training.

    Uses ``dataset.iter(batch_size=...)`` — the optimized Arrow batch
    iterator — and unpacks the per-language strings from each chunk's
    list of translation dicts. For every dataset chunk, two lists are
    yielded: first the English sentences, then the Vietnamese ones.
    """
    for chunk in dataset.iter(batch_size=batch_size):
        # chunk is shaped like {'translation': [{'en': ..., 'vi': ...}, ...]}
        pairs = chunk["translation"]
        for lang in ("en", "vi"):
            # One flat list of strings per language, as the trainer expects.
            yield [pair[lang] for pair in pairs]


def instantiate_tokenizer() -> Tokenizer:
    """
    Build an untrained BPE tokenizer with normalization, pre-tokenization,
    and decoding already configured.

    Returns:
        Tokenizer: an empty BPE tokenizer using "[UNK]" as the unknown token;
        its vocabulary is learned later by a trainer.
    """
    bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

    # Text clean-up applied before tokenization: NFKC unicode folding
    # followed by lowercasing.
    bpe_tokenizer.normalizer = normalizers.Sequence(
        [normalizers.NFKC(), normalizers.Lowercase()]
    )

    # Split on whitespace/punctuation first; BPE merges are learned
    # within these "word" units.
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    # Rejoins BPE sub-word pieces when decoding token ids back to text.
    bpe_tokenizer.decoder = decoders.BPEDecoder()

    print("Tokenizer (empty) initialized.")
    return bpe_tokenizer


def train_tokenizer():
    """
    Train the BPE tokenizer on the IWSLT en-vi corpus and save it to disk.

    Loads the raw training pairs via ``get_raw_data``, streams them through
    ``get_training_corpus``, trains a fresh tokenizer from
    ``instantiate_tokenizer``, and writes the result to
    ``TOKENIZER_SAVE_PATH`` (creating the directory if needed).
    """
    # Trainer holds the BPE hyper-parameters (target vocab + reserved tokens).
    trainer = trainers.BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=SPECIAL_TOKENS)

    print("Tokenizer Trainer initialized.")

    train_dataset = get_raw_data(DATA_PATH, for_tokenizer=True)
    if not isinstance(train_dataset, Dataset):
        train_dataset = Dataset.from_list(train_dataset)
    print(f"Starting tokenizer training on {len(train_dataset)} pairs...")

    # Lazy iterator over the corpus; yields en/vi string batches alternately.
    text_iterator = get_training_corpus(
        train_dataset,
        batch_size=BATCH_SIZE_FOR_TOKENIZER,
    )

    # The corpus generator yields TWO batches (en, then vi) per dataset
    # chunk, so the step count is 2 * ceil(len / batch_size). Ceiling
    # division counts the final partial chunk that floor division would
    # drop; max(..., 1) keeps tqdm sane on an empty dataset.
    num_chunks = -(-len(train_dataset) // BATCH_SIZE_FOR_TOKENIZER)
    total_steps = max(2 * num_chunks, 1)

    tokenizer: Tokenizer = instantiate_tokenizer()
    try:
        tokenizer.train_from_iterator(
            tqdm(
                text_iterator,
                total=total_steps,
                desc="Training Tokenizer (IWSLT-Local)",
            ),
            trainer=trainer,
            length=total_steps,
        )
    except KeyboardInterrupt:
        # Fall through and still save whatever merges were learned so far.
        print("\nTokenizer training interrupted by user.")

    print("Tokenizer training complete.")

    # Ensure the artifacts directory exists; Tokenizer.save does not create it.
    TOKENIZER_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(TOKENIZER_SAVE_PATH))

    print(f"Tokenizer saved to: {TOKENIZER_SAVE_PATH}")
    print(f"Total vocabulary size: {tokenizer.get_vocab_size()}")


if __name__ == "__main__":
    train_tokenizer()

    # Track the run and publish the trained tokenizer as a W&B artifact.
    run = wandb.init(
        entity="alaindelong-hcmut",
        project="Attention Is All You Build",
        job_type="tokenizer-train",
    )

    tokenizer_artifact = wandb.Artifact(
        name="iwslt_en-vi_tokenizer",
        type="tokenizer",
        description="BPE Tokenizer trained on IWSLT 15 (133k+ pairs en-vi)",
        metadata={
            # Derived from the module constant so the logged metadata can
            # never drift from the actual training configuration.
            "vocab_size": VOCAB_SIZE,
            "algorithm": "BPE",
            "framework": "huggingface",
            "training_data": "iwslt-15-en-vi-133k",
            # The tokenizer applies a Lowercase normalizer (see
            # instantiate_tokenizer), so this flag must be True.
            "lower_case": True,
        },
    )
    tokenizer_artifact.add_file(local_path=str(TOKENIZER_SAVE_PATH))
    run.log_artifact(tokenizer_artifact, aliases=["baseline"])

    run.finish()