#!/usr/bin/env python3
"""
compare_with_external.py
Compare our best 3 tokenizers (one per vocab size) against existing
Arabic/Darija tokenizers from HuggingFace.
Our tokenizers (from benchmark):
- concat_bpe_8000 (V=8K)
- concat_wordpiece_16000 (V=16K)
- concat_wordpiece_32000 (V=32K)
External tokenizers:
- CAMeL-Lab/bert-base-arabic-camelbert-msa (WordPiece 30K, MSA)
- asafaya/bert-base-arabic (WordPiece 32K, MSA)
- riotu-lab/Aranizer-SP-86k (SentencePiece 86K, MSA)
- SI2M-Lab/DarijaBERT (WordPiece 80K, Darija Arabic)
- SI2M-Lab/DarijaBERT-arabizi (WordPiece 110K, Darija Arabizi)
"""
import json, os, sys, time, re, warnings
from collections import Counter
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Tuple
import numpy as np
warnings.filterwarnings("ignore")
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Hard-coded project root; every other path below is derived from it.
BASE = "/root/oiq_cc_tokenizer"
# Root for generated artifacts; the comparison CSV/JSON are written here.
RESULTS = os.path.join(BASE, "results")
# Directory scanned for the "{split}_{script}.txt" evaluation corpora.
CORPORA = os.path.join(RESULTS, "corpora")
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
TOKENIZER_DIR = os.path.join(RESULTS, "tokenizers")
# Parent directory of our exported tokenizers (one sub-directory per tokenizer and script).
TRANS_DIR = os.path.join(RESULTS, "transformers_tokenizers")
import regex
# A "word" is a maximal run of Unicode letters (\p{L}), combining marks (\p{M})
# or numbers (\p{N}); everything else (spaces, punctuation, symbols) separates words.
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)
# Matches a single code point in U+0600-U+06FF or U+0750-U+077F
# (the Unicode "Arabic" and "Arabic Supplement" blocks).
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")
def segment_words(text):
    """Return the list of word-like runs found in *text*, in order of appearance.

    A run is whatever ``_WORD_PAT`` matches; an input with no match yields ``[]``.
    """
    return [match.group(0) for match in _WORD_PAT.finditer(text)]
def count_graphemes(text):
    """Return the number of extended grapheme clusters (``\\X`` matches) in *text*."""
    # Count matches lazily instead of materialising the list of clusters.
    return sum(1 for _ in regex.finditer(r"\X", text))
def detect_script(text):
    """Classify *text* as ``"ar"`` or ``"az"``.

    Returns ``"ar"`` when strictly more than 30% of the characters of *text*
    match ``_AR_PAT``; every other input, including the empty string, is ``"az"``.
    """
    arabic_chars = sum(1 for _ in _AR_PAT.finditer(text))
    if arabic_chars > len(text) * 0.3:
        return "ar"
    return "az"
# ---------------------------------------------------------------------------
# Load test corpora
# ---------------------------------------------------------------------------
def load_test_texts():
    """Read the evaluation corpora into a ``{script: [line, ...]}`` dict.

    For every split in ("test", "val") and every script in ("ar", "az", "mi"),
    the file ``CORPORA/{split}_{script}.txt`` is read if it exists. Lines are
    stripped and blank lines dropped. Missing files are skipped silently, so a
    script may map to an empty list.
    """
    scripts = ("ar", "az", "mi")
    texts = {script: [] for script in scripts}
    for split in ("test", "val"):
        for script in scripts:
            path = os.path.join(CORPORA, f"{split}_{script}.txt")
            if not os.path.exists(path):
                continue
            with open(path, encoding="utf-8") as handle:
                for raw_line in handle:
                    line = raw_line.strip()
                    if line:
                        texts[script].append(line)
    return texts
