daa-tokenizers / eval_test_set.py
Ouaill's picture
Add eval_test_set.py (test-set-only consistent eval)
24d26d6 verified
Raw
History Blame Contribute Delete
8.15 kB
#!/usr/bin/env python3 -u
"""
eval_test_set.py — Re-evaluate ALL 24 tokenizers on TEST SET ONLY.
Produces a single source of truth for ALL tables in the paper.
"""
import json, os, sys, time, csv, gc, warnings
from collections import Counter
from dataclasses import dataclass, asdict
from typing import List, Dict
import numpy as np
import regex
# Silence every warning process-wide (presumably to keep the progress output
# readable). NOTE(review): this also hides warnings raised by numpy/regex.
warnings.filterwarnings("ignore")
# Root of the experiment outputs. Defaults to the original location but can be
# overridden with $OIQ_RESULTS_DIR so the evaluation runs on other machines.
BASE = os.environ.get("OIQ_RESULTS_DIR", "/root/oiq_cc_tokenizer/results")
CORPORA = os.path.join(BASE, "corpora")  # holds the test_{ar,az,mi}.txt splits
TOK_DIR = os.path.join(BASE, "tokenizers")  # serialized tokenizer JSON files
# A "word" is a maximal run of letters, combining marks and digits.
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)
# One Arabic (U+0600-U+06FF) or Arabic Supplement (U+0750-U+077F) code point.
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")
# Token strings treated as special markers (excluded from content-token metrics).
_SPECIAL = {"<unk>", "<s>", "</s>", "[CLS]", "[SEP]", "[PAD]", "[UNK]", "<pad>", ""}
def segment_words(t):
    """Return the word-like runs (letters, marks, digits) of ``t`` in order."""
    return [match.group() for match in _WORD_PAT.finditer(t)]
def count_graphemes(t):
    """Count the extended grapheme clusters (user-perceived characters) in ``t``."""
    return sum(1 for _ in regex.finditer(r"\X", t))
def detect_script(t):
    """Label ``t`` "ar" when more than 30% of its characters are Arabic, else "az"."""
    arabic_chars = len(_AR_PAT.findall(t))
    if arabic_chars > len(t) * 0.3:
        return "ar"
    return "az"
def filter_sp(tokens):
    """Return ``tokens`` without special or empty token strings, order preserved."""
    return [piece for piece in tokens if piece not in _SPECIAL]
@dataclass
class M:
    """Metric record for one evaluated tokenizer.

    Field order matters: it fixes the column order of the CSV/JSON output.
    """

    # --- identity ---
    name: str = ""
    tokenizer_type: str = ""  # "shared" or "concatenated"
    algorithm: str = ""  # "bpe", "unigram", "wordpiece" or "bbpe"
    vocab_size: int = 0
    # --- fertility: content tokens per word (mean over texts) ---
    fertility_overall: float = 0.0
    fertility_ar: float = 0.0
    fertility_az: float = 0.0
    # --- graphemes per content token (mean over texts) ---
    cpt_overall: float = 0.0
    cpt_ar: float = 0.0
    cpt_az: float = 0.0
    # --- |ar - az| / max(ar, az) ---
    fertility_disparity: float = 0.0
    cpt_disparity: float = 0.0
    # NOTE(review): not set by evaluate(); stays 0.0 in this script.
    oov_rate: float = 0.0
    # --- distribution of content-token frequencies ---
    vocab_gini: float = 0.0
    shannon_entropy: float = 0.0
    # --- share of texts whose decode round-trips (whitespace-stripped) ---
    exact_match_ar: float = 0.0
    exact_match_az: float = 0.0
class RawConcat:
    """Pair of tokenizers routed by script: one for "ar" text, one for everything else.

    Each text is sent, whole, to the sub-tokenizer chosen by ``detect_script``.
    """

    def __init__(self, ar_j, az_j):
        # ``tokenizers`` is imported at construction time, not at module import.
        from tokenizers import Tokenizer
        self.ar = Tokenizer.from_file(ar_j)
        self.az = Tokenizer.from_file(az_j)

    def _pick(self, script):
        """Return the sub-tokenizer for ``script`` ("ar" -> ar, anything else -> az)."""
        return self.ar if script == "ar" else self.az

    def encode(self, text):
        """Tokenize ``text``; return ``(tokens, ids, script)``."""
        script = detect_script(text)
        encoding = self._pick(script).encode(text)
        return encoding.tokens, encoding.ids, script

    def decode(self, ids, script):
        """Decode ``ids`` with the sub-tokenizer for ``script``, skipping special tokens."""
        return self._pick(script).decode(ids, skip_special_tokens=True)
class RawShared:
    """Single tokenizer used for every text; the detected script is only reported."""

    def __init__(self, j):
        # ``tokenizers`` is imported at construction time, not at module import.
        from tokenizers import Tokenizer
        self.tok = Tokenizer.from_file(j)

    def encode(self, text):
        """Tokenize ``text``; return ``(tokens, ids, script)``."""
        encoding = self.tok.encode(text)
        script = detect_script(text)
        return encoding.tokens, encoding.ids, script

    def decode(self, ids, script):
        """Decode ``ids``, skipping special tokens.

        ``script`` is unused; it is accepted so the signature matches RawConcat.decode.
        """
        return self.tok.decode(ids, skip_special_tokens=True)
