| """ |
| WWHO(SGPE) GPE Trainer |
| """ |
|
|
| import argparse |
| import gc |
| import heapq |
| import json |
| import logging |
| import os |
| import pickle |
| import re |
| import time |
| from collections import Counter, defaultdict |
| from multiprocessing import Pool, cpu_count |
|
|
| from tqdm import tqdm |
|
|
| from router import CodeSwitchSegmenter |
| from export import export_hf_tokenizer |
|
|
| |
|
|
| try: |
| import psutil as _psutil |
| def _ram_mb() -> str: |
| p = _psutil.Process() |
| rss = p.memory_info().rss / 1024**2 |
| avail = _psutil.virtual_memory().available / 1024**2 |
| return f"RSS={rss:.0f}MB avail={avail:.0f}MB" |
| except ImportError: |
| def _ram_mb() -> str: |
| try: |
| with open("/proc/meminfo") as f: |
| info = {l.split(":")[0]: int(l.split()[1]) |
| for l in f if ":" in l} |
| avail = info.get("MemAvailable", 0) // 1024 |
| return f"avail={avail}MB" |
| except Exception: |
| return "ram=N/A" |
|
|
_logger: logging.Logger | None = None


def _log(msg: str):
    """Print *msg* prefixed with wall-clock time and a memory snapshot.

    The same line is also sent to ``_logger`` at INFO level once logging has
    been set up.
    """
    stamp = time.strftime('%H:%M:%S')
    line = f"[{stamp}] [{_ram_mb()}] {msg}"
    print(line, flush=True)
    if not _logger:
        return
    _logger.info(line)
|
|
def _setup_logging(output_dir: str):
    """Send ``_log`` output to ``<output_dir>/training.log`` (UTF-8).

    A ``FileHandler`` is attached directly to the ``wwho_trainer`` logger
    instead of calling ``logging.basicConfig``. basicConfig is silently a
    no-op once the root logger has handlers, and it opens the file with the
    locale's default encoding, which can fail on non-ASCII messages. Calling
    this twice for the same directory does not add a duplicate handler.

    Args:
        output_dir: Directory for the log file; created if missing.
    """
    global _logger
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "training.log")

    logger = logging.getLogger("wwho_trainer")
    logger.setLevel(logging.INFO)
    abs_path = os.path.abspath(log_path)
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
        for h in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_path, encoding="utf-8")
        handler.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(handler)
    # Lines are already fully formatted; don't also emit them via root handlers.
    logger.propagate = False

    _logger = logger
    _log(f"Log started: {log_path}")
|
|
|
|
# Reserved tokens. GPETrainer._save_output assigns them ids 0..4 in this
# order, and the merge budget subtracts len(SPECIAL_TOKENS) from vocab_size.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
|
|
| |
# Per-process state filled in by _init_worker: in every Pool worker (as the
# initializer) and once in the parent, which also calls _is_boundary_token.
_worker_segmenter: CodeSwitchSegmenter | None = None
# Language name -> DFA object, as returned by linguis_trie.load_dfa_map.
_worker_dfa_map: dict | None = None
_worker_script_mode: str = "mixed"
|
|
|
|
def _init_worker(script_mode: str):
    """Load the DFAs for *script_mode* and build this process's segmenter.

    Used as the multiprocessing Pool initializer and also called directly in
    the parent process. It populates the module-level ``_worker_*`` globals.
    """
    global _worker_segmenter, _worker_dfa_map, _worker_script_mode
    from linguis_trie import load_dfa_map

    _worker_script_mode = script_mode
    dfa_map = load_dfa_map(script_mode)
    _worker_dfa_map = dfa_map

    blocks_by_language = {}
    for language, dfa in dfa_map.items():
        blocks_by_language[language] = dfa.unicode_blocks
    _worker_segmenter = CodeSwitchSegmenter(blocks_by_language)
