WWHO / gpe_trainer.py
thekusaldarshana's picture
Separate Before you Compress
e51bea7
"""
WWHO(SGPE) GPE Trainer
"""
import argparse
import gc
import heapq
import json
import logging
import os
import pickle
import re
import time
from collections import Counter, defaultdict
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from router import CodeSwitchSegmenter
from export import export_hf_tokenizer
# ─── Logging ──────
try:
import psutil as _psutil
def _ram_mb() -> str:
p = _psutil.Process()
rss = p.memory_info().rss / 1024**2
avail = _psutil.virtual_memory().available / 1024**2
return f"RSS={rss:.0f}MB avail={avail:.0f}MB"
except ImportError:
def _ram_mb() -> str:
try:
with open("/proc/meminfo") as f:
info = {l.split(":")[0]: int(l.split()[1])
for l in f if ":" in l}
avail = info.get("MemAvailable", 0) // 1024
return f"avail={avail}MB"
except Exception:
return "ram=N/A"
# Set by _setup_logging; until then _log only prints to stdout.
_logger: logging.Logger | None = None


def _log(msg: str):
    """Print *msg* prefixed with wall-clock time and a RAM summary.

    The same line is also sent to ``_logger`` at INFO level once it has been set.
    """
    clock = time.strftime('%H:%M:%S')
    line = f"[{clock}] [{_ram_mb()}] {msg}"
    print(line, flush=True)
    if _logger is not None:
        _logger.info(line)
def _setup_logging(output_dir: str):
    """Create *output_dir* and send ``_log`` output to ``<output_dir>/training.log``.

    Configures the root logger with ``logging.basicConfig``, which does nothing if
    the root logger already has handlers. It then binds the module-level
    ``_logger`` ("wwho_trainer") that ``_log`` writes to.

    Args:
        output_dir: directory that receives ``training.log`` (created if missing).
    """
    global _logger
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "training.log")
    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        format="%(message)s",
        # Logged text can be non-ASCII (e.g. file paths); don't depend on the
        # platform's default locale encoding for the log file.
        encoding="utf-8",
    )
    _logger = logging.getLogger("wwho_trainer")
    # If basicConfig was a no-op (root already configured, possibly at a higher
    # level), still let this logger's INFO records through to existing handlers.
    _logger.setLevel(logging.INFO)
    _log(f"Log started: {log_path}")
# Reserved tokens. GPETrainer._save_output assigns them vocab IDs
# 0..len(SPECIAL_TOKENS)-1 in this order, and GPETrainer.train subtracts their
# count from the merge budget.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
# ─── Multiprocessing ──────
# Per-process state populated by _init_worker. It runs in every Pool worker (as
# the pool initializer) and in the parent process (stream_and_count calls it
# directly, because segment_into_words / _is_boundary_token run there).
_worker_segmenter: CodeSwitchSegmenter | None = None
# Language name -> DFA object from linguis_trie.load_dfa_map(script_mode); each
# DFA is used via .unicode_blocks and .tokenize(text, leading_space=...).
_worker_dfa_map: dict | None = None
# Script mode this process was initialised with (not read by the code shown here).
_worker_script_mode: str = "mixed"
def _init_worker(script_mode: str):
    """Load the language DFAs for *script_mode* and build this process's segmenter.

    Used as the ``Pool`` initializer and also called directly in the parent
    process. Populates the module-level ``_worker_*`` globals.
    """
    global _worker_segmenter, _worker_dfa_map, _worker_script_mode
    from linguis_trie import load_dfa_map

    _worker_script_mode = script_mode
    dfa_by_lang = load_dfa_map(script_mode)
    _worker_dfa_map = dfa_by_lang
    # The segmenter only needs each language's Unicode block ranges.
    blocks_by_lang = {}
    for lang_name, lang_dfa in dfa_by_lang.items():
        blocks_by_lang[lang_name] = lang_dfa.unicode_blocks
    _worker_segmenter = CodeSwitchSegmenter(blocks_by_lang)