def gini_coefficient(freqs):
    """Return the Gini coefficient of a list of frequencies, clipped to [0, 1].

    0.0 means every entry has the same frequency; the value grows as the mass
    concentrates on fewer entries. Empty input or a zero total gives 0.0.
    Computed as (n + 1 - 2 * sum(cumsum(sorted(x))) / sum(x)) / n.
    """
    if not freqs:
        return 0.0
    ascending = sorted(freqs)
    mass = sum(ascending)
    if mass == 0:
        return 0.0
    count = len(ascending)
    cum_total = np.sum(np.cumsum(ascending))
    coeff = (count + 1 - 2 * cum_total / mass) / count
    return float(np.clip(coeff, 0, 1))
def shannon_entropy(freqs):
    """Return the Shannon entropy, in bits, of the distribution given by ``freqs``.

    Zero-count entries contribute nothing; empty input or a zero total gives 0.0.
    """
    total = sum(freqs) if freqs else 0
    if total == 0:
        return 0.0
    bits = 0.0
    for count in (c for c in freqs if c > 0):
        prob = count / total
        bits -= prob * np.log2(prob)
    return float(bits)
def evaluate(tok, name, ttype, algo, vsz, texts):
    """Compute tokenization metrics for one tokenizer over ``texts``.

    Args:
        tok: object with ``encode(text) -> (tokens, ids, script)`` and
            ``decode(ids, script) -> str`` (RawShared / RawConcat).
        name, ttype, algo, vsz: identification copied into the result record.
        texts: sequence of text lines.

    Returns:
        An ``M`` record with:
        - per-text fertility (content tokens / words) and graphemes per content
          token, each averaged overall and per script;
        - disparities |ar - az| / max(ar, az);
        - per-script exact-match rates of the decode round trip (whitespace-stripped);
        - Gini and entropy of the content-token frequencies.

    Texts without word characters are skipped. Texts whose encoding raises are
    also skipped, but they are counted and reported rather than dropped silently.
    """
    m = M(name=name, tokenizer_type=ttype, algorithm=algo, vocab_size=vsz)
    ar_f, az_f, all_f = [], [], []
    ar_c, az_c, all_c = [], [], []
    ar_ok, az_ok, ar_n, az_n = 0, 0, 0, 0
    n_failed = 0
    token_counts = Counter()
    n_texts = len(texts)
    for i, text in enumerate(texts):
        if (i + 1) % 5000 == 0:
            print(f" [{i+1}/{n_texts}] {name}", flush=True)
        words = segment_words(text)
        if not words:
            # Fertility is undefined without words; skip before paying for encode.
            continue
        try:
            tokens, ids, script = tok.encode(text)
        except Exception as exc:
            # Best-effort: skip texts the tokenizer cannot handle, but count them.
            # Catching Exception (not a bare except) keeps Ctrl-C working.
            n_failed += 1
            if n_failed <= 3:
                print(f"  ! {name}: encode failed on text #{i}: {exc!r}", flush=True)
            continue
        content = filter_sp(tokens)
        fert = len(content) / len(words)
        cpt = count_graphemes(text) / max(len(content), 1)
        all_f.append(fert)
        all_c.append(cpt)
        token_counts.update(content)
        try:
            exact = tok.decode(ids, script).strip() == text.strip()
        except Exception:
            exact = False  # a decode error counts as a failed round trip
        if script == "ar":
            ar_f.append(fert)
            ar_c.append(cpt)
            ar_n += 1
            ar_ok += exact
        else:
            az_f.append(fert)
            az_c.append(cpt)
            az_n += 1
            az_ok += exact
    if n_failed:
        print(f"  ! {name}: skipped {n_failed}/{n_texts} texts whose encoding raised", flush=True)

    def _mean(xs):
        # Always a float so the record's float fields stay floats.
        return float(np.mean(xs)) if xs else 0.0

    m.fertility_ar = _mean(ar_f)
    m.fertility_az = _mean(az_f)
    m.fertility_overall = _mean(all_f)
    m.cpt_ar = _mean(ar_c)
    m.cpt_az = _mean(az_c)
    m.cpt_overall = _mean(all_c)
    # The 1e-9 floor avoids division by zero when both means are 0.
    mx = max(m.fertility_ar, m.fertility_az, 1e-9)
    m.fertility_disparity = abs(m.fertility_ar - m.fertility_az) / mx
    cpt_mx = max(m.cpt_ar, m.cpt_az, 1e-9)
    m.cpt_disparity = abs(m.cpt_ar - m.cpt_az) / cpt_mx
    m.exact_match_ar = ar_ok / max(ar_n, 1)
    m.exact_match_az = az_ok / max(az_n, 1)
    freqs = list(token_counts.values())
    m.vocab_gini = gini_coefficient(freqs)
    m.shannon_entropy = shannon_entropy(freqs)
    return m
