daa-tokenizers / eval_external_datasets.py
Ouaill's picture
Upload eval_external_datasets.py with huggingface_hub
7ce8a98 verified
Raw
History Blame Contribute Delete
12.8 kB
#!/usr/bin/env python3 -u
"""
eval_external_datasets.py — Evaluate the DarijaBERT family vs our tokenizers
on three independent datasets (10K sample each):
- DODa (Arabizi) atlasia/DODa
- Darija-Wiki (Arabic) atlasia/Moroccan-Darija-Wiki-Dataset
- Atlaset (Arabic) atlasia/Atlaset
For each dataset × tokenizer, compute: F_ar, F_az, F (overall), CPT, and Gain.
"""
import json, os, csv, gc, random, warnings
from dataclasses import dataclass, asdict
import numpy as np
import regex
# Silence every warning category for the whole process so progress output stays readable.
warnings.filterwarnings("ignore")
# Seed the global `random` module; load_dataset_texts() draws its samples from it.
random.seed(42)
# Root directory for result artifacts (the CSV/JSON outputs are written directly under it).
BASE = "/root/oiq_cc_tokenizer/results"
# Directory holding the tokenizer JSON files used by the "concat" tokenizer configs.
TOK_DIR = os.path.join(BASE, "tokenizers")
# Local directory whose data/ subfolder holds the Atlaset parquet shards.
PARQUET_DIR = "/root/oiq_cc_tokenizer/tmp_atlastet"
# Hugging Face access token taken from the environment; empty string when unset.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Number of texts sampled per dataset.
N_SAMPLE = 10000
# A "word": maximal run of Unicode letters, combining marks and digits.
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)
# One code point in U+0600–U+06FF or U+0750–U+077F; used by detect_script().
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")
# One ASCII letter. Not referenced elsewhere in the visible part of this file.
_LAT_PAT = regex.compile(r"[a-zA-Z]")
# Token strings that filter_sp() drops before tokens are counted.
_SPECIAL = {"<unk>", "<s>", "</s>", "[CLS]", "[SEP]", "[PAD]", "[UNK]", "<pad>", "",
            "<|im_start|>", "<|im_end|>"}
def segment_words(t):
    """Return, in order, every substring of ``t`` matched by ``_WORD_PAT``."""
    return [match.group(0) for match in _WORD_PAT.finditer(t)]
def count_graphemes(t):
    """Return the number of ``\\X`` (extended grapheme cluster) matches in ``t``."""
    # Count matches lazily instead of materialising the list of clusters.
    return sum(1 for _ in regex.finditer(r"\X", t))
def filter_sp(tokens):
    """Return a new list with the tokens of ``tokens`` that are not in ``_SPECIAL``."""
    kept = []
    for token in tokens:
        if token in _SPECIAL:
            continue
        kept.append(token)
    return kept
def detect_script(t):
    """Label ``t`` as ``"ar"`` or ``"az"`` from its share of ``_AR_PAT`` characters.

    The text is ``"ar"`` when strictly more than 30% of its code points
    (whitespace and punctuation included in the total) match ``_AR_PAT``;
    anything else, including the empty string, is ``"az"``.
    """
    arabic_count = sum(1 for _ in _AR_PAT.finditer(t))
    if arabic_count > len(t) * 0.3:
        return "ar"
    return "az"
def normalize_decode(s):
    """Remove every ``##`` marker from ``s`` and collapse whitespace.

    Markers are removed first; the remainder is then split on any
    whitespace and re-joined with single spaces, which also trims both ends.
    Pieces that were separated by whitespace stay separated.
    """
    without_markers = "".join(s.split("##"))
    return " ".join(without_markers.split())
