Twitch-BPE / src /autotune.py
Soldier-Boy's picture
create: src files
c6e5251 verified
from __future__ import annotations
from typing import List, Dict
from . import config as CFG
from .bpe_trainer import train_bpe
from .bpe_tokenizer import BPETokenizer
from .eval_metrics import compression_ratio, roundtrip_accuracy
from .utils import get_logger
# Module-wide logger from the project's utils helper; autotune() below uses it
# to report per-round progress.
logger = get_logger()
def autotune(
    train_lines: List[str],
    val_lines: List[str],
    *,
    vocab_targets: List[int] | None = None,
) -> Dict:
    """Train BPE tokenizers at successive vocab sizes until the compression target is met.

    The only knob tuned is the vocabulary target. Each round trains a tokenizer,
    saves it to ``CFG.TOKENIZER_DIR``, reloads it from there, and measures the
    compression ratio and round-trip accuracy on ``val_lines``. At most
    ``CFG.AUTOTUNE_MAX_ROUNDS`` rounds are run. The search stops early at the
    first round whose reloaded vocabulary has more than 5000 entries and whose
    ratio is at least ``CFG.RATIO_TARGET``.

    Args:
        train_lines: Training corpus passed to ``train_bpe``.
        val_lines: Validation corpus used for both metrics. Round-trip accuracy
            is sampled on at most 2000 of these lines.
        vocab_targets: Vocabulary targets to try, in order. Defaults to
            6000, 8000, 10000, 12000, 14000.

    Returns:
        A dict with keys ``round``, ``vocab_size``, ``compression_ratio`` and
        ``roundtrip_accuracy``. It describes the accepting round, or, if no
        round met the target, the round with the highest compression ratio
        (earliest wins on ties). In both cases the tokenizer left in
        ``CFG.TOKENIZER_DIR`` is the one the record describes. If no round
        ran at all, ``{"error": "autotune_failed"}`` is returned.
    """
    if vocab_targets is None:
        vocab_targets = [6000, 8000, 10000, 12000, 14000]
    # Clamp at 0 so a non-positive AUTOTUNE_MAX_ROUNDS runs no rounds. A
    # negative slice bound would otherwise drop items from the end instead.
    max_rounds = max(0, min(CFG.AUTOTUNE_MAX_ROUNDS, len(vocab_targets)))

    best_ratio = None
    best_record = None
    best_art = None
    saved_art = None  # artifact whose files are currently in CFG.TOKENIZER_DIR

    for round_no, vt in enumerate(vocab_targets[:max_rounds], start=1):
        art = train_bpe(train_lines, vocab_target=vt)
        # Evaluate the tokenizer as reloaded from disk, so the metrics describe
        # exactly what was saved.
        art.save(CFG.TOKENIZER_DIR)
        saved_art = art
        tok = BPETokenizer(CFG.TOKENIZER_DIR)
        vocab_size = len(tok.vocab)

        ratio = compression_ratio(val_lines, tok)
        acc, _mismatches = roundtrip_accuracy(val_lines, tok, sample=min(2000, len(val_lines)))
        logger.info(f"[autotune] round={round_no} vocab={vocab_size} ratio={ratio:.3f} rt_acc={acc:.3f}")

        record = {
            "round": round_no,
            "vocab_size": vocab_size,
            "compression_ratio": ratio,
            "roundtrip_accuracy": acc,
        }
        # Strict ">" keeps the earliest round when ratios tie.
        if best_record is None or ratio > best_ratio:
            best_ratio, best_record, best_art = ratio, record, art
        if vocab_size > 5000 and ratio >= CFG.RATIO_TARGET:
            return record  # this round's tokenizer is the one on disk

    if best_record is None:
        return {"error": "autotune_failed"}

    # No round met the target. The directory holds the LAST round's tokenizer,
    # but we report the BEST round, so restore that artifact to keep the
    # returned metrics and the on-disk tokenizer consistent.
    if best_art is not saved_art:
        best_art.save(CFG.TOKENIZER_DIR)
    return best_record