"""
eval_external_datasets.py — Evaluate the DarijaBERT family against our tokenizers
on three independent datasets (a random sample of up to 10K texts each):
  - DODa         (Arabizi)  atlasia/DODa
  - Darija-Wiki  (Arabic)   atlasia/Moroccan-Darija-Wiki-Dataset
  - Atlaset      (Arabic)   atlasia/Atlaset

For each dataset × tokenizer, compute: fertility (tokens per word) for
Arabic-script texts (F_ar), Arabizi texts (F_az) and overall (F); characters
(grapheme clusters) per token (CPT); and Gain, the relative reduction in
overall fertility versus the paired DarijaBERT baseline.
"""
|
|
| import json, os, csv, gc, random, warnings |
| from dataclasses import dataclass, asdict |
|
|
| import numpy as np |
| import regex |
# Make the global RNG repeatable from run to run.
random.seed(42)

# Suppress every warning for the whole run to keep the progress log readable.
warnings.filterwarnings("ignore")


# ── Paths and run configuration ───────────────────────────────────────────────
BASE = "/root/oiq_cc_tokenizer/results"
TOK_DIR = os.path.join(BASE, "tokenizers")            # tokenizer JSON files for our models
PARQUET_DIR = "/root/oiq_cc_tokenizer/tmp_atlastet"   # local Atlaset parquet snapshot
HF_TOKEN = os.environ.get("HF_TOKEN", "")             # "" when the variable is unset
N_SAMPLE = 10000                                      # texts sampled per dataset
|
|
# A "word" is a maximal run of letters, combining marks and digits.
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)

# One Arabic (U+0600–U+06FF) or Arabic Supplement (U+0750–U+077F) code point.
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")

# One ASCII Latin letter. NOTE(review): not referenced in the visible code.
_LAT_PAT = regex.compile(r"[a-zA-Z]")

# Tokens removed before counting. NOTE(review): this also drops the unknown-token
# markers ("<unk>", "[UNK]"), which lowers the measured fertility of tokenizers
# that emit many unknowns; confirm that is intended.
_SPECIAL = {
    "<unk>", "<s>", "</s>",
    "[CLS]", "[SEP]", "[PAD]", "[UNK]",
    "<pad>", "",
    "<|im_start|>", "<|im_end|>",
}
|
|
|
|
# Extended grapheme cluster pattern, compiled once: count_graphemes runs on
# every text for every tokenizer, so it should not re-resolve the pattern.
_GRAPHEME_PAT = regex.compile(r"\X")


def segment_words(t):
    """Return the maximal runs of letters, combining marks and digits in *t*."""
    return _WORD_PAT.findall(t)


def count_graphemes(t):
    """Return the number of extended grapheme clusters (user-perceived characters) in *t*."""
    return len(_GRAPHEME_PAT.findall(t))


def filter_sp(tokens):
    """Return *tokens* without the entries listed in _SPECIAL, preserving order."""
    return [tok for tok in tokens if tok not in _SPECIAL]
|
|
def detect_script(t):
    """Classify *t* as "ar" or "az".

    A text is "ar" when the characters matched by _AR_PAT make up more than
    30% of its total length (spaces, digits and punctuation included in the
    length). Everything else, including the empty string, is "az".
    """
    arabic_count = len(_AR_PAT.findall(t))
    if arabic_count > 0.3 * len(t):
        return "ar"
    return "az"
|
|
def normalize_decode(s):
    """Remove every "##" marker from *s* and collapse whitespace runs to single spaces.

    Leading and trailing whitespace is dropped as well.
    """
    without_markers = s.replace("##", "")
    return " ".join(without_markers.split())
|
|
|
|
class RawConcat:
    """Pair of `tokenizers` JSON tokenizers, chosen per text by script.

    Texts that detect_script labels "ar" go to the Arabic-script tokenizer;
    every other text ("az") goes to the second tokenizer.
    """

    def __init__(self, ar_j, az_j):
        """Load both tokenizers from their JSON files (*ar_j*, *az_j*)."""
        from tokenizers import Tokenizer

        self.ar, self.az = Tokenizer.from_file(ar_j), Tokenizer.from_file(az_j)

    def encode(self, text):
        """Return (tokens, ids, script) for *text*, using the tokenizer for its script."""
        script = detect_script(text)
        chosen = self.ar if script == "ar" else self.az
        encoding = chosen.encode(text)
        return encoding.tokens, encoding.ids, script

    def decode(self, ids, script):
        """Decode *ids* with the tokenizer for *script*, skipping special tokens."""
        chosen = self.az if script != "ar" else self.ar
        return chosen.decode(ids, skip_special_tokens=True)
