| |
| """ |
| compare_with_external.py |
| Compare our best 3 tokenizers (one per vocab size) against existing |
| Arabic/Darija tokenizers from HuggingFace. |
| |
| Our tokenizers (from benchmark): |
| - concat_bpe_8000 (V=8K) |
| - concat_wordpiece_16000 (V=16K) |
| - concat_wordpiece_32000 (V=32K) |
| |
| External tokenizers: |
| - CAMeL-Lab/bert-base-arabic-camelbert-msa (WordPiece 30K, MSA) |
| - asafaya/bert-base-arabic (WordPiece 32K, MSA) |
| - riotu-lab/Aranizer-SP-86k (SentencePiece 86K, MSA) |
| - SI2M-Lab/DarijaBERT (WordPiece 80K, Darija Arabic) |
| - SI2M-Lab/DarijaBERT-arabizi (WordPiece 110K, Darija Arabizi) |
| """ |
|
|
| import json, os, sys, time, re, warnings |
| from collections import Counter |
| from dataclasses import dataclass, field, asdict |
| from typing import List, Dict, Tuple |
|
|
| import numpy as np |
|
|
# Suppress every Python warning for the rest of the process.
# NOTE(review): this is process-wide and also hides deprecation warnings from
# this script's own dependencies; consider narrowing by category or module.
warnings.filterwarnings("ignore")

# Project root. Override with the OIQ_CC_TOKENIZER_BASE environment variable
# to run outside the original checkout location; the default is unchanged.
BASE = os.environ.get("OIQ_CC_TOKENIZER_BASE", "/root/oiq_cc_tokenizer")
RESULTS = os.path.join(BASE, "results")  # external_comparison.{csv,json} go here
CORPORA = os.path.join(RESULTS, "corpora")  # holds {split}_{script}.txt files
TOKENIZER_DIR = os.path.join(RESULTS, "tokenizers")
TRANS_DIR = os.path.join(RESULTS, "transformers_tokenizers")  # our tokenizer.json dirs
|
|
| import regex |
# A "word" is a maximal run of letters, combining marks and digits in any
# script; segment_words() uses it to produce the fertility denominator.
_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)
# Arabic-script code points counted by detect_script(): Arabic (U+0600-06FF),
# Arabic Supplement (U+0750-077F), Arabic Extended-B/-A (U+0870-08FF) and the
# Arabic ranges of Presentation Forms-A/-B (U+FB50-FDFF, U+FE70-FEFC).
# U+FEFF (BOM / zero-width no-break space) is deliberately left out.
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F\u0870-\u08FF\uFB50-\uFDFF\uFE70-\uFEFC]")
|
|
|
|
def segment_words(text):
    """Return the word-like substrings of *text*, in order of appearance.

    A word is any maximal run matched by the module-level ``_WORD_PAT``.
    """
    return [match.group() for match in _WORD_PAT.finditer(text)]
|
|
|
|
def count_graphemes(text):
    """Return how many extended grapheme clusters (regex ``\\X`` matches) *text* has."""
    return sum(1 for _ in regex.finditer(r"\X", text))
|
|
|
|
def detect_script(text):
    """Label *text* as Arabic script ("ar") or Arabizi ("az").

    The text is "ar" when more than 30% of all its code points (spaces,
    digits and punctuation included) match ``_AR_PAT``; anything else,
    including the empty string, is "az".
    """
    arabic_count = len(_AR_PAT.findall(text))
    if arabic_count > len(text) * 0.3:
        return "ar"
    return "az"
|
|
|
|
| |
| |
| |
def load_test_texts(splits=("test", "val"), corpora_dir=None):
    """Load evaluation sentences grouped by script.

    Reads ``{split}_{script}.txt`` from *corpora_dir* for every split in
    *splits* and every script in ("ar", "az", "mi"). Missing files are
    skipped; lines are stripped and blank lines dropped.

    Args:
        splits: corpus splits to concatenate, in order. The default pools the
            test AND validation splits (the historical behaviour); pass
            ``("test",)`` to evaluate on held-out test data only.
        corpora_dir: directory holding the files; defaults to ``CORPORA``.

    Returns:
        dict mapping "ar", "az" and "mi" to lists of non-empty lines, ordered
        by split and then by file order.
    """
    if corpora_dir is None:
        corpora_dir = CORPORA
    texts = {"ar": [], "az": [], "mi": []}
    for split in splits:
        for script, bucket in texts.items():
            path = os.path.join(corpora_dir, f"{split}_{script}.txt")
            try:
                fh = open(path, encoding="utf-8")
            except FileNotFoundError:
                continue  # not every split/script combination exists
            with fh:
                # Strip once per line, keep only non-blank results.
                bucket.extend(s for s in (line.strip() for line in fh) if s)
    return texts
