#!/usr/bin/env python3
"""
tokenizer/train_sp_tokenizer.py - train a Korean SentencePiece Unigram tokenizer.

Uses the Unigram model so that one Korean syllable (3 UTF-8 bytes) becomes one token.
character_coverage=0.9995 to cover the full set of 11,172 Hangul syllables.

Usage:
    python tokenizer/train_sp_tokenizer.py \
        --input "data/raw/namuwiki_ko/*.txt,data/raw/ko_wiki_0000.txt" \
        --vocab_size 64000 \
        --output_dir tokenizer/korean_sp

Output:
    tokenizer/korean_sp/tokenizer.model  (SentencePiece model)
    tokenizer/korean_sp/tokenizer.vocab  (vocabulary listing)
"""
from __future__ import annotations
import argparse
import glob
import os
import sys
import tempfile
from pathlib import Path
def expand_inputs(input_spec: str) -> list[str]:
"""์ฝค๋งˆ๋กœ ๊ตฌ๋ถ„๋œ ๊ธ€๋กœ๋ธŒ ํŒจํ„ด๋“ค์„ ์‹ค์ œ ํŒŒ์ผ ๊ฒฝ๋กœ ๋ชฉ๋ก์œผ๋กœ ํ™•์žฅ."""
files: list[str] = []
for pattern in input_spec.split(","):
pattern = pattern.strip()
if any(c in pattern for c in ("*", "?", "[")):
matched = sorted(glob.glob(pattern, recursive=True))
if not matched:
print(f"WARNING: ํŒจํ„ด์— ์ผ์น˜ํ•˜๋Š” ํŒŒ์ผ ์—†์Œ: {pattern!r}", file=sys.stderr)
files.extend(matched)
else:
if Path(pattern).exists():
files.append(pattern)
else:
print(f"WARNING: ํŒŒ์ผ ์—†์Œ: {pattern!r}", file=sys.stderr)
return files
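

# Illustrative call (hypothetical paths): glob patterns are expanded and sorted, while
# plain paths are kept only if they exist on disk.
#   expand_inputs("data/raw/ko/*.txt,data/raw/wiki.txt")
#   -> ["data/raw/ko/part_000.txt", "data/raw/ko/part_001.txt", "data/raw/wiki.txt"]

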
def train(
    input_files: list[str],
    output_dir: Path,
    vocab_size: int,
    num_threads: int,
    input_sentence_size: int,
) -> None:
    try:
        import sentencepiece as spm
    except ImportError:
        print(
            "ERROR: sentencepiece is not installed.\n"
            "  pip install --break-system-packages sentencepiece",
            file=sys.stderr,
        )
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)
    model_prefix = str(output_dir / "tokenizer")

    print(f"Number of input files: {len(input_files)}")
    for f in input_files[:5]:
        print(f"  {f}")
    if len(input_files) > 5:
        print(f"  ... and {len(input_files) - 5} more")
    print(f"Vocabulary size: {vocab_size:,}")
    print(f"Output path: {model_prefix}.model / .vocab")
    print()

    # SentencePiece takes the list of input files as a single comma-separated string
    input_str = ",".join(input_files)
    spm.SentencePieceTrainer.train(
        input=input_str,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        model_type="unigram",            # more natural for Korean than BPE
        character_coverage=0.9995,       # cover the full set of 11,172 Hangul syllables
        normalization_rule_name="nfkc",  # Unicode NFKC normalization (unifies Korean compatibility characters)
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        pad_piece="<pad>",
        bos_piece="<s>",
        eos_piece="</s>",
        unk_piece="<unk>",
        user_defined_symbols=[],
        num_threads=num_threads,
        input_sentence_size=input_sentence_size,
        shuffle_input_sentence=True,
        # training stability
        seed_sentencepiece_size=1_000_000,
        shrinking_factor=0.75,
        max_sentence_length=4096,
    )
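
    # Optional verification sketch (assumption: training succeeded and wrote the files
    # checked below). Loading the model should show the special-token IDs set by the
    # trainer flags above:
    #   sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")
    #   assert (sp.pad_id(), sp.bos_id(), sp.eos_id(), sp.unk_id()) == (0, 1, 2, 3)
    #   assert sp.get_piece_size() == vocab_size
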
    model_path = Path(f"{model_prefix}.model")
    vocab_path = Path(f"{model_prefix}.vocab")
    if model_path.exists():
        size_mb = model_path.stat().st_size / 1e6
        print("Training complete!")
        print(f"  Model: {model_path} ({size_mb:.1f} MB)")
        print(f"  Vocab: {vocab_path}")
        print()
        print("Next step:")
        print("  python tokenizer/convert_sp_to_hf.py \\")
        print(f"    --model {model_path} \\")
        print(f"    --output {output_dir}/tokenizer.json")
    else:
        print("ERROR: training failed - output files were not created", file=sys.stderr)
        sys.exit(1)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Train a Korean SentencePiece Unigram tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Comma-separated files/glob patterns (e.g. 'data/raw/ko/*.txt,data/raw/wiki.txt')",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=64000,
        help="Vocabulary size",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("tokenizer/korean_sp"),
        help="Directory where the model is saved",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=64,
        help="Number of CPU threads to use for training",
    )
    parser.add_argument(
        "--input_sentence_size",
        type=int,
        default=10_000_000,
        help="Maximum number of sentences sampled for training (0 = unlimited)",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    input_files = expand_inputs(args.input)
    if not input_files:
        print("ERROR: no input files found.", file=sys.stderr)
        sys.exit(1)
    train(
        input_files=input_files,
        output_dir=args.output_dir,
        vocab_size=args.vocab_size,
        num_threads=args.num_threads,
        input_sentence_size=args.input_sentence_size,
    )


if __name__ == "__main__":
    main()
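
# After the convert_sp_to_hf.py step suggested above, the resulting tokenizer.json can be
# loaded with the Hugging Face `tokenizers` library. A sketch, assuming this script's
# default output directory:
#   from tokenizers import Tokenizer
#   tok = Tokenizer.from_file("tokenizer/korean_sp/tokenizer.json")
#   print(tok.encode("한국어 토크나이저").tokens)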