class RawConcat:
    """Two tokenizers behind one interface, selected per text by script label.

    ``self.ar`` handles texts labelled ``"ar"``; ``self.az`` handles every
    other label.
    """

    def __init__(self, ar_j, az_j):
        """Load both tokenizers with ``Tokenizer.from_file``.

        Args:
            ar_j: path to the tokenizer JSON used for ``"ar"`` texts.
            az_j: path to the tokenizer JSON used for all other texts.
        """
        from tokenizers import Tokenizer

        self.ar = Tokenizer.from_file(ar_j)
        self.az = Tokenizer.from_file(az_j)

    def _pick(self, script):
        """Return the tokenizer responsible for ``script``."""
        if script == "ar":
            return self.ar
        return self.az

    def encode(self, text):
        """Tokenize ``text`` with the tokenizer matching ``detect_script(text)``.

        Returns:
            ``(tokens, ids, script)`` where ``script`` is the detected label.
        """
        script = detect_script(text)
        encoding = self._pick(script).encode(text)
        return encoding.tokens, encoding.ids, script

    def decode(self, ids, script):
        """Decode ``ids`` with the tokenizer for ``script``, skipping special tokens."""
        return self._pick(script).decode(ids, skip_special_tokens=True)
class HFTok:
    """Wrap a ``transformers`` tokenizer behind ``encode``/``decode`` methods.

    The method signatures and return shapes mirror those of ``RawConcat`` so
    both kinds of tokenizer can be passed to the same evaluation code.
    """

    def __init__(self, repo):
        """Load the tokenizer for ``repo`` with ``AutoTokenizer.from_pretrained``.

        ``HF_TOKEN`` is forwarded as ``token`` only when it is non-empty.
        """
        from transformers import AutoTokenizer

        # NOTE(review): trust_remote_code=True lets the loaded repo run its own
        # Python code; only use it with repos you trust.
        load_args = dict(trust_remote_code=True)
        if HF_TOKEN:
            load_args["token"] = HF_TOKEN
        self.tok = AutoTokenizer.from_pretrained(repo, **load_args)

    def encode(self, text):
        """Tokenize ``text`` without adding special tokens.

        Returns:
            ``(tokens, ids, script)`` where ``script`` is ``detect_script(text)``.
        """
        token_ids = self.tok.encode(text, add_special_tokens=False)
        pieces = self.tok.convert_ids_to_tokens(token_ids)
        return pieces, token_ids, detect_script(text)

    def decode(self, ids, script):
        """Decode ``ids``, skipping special tokens; ``script`` is accepted but unused."""
        return self.tok.decode(ids, skip_special_tokens=True)
@dataclass
class Result:
    """Metrics for one (dataset, tokenizer) pair; one row of the CSV/JSON output."""

    # Dataset name, e.g. "DODa".
    dataset: str = ""
    # Display name of the tokenizer (first element of a TOKENIZERS entry).
    tokenizer: str = ""
    # Vocabulary size as declared in the TOKENIZERS entry (not measured).
    vocab_size: int = 0
    # Origin tag from the TOKENIZERS entry ("external" or "ours").
    source: str = ""
    # Mean tokens-per-word over texts labelled "ar".
    fertility_ar: float = 0.0
    # Mean tokens-per-word over texts labelled "az".
    fertility_az: float = 0.0
    # Mean tokens-per-word over all scored texts.
    fertility_overall: float = 0.0
    # Mean grapheme-clusters-per-token over texts labelled "ar".
    cpt_ar: float = 0.0
    # Mean grapheme-clusters-per-token over texts labelled "az".
    cpt_az: float = 0.0
    # Mean grapheme-clusters-per-token over all scored texts.
    cpt_overall: float = 0.0
    # Relative overall-fertility reduction versus the configured baseline, in percent.
    gain_pct: float = 0.0
    # Number of texts that contributed to the means.
    n_texts: int = 0