|
|
|
|
class HFTok:
    """Adapter giving a Hugging Face AutoTokenizer the encode/decode interface used here."""

    def __init__(self, repo):
        """Load the tokenizer of Hub repository *repo*, authenticating with HF_TOKEN if set."""
        from transformers import AutoTokenizer

        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the repository; only use it with repositories you trust.
        options = dict(trust_remote_code=True)
        if HF_TOKEN:
            options["token"] = HF_TOKEN
        self.tok = AutoTokenizer.from_pretrained(repo, **options)

    def encode(self, text):
        """Return (tokens, ids, script) for *text*, without adding special tokens."""
        token_ids = self.tok.encode(text, add_special_tokens=False)
        pieces = self.tok.convert_ids_to_tokens(token_ids)
        return pieces, token_ids, detect_script(text)

    def decode(self, ids, script):
        """Decode *ids*, skipping special tokens; *script* is accepted but unused."""
        return self.tok.decode(ids, skip_special_tokens=True)
|
|
|
|
@dataclass
class Result:
    """One row of the evaluation table: metrics of one tokenizer on one dataset.

    Every field has a zero/empty default so a failed run can still be recorded.
    Field order is the positional-constructor and asdict() order.
    """

    # Identity of the row.
    dataset: str = ""
    tokenizer: str = ""
    vocab_size: int = 0
    source: str = ""                 # "external" or "ours" in TOKENIZERS

    # Mean content tokens per word: Arabic-script texts, other texts, all texts.
    fertility_ar: float = 0.0
    fertility_az: float = 0.0
    fertility_overall: float = 0.0

    # Mean grapheme clusters per content token, same three groups.
    cpt_ar: float = 0.0
    cpt_az: float = 0.0
    cpt_overall: float = 0.0

    # Relative overall-fertility reduction vs. the paired baseline, in percent.
    gain_pct: float = 0.0
    # Number of texts that contributed to the means.
    n_texts: int = 0
|
|
|
|
def evaluate(tok, texts):
    """Measure fertility and characters-per-token (CPT) of *tok* on *texts*.

    For every text containing at least one word (per segment_words):
      fertility = content tokens / words
      CPT       = grapheme clusters / content tokens (denominator floored at 1)
    where content tokens are the tokens from tok.encode(text) minus the
    entries of _SPECIAL. Means are reported overall and per script label
    returned by tok.encode: "ar", with every other label pooled as "az".

    A text whose processing raises is skipped. The number of such texts and
    the first error are printed at the end rather than silently dropped.

    Args:
        tok: object with encode(text) -> (tokens, ids, script), e.g. RawConcat or HFTok.
        texts: sized sequence of strings (len() is used for progress output).

    Returns:
        dict with fertility_ar, fertility_az, fertility_overall, cpt_ar, cpt_az,
        cpt_overall (0.0 for an empty group) and n_texts, the number of texts
        that contributed.
    """
    fert = {"ar": [], "az": []}
    cpt = {"ar": [], "az": []}
    all_f, all_c = [], []
    n_failed, first_error = 0, None
    total = len(texts)

    for i, text in enumerate(texts, start=1):
        if i % 5000 == 0:
            print(f" [{i}/{total}]", flush=True)
        try:
            tokens, _ids, script = tok.encode(text)
            content = filter_sp(tokens)
            words = segment_words(text)
            if not words:
                continue
            f_val = len(content) / len(words)
            c_val = count_graphemes(text) / max(len(content), 1)
        except Exception as e:  # best effort per text, but never silent
            n_failed += 1
            if first_error is None:
                first_error = e
            continue

        group = "ar" if script == "ar" else "az"
        fert[group].append(f_val)
        cpt[group].append(c_val)
        all_f.append(f_val)
        all_c.append(c_val)

    if n_failed:
        print(f" WARNING: {n_failed} text(s) raised during evaluation; "
              f"first error: {first_error!r}", flush=True)

    def _mean(values):
        # Empty groups report 0.0 instead of NaN.
        return float(np.mean(values)) if values else 0.0

    return {
        "fertility_ar": _mean(fert["ar"]),
        "fertility_az": _mean(fert["az"]),
        "fertility_overall": _mean(all_f),
        "cpt_ar": _mean(cpt["ar"]),
        "cpt_az": _mean(cpt["az"]),
        "cpt_overall": _mean(all_c),
        "n_texts": len(all_f),
    }