# ---------------------------------------------------------------------------
# Tokenizer wrappers
# ---------------------------------------------------------------------------
class OurConcatTokenizer:
    """Pair of script-specific tokenizers (HuggingFace ``tokenizers`` library).

    One tokenizer is loaded from ``ar_dir`` and one from ``az_dir``; each call
    is routed to one of them.
    """

    def __init__(self, ar_dir, az_dir):
        # Imported lazily so the module can be loaded without ``tokenizers``.
        from tokenizers import Tokenizer

        ar_file = os.path.join(ar_dir, "tokenizer.json")
        az_file = os.path.join(az_dir, "tokenizer.json")
        self.tok_ar = Tokenizer.from_file(ar_file)
        self.tok_az = Tokenizer.from_file(az_file)

    def encode(self, text):
        """Encode *text* with the tokenizer chosen by ``detect_script``.

        Returns a ``(tokens, ids)`` tuple taken from the encoding object.
        """
        backend = self.tok_ar if detect_script(text) == "ar" else self.tok_az
        encoding = backend.encode(text)
        return encoding.tokens, encoding.ids

    def decode(self, ids, script=None):
        """Decode *ids*, skipping special tokens.

        The "ar" tokenizer is used only when ``script == "ar"``; any other
        value (including the default ``None``) selects the "az" tokenizer.
        """
        backend = self.tok_ar if script == "ar" else self.tok_az
        return backend.decode(ids, skip_special_tokens=True)