def _pretokenize_line(text: str) -> list[str]:
    """Split one line of text into pre-tokens using this process's segmenter and DFAs.

    Some segments are kept whole as a single token:
    - segments labelled "latin";
    - segments whose language has no DFA in ``_worker_dfa_map``.

    Every other segment is split by its language's DFA, which receives the
    segment's ``has_leading_space`` flag.
    """
    pieces: list[str] = []
    for segment in _worker_segmenter.segment(text):
        dfa = None if segment.language == "latin" else _worker_dfa_map.get(segment.language)
        if not dfa:
            pieces.append(segment.text)
        else:
            pieces.extend(
                dfa.tokenize(segment.text, leading_space=segment.has_leading_space)
            )
    return pieces
def _is_boundary_token(token: str) -> bool:
    """Return True unless *token* contains a character of a non-Latin language.

    A character counts against the token only when the segmenter maps it to a
    language other than ``None`` or ``"latin"``. The token is always a boundary
    in two cases:
    - the token is empty;
    - no segmenter has been initialised in this process (``_init_worker`` not
      called yet).
    """
    # The segmenter is loop-invariant: check and bind it once instead of per character.
    segmenter = _worker_segmenter
    if not segmenter:
        return True
    char_language = segmenter._get_char_language
    for ch in token:
        lang = char_language(ch)
        if lang is not None and lang != "latin":
            return False
    return True
def segment_into_words(syllables: list[str]) -> list[list[str]]:
    """Group a pre-token stream into words.

    Grouping rules:
    - Each boundary token (per ``_is_boundary_token``) becomes a one-element
      word of its own.
    - A non-boundary token whose first character is a space, tab, newline or
      carriage return starts a new word.
    - Any other non-boundary token is appended to the current word.
    """
    words: list[list[str]] = []
    word: list[str] = []
    for token in syllables:
        boundary = _is_boundary_token(token)
        starts_new = boundary or token.startswith((" ", "\t", "\n", "\r"))
        if starts_new and word:
            words.append(word)
            word = []
        if boundary:
            words.append([token])
        else:
            word.append(token)
    if word:
        words.append(word)
    return words
