#!/usr/bin/env python3
"""
compare_with_external.py

Compare our best 3 tokenizers (one per vocab size) against existing
Arabic/Darija tokenizers from HuggingFace.

Our tokenizers (from benchmark):
  - concat_bpe_8000        (V=8K)
  - concat_wordpiece_16000 (V=16K)
  - concat_wordpiece_32000 (V=32K)

External tokenizers:
  - CAMeL-Lab/bert-base-arabic-camelbert-msa (WordPiece 30K, MSA)
  - asafaya/bert-base-arabic                 (WordPiece 32K, MSA)
  - riotu-lab/Aranizer-SP-86k                (SentencePiece 86K, MSA)
  - SI2M-Lab/DarijaBERT                      (WordPiece 80K, Darija Arabic)
  - SI2M-Lab/DarijaBERT-arabizi              (WordPiece 110K, Darija Arabizi)
"""
import csv
import json
import os
import re
import sys
import time
import warnings
from collections import Counter
from dataclasses import asdict, dataclass, field, fields
from typing import Dict, List, Optional, Tuple

import numpy as np
import regex

warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
BASE = "/root/oiq_cc_tokenizer"
RESULTS = os.path.join(BASE, "results")
CORPORA = os.path.join(RESULTS, "corpora")
TOKENIZER_DIR = os.path.join(RESULTS, "tokenizers")
TRANS_DIR = os.path.join(RESULTS, "transformers_tokenizers")

# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")
# Compiled once: count_graphemes() is called for every evaluated text.
_GRAPHEME_PAT = regex.compile(r"\X")


def segment_words(text: str) -> List[str]:
    """Return the maximal runs of letters, marks and digits found in *text*."""
    return _WORD_PAT.findall(text)


def count_graphemes(text: str) -> int:
    """Return the number of extended grapheme clusters (``\\X``) in *text*."""
    return len(_GRAPHEME_PAT.findall(text))


def detect_script(text: str) -> str:
    """Classify *text* as ``"ar"`` or ``"az"``.

    The text is ``"ar"`` when more than 30% of its characters (measured
    against ``len(text)``, so whitespace and punctuation count toward the
    denominator) fall in U+0600-U+06FF or U+0750-U+077F; otherwise ``"az"``.
    An empty string is therefore ``"az"``.
    """
    arabic_chars = len(_AR_PAT.findall(text))
    return "ar" if arabic_chars > len(text) * 0.3 else "az"


# ---------------------------------------------------------------------------
# Load test corpora
# ---------------------------------------------------------------------------
def load_test_texts() -> Dict[str, List[str]]:
    """Read the ``test`` and ``val`` corpora for each script from ``CORPORA``.

    Returns:
        A dict with keys ``"ar"``, ``"az"`` and ``"mi"``. Each value holds the
        stripped, non-empty lines of ``{split}_{script}.txt`` for both splits.
        Files that do not exist are skipped, leaving that contribution empty.
    """
    texts: Dict[str, List[str]] = {"ar": [], "az": [], "mi": []}
    for split in ("test", "val"):
        for script in ("ar", "az", "mi"):
            path = os.path.join(CORPORA, f"{split}_{script}.txt")
            if not os.path.exists(path):
                continue
            with open(path, encoding="utf-8") as f:
                stripped = (line.strip() for line in f)
                texts[script].extend(line for line in stripped if line)
    return texts


# ---------------------------------------------------------------------------
# Tokenizer wrappers
# ---------------------------------------------------------------------------
class OurConcatTokenizer:
    """Wrapper for our concatenated tokenizers (HuggingFace tokenizers lib).

    Holds one tokenizer per script and routes each text to one of them based
    on ``detect_script``.
    """

    def __init__(self, ar_dir: str, az_dir: str):
        """Load ``tokenizer.json`` from *ar_dir* and from *az_dir*."""
        # Imported lazily so the module can be imported without `tokenizers`.
        from tokenizers import Tokenizer

        self.tok_ar = Tokenizer.from_file(os.path.join(ar_dir, "tokenizer.json"))
        self.tok_az = Tokenizer.from_file(os.path.join(az_dir, "tokenizer.json"))

    def encode(self, text: str) -> Tuple[List[str], List[int]]:
        """Encode *text* with the tokenizer matching its detected script.

        Returns:
            ``(tokens, ids)`` taken from the resulting encoding.
        """
        tokenizer = self.tok_ar if detect_script(text) == "ar" else self.tok_az
        enc = tokenizer.encode(text)
        return enc.tokens, enc.ids

    def decode(self, ids: List[int], script: Optional[str] = None) -> str:
        """Decode *ids*, skipping special tokens.

        The Arabic tokenizer is used only when *script* is ``"ar"``; any other
        value, including the default ``None``, selects the ``az`` tokenizer.
        Callers should pass the script the ids were encoded with.
        """
        tokenizer = self.tok_ar if script == "ar" else self.tok_az
        return tokenizer.decode(ids, skip_special_tokens=True)