def _load_test_texts():
    """Return the stripped, non-blank lines of the test_ar, test_az and test_mi splits.

    Splits are read in that order. A missing split is reported and skipped.
    """
    texts = []
    for split in ("test_ar", "test_az", "test_mi"):
        path = os.path.join(CORPORA, f"{split}.txt")
        if not os.path.exists(path):
            print(f"WARNING: test split not found, skipping: {path}", flush=True)
            continue
        # Explicit UTF-8: the corpora contain Arabic script, and the default
        # encoding of open() depends on the platform locale.
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    texts.append(line)
    return texts


def _print_summary(r):
    """Print the one-line headline metrics of a freshly evaluated tokenizer."""
    print(f" F={r.fertility_overall:.4f} F_ar={r.fertility_ar:.4f} F_az={r.fertility_az:.4f} ΔF={r.fertility_disparity:.4f} CPT={r.cpt_overall:.3f} G={r.vocab_gini:.3f} H={r.shannon_entropy:.2f} EM_ar={r.exact_match_ar:.2%}", flush=True)


def _run_one(tok_cls, tok_args, name, ttype, algo, vsz, texts):
    """Build ``tok_cls(*tok_args)``, evaluate it, print a summary, and free it."""
    print(f"\n{name}", flush=True)
    tok = tok_cls(*tok_args)
    r = evaluate(tok, name, ttype, algo, vsz, texts)
    _print_summary(r)
    # Release each tokenizer before the next one is loaded
    # (presumably to bound peak memory).
    del tok
    gc.collect()
    return r


def _save_results(results):
    """Write ``results`` to test_set_results.csv/.json under BASE; return the CSV path."""
    out_csv = os.path.join(BASE, "test_set_results.csv")
    out_json = os.path.join(BASE, "test_set_results.json")
    rows = [asdict(r) for r in results]
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)
    return out_csv


def _print_table(results):
    """Print all results as a fixed-width table sorted by (vocab size, type, algorithm)."""
    print("\n" + "=" * 150, flush=True)
    hdr = f"{'Name':<25} {'Type':<14} {'V':>5} {'F_all':>7} {'F_ar':>7} {'F_az':>7} {'ΔF':>7} {'CPT_all':>7} {'CPT_ar':>7} {'CPT_az':>7} {'Gini':>6} {'Ent':>6} {'EM_ar':>7} {'EM_az':>7}"
    print(hdr, flush=True)
    print("-" * 150, flush=True)
    for r in sorted(results, key=lambda x: (x.vocab_size, x.tokenizer_type, x.algorithm)):
        print(f"{r.name:<25} {r.tokenizer_type:<14} {r.vocab_size:>5,} {r.fertility_overall:>7.4f} {r.fertility_ar:>7.4f} {r.fertility_az:>7.4f} {r.fertility_disparity:>7.4f} {r.cpt_overall:>7.3f} {r.cpt_ar:>7.3f} {r.cpt_az:>7.3f} {r.vocab_gini:>6.3f} {r.shannon_entropy:>6.2f} {r.exact_match_ar:>7.2%} {r.exact_match_az:>7.2%}", flush=True)
    print("=" * 150, flush=True)


def main():
    """Evaluate every available shared and concatenated tokenizer on the test splits.

    For each vocab size V in (8000, 16000, 32000) and each algorithm, this
    evaluates:
    - ``shared_{algo}_{V}.json``, if present;
    - the pair ``concat_ar_{algo}_{V//2}.json`` / ``concat_az_{algo}_{V//2}.json``,
      if both are present.

    Results are saved as CSV/JSON under BASE and printed as a table. Exits with
    an error message when there are no test texts or no tokenizers.
    """
    texts = _load_test_texts()
    print(f"{len(texts)} test texts", flush=True)
    if not texts:
        # Evaluating on nothing would overwrite real results with all-zero rows.
        sys.exit(f"No test texts found under {CORPORA}; aborting.")
    results = []
    for vsz in (8000, 16000, 32000):
        for algo in ("bpe", "unigram", "wordpiece", "bbpe"):
            # Shared: one tokenizer with the full vocabulary.
            jpath = os.path.join(TOK_DIR, f"shared_{algo}_{vsz}.json")
            if os.path.exists(jpath):
                results.append(_run_one(RawShared, (jpath,), f"shared_{algo}_{vsz}",
                                        "shared", algo, vsz, texts))
            # Concat: each sub-tokenizer gets half of the vocab budget (vsz // 2).
            ar_j = os.path.join(TOK_DIR, f"concat_ar_{algo}_{vsz//2}.json")
            az_j = os.path.join(TOK_DIR, f"concat_az_{algo}_{vsz//2}.json")
            if os.path.exists(ar_j) and os.path.exists(az_j):
                results.append(_run_one(RawConcat, (ar_j, az_j), f"concat_{algo}_{vsz}",
                                        "concatenated", algo, vsz, texts))
    if not results:
        sys.exit(f"No tokenizers found under {TOK_DIR}; nothing to save.")
    out_csv = _save_results(results)
    print(f"\nSaved: {out_csv}", flush=True)
    _print_table(results)
    print("DONE!", flush=True)
if __name__ == "__main__":
    # Run the full evaluation only when executed as a script, not on import.
    main()