|
|
|
|
# Tokenizer configurations: (display name, source, total vocab size, kind, *spec)
#   kind "hf"     -> spec = (Hugging Face repo id,)
#   kind "concat" -> spec = (Arabic-script JSON, "az" JSON), both relative to TOK_DIR
_EXTERNAL_TOKENIZERS = [
    ("DarijaBERT-ar", "external", 80000, "hf", "SI2M-Lab/DarijaBERT"),
    ("DarijaBERT-az", "external", 110000, "hf", "SI2M-Lab/DarijaBERT-arabizi"),
    ("DarijaBERT-mix", "external", 160000, "hf", "SI2M-Lab/DarijaBERT-mix"),
]
_OUR_TOKENIZERS = [
    ("Ours (80K WP)", "ours", 80000, "concat",
     "concat_ar_wordpiece_40000.json", "concat_az_wordpiece_40000.json"),
    ("Ours (110K WP)", "ours", 110000, "concat",
     "concat_ar_wordpiece_55000.json", "concat_az_wordpiece_55000.json"),
    ("Ours (32K BPE)", "ours", 32000, "concat",
     "concat_ar_bpe_16000.json", "concat_az_bpe_16000.json"),
]
TOKENIZERS = _EXTERNAL_TOKENIZERS + _OUR_TOKENIZERS
|
|
# Each of our tokenizers is compared against one DarijaBERT model:
# gain_pct = (F_baseline - F_ours) / F_baseline * 100 on overall fertility.
_GAIN_PAIRS = {
    "Ours (80K WP)": "DarijaBERT-ar",
    "Ours (110K WP)": "DarijaBERT-az",
    "Ours (32K BPE)": "DarijaBERT-mix",
}
# Same pairing for every dataset; each dataset gets its own dict copy.
GAIN_BASELINES = {ds: dict(_GAIN_PAIRS) for ds in ("DODa", "Darija-Wiki", "Atlaset")}
|
|
|
|
def _collect_texts(candidates, min_len, cap):
    """Return stripped string candidates of length >= *min_len*, stopping once *cap* are kept."""
    texts = []
    for t in candidates:
        if not isinstance(t, str):
            continue
        t = t.strip()
        if len(t) >= min_len:
            texts.append(t)
            if len(texts) >= cap:
                break
    return texts


def _stream_hf_column(repo, column):
    """Yield row[column] (default "") from the streamed train split of a Hub dataset."""
    from datasets import load_dataset

    # Passing "" would be sent as an explicit (empty) token and bypass any
    # locally saved login; None lets the library fall back to it.
    ds = load_dataset(repo, split="train", streaming=True, token=HF_TOKEN or None)
    for row in ds:
        yield row.get(column, "")


def _iter_parquet_texts(parquet_dir):
    """Yield the "text" column of data/train-*.parquet under *parquet_dir*, file by file."""
    import glob
    import pyarrow.parquet as pq

    files = sorted(glob.glob(os.path.join(parquet_dir, "data", "train-*.parquet")))
    if not files:
        print(f" WARNING: no parquet files found under {parquet_dir}", flush=True)
    for fp in files:
        pf = pq.ParquetFile(fp)
        for batch in pf.iter_batches(batch_size=50000, columns=["text"]):
            yield from batch.column("text").to_pylist()