class HFTokenizer:
    """Wrapper for HuggingFace transformers tokenizers."""

    def __init__(self, repo_id: str):
        """Load the tokenizer published under *repo_id*."""
        # Imported lazily so the module can be imported without `transformers`.
        from transformers import AutoTokenizer

        # NOTE(review): trust_remote_code=True allows code shipped in the repo
        # to run locally; only use with repositories you trust.
        self.tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    def encode(self, text: str) -> Tuple[List[str], List[int]]:
        """Encode *text* without adding special tokens.

        Returns:
            ``(tokens, ids)`` where tokens are the string forms of the ids.
        """
        ids = self.tok.encode(text, add_special_tokens=False)
        tokens = self.tok.convert_ids_to_tokens(ids)
        return tokens, ids

    def decode(self, ids: List[int], script: Optional[str] = None) -> str:
        """Decode *ids*, skipping special tokens.

        *script* is accepted for interface parity with ``OurConcatTokenizer``
        and is ignored.
        """
        return self.tok.decode(ids, skip_special_tokens=True)


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@dataclass
class CompMetrics:
    """Aggregated comparison metrics for one tokenizer."""

    name: str = ""
    source: str = ""
    vocab_size: int = 0
    # Mean tokens-per-word, per detected script and over all texts.
    fertility_ar: float = 0.0
    fertility_az: float = 0.0
    fertility_overall: float = 0.0
    # |fertility_ar - fertility_az| / max(fertility_ar, fertility_az).
    disparity: float = 0.0
    # Mean grapheme clusters per token, per detected script.
    cpt_ar: float = 0.0
    cpt_az: float = 0.0
    # Share of texts whose decoded form equals the input (after strip()).
    exact_match_ar: float = 0.0
    exact_match_az: float = 0.0