|
|
|
|
def _pretokenize_line(text: str) -> list[str]:
    """Pre-tokenize one line using this process's segmenter and DFA map.

    Latin segments, and segments whose language has no usable DFA, are kept
    whole. Every other segment is split by its DFA's ``tokenize``, passing
    along the segment's leading-space flag.
    """
    pieces: list[str] = []
    for segment in _worker_segmenter.segment(text):
        dfa = None
        if segment.language != "latin":
            dfa = _worker_dfa_map.get(segment.language)
        if dfa:
            pieces.extend(
                dfa.tokenize(segment.text, leading_space=segment.has_leading_space)
            )
        else:
            pieces.append(segment.text)
    return pieces
|
|
|
|
def _is_boundary_token(token: str) -> bool:
    """Return False if any character of *token* is in a non-Latin script.

    Latin characters, and characters the segmenter maps to ``None``, leave the
    token a boundary token. With no segmenter initialised in this process,
    every token (including ``""``) is treated as a boundary.
    """
    segmenter = _worker_segmenter
    if not segmenter:
        return True
    for ch in token:
        script = segmenter._get_char_language(ch)
        if script is not None and script != "latin":
            return False
    return True
|
|
def segment_into_words(syllables: list[str]) -> list[list[str]]:
    """Group a pre-token stream into word-level token lists.

    Each boundary token (per ``_is_boundary_token``) closes any word in
    progress and is emitted as its own one-element word. Runs of non-boundary
    tokens form a word. A new word starts whenever a non-boundary token
    begins with a space, tab, LF or CR.
    """
    whitespace_starts = (" ", "\t", "\n", "\r")
    result: list[list[str]] = []
    word: list[str] = []

    for token in syllables:
        if not _is_boundary_token(token):
            # Non-boundary tokens are never empty ("" counts as a boundary).
            if word and token.startswith(whitespace_starts):
                result.append(word)
                word = []
            word.append(token)
            continue
        if word:
            result.append(word)
            word = []
        result.append([token])

    if word:
        result.append(word)
    return result