def load_dataset_texts(dataset_name):
    """Return a reproducible random sample of up to N_SAMPLE texts from one dataset.

    Texts are stripped and must reach a minimum length (3 characters, or 10
    for Darija-Wiki). The sample is drawn from a candidate pool of the first
    qualifying texts in storage order: N_SAMPLE * 3 for the Hub datasets and
    N_SAMPLE * 5 for the local Atlaset parquet files. It is not a sample of
    the whole dataset.

    Sampling uses a private RNG seeded with 42, so a dataset's sample does not
    depend on which datasets were sampled before it (e.g. when main() resumes
    and skips finished datasets).

    Args:
        dataset_name: "DODa", "Darija-Wiki" or "Atlaset".

    Returns:
        list of str; empty for an unknown name or when nothing could be loaded.
    """
    if dataset_name == "DODa":
        print(" Loading DODa (atlasia/DODa)...", flush=True)
        pool = _collect_texts(_stream_hf_column("atlasia/DODa", "darija"),
                              min_len=3, cap=N_SAMPLE * 3)
    elif dataset_name == "Darija-Wiki":
        print(" Loading Darija-Wiki (atlasia/Moroccan-Darija-Wiki-Dataset)...", flush=True)
        pool = _collect_texts(
            _stream_hf_column("atlasia/Moroccan-Darija-Wiki-Dataset", "content"),
            min_len=10, cap=N_SAMPLE * 3)
    elif dataset_name == "Atlaset":
        print(" Loading Atlaset from local parquet...", flush=True)
        pool = _collect_texts(_iter_parquet_texts(PARQUET_DIR),
                              min_len=3, cap=N_SAMPLE * 5)
    else:
        return []

    rng = random.Random(42)
    return rng.sample(pool, min(N_SAMPLE, len(pool)))
|
|
|
|
def _load_checkpoint(csv_path):
    """Rebuild Result rows from an earlier checkpoint CSV; empty list if the file is absent."""
    results = []
    if not os.path.exists(csv_path):
        return results
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            results.append(Result(
                dataset=row["dataset"],
                tokenizer=row["tokenizer"],
                source=row.get("source", ""),
                vocab_size=int(row.get("vocab_size", 0)),
                fertility_ar=float(row.get("fertility_ar", 0)),
                fertility_az=float(row.get("fertility_az", 0)),
                fertility_overall=float(row.get("fertility_overall", 0)),
                cpt_ar=float(row.get("cpt_ar", 0)),
                cpt_az=float(row.get("cpt_az", 0)),
                cpt_overall=float(row.get("cpt_overall", 0)),
                gain_pct=float(row.get("gain_pct", 0)),
                n_texts=int(row.get("n_texts", 0)),
            ))
    return results


def _evaluate_tokenizer(ds_name, cfg, texts):
    """Load the tokenizer described by one TOKENIZERS entry and score it on *texts*.

    Returns a Result. If loading or evaluation raises, the error is printed and
    a zero-filled Result is returned so the table stays complete.
    """
    name, source, vsz, kind = cfg[0], cfg[1], cfg[2], cfg[3]
    print(f"\n Tokenizer: {name} ({vsz:,})", flush=True)

    tok = None
    try:
        if kind == "concat":
            tok = RawConcat(os.path.join(TOK_DIR, cfg[4]), os.path.join(TOK_DIR, cfg[5]))
        else:
            tok = HFTok(cfg[4])
        m = evaluate(tok, texts)
    except Exception as e:
        print(f" FAILED: {e}", flush=True)
        return Result(dataset=ds_name, tokenizer=name, vocab_size=vsz, source=source)
    finally:
        # Release the tokenizer before the next one is loaded.
        tok = None
        gc.collect()

    r = Result(
        dataset=ds_name, tokenizer=name, vocab_size=vsz, source=source,
        fertility_ar=round(m["fertility_ar"], 3),
        fertility_az=round(m["fertility_az"], 3),
        fertility_overall=round(m["fertility_overall"], 3),
        cpt_ar=round(m["cpt_ar"], 3),
        cpt_az=round(m["cpt_az"], 3),
        cpt_overall=round(m["cpt_overall"], 3),
        n_texts=m["n_texts"],
    )
    print(f" F_ar={r.fertility_ar:.3f} F_az={r.fertility_az:.3f} "
          f"F={r.fertility_overall:.3f} CPT={r.cpt_overall:.3f}", flush=True)
    return r


def _apply_gains(ds_name, ds_results):
    """Set gain_pct on our tokenizers: % reduction of overall fertility vs. their baseline."""
    for our_name, baseline_name in GAIN_BASELINES.get(ds_name, {}).items():
        ours = ds_results.get(our_name)
        base = ds_results.get(baseline_name)
        if ours is None or base is None:
            continue
        f_base, f_ours = base.fertility_overall, ours.fertility_overall
        # A failed run leaves fertility at 0.0. Comparing against it would
        # report a bogus +100% gain (ours failed) or divide by zero (baseline failed).
        if f_base > 0 and f_ours > 0:
            ours.gain_pct = round((f_base - f_ours) / f_base * 100, 1)