|
|
|
|
| |
| |
| |
class OurConcatTokenizer:
    """Script-routed pair of our HuggingFace `tokenizers` models.

    Text that ``detect_script`` labels "ar" is handled by the Arabic-script
    tokenizer; everything else is handled by the Arabizi one.
    """

    def __init__(self, ar_dir, az_dir):
        """Load ``tokenizer.json`` from *ar_dir* (Arabic) and *az_dir* (Arabizi)."""
        from tokenizers import Tokenizer

        def _load(directory):
            return Tokenizer.from_file(os.path.join(directory, "tokenizer.json"))

        self.tok_ar = _load(ar_dir)
        self.tok_az = _load(az_dir)

    def _pick(self, script):
        """Return the Arabic tokenizer for "ar", the Arabizi one for anything else."""
        return self.tok_ar if script == "ar" else self.tok_az

    def encode(self, text):
        """Encode *text* with the tokenizer matching its script; return (tokens, ids)."""
        encoding = self._pick(detect_script(text)).encode(text)
        return encoding.tokens, encoding.ids

    def decode(self, ids, script=None):
        """Decode *ids*, dropping special tokens.

        *script* should be the label the text was encoded under; any value
        other than "ar" (including the default None) decodes with the
        Arabizi tokenizer.
        """
        return self._pick(script).decode(ids, skip_special_tokens=True)
|
|
|
|
class HFTokenizer:
    """Adapter giving a `transformers` tokenizer the encode/decode interface
    that ``evaluate_tokenizer`` expects."""

    def __init__(self, repo_id, *, trust_remote_code=True):
        """Load the tokenizer for *repo_id* via ``AutoTokenizer.from_pretrained``.

        SECURITY: ``trust_remote_code=True`` allows the Hub repository to run
        its own Python code when the tokenizer is loaded. The default keeps
        the original behaviour; pass ``False`` for repos that only ship a
        standard tokenizer class.
        """
        from transformers import AutoTokenizer
        self.tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=trust_remote_code)

    def encode(self, text):
        """Return (tokens, ids) for *text*, without adding special tokens."""
        ids = self.tok.encode(text, add_special_tokens=False)
        tokens = self.tok.convert_ids_to_tokens(ids)
        return tokens, ids

    def decode(self, ids, script=None):
        """Decode *ids*, skipping special tokens.

        *script* is accepted only for interface compatibility with
        OurConcatTokenizer and is ignored.
        """
        return self.tok.decode(ids, skip_special_tokens=True)
|
|
|
|
| |
| |
| |
@dataclass
class CompMetrics:
    """One row of the comparison table and of the saved CSV/JSON records.

    Field order matters: ``asdict`` preserves it, so it fixes the CSV column order.
    """

    # Identity of the tokenizer being evaluated.
    name: str = ""
    source: str = ""  # "ours", "external_msa" or "external_darija" in main()
    vocab_size: int = 0
    # Fertility: tokens per word, averaged over texts.
    fertility_ar: float = 0.0
    fertility_az: float = 0.0
    fertility_overall: float = 0.0
    # Relative gap |F_ar - F_az| / max(F_ar, F_az).
    disparity: float = 0.0
    # Grapheme clusters per token, averaged over texts.
    cpt_ar: float = 0.0
    cpt_az: float = 0.0
    # Fraction of texts whose stripped decode(encode(text)) equals the stripped input.
    exact_match_ar: float = 0.0
    exact_match_az: float = 0.0
|
|
|
|
# Structural special tokens a tokenizer may add around an encoding (e.g.
# [CLS]/[SEP] from a BERT-style post-processor). They carry no text and are
# never counted as tokens.
_STRUCTURAL_SPECIAL_TOKENS = frozenset(
    {"[CLS]", "[SEP]", "[PAD]", "[MASK]", "<s>", "</s>", "<pad>", "<mask>"}
)
# Unknown-token markers; excluded from token counts unless count_unk=True.
_UNK_TOKENS = frozenset({"[UNK]", "<unk>"})


def _mean_or_zero(values):
    """Return the mean of *values* as a float, or 0.0 for an empty list."""
    return float(np.mean(values)) if values else 0.0