# ─── Symbol Table ──────
class SymbolTable:
def __init__(self):
self._str_to_id: dict[str, int] = {}
self._id_to_str: list[str] = []
def get_or_add(self, token: str) -> int:
if token in self._str_to_id:
return self._str_to_id[token]
new_id = len(self._id_to_str)
self._str_to_id[token] = new_id
self._id_to_str.append(token)
return new_id
def add_merged(self, a_id: int, b_id: int) -> int:
merged_str = self._id_to_str[a_id] + self._id_to_str[b_id]
return self.get_or_add(merged_str)
def to_str(self, token_id: int) -> str:
return self._id_to_str[token_id]
def to_id(self, token: str) -> int | None:
return self._str_to_id.get(token)
def __len__(self) -> int:
return len(self._id_to_str)
# ─── GPETrainer ──────
class GPETrainer:
def __init__(
self,
vocab_size: int = 128_000,
min_freq: int = 2,
num_workers: int | None = None,
checkpoint_every: int = 20_000,
prune_freq: int = 100,
script_mode: str = "mixed",
):
self.target_vocab_size = vocab_size
self.min_freq = min_freq
self.num_workers = num_workers or max(1, cpu_count() - 1)
self.checkpoint_every = checkpoint_every
self.prune_freq = prune_freq
self.script_mode = script_mode
self.merges: list[tuple[int, int]] = []
self.symbols = SymbolTable()
def stream_and_count(
self, train_file: str, output_dir: str = "output"
) -> tuple[Counter, set[str]]:
# ── 1. Count lines ──────
print(" counting lines...", end=" ", flush=True)
with open(train_file, "r", encoding="utf-8") as f:
num_lines = sum(1 for _ in f)
print(f"{num_lines:,}")
CHUNK_SIZE = 5_000_000
BATCH = 4_096
partial_dir = os.path.join(output_dir, "_partial_counters")
os.makedirs(partial_dir, exist_ok=True)
_init_worker(self.script_mode)
total_lines = 0
chunk_idx = 0
partial_paths: list[str] = []
PARTIAL_PRUNE = 2
def _save_partial(counter: Counter, idx: int, n_sent: int):
if PARTIAL_PRUNE > 1:
to_save = Counter(
{k: v for k, v in counter.items() if v >= PARTIAL_PRUNE}
)
else:
to_save = counter
pkl_path = os.path.join(partial_dir, f"partial_{idx:04d}.pkl")
with open(pkl_path, "wb") as pf:
pickle.dump(to_save, pf, protocol=pickle.HIGHEST_PROTOCOL)
partial_paths.append(pkl_path)
pkl_mb = os.path.getsize(pkl_path) / 1024**2
pbar.write(
f" chunk {idx+1} done: {n_sent:,} sent "
f"-> {len(to_save):,} word types (pruned from {len(counter):,}) "
f"-> {pkl_path} ({pkl_mb:.0f} MB)"
)
_log(f"CHUNK {idx+1} saved: {n_sent:,} sent, "
f"{len(to_save):,} word types, {pkl_mb:.0f} MB")
del to_save
counter.clear()
gc.collect()
_log(f"CHUNK {idx+1} post-gc")
chunk_counter: Counter = Counter()
chunk_sent = 0
batch_buf: list[str] = []
pool = Pool(
processes=self.num_workers,
initializer=_init_worker,
initargs=(self.script_mode,),
)
with open(train_file, "r", encoding="utf-8") as f:
pbar = tqdm(f, total=num_lines, unit=" sent",
desc=f" pre-tokenizing [chunk 1]")
for raw_line in pbar:
try:
obj = json.loads(raw_line)
text = obj.get("text", "").strip()
except json.JSONDecodeError:
text = raw_line.strip()
if not text:
continue
batch_buf.append(text)
total_lines += 1
chunk_sent += 1
if len(batch_buf) >= BATCH:
self._process_batch(pool, batch_buf, chunk_counter)
batch_buf = []
if chunk_sent >= CHUNK_SIZE:
if batch_buf:
self._process_batch(pool, batch_buf, chunk_counter)
batch_buf = []
pool.close()
pool.join()
pool = None
gc.collect()
_save_partial(chunk_counter, chunk_idx, chunk_sent)
chunk_idx += 1
chunk_sent = 0
pbar.set_description(
f" pre-tokenizing [chunk {chunk_idx + 1}]"
)
gc.collect()
pool = Pool(
processes=self.num_workers,
initializer=_init_worker,
initargs=(self.script_mode,),
)
if batch_buf:
self._process_batch(pool, batch_buf, chunk_counter)
pool.close()
pool.join()
gc.collect()
if chunk_counter:
_save_partial(chunk_counter, chunk_idx, chunk_sent)
chunk_idx += 1
pbar.close()
print(f" {total_lines:,} sentences -> {chunk_idx} chunks processed")
# ── 3. Sequential merge with intermediate pruning ──────
_log(f"MERGE START: {len(partial_paths)} partial counters, min_freq={self.min_freq}")
N = len(partial_paths)
word_counter: Counter = Counter()
for i, pkl_path in enumerate(partial_paths):
_log(f"MERGE [{i+1}/{N}] loading {pkl_path}")
with open(pkl_path, "rb") as pf:
partial: Counter = pickle.load(pf)
_log(f"MERGE [{i+1}/{N}] loaded {len(partial):,} types, updating master...")
word_counter.update(partial)
del partial
gc.collect()
_log(f"MERGE [{i+1}/{N}] after update+gc: {len(word_counter):,} types")
remaining = N - i - 1
safe_prune = max(1, self.min_freq - remaining)
before = len(word_counter)
if safe_prune > 1:
word_counter = Counter(
{k: v for k, v in word_counter.items() if v >= safe_prune}
)
if i > 0 and i % 5 == 0:
hard_threshold = max(2, self.min_freq // 2)
word_counter = Counter(
{k: v for k, v in word_counter.items() if v >= hard_threshold}
)
_log(f"MERGE [{i+1}/{N}] HARD PRUNE TRIGGERED (threshold={hard_threshold})")
gc.collect()
pruned_n = before - len(word_counter)
if pruned_n > 0:
msg = (f" [{i+1}/{N}] merged -> {len(word_counter):,} types "
f"(pruned {pruned_n:,})")
print(msg, flush=True)
_log(f"MERGE [{i+1}/{N}] post-prune: {len(word_counter):,} types "
f"(removed {pruned_n:,})")
else:
print(f" [{i+1}/{N}] merged -> {len(word_counter):,} types", flush=True)
_log(f"MERGE [{i+1}/{N}] no prune needed, {len(word_counter):,} types")
os.remove(pkl_path)
_log(f"MERGE [{i+1}/{N}] deleted {pkl_path}")
try:
os.rmdir(partial_dir)
except OSError:
pass
n_types = len(word_counter)
n_instances = sum(word_counter.values())
print(f"\n Final: {total_lines:,} sent -> {n_types:,} word types "
f"({n_instances:,} instances)")
return word_counter, set()
def _process_batch(
self,
pool: Pool,
batch: list[str],
word_counter: Counter,
):
syllable_streams = pool.map(_pretokenize_line, batch, chunksize=128)
for stream in syllable_streams:
words = segment_into_words(stream)
for w in words:
if not w:
continue
if not _is_boundary_token(w[0]):
word_counter[tuple(w)] += 1
@staticmethod
def compute_syllable_freqs(word_counter: Counter) -> Counter:
syl_freq: Counter[str] = Counter()
for word_tuple, word_freq in word_counter.items():
for syl in word_tuple:
syl_freq[syl] += word_freq
return syl_freq
def build_word_types(
self,
word_counter: Counter,
boundary_tokens: set[str],
syl_freq: Counter | None = None,
) -> tuple[list[list[int]], list[int]]:
UNK_SENTINEL = -1
pruned_set: set[str] = set()
if syl_freq is not None and self.prune_freq > 0:
for syl, freq in syl_freq.items():
if freq < self.prune_freq:
pruned_set.add(syl)
word_types: list[list[int]] = []
word_freqs: list[int] = []
pruned_word_count = 0
for word_tuple, freq in word_counter.items():
ids = []
for tok in word_tuple:
if tok in pruned_set:
ids.append(UNK_SENTINEL)
else:
ids.append(self.symbols.get_or_add(tok))
word_types.append(ids)
word_freqs.append(freq)
if UNK_SENTINEL in ids:
pruned_word_count += 1
if pruned_set:
print(f" pruned {len(pruned_set):,} rare syllables (freq < {self.prune_freq})")
print(f" {pruned_word_count:,} word types contain [UNK] syllables")
return word_types, word_freqs
@staticmethod
def build_token_index(word_types: list[list[int]]) -> dict[int, set[int]]:
index: dict[int, set[int]] = defaultdict(set)
for wt_idx, wt in enumerate(word_types):
for tid in wt:
if tid >= 0:
index[tid].add(wt_idx)
return dict(index)
def count_all_pairs(
self,
word_types: list[list[int]],
word_freqs: list[int],
) -> dict[tuple[int, int], int]:
counts: dict[tuple[int, int], int] = defaultdict(int)
for wt_idx, wt in enumerate(word_types):
f = word_freqs[wt_idx]
for i in range(len(wt) - 1):
a, b = wt[i], wt[i + 1]
if a < 0 or b < 0:
continue
counts[(a, b)] += f
return dict(counts)
@staticmethod
def _build_heap(pair_counts: dict) -> list:
heap = [(-freq, pair) for pair, freq in pair_counts.items() if freq > 0]
heapq.heapify(heap)
return heap
@staticmethod
def _heap_push(heap, pair, freq):
if freq > 0:
heapq.heappush(heap, (-freq, pair))
def _pop_best(self, heap, pair_counts):
while heap:
neg_freq, pair = heapq.heappop(heap)
actual = pair_counts.get(pair, 0)
if actual <= 0:
continue
if actual != -neg_freq:
self._heap_push(heap, pair, actual)
continue
return pair, actual
return None, 0
def merge_and_update(
self,
word_types: list[list[int]],
word_freqs: list[int],
pair: tuple[int, int],
pair_counts: dict[tuple[int, int], int],
token_index: dict[int, set[int]],
merged_id: int,
heap: list,
) -> int:
a, b = pair
total_applied = 0
candidates = list(token_index.get(a, set()) & token_index.get(b, set()))
pair_counts.pop(pair, None)
dirty_pairs: dict[tuple[int, int], int] = {}
for wt_idx in candidates:
wt = word_types[wt_idx]
freq = word_freqs[wt_idx]
if len(wt) < 2:
continue
new_wt: list[int] = []
i = 0
changed = False
while i < len(wt):
if i + 1 < len(wt) and wt[i] == a and wt[i + 1] == b:
if new_wt and new_wt[-1] >= 0:
lp = (new_wt[-1], a)
pair_counts[lp] = pair_counts.get(lp, 0) - freq
dirty_pairs[lp] = pair_counts[lp]
if i + 2 < len(wt) and wt[i + 2] >= 0:
rp = (b, wt[i + 2])
pair_counts[rp] = pair_counts.get(rp, 0) - freq
dirty_pairs[rp] = pair_counts[rp]
new_wt.append(merged_id)
total_applied += freq
changed = True
if len(new_wt) >= 2 and new_wt[-2] >= 0:
lp = (new_wt[-2], merged_id)
pair_counts[lp] = pair_counts.get(lp, 0) + freq
dirty_pairs[lp] = pair_counts[lp]
if i + 2 < len(wt) and wt[i + 2] >= 0:
rp = (merged_id, wt[i + 2])
pair_counts[rp] = pair_counts.get(rp, 0) + freq
dirty_pairs[rp] = pair_counts[rp]
i += 2
else:
new_wt.append(wt[i])
i += 1
if changed:
word_types[wt_idx] = new_wt
if merged_id not in token_index:
token_index[merged_id] = set()
token_index[merged_id].add(wt_idx)
remaining = set(new_wt)
if a not in remaining and wt_idx in token_index.get(a, set()):
token_index[a].discard(wt_idx)
if b not in remaining and wt_idx in token_index.get(b, set()):
token_index[b].discard(wt_idx)
for tok_id in (a, b):
if tok_id in token_index and not token_index[tok_id]:
del token_index[tok_id]
for p, cnt in dirty_pairs.items():
if cnt <= 0:
pair_counts.pop(p, None)
else:
self._heap_push(heap, p, cnt)
return total_applied
def save_checkpoint(self, step: int, output_dir: str, elapsed: float):
merge_strs = [
[self.symbols.to_str(a), self.symbols.to_str(b)]
for a, b in self.merges
]
ckpt = {
"step": step,
"script_mode": self.script_mode,
"merges": merge_strs,
"elapsed_seconds": round(elapsed, 1),
}
path = os.path.join(output_dir, f"checkpoint_{step}.json")
with open(path, "w", encoding="utf-8") as f:
json.dump(ckpt, f, ensure_ascii=False)
size_mb = os.path.getsize(path) / (1024 * 1024)
return path, size_mb
def load_checkpoint(self, ckpt_path: str):
with open(ckpt_path, "r", encoding="utf-8") as f:
ckpt = json.load(f)
print(f" loaded checkpoint: step {ckpt['step']}, "
f"{len(ckpt['merges'])} merges, "
f"{ckpt['elapsed_seconds']:.1f}s elapsed")
return ckpt
def replay_merges(self, merge_strs, word_types, word_freqs, token_index, pair_counts):
print(f" replaying {len(merge_strs)} merges...", flush=True)
t0 = time.time()
dummy_heap: list = []
for a_str, b_str in tqdm(merge_strs, desc=" replaying", unit=" merge"):
a_id = self.symbols.to_id(a_str)
b_id = self.symbols.to_id(b_str)
if a_id is None or b_id is None:
continue
merged_id = self.symbols.add_merged(a_id, b_id)
self.merges.append((a_id, b_id))
self.merge_and_update(
word_types, word_freqs, (a_id, b_id), pair_counts,
token_index, merged_id, dummy_heap,
)
print(f" replayed {len(self.merges)} merges in {time.time()-t0:.1f}s")
def train(self, train_file: str, output_dir: str = "output",
resume_path: str | None = None):
os.makedirs(output_dir, exist_ok=True)
print(f"WWHO (SGPE) GPE Trainer — script_mode={self.script_mode}, "
f"workers={self.num_workers}")
print(f"Training file: {train_file}\n")
print("[1/5] Streaming pre-tokenization (CodeSwitchRouter)...")
t_start = time.time()
word_counter, boundary_tokens = self.stream_and_count(train_file, output_dir)
print("\n[2/5] Building ID corpus...")
syl_freq = None
if self.prune_freq > 0:
syl_freq = self.compute_syllable_freqs(word_counter)
total_syls = len(syl_freq)
surviving = sum(1 for f in syl_freq.values() if f >= self.prune_freq)
print(f" syllable pruning: {total_syls:,} unique syllables, "
f"{surviving:,} survive (freq >= {self.prune_freq})")
word_types, word_freqs = self.build_word_types(
word_counter, boundary_tokens, syl_freq=syl_freq,
)
del word_counter, syl_freq
base_vocab = len(self.symbols)
total_instances = sum(word_freqs)
print(f" base vocab (syllables + boundaries): {base_vocab:,}")
print(f" word types: {len(word_types):,} ({total_instances:,} instances)")
print("\n[3/5] Building index and counting pairs...")
token_index = self.build_token_index(word_types)
pair_counts = self.count_all_pairs(word_types, word_freqs)
print(f" {len(pair_counts):,} unique pairs")
start_step = 0
elapsed_prior = 0.0
if resume_path:
print(f"\n Resuming from {resume_path}...")
ckpt = self.load_checkpoint(resume_path)
self.replay_merges(
ckpt["merges"], word_types, word_freqs, token_index, pair_counts,
)
start_step = ckpt["step"]
elapsed_prior = ckpt["elapsed_seconds"]
pair_counts = self.count_all_pairs(word_types, word_freqs)
print(f" rebuilt pair counts: {len(pair_counts):,} unique pairs")
total_vocab_needed = self.target_vocab_size - len(SPECIAL_TOKENS)
num_merges = max(0, total_vocab_needed - base_vocab)
remaining = num_merges - start_step
print(f"\n merge budget: {num_merges:,} "
f"(starting at {start_step}, {remaining:,} remaining, min_freq={self.min_freq})")
print(f"\n[4/5] Merge loop...")
heap = self._build_heap(pair_counts)
t0 = time.time()
pbar = tqdm(range(start_step + 1, num_merges + 1),
desc=" merging", unit=" merge")
for step in pbar:
best_pair, freq = self._pop_best(heap, pair_counts)
if best_pair is None or freq < self.min_freq:
pbar.write(f" stopping at step {step}: "
f"{'no pairs' if best_pair is None else f'freq={freq} < {self.min_freq}'}")
break
a_id, b_id = best_pair
merged_id = self.symbols.add_merged(a_id, b_id)
self.merges.append(best_pair)
n_applied = self.merge_and_update(
word_types, word_freqs, best_pair, pair_counts,
token_index, merged_id, heap,
)
if step <= 10 or step % 1000 == 0:
a_s = self.symbols.to_str(a_id)
b_s = self.symbols.to_str(b_id)
m_s = self.symbols.to_str(merged_id)
elapsed = time.time() - t0 + elapsed_prior
pbar.write(f" [{step:>7}/{num_merges}] "
f"'{a_s}' + '{b_s}' -> '{m_s}' "
f"(freq={freq:,}, applied={n_applied:,}) [{elapsed:.1f}s]")
if self.checkpoint_every > 0 and step % self.checkpoint_every == 0:
elapsed = time.time() - t0 + elapsed_prior
path, sz = self.save_checkpoint(step, output_dir, elapsed)
pbar.write(f" >> checkpoint: {path} ({sz:.2f} MB)")
pbar.set_postfix(freq=freq, vocab=len(self.symbols))
pbar.close()
merge_elapsed = time.time() - t0
total_elapsed = merge_elapsed + elapsed_prior
print(f" done: {len(self.merges)} merges in {merge_elapsed:.1f}s "
f"(total {total_elapsed:.1f}s)")
print("\n[5/5] Building vocabulary and exporting...")
self._save_output(word_types, word_freqs, boundary_tokens, output_dir)
wall = time.time() - t_start
print(f"\nTotal wall time: {wall:.1f}s ({wall/60:.1f} min)")
def _save_output(self, word_types, word_freqs, boundary_tokens, output_dir):
final_freq: Counter[int] = Counter()
for wt_idx, wt in enumerate(word_types):
f = word_freqs[wt_idx]
for tid in wt:
if tid >= 0:
final_freq[tid] += f
vocab: dict[str, int] = {}
for i, st in enumerate(SPECIAL_TOKENS):
vocab[st] = i
next_id = len(SPECIAL_TOKENS)
for tid, _ in final_freq.most_common():
if len(vocab) >= self.target_vocab_size:
break
tok_str = self.symbols.to_str(tid)
if tok_str not in vocab:
vocab[tok_str] = next_id
next_id += 1
for sid in range(len(self.symbols)):
if len(vocab) >= self.target_vocab_size:
break
s = self.symbols.to_str(sid)
if s not in vocab:
vocab[s] = next_id
next_id += 1
print(f" vocab size: {len(vocab):,}")
print(f" merge rules: {len(self.merges):,}")
merge_strs = [
[self.symbols.to_str(a), self.symbols.to_str(b)]
for a, b in self.merges
]
output = {
"version": "wwho_sgpe",
"script_mode": self.script_mode,
"vocab_size": len(vocab),
"special_tokens": SPECIAL_TOKENS,
"num_merges": len(self.merges),
"prune_freq": self.prune_freq,
"leading_space": True,
"merges": merge_strs,
"vocab": vocab,
}
path = os.path.join(output_dir, "vocab.json")
with open(path, "w", encoding="utf-8") as f:
json.dump(output, f, ensure_ascii=False, indent=2)
size_mb = os.path.getsize(path) / (1024 * 1024)
print(f" saved: {path} ({size_mb:.2f} MB)")
self.save_checkpoint(len(self.merges), output_dir, 0)
hf_path = os.path.join(output_dir, "tokenizer.json")
export_hf_tokenizer(vocab, merge_strs, SPECIAL_TOKENS, hf_path,
script_mode=self.script_mode)
print(f"\n{'='*60}")
print(f"TRAINING COMPLETE — WWHO")
print(f" Script mode: {self.script_mode}")
print(f" Vocab size: {len(vocab):,}")
print(f" Merge rules: {len(self.merges):,}")
print(f" Word types: {len(word_types):,}")
print(f"{'='*60}")
def main():
    """Command-line entry point: parse flags, set up logging, then train."""
    parser = argparse.ArgumentParser(description="WWHO (SGPE) GPE Trainer")
    add_arg = parser.add_argument
    add_arg("--train_file", type=str, default="dataset/mixed_train.jsonl")
    add_arg("--vocab_size", type=int, default=128_000,
            help="Target SGPE vocab size (default 128K)")
    add_arg("--min_freq", type=int, default=2)
    add_arg("--prune_freq", type=int, default=100,
            help="Drop syllables below this corpus frequency to [UNK]")
    add_arg("--output_dir", type=str, default="output")
    add_arg("--num_workers", type=int, default=None)
    add_arg("--checkpoint_every", type=int, default=20_000)
    add_arg("--resume", type=str, default=None)
    add_arg("--script_mode", type=str, default="mixed",
            choices=["sinhala", "devanagari", "mixed"],
            help=("Which Indic script(s) to merge in BPE "
                  "(English/code always stays as boundary tokens)"))
    opts = parser.parse_args()

    _setup_logging(opts.output_dir)
    _log(f"Starting WWHO (SGPE) trainer: train_file={opts.train_file} "
         f"vocab_size={opts.vocab_size} script_mode={opts.script_mode} "
         f"prune_freq={opts.prune_freq} min_freq={opts.min_freq}")

    trainer = GPETrainer(
        vocab_size=opts.vocab_size,
        min_freq=opts.min_freq,
        num_workers=opts.num_workers,
        checkpoint_every=opts.checkpoint_every,
        prune_freq=opts.prune_freq,
        script_mode=opts.script_mode,
    )
    trainer.train(opts.train_file, opts.output_dir, resume_path=opts.resume)


if __name__ == "__main__":
    main()