def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of *values*, or ``0.0`` when it is empty."""
    return float(np.mean(values)) if values else 0.0


def evaluate_tokenizer(tok, name, source, vocab_size, test_texts) -> CompMetrics:
    """Compute fertility, chars-per-token and round-trip metrics for *tok*.

    Args:
        tok: Object exposing ``encode(text) -> (tokens, ids)`` and
            ``decode(ids, script=...) -> str``.
        name: Display name stored in the result and used in progress output.
        source: Source label stored in the result.
        vocab_size: Vocabulary size stored in the result.
        test_texts: Mapping of corpus key (``"ar"``, ``"az"``, ``"mi"``) to a
            list of texts. Missing keys are treated as empty. Every text is
            bucketed by ``detect_script``, not by the key it was stored under.

    Returns:
        A populated ``CompMetrics``. Texts without any word (per
        ``segment_words``) are ignored. Texts whose encoding raises are
        skipped and reported in a single warning line.
    """
    metrics = CompMetrics(name=name, source=source, vocab_size=vocab_size)

    # detect_script() only ever yields "ar" or "az".
    per_script = {
        "ar": {"fert": [], "cpt": [], "total": 0, "match": 0},
        "az": {"fert": [], "cpt": [], "total": 0, "match": 0},
    }
    all_fert: List[float] = []
    n_failed = 0
    first_error = None

    all_texts = [t for key in ("ar", "az", "mi") for t in test_texts.get(key, [])]
    n_total = len(all_texts)

    for i, text in enumerate(all_texts):
        if (i + 1) % 5000 == 0:
            print(f" [{i+1}/{n_total}] {name}", flush=True)

        words = segment_words(text)
        if not words:
            continue

        try:
            tokens, ids = tok.encode(text)
            # Drop empty tokens and anything shaped like a special token
            # ("[CLS]", "<s>", ...).
            # NOTE(review): this also drops literal tokens that merely start
            # with "[" or "<" (e.g. bracket punctuation).
            filtered = [t for t in tokens if t and not t.startswith(("[", "<"))]
        except Exception as exc:
            # Best effort: one bad text must not abort the whole benchmark,
            # but the failure is counted instead of silently swallowed.
            n_failed += 1
            if first_error is None:
                first_error = exc
            continue

        fertility = len(filtered) / len(words)
        all_fert.append(fertility)

        script = detect_script(text)
        try:
            decoded = tok.decode(ids, script=script)
            exact = decoded.strip() == text.strip()
        except Exception:
            # A tokenizer that cannot decode simply scores no exact match.
            exact = False

        bucket = per_script[script]
        bucket["fert"].append(fertility)
        bucket["cpt"].append(count_graphemes(text) / max(len(filtered), 1))
        bucket["total"] += 1
        if exact:
            bucket["match"] += 1

    if n_failed:
        print(
            f" WARNING: {name}: skipped {n_failed} text(s) that failed to encode "
            f"(first error: {first_error!r})",
            flush=True,
        )

    ar, az = per_script["ar"], per_script["az"]
    metrics.fertility_ar = _mean(ar["fert"])
    metrics.fertility_az = _mean(az["fert"])
    metrics.fertility_overall = _mean(all_fert)
    # The 1e-9 floor keeps the ratio defined when both fertilities are zero.
    metrics.disparity = abs(metrics.fertility_ar - metrics.fertility_az) / max(
        metrics.fertility_ar, metrics.fertility_az, 1e-9
    )
    metrics.cpt_ar = _mean(ar["cpt"])
    metrics.cpt_az = _mean(az["cpt"])
    metrics.exact_match_ar = ar["match"] / max(ar["total"], 1)
    metrics.exact_match_az = az["match"] / max(az["total"], 1)
    return metrics


# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------
def save_results(results: List[CompMetrics], out_csv: str, out_json: str) -> None:
    """Write *results* to *out_csv* and *out_json*.

    The CSV header is derived from the ``CompMetrics`` field list, so an empty
    *results* list produces a header-only CSV and a JSON ``[]`` rather than
    raising. Parent directories are created when missing.
    """
    for path in (out_csv, out_json):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    rows = [asdict(m) for m in results]

    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[fld.name for fld in fields(CompMetrics)])
        writer.writeheader()
        writer.writerows(rows)

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2)


def _print_progress(m: CompMetrics, elapsed: float) -> None:
    """Print the one-line result shown after each tokenizer is evaluated."""
    print(f" [{elapsed:.1f}s] Fert={m.fertility_overall:.3f} Disp={m.disparity:.3f} EM_ar={m.exact_match_ar:.2%} EM_az={m.exact_match_az:.2%}")


def _print_summary(results: List[CompMetrics]) -> None:
    """Print the final table, ordered by source and then vocabulary size."""
    print("\n" + "=" * 120)
    print(f"{'Name':<30} {'Source':<16} {'V':>6} {'Fert':>7} {'F_ar':>7} {'F_az':>7} {'Disp':>7} {'CPT_ar':>7} {'CPT_az':>7} {'EM_ar':>7} {'EM_az':>7}")
    print("-" * 120)
    for m in sorted(results, key=lambda x: (x.source, x.vocab_size)):
        print(f"{m.name:<30} {m.source:<16} {m.vocab_size:>6,} {m.fertility_overall:>7.3f} {m.fertility_ar:>7.3f} {m.fertility_az:>7.3f} {m.disparity:>7.3f} {m.cpt_ar:>7.3f} {m.cpt_az:>7.3f} {m.exact_match_ar:>7.2%} {m.exact_match_az:>7.2%}")
    print("=" * 120)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Evaluate our tokenizers and the external ones, then save and print results."""
    print("Loading test texts...")
    test_texts = load_test_texts()
    total = sum(len(v) for v in test_texts.values())
    print(f" Total: {total} texts (ar={len(test_texts['ar'])}, az={len(test_texts['az'])}, mi={len(test_texts['mi'])})")

    all_results = []

    # --- Our tokenizers ---
    # (display name, vocab size, directory stem under TRANS_DIR)
    ours = [
        ("Ours: concat_bpe_8K", 8000, "concat_bpe_8000"),
        ("Ours: concat_wp_16K", 16000, "concat_wordpiece_16000"),
        ("Ours: concat_wp_32K", 32000, "concat_wordpiece_32000"),
    ]
    for name, vsz, stem in ours:
        ar_dir = os.path.join(TRANS_DIR, f"{stem}_tokenizer_ar")
        az_dir = os.path.join(TRANS_DIR, f"{stem}_tokenizer_az")
        if not (os.path.exists(ar_dir) and os.path.exists(az_dir)):
            print(f"\nSKIP {name} (missing: {ar_dir} or {az_dir})")
            continue
        print(f"\nEvaluating {name}...")
        t0 = time.perf_counter()
        tok = OurConcatTokenizer(ar_dir, az_dir)
        m = evaluate_tokenizer(tok, name, "ours", vsz, test_texts)
        _print_progress(m, time.perf_counter() - t0)
        all_results.append(m)

    # --- External tokenizers ---
    externals = [
        ("CaMeLBERT-MSA (30K WP)", "external_msa", 30000, "CAMeL-Lab/bert-base-arabic-camelbert-msa"),
        ("Asafaya-BERT (32K WP)", "external_msa", 32000, "asafaya/bert-base-arabic"),
        ("Aranizer (86K SP)", "external_msa", 86000, "riotu-lab/Aranizer-SP-86k"),
        ("DarijaBERT-ar (80K WP)", "external_darija", 80000, "SI2M-Lab/DarijaBERT"),
        ("DarijaBERT-az (110K WP)", "external_darija", 110000, "SI2M-Lab/DarijaBERT-arabizi"),
    ]
    for name, source, vsz, repo in externals:
        print(f"\nEvaluating {name} ({repo})...")
        try:
            t0 = time.perf_counter()
            tok = HFTokenizer(repo)
            m = evaluate_tokenizer(tok, name, source, vsz, test_texts)
            _print_progress(m, time.perf_counter() - t0)
            all_results.append(m)
        except Exception as e:
            # Download or load problems for one repo must not stop the others.
            print(f" FAILED: {e}")

    # --- Save results ---
    out_csv = os.path.join(RESULTS, "external_comparison.csv")
    out_json = os.path.join(RESULTS, "external_comparison.json")
    save_results(all_results, out_csv, out_json)
    print(f"\nResults saved: {out_csv}, {out_json}")

    # --- Print summary table ---
    _print_summary(all_results)


if __name__ == "__main__":
    main()