def evaluate_tokenizer(tok, name, source, vocab_size, test_texts, *, count_unk=False):
    """Compute fertility, disparity, chars-per-token and round-trip metrics.

    Every text in test_texts["ar"], ["az"] and ["mi"] is encoded with *tok*.
    Each text is assigned to the "ar" or "az" bucket by ``detect_script``, so
    code-mixed "mi" texts land in one of the two. Texts with no word (per
    ``segment_words``) are skipped.

    Args:
        tok: object exposing ``encode(text) -> (tokens, ids)`` and
            ``decode(ids, script=...) -> str``.
        name, source, vocab_size: copied into the returned CompMetrics.
        test_texts: dict with lists of strings under "ar", "az" and "mi".
        count_unk: when True, [UNK]/<unk> tokens count toward fertility and
            CPT. The default (False) keeps the historical behaviour of
            dropping them. That flatters tokenizers that map unseen words to
            UNK, so enable it for a stricter comparison.

    Returns:
        CompMetrics with per-script means. ``disparity`` stays 0.0 unless
        both buckets received at least one text.

    Texts whose encoding raises are skipped and reported in a single
    warning line, rather than being silently ignored.
    """
    ignored = _STRUCTURAL_SPECIAL_TOKENS if count_unk else _STRUCTURAL_SPECIAL_TOKENS | _UNK_TOKENS
    metrics = CompMetrics(name=name, source=source, vocab_size=vocab_size)

    fert = {"ar": [], "az": []}  # per-text fertility, by script bucket
    cpt = {"ar": [], "az": []}  # per-text grapheme clusters per token
    exact = {"ar": 0, "az": 0}  # round-trip exact matches
    all_fert = []
    n_failed, first_error = 0, None

    all_texts = test_texts["ar"] + test_texts["az"] + test_texts["mi"]
    n_total = len(all_texts)

    for i, text in enumerate(all_texts):
        if (i + 1) % 5000 == 0:
            print(f" [{i+1}/{n_total}] {name}", flush=True)

        words = segment_words(text)
        if not words:
            continue  # fertility is undefined without words; skip before encoding

        script = detect_script(text)
        bucket = "ar" if script == "ar" else "az"
        try:
            tokens, ids = tok.encode(text)
        except Exception as e:  # third-party tokenizers raise arbitrary types
            n_failed += 1
            if first_error is None:
                first_error = e
            continue

        # Match exact special-token strings rather than a "[" / "<" prefix,
        # so genuine tokens such as "[" or "<" are still counted.
        n_tokens = sum(1 for t in tokens if t not in ignored)
        fertility = n_tokens / len(words)
        all_fert.append(fertility)
        fert[bucket].append(fertility)
        cpt[bucket].append(count_graphemes(text) / max(n_tokens, 1))

        # Round-trip check; a decode failure deliberately counts as a mismatch.
        try:
            if tok.decode(ids, script=script).strip() == text.strip():
                exact[bucket] += 1
        except Exception:
            pass

    metrics.fertility_ar = _mean_or_zero(fert["ar"])
    metrics.fertility_az = _mean_or_zero(fert["az"])
    metrics.fertility_overall = _mean_or_zero(all_fert)
    # With one bucket empty its fertility would be 0 and the ratio a
    # spurious 1.0, so disparity is only computed when both were measured.
    if fert["ar"] and fert["az"]:
        metrics.disparity = abs(metrics.fertility_ar - metrics.fertility_az) / max(
            metrics.fertility_ar, metrics.fertility_az, 1e-9
        )
    metrics.cpt_ar = _mean_or_zero(cpt["ar"])
    metrics.cpt_az = _mean_or_zero(cpt["az"])
    metrics.exact_match_ar = exact["ar"] / max(len(fert["ar"]), 1)
    metrics.exact_match_az = exact["az"] / max(len(fert["az"]), 1)

    if n_failed:
        print(
            f" WARNING: {name}: encode failed on {n_failed}/{n_total} texts; "
            f"first error: {first_error!r}",
            flush=True,
        )
    return metrics
|
|
|
|
| |
| |
| |
def _print_summary(m, elapsed):
    """Print the one-line headline metrics of a freshly evaluated tokenizer."""
    print(f" [{elapsed:.1f}s] Fert={m.fertility_overall:.3f} Disp={m.disparity:.3f} EM_ar={m.exact_match_ar:.2%} EM_az={m.exact_match_az:.2%}")


def _save_results(all_results):
    """Write *all_results* to RESULTS/external_comparison.{csv,json}; return both paths."""
    import csv
    from dataclasses import fields

    out_csv = os.path.join(RESULTS, "external_comparison.csv")
    out_json = os.path.join(RESULTS, "external_comparison.json")
    rows = [asdict(m) for m in all_results]

    # Take the header from the dataclass so the CSV is well-formed even when
    # all_results is empty (indexing all_results[0] would raise IndexError).
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=[fl.name for fl in fields(CompMetrics)])
        w.writeheader()
        w.writerows(rows)

    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(rows, f, indent=2, ensure_ascii=False)

    return out_csv, out_json