def evaluate(tok, texts):
    """Compute mean fertility and characters-per-token of ``tok`` over ``texts``.

    For every text that contains at least one word (per ``segment_words``):

    * fertility = non-special tokens / words
    * CPT       = grapheme clusters / non-special tokens (denominator >= 1)

    Each value goes into the overall pool and into the ``"ar"`` pool when the
    script label returned by ``tok.encode`` is ``"ar"``, otherwise into the
    ``"az"`` pool.

    Scoring is best-effort: a text whose processing raises is skipped, but the
    failures are counted and summarised on stdout instead of being silently
    discarded.

    Args:
        tok: object whose ``encode(text)`` returns ``(tokens, ids, script)``.
        texts: sized sequence of strings.

    Returns:
        dict with keys ``fertility_ar``, ``fertility_az``,
        ``fertility_overall``, ``cpt_ar``, ``cpt_az``, ``cpt_overall`` (floats,
        ``0.0`` when the corresponding pool is empty) and ``n_texts`` (number
        of texts scored).
    """
    ar_f, az_f, all_f = [], [], []
    ar_c, az_c, all_c = [], [], []
    n = 0
    n_failed = 0
    first_error = None
    for i, text in enumerate(texts):
        if (i + 1) % 5000 == 0:
            print(f" [{i+1}/{len(texts)}]", flush=True)
        try:
            tokens, _ids, script = tok.encode(text)
            content = filter_sp(tokens)
            words = segment_words(text)
            if not words:
                continue
            fert = len(content) / len(words)
            cpt = count_graphemes(text) / max(len(content), 1)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # abort the run; remember the first error for the summary below.
            n_failed += 1
            if first_error is None:
                first_error = exc
            continue
        all_f.append(fert)
        all_c.append(cpt)
        if script == "ar":
            ar_f.append(fert)
            ar_c.append(cpt)
        else:
            az_f.append(fert)
            az_c.append(cpt)
        n += 1

    if n_failed:
        print(f" WARNING: skipped {n_failed} of {len(texts)} texts after errors; "
              f"first error: {first_error!r}", flush=True)

    def _mean(values):
        # np.mean of an empty list is NaN (with a warning); report 0.0 instead.
        return float(np.mean(values)) if values else 0.0

    return {
        "fertility_ar": _mean(ar_f),
        "fertility_az": _mean(az_f),
        "fertility_overall": _mean(all_f),
        "cpt_ar": _mean(ar_c),
        "cpt_az": _mean(az_c),
        "cpt_overall": _mean(all_c),
        "n_texts": n,
    }