|
|
|
|
| |
|
|
| class SymbolTable: |
| def __init__(self): |
| self._str_to_id: dict[str, int] = {} |
| self._id_to_str: list[str] = [] |
|
|
| def get_or_add(self, token: str) -> int: |
| if token in self._str_to_id: |
| return self._str_to_id[token] |
| new_id = len(self._id_to_str) |
| self._str_to_id[token] = new_id |
| self._id_to_str.append(token) |
| return new_id |
|
|
| def add_merged(self, a_id: int, b_id: int) -> int: |
| merged_str = self._id_to_str[a_id] + self._id_to_str[b_id] |
| return self.get_or_add(merged_str) |
|
|
| def to_str(self, token_id: int) -> str: |
| return self._id_to_str[token_id] |
|
|
| def to_id(self, token: str) -> int | None: |
| return self._str_to_id.get(token) |
|
|
| def __len__(self) -> int: |
| return len(self._id_to_str) |
|
|
|
|
| |
|
|
| class GPETrainer: |
|
|
| def __init__( |
| self, |
| vocab_size: int = 128_000, |
| min_freq: int = 2, |
| num_workers: int | None = None, |
| checkpoint_every: int = 20_000, |
| prune_freq: int = 100, |
| script_mode: str = "mixed", |
| ): |
| self.target_vocab_size = vocab_size |
| self.min_freq = min_freq |
| self.num_workers = num_workers or max(1, cpu_count() - 1) |
| self.checkpoint_every = checkpoint_every |
| self.prune_freq = prune_freq |
| self.script_mode = script_mode |
| self.merges: list[tuple[int, int]] = [] |
| self.symbols = SymbolTable() |
|
|
| def stream_and_count( |
| self, train_file: str, output_dir: str = "output" |
| ) -> tuple[Counter, set[str]]: |
| |
| print(" counting lines...", end=" ", flush=True) |
| with open(train_file, "r", encoding="utf-8") as f: |
| num_lines = sum(1 for _ in f) |
| print(f"{num_lines:,}") |
|
|
| CHUNK_SIZE = 5_000_000 |
| BATCH = 4_096 |
|
|
| partial_dir = os.path.join(output_dir, "_partial_counters") |
| os.makedirs(partial_dir, exist_ok=True) |
|
|
| _init_worker(self.script_mode) |
|
|
| total_lines = 0 |
| chunk_idx = 0 |
| partial_paths: list[str] = [] |
|
|
| PARTIAL_PRUNE = 2 |
| def _save_partial(counter: Counter, idx: int, n_sent: int): |
| if PARTIAL_PRUNE > 1: |
| to_save = Counter( |
| {k: v for k, v in counter.items() if v >= PARTIAL_PRUNE} |
| ) |
| else: |
| to_save = counter |
| pkl_path = os.path.join(partial_dir, f"partial_{idx:04d}.pkl") |
| with open(pkl_path, "wb") as pf: |
| pickle.dump(to_save, pf, protocol=pickle.HIGHEST_PROTOCOL) |
| partial_paths.append(pkl_path) |
| pkl_mb = os.path.getsize(pkl_path) / 1024**2 |
| pbar.write( |
| f" chunk {idx+1} done: {n_sent:,} sent " |
| f"-> {len(to_save):,} word types (pruned from {len(counter):,}) " |
| f"-> {pkl_path} ({pkl_mb:.0f} MB)" |
| ) |
| _log(f"CHUNK {idx+1} saved: {n_sent:,} sent, " |
| f"{len(to_save):,} word types, {pkl_mb:.0f} MB") |
| del to_save |
| counter.clear() |
| gc.collect() |
| _log(f"CHUNK {idx+1} post-gc") |
|
|
| chunk_counter: Counter = Counter() |
| chunk_sent = 0 |
| batch_buf: list[str] = [] |
|
|
| pool = Pool( |
| processes=self.num_workers, |
| initializer=_init_worker, |
| initargs=(self.script_mode,), |
| ) |
|
|
| with open(train_file, "r", encoding="utf-8") as f: |
| pbar = tqdm(f, total=num_lines, unit=" sent", |
| desc=f" pre-tokenizing [chunk 1]") |
|
|
| for raw_line in pbar: |
| try: |
| obj = json.loads(raw_line) |
| text = obj.get("text", "").strip() |
| except json.JSONDecodeError: |
| text = raw_line.strip() |
| if not text: |
| continue |
|
|
| batch_buf.append(text) |
| total_lines += 1 |
| chunk_sent += 1 |
|
|
| if len(batch_buf) >= BATCH: |
| self._process_batch(pool, batch_buf, chunk_counter) |
| batch_buf = [] |
| if chunk_sent >= CHUNK_SIZE: |
| if batch_buf: |
| self._process_batch(pool, batch_buf, chunk_counter) |
| batch_buf = [] |
| pool.close() |
| pool.join() |
| pool = None |
| gc.collect() |
|
|
| _save_partial(chunk_counter, chunk_idx, chunk_sent) |
| chunk_idx += 1 |
| chunk_sent = 0 |
| pbar.set_description( |
| f" pre-tokenizing [chunk {chunk_idx + 1}]" |
| ) |
| gc.collect() |
|
|
| pool = Pool( |
| processes=self.num_workers, |
| initializer=_init_worker, |
| initargs=(self.script_mode,), |
| ) |
|
|
| if batch_buf: |
| self._process_batch(pool, batch_buf, chunk_counter) |
| pool.close() |
| pool.join() |
| gc.collect() |
|
|
| if chunk_counter: |
| _save_partial(chunk_counter, chunk_idx, chunk_sent) |
| chunk_idx += 1 |
|
|
| pbar.close() |
|
|
| print(f" {total_lines:,} sentences -> {chunk_idx} chunks processed") |
|
|
| |
| _log(f"MERGE START: {len(partial_paths)} partial counters, min_freq={self.min_freq}") |
| N = len(partial_paths) |
| word_counter: Counter = Counter() |
| for i, pkl_path in enumerate(partial_paths): |
| _log(f"MERGE [{i+1}/{N}] loading {pkl_path}") |
| with open(pkl_path, "rb") as pf: |
| partial: Counter = pickle.load(pf) |
| _log(f"MERGE [{i+1}/{N}] loaded {len(partial):,} types, updating master...") |
| word_counter.update(partial) |
| del partial |
| gc.collect() |
| _log(f"MERGE [{i+1}/{N}] after update+gc: {len(word_counter):,} types") |
|
|
| remaining = N - i - 1 |
| safe_prune = max(1, self.min_freq - remaining) |
| before = len(word_counter) |
| |
| if safe_prune > 1: |
| word_counter = Counter( |
| {k: v for k, v in word_counter.items() if v >= safe_prune} |
| ) |
| |
| if i > 0 and i % 5 == 0: |
| hard_threshold = max(2, self.min_freq // 2) |
| word_counter = Counter( |
| {k: v for k, v in word_counter.items() if v >= hard_threshold} |
| ) |
| _log(f"MERGE [{i+1}/{N}] HARD PRUNE TRIGGERED (threshold={hard_threshold})") |
|
|
| gc.collect() |
| pruned_n = before - len(word_counter) |
| |
| if pruned_n > 0: |
| msg = (f" [{i+1}/{N}] merged -> {len(word_counter):,} types " |
| f"(pruned {pruned_n:,})") |
| print(msg, flush=True) |
| _log(f"MERGE [{i+1}/{N}] post-prune: {len(word_counter):,} types " |
| f"(removed {pruned_n:,})") |
| else: |
| print(f" [{i+1}/{N}] merged -> {len(word_counter):,} types", flush=True) |
| _log(f"MERGE [{i+1}/{N}] no prune needed, {len(word_counter):,} types") |
| |
| os.remove(pkl_path) |
| _log(f"MERGE [{i+1}/{N}] deleted {pkl_path}") |
|
|
| try: |
| os.rmdir(partial_dir) |
| except OSError: |
| pass |
|
|
| n_types = len(word_counter) |
| n_instances = sum(word_counter.values()) |
| print(f"\n Final: {total_lines:,} sent -> {n_types:,} word types " |
| f"({n_instances:,} instances)") |
| return word_counter, set() |
|
|
| def _process_batch( |
| self, |
| pool: Pool, |
| batch: list[str], |
| word_counter: Counter, |
| ): |
| syllable_streams = pool.map(_pretokenize_line, batch, chunksize=128) |
|
|
| for stream in syllable_streams: |
| words = segment_into_words(stream) |
| for w in words: |
| if not w: |
| continue |
| if not _is_boundary_token(w[0]): |
| word_counter[tuple(w)] += 1 |
|
|
| @staticmethod |
| def compute_syllable_freqs(word_counter: Counter) -> Counter: |
| syl_freq: Counter[str] = Counter() |
| for word_tuple, word_freq in word_counter.items(): |
| for syl in word_tuple: |
| syl_freq[syl] += word_freq |
| return syl_freq |
|
|
| def build_word_types( |
| self, |
| word_counter: Counter, |
| boundary_tokens: set[str], |
| syl_freq: Counter | None = None, |
| ) -> tuple[list[list[int]], list[int]]: |
| UNK_SENTINEL = -1 |
| pruned_set: set[str] = set() |
|
|
| if syl_freq is not None and self.prune_freq > 0: |
| for syl, freq in syl_freq.items(): |
| if freq < self.prune_freq: |
| pruned_set.add(syl) |
|
|
| word_types: list[list[int]] = [] |
| word_freqs: list[int] = [] |
| pruned_word_count = 0 |
|
|
| for word_tuple, freq in word_counter.items(): |
| ids = [] |
| for tok in word_tuple: |
| if tok in pruned_set: |
| ids.append(UNK_SENTINEL) |
| else: |
| ids.append(self.symbols.get_or_add(tok)) |
| word_types.append(ids) |
| word_freqs.append(freq) |
| if UNK_SENTINEL in ids: |
| pruned_word_count += 1 |
|
|
| if pruned_set: |
| print(f" pruned {len(pruned_set):,} rare syllables (freq < {self.prune_freq})") |
| print(f" {pruned_word_count:,} word types contain [UNK] syllables") |
|
|
| return word_types, word_freqs |
|
|
| @staticmethod |
| def build_token_index(word_types: list[list[int]]) -> dict[int, set[int]]: |
| index: dict[int, set[int]] = defaultdict(set) |
| for wt_idx, wt in enumerate(word_types): |
| for tid in wt: |
| if tid >= 0: |
| index[tid].add(wt_idx) |
| return dict(index) |
|
|
| def count_all_pairs( |
| self, |
| word_types: list[list[int]], |
| word_freqs: list[int], |
| ) -> dict[tuple[int, int], int]: |
| counts: dict[tuple[int, int], int] = defaultdict(int) |
| for wt_idx, wt in enumerate(word_types): |
| f = word_freqs[wt_idx] |
| for i in range(len(wt) - 1): |
| a, b = wt[i], wt[i + 1] |
| if a < 0 or b < 0: |
| continue |
| counts[(a, b)] += f |
| return dict(counts) |
|
|
| @staticmethod |
| def _build_heap(pair_counts: dict) -> list: |
| heap = [(-freq, pair) for pair, freq in pair_counts.items() if freq > 0] |
| heapq.heapify(heap) |
| return heap |
|
|
| @staticmethod |
| def _heap_push(heap, pair, freq): |
| if freq > 0: |
| heapq.heappush(heap, (-freq, pair)) |
|
|
| def _pop_best(self, heap, pair_counts): |
| while heap: |
| neg_freq, pair = heapq.heappop(heap) |
| actual = pair_counts.get(pair, 0) |
| if actual <= 0: |
| continue |
| if actual != -neg_freq: |
| self._heap_push(heap, pair, actual) |
| continue |
| return pair, actual |
| return None, 0 |
|
|
| def merge_and_update( |
| self, |
| word_types: list[list[int]], |
| word_freqs: list[int], |
| pair: tuple[int, int], |
| pair_counts: dict[tuple[int, int], int], |
| token_index: dict[int, set[int]], |
| merged_id: int, |
| heap: list, |
| ) -> int: |
| a, b = pair |
| total_applied = 0 |
| candidates = list(token_index.get(a, set()) & token_index.get(b, set())) |
| pair_counts.pop(pair, None) |
| dirty_pairs: dict[tuple[int, int], int] = {} |
|
|
| for wt_idx in candidates: |
| wt = word_types[wt_idx] |
| freq = word_freqs[wt_idx] |
| if len(wt) < 2: |
| continue |
| new_wt: list[int] = [] |
| i = 0 |
| changed = False |
|
|
| while i < len(wt): |
| if i + 1 < len(wt) and wt[i] == a and wt[i + 1] == b: |
| if new_wt and new_wt[-1] >= 0: |
| lp = (new_wt[-1], a) |
| pair_counts[lp] = pair_counts.get(lp, 0) - freq |
| dirty_pairs[lp] = pair_counts[lp] |
| if i + 2 < len(wt) and wt[i + 2] >= 0: |
| rp = (b, wt[i + 2]) |
| pair_counts[rp] = pair_counts.get(rp, 0) - freq |
| dirty_pairs[rp] = pair_counts[rp] |
| new_wt.append(merged_id) |
| total_applied += freq |
| changed = True |
| if len(new_wt) >= 2 and new_wt[-2] >= 0: |
| lp = (new_wt[-2], merged_id) |
| pair_counts[lp] = pair_counts.get(lp, 0) + freq |
| dirty_pairs[lp] = pair_counts[lp] |
| if i + 2 < len(wt) and wt[i + 2] >= 0: |
| rp = (merged_id, wt[i + 2]) |
| pair_counts[rp] = pair_counts.get(rp, 0) + freq |
| dirty_pairs[rp] = pair_counts[rp] |
| i += 2 |
| else: |
| new_wt.append(wt[i]) |
| i += 1 |
|
|
| if changed: |
| word_types[wt_idx] = new_wt |
| if merged_id not in token_index: |
| token_index[merged_id] = set() |
| token_index[merged_id].add(wt_idx) |
| remaining = set(new_wt) |
| if a not in remaining and wt_idx in token_index.get(a, set()): |
| token_index[a].discard(wt_idx) |
| if b not in remaining and wt_idx in token_index.get(b, set()): |
| token_index[b].discard(wt_idx) |
|
|
| for tok_id in (a, b): |
| if tok_id in token_index and not token_index[tok_id]: |
| del token_index[tok_id] |
|
|
| for p, cnt in dirty_pairs.items(): |
| if cnt <= 0: |
| pair_counts.pop(p, None) |
| else: |
| self._heap_push(heap, p, cnt) |
|
|
| return total_applied |
|
|
| def save_checkpoint(self, step: int, output_dir: str, elapsed: float): |
| merge_strs = [ |
| [self.symbols.to_str(a), self.symbols.to_str(b)] |
| for a, b in self.merges |
| ] |
| ckpt = { |
| "step": step, |
| "script_mode": self.script_mode, |
| "merges": merge_strs, |
| "elapsed_seconds": round(elapsed, 1), |
| } |
| path = os.path.join(output_dir, f"checkpoint_{step}.json") |
| with open(path, "w", encoding="utf-8") as f: |
| json.dump(ckpt, f, ensure_ascii=False) |
| size_mb = os.path.getsize(path) / (1024 * 1024) |
| return path, size_mb |
|
|
| def load_checkpoint(self, ckpt_path: str): |
| with open(ckpt_path, "r", encoding="utf-8") as f: |
| ckpt = json.load(f) |
| print(f" loaded checkpoint: step {ckpt['step']}, " |
| f"{len(ckpt['merges'])} merges, " |
| f"{ckpt['elapsed_seconds']:.1f}s elapsed") |
| return ckpt |
|
|
| def replay_merges(self, merge_strs, word_types, word_freqs, token_index, pair_counts): |
| print(f" replaying {len(merge_strs)} merges...", flush=True) |
| t0 = time.time() |
| dummy_heap: list = [] |
| for a_str, b_str in tqdm(merge_strs, desc=" replaying", unit=" merge"): |
| a_id = self.symbols.to_id(a_str) |
| b_id = self.symbols.to_id(b_str) |
| if a_id is None or b_id is None: |
| continue |
| merged_id = self.symbols.add_merged(a_id, b_id) |
| self.merges.append((a_id, b_id)) |
| self.merge_and_update( |
| word_types, word_freqs, (a_id, b_id), pair_counts, |
| token_index, merged_id, dummy_heap, |
| ) |
| print(f" replayed {len(self.merges)} merges in {time.time()-t0:.1f}s") |
|
|
| def train(self, train_file: str, output_dir: str = "output", |
| resume_path: str | None = None): |
| os.makedirs(output_dir, exist_ok=True) |
|
|
| print(f"WWHO (SGPE) GPE Trainer — script_mode={self.script_mode}, " |
| f"workers={self.num_workers}") |
| print(f"Training file: {train_file}\n") |
|
|
| print("[1/5] Streaming pre-tokenization (CodeSwitchRouter)...") |
| t_start = time.time() |
| word_counter, boundary_tokens = self.stream_and_count(train_file, output_dir) |
|
|
| print("\n[2/5] Building ID corpus...") |
| syl_freq = None |
| if self.prune_freq > 0: |
| syl_freq = self.compute_syllable_freqs(word_counter) |
| total_syls = len(syl_freq) |
| surviving = sum(1 for f in syl_freq.values() if f >= self.prune_freq) |
| print(f" syllable pruning: {total_syls:,} unique syllables, " |
| f"{surviving:,} survive (freq >= {self.prune_freq})") |
|
|
| word_types, word_freqs = self.build_word_types( |
| word_counter, boundary_tokens, syl_freq=syl_freq, |
| ) |
| del word_counter, syl_freq |
|
|
| base_vocab = len(self.symbols) |
| total_instances = sum(word_freqs) |
| print(f" base vocab (syllables + boundaries): {base_vocab:,}") |
| print(f" word types: {len(word_types):,} ({total_instances:,} instances)") |
|
|
| print("\n[3/5] Building index and counting pairs...") |
| token_index = self.build_token_index(word_types) |
| pair_counts = self.count_all_pairs(word_types, word_freqs) |
| print(f" {len(pair_counts):,} unique pairs") |
|
|
| start_step = 0 |
| elapsed_prior = 0.0 |
| if resume_path: |
| print(f"\n Resuming from {resume_path}...") |
| ckpt = self.load_checkpoint(resume_path) |
| self.replay_merges( |
| ckpt["merges"], word_types, word_freqs, token_index, pair_counts, |
| ) |
| start_step = ckpt["step"] |
| elapsed_prior = ckpt["elapsed_seconds"] |
| pair_counts = self.count_all_pairs(word_types, word_freqs) |
| print(f" rebuilt pair counts: {len(pair_counts):,} unique pairs") |
|
|
| total_vocab_needed = self.target_vocab_size - len(SPECIAL_TOKENS) |
| num_merges = max(0, total_vocab_needed - base_vocab) |
| remaining = num_merges - start_step |
| print(f"\n merge budget: {num_merges:,} " |
| f"(starting at {start_step}, {remaining:,} remaining, min_freq={self.min_freq})") |
|
|
| print(f"\n[4/5] Merge loop...") |
| heap = self._build_heap(pair_counts) |
| t0 = time.time() |
| pbar = tqdm(range(start_step + 1, num_merges + 1), |
| desc=" merging", unit=" merge") |
|
|
| for step in pbar: |
| best_pair, freq = self._pop_best(heap, pair_counts) |
| if best_pair is None or freq < self.min_freq: |
| pbar.write(f" stopping at step {step}: " |
| f"{'no pairs' if best_pair is None else f'freq={freq} < {self.min_freq}'}") |
| break |
|
|
| a_id, b_id = best_pair |
| merged_id = self.symbols.add_merged(a_id, b_id) |
| self.merges.append(best_pair) |
|
|
| n_applied = self.merge_and_update( |
| word_types, word_freqs, best_pair, pair_counts, |
| token_index, merged_id, heap, |
| ) |
|
|
| if step <= 10 or step % 1000 == 0: |
| a_s = self.symbols.to_str(a_id) |
| b_s = self.symbols.to_str(b_id) |
| m_s = self.symbols.to_str(merged_id) |
| elapsed = time.time() - t0 + elapsed_prior |
| pbar.write(f" [{step:>7}/{num_merges}] " |
| f"'{a_s}' + '{b_s}' -> '{m_s}' " |
| f"(freq={freq:,}, applied={n_applied:,}) [{elapsed:.1f}s]") |
|
|
| if self.checkpoint_every > 0 and step % self.checkpoint_every == 0: |
| elapsed = time.time() - t0 + elapsed_prior |
| path, sz = self.save_checkpoint(step, output_dir, elapsed) |
| pbar.write(f" >> checkpoint: {path} ({sz:.2f} MB)") |
|
|
| pbar.set_postfix(freq=freq, vocab=len(self.symbols)) |
|
|
| pbar.close() |
| merge_elapsed = time.time() - t0 |
| total_elapsed = merge_elapsed + elapsed_prior |
| print(f" done: {len(self.merges)} merges in {merge_elapsed:.1f}s " |
| f"(total {total_elapsed:.1f}s)") |
|
|
| print("\n[5/5] Building vocabulary and exporting...") |
| self._save_output(word_types, word_freqs, boundary_tokens, output_dir) |
|
|
| wall = time.time() - t_start |
| print(f"\nTotal wall time: {wall:.1f}s ({wall/60:.1f} min)") |
|
|
| def _save_output(self, word_types, word_freqs, boundary_tokens, output_dir): |
| final_freq: Counter[int] = Counter() |
| for wt_idx, wt in enumerate(word_types): |
| f = word_freqs[wt_idx] |
| for tid in wt: |
| if tid >= 0: |
| final_freq[tid] += f |
|
|
| vocab: dict[str, int] = {} |
| for i, st in enumerate(SPECIAL_TOKENS): |
| vocab[st] = i |
| next_id = len(SPECIAL_TOKENS) |
|
|
| for tid, _ in final_freq.most_common(): |
| if len(vocab) >= self.target_vocab_size: |
| break |
| tok_str = self.symbols.to_str(tid) |
| if tok_str not in vocab: |
| vocab[tok_str] = next_id |
| next_id += 1 |
|
|
| for sid in range(len(self.symbols)): |
| if len(vocab) >= self.target_vocab_size: |
| break |
| s = self.symbols.to_str(sid) |
| if s not in vocab: |
| vocab[s] = next_id |
| next_id += 1 |
|
|
| print(f" vocab size: {len(vocab):,}") |
| print(f" merge rules: {len(self.merges):,}") |
|
|
| merge_strs = [ |
| [self.symbols.to_str(a), self.symbols.to_str(b)] |
| for a, b in self.merges |
| ] |
|
|
| output = { |
| "version": "wwho_sgpe", |
| "script_mode": self.script_mode, |
| "vocab_size": len(vocab), |
| "special_tokens": SPECIAL_TOKENS, |
| "num_merges": len(self.merges), |
| "prune_freq": self.prune_freq, |
| "leading_space": True, |
| "merges": merge_strs, |
| "vocab": vocab, |
| } |
|
|
| path = os.path.join(output_dir, "vocab.json") |
| with open(path, "w", encoding="utf-8") as f: |
| json.dump(output, f, ensure_ascii=False, indent=2) |
| size_mb = os.path.getsize(path) / (1024 * 1024) |
| print(f" saved: {path} ({size_mb:.2f} MB)") |
|
|
| self.save_checkpoint(len(self.merges), output_dir, 0) |
|
|
| hf_path = os.path.join(output_dir, "tokenizer.json") |
| export_hf_tokenizer(vocab, merge_strs, SPECIAL_TOKENS, hf_path, |
| script_mode=self.script_mode) |
|
|
| print(f"\n{'='*60}") |
| print(f"TRAINING COMPLETE — WWHO") |
| print(f" Script mode: {self.script_mode}") |
| print(f" Vocab size: {len(vocab):,}") |
| print(f" Merge rules: {len(self.merges):,}") |
| print(f" Word types: {len(word_types):,}") |
| print(f"{'='*60}") |
|
|
|
|
def main():
    """Parse command-line options, set up file logging, and run training."""
    parser = argparse.ArgumentParser(description="WWHO (SGPE) GPE Trainer")
    add = parser.add_argument
    add("--train_file", type=str, default="dataset/mixed_train.jsonl")
    add("--vocab_size", type=int, default=128_000,
        help="Target SGPE vocab size (default 128K)")
    add("--min_freq", type=int, default=2)
    add("--prune_freq", type=int, default=100,
        help="Drop syllables below this corpus frequency to [UNK]")
    add("--output_dir", type=str, default="output")
    add("--num_workers", type=int, default=None)
    add("--checkpoint_every", type=int, default=20_000)
    add("--resume", type=str, default=None)
    add("--script_mode", type=str, default="mixed",
        choices=["sinhala", "devanagari", "mixed"],
        help="Which Indic script(s) to merge in BPE "
             "(English/code always stays as boundary tokens)")
    args = parser.parse_args()

    _setup_logging(args.output_dir)
    _log(f"Starting WWHO (SGPE) trainer: train_file={args.train_file} "
         f"vocab_size={args.vocab_size} script_mode={args.script_mode} "
         f"prune_freq={args.prune_freq} min_freq={args.min_freq}")

    trainer_kwargs = dict(
        vocab_size=args.vocab_size,
        min_freq=args.min_freq,
        num_workers=args.num_workers,
        checkpoint_every=args.checkpoint_every,
        prune_freq=args.prune_freq,
        script_mode=args.script_mode,
    )
    trainer = GPETrainer(**trainer_kwargs)
    trainer.train(args.train_file, args.output_dir, resume_path=args.resume)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI options and run training via main().
    main()
|
|