File size: 4,904 Bytes
360e354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from datasets import load_dataset
import argparse
import logging.config
import os
import random
import re
import sentencepiece as spm

from utils import default_logging_config

logger = logging.getLogger(__name__)


def _build_arg_parser():
    """Return the command-line parser; both flags are opt-in booleans defaulting to False."""
    parser = argparse.ArgumentParser(description="Train a sentencepiece tokenization model.")
    parser.add_argument(
        "--train",
        action="store_true",
        default=False,
        help="Train a sentencepiece tokenization model.",
    )
    parser.add_argument(
        "--wikipedia",
        action="store_true",
        default=False,
        help="Use wikipedia dataset.",
    )
    return parser


arg_parser = _build_arg_parser()
args = arg_parser.parse_args()
# Logging is configured only once argument parsing has succeeded.
logging.config.dictConfig(default_logging_config)

# --- sentencepiece training limits ---
input_sentence_size = 9_000_000  # passed to the trainer as --input_sentence_size
# Paragraphs longer than this go to the uber-chunk file; also passed as
# --max_sentence_length. NOTE(review): 4192 may be a typo for 4096 — confirm.
max_line_char_len = 4192
vocab_size = 900_000  # passed to the trainer as --vocab_size

# --- file locations ---
corpus_dir = "sp_data"
corpus_file_prefix = corpus_dir + "/sp_corpus"  # partitions are <prefix>_<idx>.txt
uber_chunk_file = corpus_dir + "/wikipedia_uber_chunks.txt"
model_file_prefix = "sp"  # the trained model is loaded from f"{model_file_prefix}.model"

# Splits a line on runs of whitespace for the naive token-count filter.
white_space_pattern = re.compile(r"\s+")

if args.wikipedia:
    # Build the sentencepiece training corpus from English Wikipedia.
    # Lines that pass the filters are appended to size-capped partition files
    # f"{corpus_file_prefix}_{idx}.txt". Paragraphs longer than max_line_char_len
    # are diverted to uber_chunk_file instead.
    # NOTE(review): every file is opened in append mode, so re-running this step
    # appends duplicate text to partitions left over from a previous run.
    wikipedia_dataset_name = "20231101.en"
    wikipedia_dataset = load_dataset("wikimedia/wikipedia", wikipedia_dataset_name)
    train_split = wikipedia_dataset["train"]
    total_page_cnt = len(train_split)
    logger.info(f"loaded {wikipedia_dataset_name} containing {total_page_cnt} pages")

    max_processed_pages = total_page_cnt  # Change to single digits for spot checking / debugging
    max_corpus_file_char_len = 1_000_000_000  # Start a new partition beyond this many characters
    pages_processed_cnt = 0

    # open(..., "a") creates files but not their parent directory.
    os.makedirs(corpus_dir, exist_ok=True)

    corpus_file_part_idx = 0
    current_corpus_file_char_len = 0
    is_completed = False
    iter_idx = 0
    while not is_completed:  # Each pass writes into one partition file
        with open(f"{corpus_file_prefix}_{corpus_file_part_idx}.txt", "a", encoding="utf-8") as f:
            while iter_idx < total_page_cnt:
                page_text = train_split[iter_idx]["text"]
                page_char_len = len(page_text)  # Character len because bytes requires encoding
                # Rotate only when the current partition already holds data. Otherwise a
                # single page larger than the cap would rotate forever and never be written.
                if (current_corpus_file_char_len > 0
                        and page_char_len + current_corpus_file_char_len > max_corpus_file_char_len):
                    corpus_file_part_idx += 1  # New partition
                    current_corpus_file_char_len = 0  # Reset tally
                    break

                for page_chunk in page_text.split("\n\n"):
                    # Skip empty paragraphs and those starting with a space.
                    if not page_chunk or page_chunk[0] == " ":
                        continue
                    if len(page_chunk) > max_line_char_len:
                        with open(uber_chunk_file, "a", encoding="utf-8") as uber_chunk_f:
                            uber_chunk_f.write(page_chunk + "\n\n")
                        continue

                    for chunk_line in page_chunk.split("\n"):
                        if not chunk_line or chunk_line[0] == " ":
                            continue
                        # Keep only lines with more than 10 naive whitespace-separated tokens.
                        if len(white_space_pattern.split(chunk_line)) > 10:
                            f.write(chunk_line + "\n")
                            current_corpus_file_char_len += len(chunk_line)

                iter_idx += 1
                pages_processed_cnt += 1

                if (pages_processed_cnt % 100) == 0:
                    logger.info(f"processed {pages_processed_cnt}/{total_page_cnt} pages")
                if pages_processed_cnt >= max_processed_pages:
                    is_completed = True
                    break
            # All pages consumed, including the last one and the empty-split case.
            if iter_idx >= total_page_cnt:
                is_completed = True

if args.train:
    # Train a word-level sentencepiece model on a random subset of the corpus
    # partition files produced by the --wikipedia step.
    max_training_files = 15  # Upper bound on partitions passed to the trainer
    corpus_file_name_prefix = os.path.basename(corpus_file_prefix)
    # sorted() makes the candidate list independent of os.listdir()'s arbitrary order.
    corpus_files = sorted(
        f"{corpus_dir}/{name}"
        for name in os.listdir(corpus_dir)
        if name.startswith(corpus_file_name_prefix)
    )
    logger.info(f"corpus_files: {corpus_files}")
    if not corpus_files:
        raise SystemExit(f"no corpus files found in {corpus_dir}; run with --wikipedia first")

    # random.sample raises ValueError when k exceeds the population size, so cap k
    # at the number of partitions that actually exist.
    training_files = random.sample(corpus_files, min(max_training_files, len(corpus_files)))
    logger.info(f"training_files: {training_files}")

    spm_training_args = [
        f"--model_prefix={model_file_prefix}",
        "--model_type=word",
        "--shuffle_input_sentence=true",
        # Keep digit runs intact. Setting "--split_digits=true" would split digits into separate pieces.
        "--split_digits=false",
        f"--input={','.join(training_files)}",
        f"--input_sentence_size={input_sentence_size}",
        f"--max_sentence_length={max_line_char_len}",
        f"--vocab_size={vocab_size}",
    ]
    spm.SentencePieceTrainer.Train(" ".join(spm_training_args))

# Smoke test: load the trained model and print how sample sentences are segmented.
sp = spm.SentencePieceProcessor()
sp.LoadFromFile(f"{model_file_prefix}.model")

for sample_sentence in (
    "Hello world!",
    "127.0.0.1 is the localhost address.",
    "1/2 is equivalent to 0.5 or 50%",
    "John was running so fast, you can just tell he's a runner.",
    "He excels at math and competed in the Math Olympiad",
    "Watson was on his way to 221B Baker Street when the robbery occurred.",
    "That's Uncopyrightable.",
    "She's full of incomprehensibilities.",
    "He's a total sesquipedalian.",
):
    print(sp.EncodeAsPieces(sample_sentence))