"""
Component 2 training script.
This script trains the custom code tokenizer and saves it for reuse.
Supported input formats:
- .jsonl with fields: prompt, code, language
- .txt where each line is one raw sample
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterable, Iterator, List
# This makes "src" imports work when script is run from project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.tokenizer.code_tokenizer import CodeTokenizer, CodeTokenizerConfig
def stream_jsonl_samples(file_path: Path, tokenizer: CodeTokenizer) -> Iterator[str]:
    """
    Yield formatted training samples from a JSONL file, one row at a time.

    Rows that are blank, are not valid JSON, or lack a prompt/code pair are
    skipped silently. A language outside the supported set falls back to
    "python". The file is never loaded fully into memory.
    """
    supported_languages = {"python", "javascript"}
    with file_path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            prompt = str(record.get("prompt", "")).strip()
            code = str(record.get("code", "")).strip()
            if not (prompt and code):
                continue
            language = str(record.get("language", "python")).strip().lower()
            if language not in supported_languages:
                language = "python"
            yield tokenizer.format_training_sample(prompt=prompt, code=code, language=language)
def stream_txt_samples(file_path: Path) -> Iterator[str]:
    """
    Yield each non-empty line of a plain-text file, stripped of surrounding
    whitespace, without reading the whole file into memory.
    """
    with file_path.open("r", encoding="utf-8") as handle:
        yield from (content for raw in handle if (content := raw.strip()))
def build_stream(input_files: List[Path], tokenizer: CodeTokenizer) -> Iterable[str]:
    """
    Lazily chain samples from every input file into one merged stream.

    ``.jsonl`` files are formatted through the tokenizer, ``.txt`` files are
    streamed line by line, and any other extension is skipped with a warning.
    """
    # Dispatch by file extension instead of an if/elif chain.
    handlers = {
        ".jsonl": lambda path: stream_jsonl_samples(path, tokenizer),
        ".txt": stream_txt_samples,
    }

    def _generator() -> Iterator[str]:
        for path in input_files:
            handler = handlers.get(path.suffix.lower())
            if handler is None:
                print(f"[warning] Skipping unsupported file type: {path}")
                continue
            yield from handler(path)

    return _generator()
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Parse command-line settings for tokenizer training.

    Args:
        argv: Optional explicit argument list. When ``None`` (the default),
            ``sys.argv[1:]`` is used, preserving the original behavior;
            passing a list makes this function testable and scriptable
            without mutating global state.

    Returns:
        Namespace with ``input``, ``output_dir``, ``vocab_size``,
        ``min_frequency``, and ``model_max_length``.
    """
    parser = argparse.ArgumentParser(description="Train custom Python/JavaScript code tokenizer.")
    parser.add_argument(
        "--input",
        nargs="+",
        required=True,
        help="One or more input files (.jsonl or .txt).",
    )
    parser.add_argument(
        "--output_dir",
        default="artifacts/tokenizer/code_tokenizer_v1",
        help="Folder where tokenizer files will be saved.",
    )
    parser.add_argument("--vocab_size", type=int, default=50_000, help="Tokenizer vocabulary size.")
    parser.add_argument("--min_frequency", type=int, default=2, help="Minimum token frequency.")
    parser.add_argument("--model_max_length", type=int, default=2048, help="Max token length hint.")
    return parser.parse_args(argv)
def main() -> None:
    """
    Train and save the tokenizer, reporting any failure in plain language
    and exiting with status 1 on error.
    """
    args = parse_args()
    try:
        input_files = [Path(name) for name in args.input]
        # Validate all paths up front so the error lists every missing file.
        missing = [str(path) for path in input_files if not path.exists()]
        if missing:
            raise FileNotFoundError(
                "Some input files do not exist:\n- " + "\n- ".join(missing)
            )
        tokenizer = CodeTokenizer(
            config=CodeTokenizerConfig(
                vocab_size=args.vocab_size,
                min_frequency=args.min_frequency,
                model_max_length=args.model_max_length,
            )
        )
        tokenizer.train(build_stream(input_files=input_files, tokenizer=tokenizer))
        tokenizer.save(args.output_dir)
        print("Tokenizer training completed successfully.")
        print(f"Saved tokenizer to: {args.output_dir}")
        print("Saved files: tokenizer.json, tokenizer_config.json")
    except Exception as exc:
        # Top-level boundary: translate any failure into a friendly message.
        print("Tokenizer training failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: check file paths and file format, then run again.")
        raise SystemExit(1)


if __name__ == "__main__":
    main()
|