class HFTokenizer:
    """Adapter exposing a ``transformers`` tokenizer through encode/decode."""

    def __init__(self, repo_id):
        # Imported lazily so the module can be loaded without ``transformers``.
        from transformers import AutoTokenizer

        # NOTE(review): trust_remote_code=True allows the repository to run its
        # own Python code at load time — confirm every repo passed here needs it.
        self.tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    def encode(self, text):
        """Return ``(tokens, ids)`` for *text*, encoded without special tokens."""
        token_ids = self.tok.encode(text, add_special_tokens=False)
        return self.tok.convert_ids_to_tokens(token_ids), token_ids

    def decode(self, ids, script=None):
        """Decode *ids*, skipping special tokens.

        *script* is accepted for signature parity with ``OurConcatTokenizer``
        and is ignored.
        """
        return self.tok.decode(ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@dataclass
class CompMetrics:
    """One row of the comparison table: the metrics of a single tokenizer.

    Field order is significant: it is the column order of ``asdict()`` and of
    positional construction.
    """

    # --- identification ---
    name: str = ""            # display label of the tokenizer
    source: str = ""          # origin tag supplied by the caller
    vocab_size: int = 0       # vocabulary size supplied by the caller

    # --- fertility: tokens per word (mean over texts) ---
    fertility_ar: float = 0.0
    fertility_az: float = 0.0
    fertility_overall: float = 0.0
    disparity: float = 0.0    # relative gap between fertility_ar and fertility_az

    # --- graphemes per token (mean over texts) ---
    cpt_ar: float = 0.0
    cpt_az: float = 0.0

    # --- share of texts whose decode(encode(text)) equals the text ---
    exact_match_ar: float = 0.0
    exact_match_az: float = 0.0
def _mean_or_zero(values):
    """Return the arithmetic mean of *values* as a float, or 0.0 when empty."""
    return float(np.mean(values)) if values else 0.0


def evaluate_tokenizer(tok, name, source, vocab_size, test_texts):
    """Compute fertility, graphemes-per-token and round-trip metrics for *tok*.

    Args:
        tok: object with ``encode(text) -> (tokens, ids)`` and
            ``decode(ids, script=...) -> str``.
        name: label stored in the result and used in progress output.
        source: origin tag stored in the result.
        vocab_size: vocabulary size stored in the result.
        test_texts: mapping with (optional) keys "ar", "az" and "mi", each a
            list of strings. All texts are pooled and then assigned to the
            "ar" or "az" bucket by ``detect_script``.

    Returns:
        A populated ``CompMetrics``. Metrics for a bucket that received no
        text are 0.0.

    Texts without any word (per ``segment_words``) are skipped. Texts that the
    tokenizer fails to encode are skipped as well, but they are counted and
    reported on stderr instead of being dropped silently.
    """
    metrics = CompMetrics(name=name, source=source, vocab_size=vocab_size)

    # Per-script accumulators; detect_script() only ever returns "ar" or "az".
    stats = {
        script: {"fert": [], "cpt": [], "match": 0, "total": 0}
        for script in ("ar", "az")
    }
    all_fert = []
    n_failed = 0
    first_error = None

    all_texts = [
        text
        for key in ("ar", "az", "mi")
        for text in test_texts.get(key, [])
    ]
    n_total = len(all_texts)

    for position, text in enumerate(all_texts, start=1):
        if position % 5000 == 0:
            print(f" [{position}/{n_total}] {name}", flush=True)

        words = segment_words(text)
        if not words:
            continue
        script = detect_script(text)

        # Third-party tokenizers raise arbitrary exception types, so the catch
        # stays broad; the failure is recorded rather than swallowed.
        try:
            tokens, ids = tok.encode(text)
            # Drop empty tokens and special tokens such as "[CLS]" or "<s>".
            # NOTE(review): this prefix test also drops literal "[" / "<"
            # tokens that come from the text itself.
            kept = [t for t in tokens if t and not t.startswith(("[", "<"))]
        except Exception as exc:
            n_failed += 1
            if first_error is None:
                first_error = exc
            continue

        fertility = len(kept) / len(words)
        all_fert.append(fertility)

        # A decoder that raises is scored as a failed round trip.
        try:
            exact = tok.decode(ids, script=script).strip() == text.strip()
        except Exception:
            exact = False

        bucket = stats[script]
        bucket["fert"].append(fertility)
        bucket["cpt"].append(count_graphemes(text) / max(len(kept), 1))
        bucket["total"] += 1
        if exact:
            bucket["match"] += 1

    if n_failed:
        print(
            f" WARNING: {name}: {n_failed}/{n_total} texts failed to encode "
            f"and were skipped (first error: {first_error!r})",
            file=sys.stderr,
            flush=True,
        )

    ar, az = stats["ar"], stats["az"]
    metrics.fertility_ar = _mean_or_zero(ar["fert"])
    metrics.fertility_az = _mean_or_zero(az["fert"])
    metrics.fertility_overall = _mean_or_zero(all_fert)
    # Relative gap between the two fertilities; 1e-9 guards the 0/0 case.
    metrics.disparity = abs(metrics.fertility_ar - metrics.fertility_az) / max(
        metrics.fertility_ar, metrics.fertility_az, 1e-9
    )
    metrics.cpt_ar = _mean_or_zero(ar["cpt"])
    metrics.cpt_az = _mean_or_zero(az["cpt"])
    metrics.exact_match_ar = ar["match"] / max(ar["total"], 1)
    metrics.exact_match_az = az["match"] / max(az["total"], 1)
    return metrics
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _evaluate_and_report(make_tokenizer, name, source, vocab_size, test_texts):
    """Build a tokenizer, evaluate it, print its one-line summary, return the metrics."""
    t0 = time.perf_counter()
    tok = make_tokenizer()
    m = evaluate_tokenizer(tok, name, source, vocab_size, test_texts)
    print(f" [{time.perf_counter()-t0:.1f}s] Fert={m.fertility_overall:.3f} Disp={m.disparity:.3f} EM_ar={m.exact_match_ar:.2%} EM_az={m.exact_match_az:.2%}")
    return m


def _save_results(results, out_csv, out_json):
    """Write the non-empty list of metrics to *out_csv* and *out_json*."""
    import csv

    rows = [asdict(m) for m in results]
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0]))
        w.writeheader()
        w.writerows(rows)
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)


