# SARF-Tokenizer / tokenizer_benchmark.py
# Last commit by almaghrabima:
#   "Update: rank by parity+efficiency, add Falcon-H1-7B" (5362025)
"""
Multi-tokenizer comparison benchmark.
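Evaluates SARF against the other tokenizers listed in TOKENIZER_DEFS on
Arabic+English text, computing fertility, chars/token, and Arabic/English
parity, then ranks tokenizers by parity (closest to 1.0) and, as a
tie-breaker, by average chars/token.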
"""
import os, sys, re, json, argparse, time
# Put the parent of this script's directory first on sys.path so that the
# `scripts.rewrite_bytes` import further down can be resolved (presumably the
# `scripts` package lives there — confirm).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def load_env_file(path):
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` .env file.

    Blank lines and lines starting with ``#`` are ignored, as are lines
    without ``=``. An optional leading ``export`` on the key is accepted, and
    one pair of matching surrounding quotes (``"..."`` or ``'...'``) is
    stripped from the value. Variables already set in the environment are
    never overridden. A missing file is silently ignored.

    Args:
        path: Path to the .env file.
    """
    if not os.path.exists(path):
        return
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key = key.strip()
            if key.startswith("export "):
                key = key[len("export "):].strip()
            value = value.strip()
            # `KEY="value"` is common .env syntax; the quotes are not part of the value.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            # os.environ rejects empty names, so skip lines like "=value".
            if key:
                os.environ.setdefault(key, value)


# Load the .env file located in the same directory that was added to sys.path.
_env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env")
load_env_file(_env_path)

# Disable hf_transfer if not installed: presumably so that an
# HF_HUB_ENABLE_HF_TRANSFER setting (e.g. from .env) does not request a
# transfer backend that is unavailable — confirm.
try:
    import hf_transfer  # noqa: F401
except ImportError:
    os.environ.pop("HF_HUB_ENABLE_HF_TRANSFER", None)
import pyarrow.parquet as pq
import glob as globmod
from scripts.rewrite_bytes import ByteRewriter
# ── Tokenizer wrappers ──────────────────────────────────────────────
class SarfTokenizer:
    """SARF tokenizer wrapper: rewrite text with ``ByteRewriter``, then encode.

    The rewriter is built from ``morf_map_path``; the encoder is a
    ``PreTrainedTokenizerFast`` loaded from ``<tokenizer_dir>/tokenizer.json``.
    """

    def __init__(self, tokenizer_dir, morf_map_path):
        from transformers import PreTrainedTokenizerFast

        tokenizer_json = os.path.join(tokenizer_dir, "tokenizer.json")
        self._tok = PreTrainedTokenizerFast(tokenizer_file=tokenizer_json)
        self._rewriter = ByteRewriter(morf_map_path)

    def encode(self, text):
        """Return token ids for ``text`` after rewriting; no special tokens are added."""
        rewritten = self._rewriter.rewrite_text(text)
        return self._tok.encode(rewritten, add_special_tokens=False)

    @property
    def vocab_size(self):
        """Vocabulary size as reported by ``len()`` of the underlying tokenizer."""
        return len(self._tok)

    @property
    def name(self):
        """Fixed display name used in result tables."""
        return "SARF (Ours)"
class TiktokenTokenizer:
    """Benchmark wrapper around a named tiktoken encoding (e.g. ``cl100k_base``)."""

    def __init__(self, encoding_name, display_name=None):
        import tiktoken

        # Fall back to the encoding name when no (truthy) display name is given.
        self._name = display_name if display_name else encoding_name
        self._enc = tiktoken.get_encoding(encoding_name)

    def encode(self, text):
        """Encode ``text``; special-token strings inside it are permitted (allowed_special="all")."""
        return self._enc.encode(text, allowed_special="all")

    @property
    def vocab_size(self):
        """Vocabulary size reported by the encoding's ``n_vocab``."""
        return self._enc.n_vocab

    @property
    def name(self):
        """Display name used in result tables."""
        return self._name
class HFTokenizer:
    """Benchmark wrapper around a Hugging Face ``AutoTokenizer``."""

    def __init__(self, model_id, display_name=None):
        from transformers import AutoTokenizer

        # NOTE: trust_remote_code=True lets the model repository run its own
        # Python code while loading; only register repositories you trust.
        try:
            tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        except Exception:
            # Any failure with the default load: retry with the slow tokenizer.
            tok = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True, use_fast=False
            )
        self._tok = tok
        # Default display name is the final path segment of the model id.
        self._name = display_name or model_id.rsplit("/", 1)[-1]

    def encode(self, text):
        """Return token ids for ``text`` without adding special tokens."""
        return self._tok.encode(text, add_special_tokens=False)

    @property
    def vocab_size(self):
        """Vocabulary size as reported by ``len()`` of the underlying tokenizer."""
        return len(self._tok)

    @property
    def name(self):
        """Display name used in result tables."""
        return self._name
# ── Tokenizer registry ──────────────────────────────────────────────
# Registry of tokenizers to benchmark, in load order.
# Each entry is (display_name, type, source):
#   "sarf"     -> SarfTokenizer; source is unused (paths come from the CLI)
#   "tiktoken" -> TiktokenTokenizer; source is the tiktoken encoding name
#   "hf"       -> HFTokenizer; source is the model id passed to from_pretrained
TOKENIZER_DEFS = [
    # Ours
    ("SARF (Ours)",      "sarf",     None),
    # tiktoken encodings
    ("GPT-4o",           "tiktoken", "o200k_base"),
    ("GPT-4",            "tiktoken", "cl100k_base"),
    # Hugging Face tokenizers
    ("ALLaM-7B",         "hf",       "humain-ai/ALLaM-7B-Instruct-preview"),
    ("AceGPT-13B",       "hf",       "FreedomIntelligence/AceGPT-13B-chat"),
    ("Gemma-3-4B",       "hf",       "google/gemma-3-4b-it"),
    ("Command-R-Arabic", "hf",       "CohereLabs/c4ai-command-r7b-arabic-02-2025"),
    ("Fanar-1-9B",       "hf",       "QCRI/Fanar-1-9B-Instruct"),
    ("Hala-9B",          "hf",       "hammh0a/Hala-9B"),
    ("Qwen3-4B",         "hf",       "Qwen/Qwen3-4B-Instruct-2507"),
    ("Qwen3-VL-4B",      "hf",       "Qwen/Qwen3-VL-4B-Instruct"),
    ("Mistral-7B-v0.3",  "hf",       "mistralai/Mistral-7B-Instruct-v0.3"),
    ("Falcon-H1-7B",     "hf",       "tiiuae/Falcon-H1-7B-Instruct"),
]
def load_all_tokenizers(tokenizer_dir, morf_map_path, defs=None):
    """Load every registered tokenizer, skipping any that fail.

    Args:
        tokenizer_dir: Directory holding SARF's ``tokenizer.json``.
        morf_map_path: Path passed to SARF's ``ByteRewriter``.
        defs: Optional sequence of ``(display_name, type, source)`` tuples to
            load instead of ``TOKENIZER_DEFS`` (e.g. to benchmark a subset).

    Returns:
        List of loaded wrapper objects, in registry order. A tokenizer that
        raises while loading is reported on stdout and omitted, so one
        failure does not abort the whole benchmark.
    """
    if defs is None:
        defs = TOKENIZER_DEFS
    tokenizers = []
    failed = []
    for display_name, typ, source in defs:
        print(f"Loading {display_name}...", end=" ", flush=True)
        t0 = time.time()
        try:
            if typ == "sarf":
                # SARF's paths come from the CLI; `source` is unused.
                tok = SarfTokenizer(tokenizer_dir, morf_map_path)
            elif typ == "tiktoken":
                tok = TiktokenTokenizer(source, display_name)
            elif typ == "hf":
                tok = HFTokenizer(source, display_name)
            else:
                raise ValueError(f"Unknown type: {typ}")
            print(f"OK (vocab={tok.vocab_size:,}, {time.time()-t0:.1f}s)")
            tokenizers.append(tok)
        except Exception as e:  # best-effort: report and continue with the rest
            print(f"FAILED: {e}")
            failed.append(display_name)
    if failed:
        # Summarize skips so they are not lost in the per-tokenizer progress lines.
        print(f"Skipped {len(failed)} tokenizer(s): {', '.join(failed)}")
    return tokenizers
# ── Data loading ─────────────────────────────────────────────────────
AR_DETECT = re.compile(r'[\u0600-\u06FF]')


def load_samples(data_dir, num_ar=5000, num_en=5000):
    """Collect Arabic and English text samples from parquet files in ``data_dir``.

    Files matching ``*.parquet`` are read in sorted order, one row group at a
    time, using only their ``text`` column. Null/non-string values and texts
    shorter than 100 characters are skipped. A text counts as Arabic when
    more than 30% of its characters fall in U+0600-U+06FF, and as English
    when fewer than 5% do (i.e. "mostly non-Arabic-script"); texts in between
    are discarded. Accepted texts are truncated to their first 2000
    characters. Reading stops as soon as both quotas are met.

    Returns:
        ``(ar_samples, en_samples)``: lists of at most ``num_ar`` and
        ``num_en`` strings respectively.
    """
    parquet_files = sorted(globmod.glob(os.path.join(data_dir, '*.parquet')))
    if not parquet_files:
        print(f"WARNING: no .parquet files found in {data_dir}")
    ar_samples, en_samples = [], []

    def quotas_met():
        return len(ar_samples) >= num_ar and len(en_samples) >= num_en

    for filepath in parquet_files:
        if quotas_met():
            break
        pf = pq.ParquetFile(filepath)
        for rg_idx in range(pf.num_row_groups):
            if quotas_met():
                break
            # Only the "text" column is used, so avoid decoding the others.
            rg = pf.read_row_group(rg_idx, columns=["text"])
            for text in rg.column("text").to_pylist():
                if quotas_met():
                    break
                # Parquet nulls come back as None; skip them along with short texts.
                if not isinstance(text, str) or len(text) < 100:
                    continue
                ar_ratio = len(AR_DETECT.findall(text)) / len(text)
                if ar_ratio > 0.3 and len(ar_samples) < num_ar:
                    ar_samples.append(text[:2000])
                elif ar_ratio < 0.05 and len(en_samples) < num_en:
                    en_samples.append(text[:2000])
    print(f"Loaded {len(ar_samples)} Arabic, {len(en_samples)} English samples")
    return ar_samples, en_samples
# ── Metrics ──────────────────────────────────────────────────────────
# NOTE(review): U+0600-U+06FF also contains Arabic punctuation (e.g. U+060C
# ARABIC COMMA) and Arabic-Indic digits, so these are matched as (parts of)
# "words" too; EN_WORD is ASCII letters only, so contractions split in two.
AR_WORD = re.compile(r'[\u0600-\u06FF]+')
EN_WORD = re.compile(r'[a-zA-Z]+')


def _corpus_stats(tokenizer, texts, word_re):
    """Return ``(chars, tokens, words, word_tokens)`` totals over ``texts``.

    ``tokens`` counts tokens of each full text; ``word_tokens`` counts tokens
    of every ``word_re`` match encoded on its own.
    """
    # Per-call memo so each distinct word is encoded only once; assumes
    # tokenizer.encode is deterministic.
    word_token_counts = {}
    chars = tokens = words = word_tokens = 0
    for text in texts:
        chars += len(text)
        tokens += len(tokenizer.encode(text))
        for word in word_re.findall(text):
            count = word_token_counts.get(word)
            if count is None:
                count = word_token_counts[word] = len(tokenizer.encode(word))
            words += 1
            word_tokens += count
    return chars, tokens, words, word_tokens


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is 0."""
    return numerator / denominator if denominator else 0


def compute_metrics(tokenizer, ar_texts, en_texts):
    """Compute fertility, chars/token, and parity for one tokenizer.

    Args:
        tokenizer: Wrapper exposing ``encode(text)``, ``name`` and ``vocab_size``.
        ar_texts: Arabic sample strings.
        en_texts: English sample strings.

    Returns:
        Dict with ``name``, ``vocab_size`` and these values rounded to 4 decimals:
        ``ar_fertility`` / ``en_fertility`` (mean tokens per word, each word
        encoded in isolation), ``ar_chars_per_token`` / ``en_chars_per_token``
        (total characters / total tokens over the full texts) and ``parity``
        (Arabic chars/token divided by English chars/token). Any ratio whose
        denominator is 0 is reported as 0.
    """
    ar_chars, ar_tokens, ar_words, ar_word_tokens = _corpus_stats(tokenizer, ar_texts, AR_WORD)
    en_chars, en_tokens, en_words, en_word_tokens = _corpus_stats(tokenizer, en_texts, EN_WORD)

    ar_fertility = _safe_ratio(ar_word_tokens, ar_words)
    ar_cpt = _safe_ratio(ar_chars, ar_tokens)
    en_fertility = _safe_ratio(en_word_tokens, en_words)
    en_cpt = _safe_ratio(en_chars, en_tokens)
    # Parity uses the unrounded chars/token values.
    parity = _safe_ratio(ar_cpt, en_cpt)
    return {
        "name": tokenizer.name,
        "vocab_size": tokenizer.vocab_size,
        "ar_fertility": round(ar_fertility, 4),
        "ar_chars_per_token": round(ar_cpt, 4),
        "en_fertility": round(en_fertility, 4),
        "en_chars_per_token": round(en_cpt, 4),
        "parity": round(parity, 4),
    }
# ── Ranking ──────────────────────────────────────────────────────────
def rank_key(r):
    """Sort key: parity nearest 1.0 first; ties go to the higher mean chars/token.

    Returns a ``(|1 - parity|, -mean_chars_per_token)`` tuple so an ascending
    sort puts the best tokenizer first.
    """
    deviation = abs(1.0 - r["parity"])
    mean_cpt = (r["ar_chars_per_token"] + r["en_chars_per_token"]) / 2.0
    return deviation, -mean_cpt
# ── Display ──────────────────────────────────────────────────────────
def print_table(results):
    """Print results, ordered by ``rank_key``, as a fixed-width console table with a legend."""
    ordered = sorted(results, key=rank_key)
    header = (
        f"{'Rank':<5} {'Tokenizer':<22} {'Vocab':>9} {'AR Fert':>9} "
        f"{'AR C/T':>9} {'EN Fert':>9} {'EN C/T':>9} {'Parity':>9}"
    )
    rule = "=" * len(header)

    print("\n" + rule)
    print("TOKENIZER BENCHMARK RESULTS")
    print(rule)
    print(header)
    print("-" * len(header))
    for position, row in enumerate(ordered, start=1):
        print(
            f"{position:<5} {row['name']:<22} {row['vocab_size']:>9,} "
            f"{row['ar_fertility']:>9.3f} {row['ar_chars_per_token']:>9.3f} "
            f"{row['en_fertility']:>9.3f} {row['en_chars_per_token']:>9.3f} "
            f"{row['parity']:>9.4f}"
        )
    print(rule)

    for legend_line in (
        "AR Fert = Arabic tokens/word (lower=better)",
        "AR C/T = Arabic chars/token (higher=better)",
        "EN Fert = English tokens/word (lower=better)",
        "EN C/T = English chars/token (higher=better)",
        "Parity = AR_C/T / EN_C/T (closer to 1.0=better)",
        "Ranked by: parity (closest to 1.0), then avg chars/token\n",
    ):
        print(legend_line)
def results_to_markdown(results):
    """Return the results, ordered by ``rank_key``, as a pipe-style markdown table string."""
    table = [
        "| Rank | Tokenizer | Vocab | AR Fertility | AR Chars/Tok | EN Fertility | EN Chars/Tok | Parity |",
        "|------|-----------|------:|-------------:|-------------:|-------------:|-------------:|-------:|",
    ]
    table.extend(
        f"| {position} | {row['name']} | {row['vocab_size']:,} | {row['ar_fertility']:.3f} | {row['ar_chars_per_token']:.3f} | {row['en_fertility']:.3f} | {row['en_chars_per_token']:.3f} | {row['parity']:.4f} |"
        for position, row in enumerate(sorted(results, key=rank_key), start=1)
    )
    return "\n".join(table)
# ── Main ─────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: load tokenizers and samples, evaluate, print, and save JSON.

    Exits with an error (via ``sys.exit``) when no tokenizer loads or when
    either language has no samples, instead of writing all-zero results.
    """
    parser = argparse.ArgumentParser(description="Multi-tokenizer comparison benchmark")
    # NOTE(review): this default spells "Deeplatent" while the two defaults
    # below use "deeplatent"; paths are case-sensitive on Linux — confirm.
    parser.add_argument("--data_dir", default="/root/.cache/Deeplatent/eval_1b/data")
    parser.add_argument("--tokenizer_dir", default="/root/.cache/deeplatent/tokenizer_parity")
    parser.add_argument("--morf_map_path", default="/root/.cache/deeplatent/morfessor_models/morf_map.json")
    parser.add_argument("--num_samples", type=int, default=5000)
    parser.add_argument("--output", default="benchmark_results.json")
    parser.add_argument("--dry_run", action="store_true", help="Test on 10 samples first")
    args = parser.parse_args()

    # Load tokenizers
    print("Loading tokenizers...")
    tokenizers = load_all_tokenizers(args.tokenizer_dir, args.morf_map_path)
    print(f"\nLoaded {len(tokenizers)} tokenizers successfully.\n")
    if not tokenizers:
        sys.exit("ERROR: no tokenizers could be loaded; nothing to benchmark.")

    # Load data
    n = 10 if args.dry_run else args.num_samples
    print(f"Loading {n} samples per language...")
    ar_texts, en_texts = load_samples(args.data_dir, n, n)
    # Parity is 0 whenever either language is empty, so the ranking would be meaningless.
    if not ar_texts or not en_texts:
        sys.exit(
            f"ERROR: need both Arabic and English samples but loaded "
            f"{len(ar_texts)} Arabic / {len(en_texts)} English from {args.data_dir}"
        )

    # Evaluate
    results = []
    for tok in tokenizers:
        print(f"Evaluating {tok.name}...", end=" ", flush=True)
        t0 = time.time()
        m = compute_metrics(tok, ar_texts, en_texts)
        print(f"done ({time.time()-t0:.1f}s)")
        results.append(m)

    # Display
    print_table(results)

    # Save
    output = {
        "num_ar_samples": len(ar_texts),
        "num_en_samples": len(en_texts),
        "results": sorted(results, key=rank_key),
        "markdown_table": results_to_markdown(results),
    }
    # Explicit UTF-8: ensure_ascii=False may emit non-ASCII characters.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()