def _print_table(all_results):
    """Print every result as a fixed-width table sorted by (source, vocab_size)."""
    print("\n" + "=" * 120)
    print(f"{'Name':<30} {'Source':<16} {'V':>6} {'Fert':>7} {'F_ar':>7} {'F_az':>7} {'Disp':>7} {'CPT_ar':>7} {'CPT_az':>7} {'EM_ar':>7} {'EM_az':>7}")
    print("-" * 120)
    for m in sorted(all_results, key=lambda x: (x.source, x.vocab_size)):
        print(f"{m.name:<30} {m.source:<16} {m.vocab_size:>6,} {m.fertility_overall:>7.3f} {m.fertility_ar:>7.3f} {m.fertility_az:>7.3f} {m.disparity:>7.3f} {m.cpt_ar:>7.3f} {m.cpt_az:>7.3f} {m.exact_match_ar:>7.2%} {m.exact_match_az:>7.2%}")
    print("=" * 120)


def main():
    """Evaluate our three tokenizers and the five HuggingFace baselines.

    Results are saved as CSV and JSON under RESULTS and printed as a table.
    A missing local tokenizer is skipped. An external tokenizer that fails to
    load or evaluate is reported and skipped.
    """
    print("Loading test texts...")
    test_texts = load_test_texts()
    total = sum(len(v) for v in test_texts.values())
    print(f" Total: {total} texts (ar={len(test_texts['ar'])}, az={len(test_texts['az'])}, mi={len(test_texts['mi'])})")

    all_results = []

    # Ours: (display name, source tag, vocab size, Arabic-script dir, Arabizi dir).
    ours = [
        ("Ours: concat_bpe_8K", "ours", 8000,
         os.path.join(TRANS_DIR, "concat_bpe_8000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_bpe_8000_tokenizer_az")),
        ("Ours: concat_wp_16K", "ours", 16000,
         os.path.join(TRANS_DIR, "concat_wordpiece_16000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_wordpiece_16000_tokenizer_az")),
        ("Ours: concat_wp_32K", "ours", 32000,
         os.path.join(TRANS_DIR, "concat_wordpiece_32000_tokenizer_ar"),
         os.path.join(TRANS_DIR, "concat_wordpiece_32000_tokenizer_az")),
    ]

    for name, source, vsz, ar_dir, az_dir in ours:
        if not (os.path.exists(ar_dir) and os.path.exists(az_dir)):
            print(f"\nSKIP {name} (missing: {ar_dir} or {az_dir})")
            continue
        print(f"\nEvaluating {name}...")
        t0 = time.perf_counter()
        tok = OurConcatTokenizer(ar_dir, az_dir)
        m = evaluate_tokenizer(tok, name, source, vsz, test_texts)
        _print_summary(m, time.perf_counter() - t0)
        all_results.append(m)

    # External baselines: (display name, source tag, vocab size, Hub repo id).
    externals = [
        ("CaMeLBERT-MSA (30K WP)", "external_msa", 30000, "CAMeL-Lab/bert-base-arabic-camelbert-msa"),
        ("Asafaya-BERT (32K WP)", "external_msa", 32000, "asafaya/bert-base-arabic"),
        ("Aranizer (86K SP)", "external_msa", 86000, "riotu-lab/Aranizer-SP-86k"),
        ("DarijaBERT-ar (80K WP)", "external_darija", 80000, "SI2M-Lab/DarijaBERT"),
        ("DarijaBERT-az (110K WP)", "external_darija", 110000, "SI2M-Lab/DarijaBERT-arabizi"),
    ]

    for name, source, vsz, repo in externals:
        print(f"\nEvaluating {name} ({repo})...")
        try:
            t0 = time.perf_counter()
            tok = HFTokenizer(repo)
            m = evaluate_tokenizer(tok, name, source, vsz, test_texts)
        except Exception as e:  # one failing baseline must not abort the whole comparison
            print(f" FAILED: {e}")
            continue
        _print_summary(m, time.perf_counter() - t0)
        all_results.append(m)

    if not all_results:
        print("\nWARNING: no tokenizer could be evaluated; result files will contain no rows.")

    out_csv, out_json = _save_results(all_results)
    print(f"\nResults saved: {out_csv}, {out_json}")

    _print_table(all_results)


if __name__ == "__main__":
    main()
|
|