def _print_summary(datasets, all_results):
    """Print one fixed-width metrics table per dataset, in *datasets* order."""
    print(f"\n{'='*120}")
    for ds_name in datasets:
        print(f"\n {ds_name}:")
        print(f" {'Tokenizer':<22} {'V':>7} {'F_ar':>7} {'F_az':>7} {'F':>7} "
              f"{'CPT_ar':>7} {'CPT_az':>7} {'CPT':>7} {'Gain':>7}")
        print(" " + "-" * 100)
        for r in all_results:
            if r.dataset != ds_name:
                continue
            gain = f"{r.gain_pct}%" if r.gain_pct != 0 else "---"
            print(f" {r.tokenizer:<22} {r.vocab_size:>7,} {r.fertility_ar:>7.3f} "
                  f"{r.fertility_az:>7.3f} {r.fertility_overall:>7.3f} "
                  f"{r.cpt_ar:>7.3f} {r.cpt_az:>7.3f} {r.cpt_overall:>7.3f} {gain:>7}")
    print(f"\n{'='*120}")


def main():
    """Evaluate every TOKENIZERS entry on each external dataset and report the metrics.

    Results are checkpointed to BASE/external_datasets_eval.csv after each
    dataset. On a re-run, datasets already present in that CSV are skipped
    and their stored rows are reused. A per-dataset table is printed at the
    end, and all rows are written to BASE/external_datasets_eval.json.
    """
    datasets = ["DODa", "Darija-Wiki", "Atlaset"]
    csv_path = os.path.join(BASE, "external_datasets_eval.csv")

    all_results = _load_checkpoint(csv_path)
    done_datasets = {r.dataset for r in all_results}
    if done_datasets:
        print(f"Resuming — already done: {done_datasets}", flush=True)

    for ds_name in datasets:
        print(f"\n{'='*80}")
        if ds_name in done_datasets:
            print(f"SKIP (already done): {ds_name}", flush=True)
            continue
        print(f"DATASET: {ds_name}", flush=True)
        texts = load_dataset_texts(ds_name)
        print(f" Sampled {len(texts):,} texts", flush=True)
        if not texts:
            # Checkpointing all-zero rows would mark this dataset as done and
            # make every later run skip it.
            print(f" WARNING: no texts loaded for {ds_name}; "
                  f"not evaluated or checkpointed.", flush=True)
            continue

        ds_results = {}
        for cfg in TOKENIZERS:
            ds_results[cfg[0]] = _evaluate_tokenizer(ds_name, cfg, texts)
        _apply_gains(ds_name, ds_results)
        all_results.extend(ds_results.values())

        # Checkpoint after each dataset so an interrupted run can resume.
        save_csv(all_results)
        print("\n Saved intermediate results.", flush=True)

    _print_summary(datasets, all_results)

    json_path = os.path.join(BASE, "external_datasets_eval.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump([asdict(r) for r in all_results], f, indent=2)
    print(f"Saved JSON: {json_path}")
    print("DONE!")
|
|
|
|
def save_csv(results, path=None):
    """Write *results* to the checkpoint CSV, replacing the file atomically.

    Rows are written to a temporary sibling file, which is then moved over the
    destination with os.replace. An interrupted write therefore never leaves
    a half-written checkpoint for the resume logic to read.

    Args:
        results: iterable of Result dataclass instances. Fields outside the
            CSV schema are ignored.
        path: destination file; defaults to BASE/external_datasets_eval.csv.
    """
    csv_path = path if path is not None else os.path.join(BASE, "external_datasets_eval.csv")
    fieldnames = ["dataset", "tokenizer", "source", "vocab_size",
                  "fertility_ar", "fertility_az", "fertility_overall",
                  "cpt_ar", "cpt_az", "cpt_overall", "gain_pct", "n_texts"]
    tmp_path = csv_path + ".tmp"
    with open(tmp_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        w.writeheader()
        for r in results:
            w.writerow(asdict(r))
    os.replace(tmp_path, csv_path)
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()
|
|