def _print_summary(results):
    """Print the comparison table, ordered by source and then vocabulary size."""
    print("\n" + "=" * 120)
    print(f"{'Name':<30} {'Source':<16} {'V':>6} {'Fert':>7} {'F_ar':>7} {'F_az':>7} {'Disp':>7} {'CPT_ar':>7} {'CPT_az':>7} {'EM_ar':>7} {'EM_az':>7}")
    print("-" * 120)
    for m in sorted(results, key=lambda x: (x.source, x.vocab_size)):
        print(f"{m.name:<30} {m.source:<16} {m.vocab_size:>6,} {m.fertility_overall:>7.3f} {m.fertility_ar:>7.3f} {m.fertility_az:>7.3f} {m.disparity:>7.3f} {m.cpt_ar:>7.3f} {m.cpt_az:>7.3f} {m.exact_match_ar:>7.2%} {m.exact_match_az:>7.2%}")
    print("=" * 120)


def main():
    """Evaluate our tokenizers and the external ones, then save and print the results.

    Our tokenizers are skipped when their directories are missing. An external
    tokenizer that fails to load or evaluate is reported and skipped. Nothing
    is written when there is no test text or no tokenizer could be evaluated.
    """
    print("Loading test texts...")
    test_texts = load_test_texts()
    total = sum(len(v) for v in test_texts.values())
    print(f" Total: {total} texts (ar={len(test_texts['ar'])}, az={len(test_texts['az'])}, mi={len(test_texts['mi'])})")
    if total == 0:
        # Evaluating on nothing would only produce (and save) all-zero rows.
        print(f"\nNo test texts found under {CORPORA}; nothing to evaluate.")
        return

    all_results = []

    # --- Our tokenizers ---
    ours = [
        ("Ours: concat_bpe_8K", "ours", 8000,
         os.path.join(TRANS_DIR, "concat_bpe_8000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_bpe_8000_tokenizer_az")),
        ("Ours: concat_wp_16K", "ours", 16000,
         os.path.join(TRANS_DIR, "concat_wordpiece_16000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_wordpiece_16000_tokenizer_az")),
        ("Ours: concat_wp_32K", "ours", 32000,
         os.path.join(TRANS_DIR, "concat_wordpiece_32000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_wordpiece_32000_tokenizer_az")),
    ]
    for name, source, vsz, ar_dir, az_dir in ours:
        if not (os.path.exists(ar_dir) and os.path.exists(az_dir)):
            print(f"\nSKIP {name} (missing: {ar_dir} or {az_dir})")
            continue
        print(f"\nEvaluating {name}...")
        all_results.append(
            _evaluate_and_report(
                lambda: OurConcatTokenizer(ar_dir, az_dir),
                name, source, vsz, test_texts,
            )
        )

    # --- External tokenizers ---
    externals = [
        ("CaMeLBERT-MSA (30K WP)", "external_msa", 30000, "CAMeL-Lab/bert-base-arabic-camelbert-msa"),
        ("Asafaya-BERT (32K WP)", "external_msa", 32000, "asafaya/bert-base-arabic"),
        ("Aranizer (86K SP)", "external_msa", 86000, "riotu-lab/Aranizer-SP-86k"),
        ("DarijaBERT-ar (80K WP)", "external_darija", 80000, "SI2M-Lab/DarijaBERT"),
        ("DarijaBERT-az (110K WP)", "external_darija", 110000, "SI2M-Lab/DarijaBERT-arabizi"),
    ]
    for name, source, vsz, repo in externals:
        print(f"\nEvaluating {name} ({repo})...")
        # Download/loading problems must not abort the remaining comparisons.
        try:
            m = _evaluate_and_report(
                lambda: HFTokenizer(repo), name, source, vsz, test_texts
            )
        except Exception as e:
            print(f" FAILED: {e}")
        else:
            all_results.append(m)

    if not all_results:
        # Previously this case crashed with IndexError on all_results[0].
        print("\nNo tokenizer could be evaluated; nothing to save.")
        return

    # --- Save results ---
    os.makedirs(RESULTS, exist_ok=True)
    out_csv = os.path.join(RESULTS, "external_comparison.csv")
    out_json = os.path.join(RESULTS, "external_comparison.json")
    _save_results(all_results, out_csv, out_json)
    print(f"\nResults saved: {out_csv}, {out_json}")

    # --- Print summary table ---
    _print_summary(all_results)


if __name__ == "__main__":
    main()