# ── Tokenizer configs ──────────────────────────────────────────────
# Each entry is a tuple: (name, source, vocab_size, kind, *spec)
#   kind == "hf":     spec is one Hugging Face repo id.
#   kind == "concat": spec is two tokenizer JSON file names (the "ar" tokenizer,
#                     then the "az" tokenizer), resolved relative to TOK_DIR.
# vocab_size is the declared size used for reporting; it is not measured here.
TOKENIZERS = [
    ("DarijaBERT-ar", "external", 80000, "hf", "SI2M-Lab/DarijaBERT"),
    ("DarijaBERT-az", "external", 110000, "hf", "SI2M-Lab/DarijaBERT-arabizi"),
    ("DarijaBERT-mix", "external", 160000, "hf", "SI2M-Lab/DarijaBERT-mix"),
    ("Ours (80K WP)", "ours", 80000, "concat",
     "concat_ar_wordpiece_40000.json", "concat_az_wordpiece_40000.json"),
    ("Ours (110K WP)", "ours", 110000, "concat",
     "concat_ar_wordpiece_55000.json", "concat_az_wordpiece_55000.json"),
    ("Ours (32K BPE)", "ours", 32000, "concat",
     "concat_ar_bpe_16000.json", "concat_az_bpe_16000.json"),
]
# Baselines for gain calculation: dataset name → {our tokenizer name → baseline tokenizer name}.
# Both names must match the first element of a TOKENIZERS entry.
# Gain = (F_baseline - F_ours) / F_baseline * 100, with F the overall fertility,
# so a positive gain means our tokenizer emits fewer tokens per word.
GAIN_BASELINES = {
    "DODa": {"Ours (80K WP)": "DarijaBERT-ar",
             "Ours (110K WP)": "DarijaBERT-az",
             "Ours (32K BPE)": "DarijaBERT-mix"},
    "Darija-Wiki": {"Ours (80K WP)": "DarijaBERT-ar",
                    "Ours (110K WP)": "DarijaBERT-az",
                    "Ours (32K BPE)": "DarijaBERT-mix"},
    "Atlaset": {"Ours (80K WP)": "DarijaBERT-ar",
                "Ours (110K WP)": "DarijaBERT-az",
                "Ours (32K BPE)": "DarijaBERT-mix"},
}
def _sample_streamed_column(repo, column, min_chars, pool_size):
    """Stream ``repo``'s train split and sample texts from its first usable rows.

    Rows are read in stream order until ``pool_size`` stripped strings of at
    least ``min_chars`` characters have been collected from ``column``; up to
    ``N_SAMPLE`` of them are then drawn with the global ``random`` module.
    """
    from datasets import load_dataset

    # Pass None rather than "" when HF_TOKEN is unset, mirroring HFTok, which
    # omits the token argument entirely in that case.
    ds = load_dataset(repo, split="train", streaming=True, token=HF_TOKEN or None)
    pool = []
    for row in ds:
        value = row.get(column, "")
        if isinstance(value, str) and len(value.strip()) >= min_chars:
            pool.append(value.strip())
            if len(pool) >= pool_size:
                break
    return random.sample(pool, min(N_SAMPLE, len(pool)))


def _sample_local_parquet(pattern, column, min_chars, pool_size):
    """Sample texts from the parquet files matching the glob ``pattern``.

    Files are read in sorted order, in batches of 50,000 rows. Reading stops
    after the first batch that brings the pool of stripped strings (at least
    ``min_chars`` long) to ``pool_size`` or more; up to ``N_SAMPLE`` of them
    are then drawn with the global ``random`` module.
    """
    import glob

    import pyarrow.parquet as pq

    files = sorted(glob.glob(pattern))
    if not files:
        print(f" WARNING: no parquet files match {pattern}", flush=True)
    pool = []
    for path in files:
        batches = pq.ParquetFile(path).iter_batches(batch_size=50000, columns=[column])
        for batch in batches:
            for value in batch.column(column).to_pylist():
                if isinstance(value, str) and len(value.strip()) >= min_chars:
                    pool.append(value.strip())
            if len(pool) >= pool_size:
                break
        if len(pool) >= pool_size:
            break
    return random.sample(pool, min(N_SAMPLE, len(pool)))


def load_dataset_texts(dataset_name):
    """Return a random sample of up to ``N_SAMPLE`` texts from one dataset.

    Supported names are ``"DODa"`` and ``"Darija-Wiki"`` (streamed from the
    Hugging Face Hub) and ``"Atlaset"`` (read from local parquet files under
    ``PARQUET_DIR``). The sample is drawn from the first rows of the source,
    not from the whole dataset.

    NOTE(review): sampling uses the global ``random`` state, so the texts drawn
    for a dataset depend on which datasets were sampled before it in the same
    process (e.g. a resumed run skips earlier ones).

    Args:
        dataset_name: one of the names above.

    Returns:
        list of stripped strings; an empty list for an unknown name.
    """
    if dataset_name == "DODa":
        print(" Loading DODa (atlasia/DODa)...", flush=True)
        return _sample_streamed_column("atlasia/DODa", "darija", 3, N_SAMPLE * 3)
    if dataset_name == "Darija-Wiki":
        print(" Loading Darija-Wiki (atlasia/Moroccan-Darija-Wiki-Dataset)...", flush=True)
        return _sample_streamed_column(
            "atlasia/Moroccan-Darija-Wiki-Dataset", "content", 10, N_SAMPLE * 3)
    if dataset_name == "Atlaset":
        print(" Loading Atlaset from local parquet...", flush=True)
        return _sample_local_parquet(
            os.path.join(PARQUET_DIR, "data", "train-*.parquet"), "text", 3, N_SAMPLE * 5)
    return []
def _cell(row, key, cast):
    """Read ``row[key]`` as a number; a missing or blank cell counts as 0."""
    return cast(row.get(key) or 0)


def _load_previous_results(csv_path):
    """Return ``(results, done_dataset_names)`` parsed from an earlier run's CSV.

    Both are empty when ``csv_path`` does not exist.
    """
    results, done = [], set()
    if not os.path.exists(csv_path):
        return results, done
    # newline="" is what the csv module expects for files it parses itself.
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            results.append(Result(
                dataset=row["dataset"], tokenizer=row["tokenizer"],
                source=row.get("source", ""), vocab_size=_cell(row, "vocab_size", int),
                fertility_ar=_cell(row, "fertility_ar", float),
                fertility_az=_cell(row, "fertility_az", float),
                fertility_overall=_cell(row, "fertility_overall", float),
                cpt_ar=_cell(row, "cpt_ar", float),
                cpt_az=_cell(row, "cpt_az", float),
                cpt_overall=_cell(row, "cpt_overall", float),
                gain_pct=_cell(row, "gain_pct", float),
                n_texts=_cell(row, "n_texts", int),
            ))
            done.add(row["dataset"])
    return results, done


def _build_tokenizer(cfg):
    """Instantiate the tokenizer wrapper described by one TOKENIZERS entry."""
    if cfg[3] == "concat":
        return RawConcat(os.path.join(TOK_DIR, cfg[4]), os.path.join(TOK_DIR, cfg[5]))
    return HFTok(cfg[4])


def _apply_gains(ds_name, ds_results):
    """Fill ``gain_pct`` of our tokenizers from their configured baselines."""
    for our_name, baseline_name in GAIN_BASELINES.get(ds_name, {}).items():
        ours = ds_results.get(our_name)
        base = ds_results.get(baseline_name)
        if ours is None or base is None:
            continue
        f_base = base.fertility_overall
        f_ours = ours.fertility_overall
        # A fertility of 0 means the tokenizer failed or scored no text; without
        # the f_ours check such a run would be reported as a 100% gain.
        if f_base > 0 and f_ours > 0:
            ours.gain_pct = round((f_base - f_ours) / f_base * 100, 1)


def _print_tables(datasets, all_results):
    """Print one metrics table per dataset to stdout."""
    print(f"\n{'='*120}")
    for ds_name in datasets:
        ds_rows = [r for r in all_results if r.dataset == ds_name]
        print(f"\n {ds_name}:")
        print(f" {'Tokenizer':<22} {'V':>7} {'F_ar':>7} {'F_az':>7} {'F':>7} "
              f"{'CPT_ar':>7} {'CPT_az':>7} {'CPT':>7} {'Gain':>7}")
        print(" " + "-" * 100)
        for r in ds_rows:
            gain = f"{r.gain_pct}%" if r.gain_pct != 0 else "---"
            print(f" {r.tokenizer:<22} {r.vocab_size:>7,} {r.fertility_ar:>7.3f} "
                  f"{r.fertility_az:>7.3f} {r.fertility_overall:>7.3f} "
                  f"{r.cpt_ar:>7.3f} {r.cpt_az:>7.3f} {r.cpt_overall:>7.3f} {gain:>7}")
    print(f"\n{'='*120}")


def main():
    """Evaluate every tokenizer in TOKENIZERS on each dataset and save the results.

    Datasets already present in ``external_datasets_eval.csv`` under ``BASE``
    are skipped (resume). After each newly evaluated dataset the CSV is
    rewritten via ``save_csv``; at the end the tables are printed and a JSON
    copy is written. A tokenizer that fails to load or evaluate is reported on
    stdout and recorded with all-zero metrics rather than aborting the run.
    """
    datasets = ["DODa", "Darija-Wiki", "Atlaset"]
    # Make sure the output directory exists before anything tries to write to it.
    os.makedirs(BASE, exist_ok=True)

    # Load existing results to allow resume
    csv_path = os.path.join(BASE, "external_datasets_eval.csv")
    all_results, done_datasets = _load_previous_results(csv_path)
    if done_datasets:
        print(f"Resuming — already done: {done_datasets}", flush=True)

    for ds_name in datasets:
        print(f"\n{'='*80}")
        if ds_name in done_datasets:
            print(f"SKIP (already done): {ds_name}", flush=True)
            continue
        print(f"DATASET: {ds_name}", flush=True)
        texts = load_dataset_texts(ds_name)
        print(f" Sampled {len(texts):,} texts", flush=True)

        ds_results = {}
        for cfg in TOKENIZERS:
            name, source, vsz = cfg[0], cfg[1], cfg[2]
            print(f"\n Tokenizer: {name} ({vsz:,})", flush=True)
            tok = None
            try:
                tok = _build_tokenizer(cfg)
                m = evaluate(tok, texts)
                r = Result(
                    dataset=ds_name, tokenizer=name, vocab_size=vsz, source=source,
                    fertility_ar=round(m["fertility_ar"], 3),
                    fertility_az=round(m["fertility_az"], 3),
                    fertility_overall=round(m["fertility_overall"], 3),
                    cpt_ar=round(m["cpt_ar"], 3),
                    cpt_az=round(m["cpt_az"], 3),
                    cpt_overall=round(m["cpt_overall"], 3),
                    n_texts=m["n_texts"],
                )
                print(f" F_ar={r.fertility_ar:.3f} F_az={r.fertility_az:.3f} "
                      f"F={r.fertility_overall:.3f} CPT={r.cpt_overall:.3f}", flush=True)
            except Exception as e:
                # Top-level boundary: report and record an all-zero row so the
                # remaining tokenizers and datasets still run.
                print(f" FAILED: {e}", flush=True)
                r = Result(dataset=ds_name, tokenizer=name, vocab_size=vsz, source=source)
            finally:
                # Release the tokenizer on success and on failure alike.
                tok = None
                gc.collect()
            ds_results[name] = r

        # Compute gains
        _apply_gains(ds_name, ds_results)
        all_results.extend(ds_results.values())

        # Save incrementally
        save_csv(all_results)
        print("\n Saved intermediate results.", flush=True)

    # Print final tables
    _print_tables(datasets, all_results)

    # Also save JSON
    json_path = os.path.join(BASE, "external_datasets_eval.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump([asdict(r) for r in all_results], f, indent=2)
    print(f"Saved JSON: {json_path}")
    print("DONE!")
def save_csv(results):
    """Write ``results`` to ``external_datasets_eval.csv`` under ``BASE``.

    This file is what a later run reads to decide which datasets to skip, so
    it is replaced atomically: rows are written to a temporary sibling file
    which is then renamed over the target. An interruption mid-write therefore
    leaves the previous CSV intact instead of a truncated one.

    Args:
        results: iterable of ``Result`` instances, written in iteration order.
    """
    csv_path = os.path.join(BASE, "external_datasets_eval.csv")
    fieldnames = ["dataset", "tokenizer", "source", "vocab_size",
                  "fertility_ar", "fertility_az", "fertility_overall",
                  "cpt_ar", "cpt_az", "cpt_overall", "gain_pct", "n_texts"]
    os.makedirs(BASE, exist_ok=True)
    tmp_path = csv_path + ".tmp"
    with open(tmp_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        w.writeheader()
        w.writerows(asdict(r) for r in results)
    # Same-directory rename: the target is swapped in a single step.
    os.replace(tmp_path, csv_path)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()