File size: 6,656 Bytes
186cb4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
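#!/usr/bin/env python3
"""bootstrap_test_set.py — Bootstrap 95% CIs on the test set, for Table 5 consistency.

For every shared and concatenated tokenizer found under the results directory,
computes test-set fertility (content tokens per word) and CPT (grapheme
clusters per content token) with percentile-bootstrap 95% confidence
intervals, writes them to bootstrap_ci_test_set.csv, and compares the point
estimates against test_set_results.csv.
"""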

import json, os, sys, csv, gc, warnings
from dataclasses import dataclass, asdict
from collections import Counter
from typing import List

import numpy as np
import regex
# NOTE(review): this silences every warning from every library for the whole
# process (including numpy RuntimeWarnings); consider narrowing the filter.
warnings.filterwarnings("ignore")

# Root of the experiment outputs. Defaults to the original location; set
# OIQ_RESULTS_DIR to run the script against a different results tree.
BASE = os.environ.get("OIQ_RESULTS_DIR") or "/root/oiq_cc_tokenizer/results"
CORPORA = os.path.join(BASE, "corpora")  # test_{ar,az,mi}.txt live here
TOK_DIR = os.path.join(BASE, "tokenizers")  # shared_*/concat_* tokenizer JSON files

_WORD_PAT = regex.compile(r"[\p{L}\p{M}\p{N}]+", regex.UNICODE)  # runs of letters, marks, digits
_AR_PAT = regex.compile(r"[\u0600-\u06FF\u0750-\u077F]")  # Arabic + Arabic Supplement blocks
_GRAPHEME_PAT = regex.compile(r"\X")  # one extended grapheme cluster; compiled once, used per text
# Special / padding tokens (and the empty string) excluded from token counts.
_SPECIAL = frozenset({"<unk>", "<s>", "</s>", "[CLS]", "[SEP]", "[PAD]", "[UNK]", "<pad>", ""})


def segment_words(t):
    """Return the maximal runs of letters, combining marks and digits in ``t``."""
    return _WORD_PAT.findall(t)


def count_graphemes(t):
    r"""Return the number of extended grapheme clusters (regex ``\X``) in ``t``."""
    return len(_GRAPHEME_PAT.findall(t))


def detect_script(t):
    """Label ``t`` ``"ar"`` if over 30% of its code points are Arabic-script, else ``"az"``.

    Arabic-script means U+0600–U+06FF or U+0750–U+077F. The share is taken
    over ``len(t)``, so spaces, digits and punctuation count against it; an
    empty string is labelled ``"az"``.
    """
    arabic_chars = len(_AR_PAT.findall(t))
    return "ar" if arabic_chars > len(t) * 0.3 else "az"


def filter_sp(tokens):
    """Return ``tokens`` in their original order with entries in ``_SPECIAL`` removed."""
    return [t for t in tokens if t not in _SPECIAL]


class RawConcat:
    """Two ``tokenizers.Tokenizer`` models, one of which is chosen per text.

    ``detect_script`` decides the routing: texts labelled ``"ar"`` go to the
    ``ar`` tokenizer, every other text goes to the ``az`` tokenizer.
    """

    def __init__(self, ar_j, az_j):
        """Load the ``ar`` tokenizer from ``ar_j`` and the ``az`` one from ``az_j`` (JSON files)."""
        from tokenizers import Tokenizer  # imported here rather than at module level

        load = Tokenizer.from_file
        self.ar = load(ar_j)
        self.az = load(az_j)

    def encode(self, text):
        """Tokenize ``text`` with the tokenizer matching its detected script.

        Returns ``(tokens, ids, script)`` where ``script`` is the
        ``detect_script`` label that was used for routing.
        """
        script = detect_script(text)
        if script == "ar":
            encoding = self.ar.encode(text)
        else:
            encoding = self.az.encode(text)
        return encoding.tokens, encoding.ids, script

class RawShared:
    """A single ``tokenizers.Tokenizer`` applied to every text, whatever its script."""

    def __init__(self, j):
        """Load the tokenizer from the JSON file at ``j``."""
        from tokenizers import Tokenizer  # imported here rather than at module level

        loaded = Tokenizer.from_file(j)
        self.tok = loaded

    def encode(self, text):
        """Return ``(tokens, ids, script)`` for ``text``.

        ``script`` is the ``detect_script`` label. It does not influence
        tokenization here; it is returned to mirror the tuple produced by
        ``RawConcat.encode``.
        """
        encoding = self.tok.encode(text)
        script = detect_script(text)
        return encoding.tokens, encoding.ids, script


def precompute_metrics(texts):
    """Segment every text into words and count its grapheme clusters.

    These per-text quantities do not depend on the tokenizer, so they are
    computed once and reused for every tokenizer passed to ``bootstrap_ci``.
    There, the word count is the denominator of fertility and the grapheme
    count is the numerator of CPT.

    Args:
        texts: iterable of strings. It is consumed exactly once, so a
            generator works as well as a list.

    Returns:
        ``(words_per_text, graphemes_per_text)``: two lists aligned with
        ``texts``. Element ``i`` holds the ``segment_words`` result and the
        ``count_graphemes`` result for text ``i``.
    """
    words_per_text = []
    graphemes_per_text = []
    # Single pass so a one-shot iterator still produces two aligned lists.
    for text in texts:
        words_per_text.append(segment_words(text))
        graphemes_per_text.append(count_graphemes(text))
    return words_per_text, graphemes_per_text


def bootstrap_ci(tok, texts, words_per_text, graphemes_per_text, n_bootstrap=500, seed=42):
    """Percentile-bootstrap 95% CIs for the mean fertility and mean CPT of one tokenizer.

    Every text is tokenized exactly once to obtain its fertility (content
    tokens per word) and its CPT (grapheme clusters per content token). The
    per-text values are then resampled with replacement ``n_bootstrap``
    times, and the mean of each resample is recorded.

    Args:
        tok: object whose ``encode(text)`` returns ``(tokens, ids, script)``,
            e.g. ``RawShared`` or ``RawConcat``.
        texts: list of input strings.
        words_per_text: per-text word lists aligned with ``texts``, as
            produced by ``precompute_metrics``.
        graphemes_per_text: per-text grapheme counts aligned with ``texts``.
        n_bootstrap: number of bootstrap resamples (must be >= 1).
        seed: seed for ``np.random.RandomState``. Keeping the default of 42
            keeps the reported intervals reproducible.

    Returns:
        ``(fert, fert_lo, fert_hi, cpt, cpt_lo, cpt_hi)`` as floats: the mean
        over usable texts, followed by the 2.5th and 97.5th percentiles of
        the bootstrap means, for each metric. All six are NaN when no text
        is usable.

    Raises:
        ValueError: if the three per-text sequences differ in length, or if
            ``n_bootstrap`` < 1.

    Texts with no words are skipped. Texts whose ``tok.encode`` call raises
    an ``Exception`` are also skipped, and their count is reported on stderr.
    """
    if not len(texts) == len(words_per_text) == len(graphemes_per_text):
        raise ValueError(
            f"length mismatch: {len(texts)} texts, {len(words_per_text)} word lists, "
            f"{len(graphemes_per_text)} grapheme counts"
        )
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")

    fert_vals = []
    cpt_vals = []
    n_failed = 0
    for text, words, n_graphemes in zip(texts, words_per_text, graphemes_per_text):
        if not words:
            continue  # fertility is undefined for a text with no words
        try:
            tokens, _ids, _script = tok.encode(text)
        except Exception:
            # Best effort: one text that fails to encode must not abort the
            # run. Catching Exception (not a bare except) still lets
            # KeyboardInterrupt / SystemExit through.
            n_failed += 1
            continue
        content = filter_sp(tokens)
        fert_vals.append(len(content) / len(words))
        # max(..., 1) avoids division by zero when only special tokens remain.
        cpt_vals.append(n_graphemes / max(len(content), 1))

    if n_failed:
        print(f"  warning: {n_failed} text(s) failed to encode and were skipped",
              file=sys.stderr, flush=True)

    if not fert_vals:
        return (float("nan"),) * 6

    fert_arr = np.asarray(fert_vals, dtype=float)
    cpt_arr = np.asarray(cpt_vals, dtype=float)
    n_valid = len(fert_arr)

    rng = np.random.RandomState(seed)
    fert_means = np.empty(n_bootstrap)
    cpt_means = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        # Draw one resample at a time, with the same indices used for both
        # metrics. This keeps the RNG stream, and therefore the intervals,
        # identical for a given seed.
        idx = rng.choice(n_valid, size=n_valid, replace=True)
        fert_means[b] = fert_arr[idx].mean()
        cpt_means[b] = cpt_arr[idx].mean()

    fert_lo, fert_hi = (float(v) for v in np.percentile(fert_means, [2.5, 97.5]))
    cpt_lo, cpt_hi = (float(v) for v in np.percentile(cpt_means, [2.5, 97.5]))
    return (float(fert_arr.mean()), fert_lo, fert_hi,
            float(cpt_arr.mean()), cpt_lo, cpt_hi)


# Column order of bootstrap_ci_test_set.csv. The last six names correspond
# positionally to the tuple returned by bootstrap_ci.
_CI_FIELDS = ["name", "fert", "fert_lo", "fert_hi", "cpt", "cpt_lo", "cpt_hi"]


def _load_test_texts():
    """Return the stripped, non-blank lines of the test_ar/test_az/test_mi corpora (missing files are skipped)."""
    texts = []
    for split in ("test_ar", "test_az", "test_mi"):
        path = os.path.join(CORPORA, f"{split}.txt")
        if not os.path.exists(path):
            continue
        # Explicit UTF-8: the corpora may hold Arabic-script text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(path, encoding="utf-8") as f:
            texts.extend(s for s in (line.strip() for line in f) if s)
    return texts


def _evaluate(name, tok_cls, tok_args, texts, words_per_text, graphemes_per_text):
    """Load one tokenizer, bootstrap its metrics, print a summary line and return a CSV row dict."""
    print(f"\n{name}", flush=True)
    tok = tok_cls(*tok_args)
    r = bootstrap_ci(tok, texts, words_per_text, graphemes_per_text)
    print(f"  F={r[0]:.4f} [{r[1]:.4f}, {r[2]:.4f}]  CPT={r[3]:.3f} [{r[4]:.3f}, {r[5]:.3f}]", flush=True)
    # Release the tokenizer before the next one is loaded, to keep peak memory down.
    del tok
    gc.collect()
    return {"name": name, **dict(zip(_CI_FIELDS[1:], r))}


def _save_results(results):
    """Write the result rows to BASE/bootstrap_ci_test_set.csv and print the path."""
    out = os.path.join(BASE, "bootstrap_ci_test_set.csv")
    with open(out, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_CI_FIELDS)
        writer.writeheader()
        writer.writerows(results)
    print(f"\nSaved: {out}", flush=True)


def _consistency_check(results):
    """Compare point estimates with BASE/test_set_results.csv (fertility tol 0.001, CPT tol 0.01)."""
    print("\n--- Consistency check ---", flush=True)
    ref_path = os.path.join(BASE, "test_set_results.csv")
    if not os.path.exists(ref_path):
        print(f"  {ref_path} not found; skipping", flush=True)
        return
    with open(ref_path, newline="", encoding="utf-8") as f:
        reference = {row["name"]: row for row in csv.DictReader(f)}

    print(f"{'Name':<25} {'Table5_F':>10} {'TestCSV_F':>10} {'Match':>8} "
          f"{'Table5_CPT':>10} {'TestCSV_CPT':>11} {'Match':>8}", flush=True)
    for r in results:
        ref = reference.get(r["name"])
        if ref is None:
            print(f"{r['name']:<25} (not in test_set_results.csv)", flush=True)
            continue
        ref_fert = float(ref["fertility_overall"])
        ref_cpt = float(ref["cpt_overall"])
        f_ok = abs(r["fert"] - ref_fert) < 0.001
        c_ok = abs(r["cpt"] - ref_cpt) < 0.01
        print(f"{r['name']:<25} {r['fert']:>10.4f} {ref_fert:>10.4f} {'OK' if f_ok else 'MISMATCH':>8} "
              f"{r['cpt']:>10.3f} {ref_cpt:>11.3f} {'OK' if c_ok else 'MISMATCH':>8}", flush=True)


def main():
    """Bootstrap fertility/CPT CIs for every shared and concatenated tokenizer under TOK_DIR.

    Writes BASE/bootstrap_ci_test_set.csv, then compares the point estimates
    against BASE/test_set_results.csv when that file exists.
    """
    texts = _load_test_texts()
    print(f"{len(texts)} test texts", flush=True)

    # Tokenizer-independent quantities, computed once and shared by all tokenizers.
    words_per_text, graphemes_per_text = precompute_metrics(texts)
    per_text = (texts, words_per_text, graphemes_per_text)

    results = []
    for vsz in (8000, 16000, 32000):
        for algo in ("bpe", "unigram", "wordpiece", "bbpe"):
            shared_json = os.path.join(TOK_DIR, f"shared_{algo}_{vsz}.json")
            if os.path.exists(shared_json):
                results.append(_evaluate(f"shared_{algo}_{vsz}", RawShared, (shared_json,), *per_text))

            # Concatenated variant: an ar and an az tokenizer, each with half the vocabulary.
            ar_json = os.path.join(TOK_DIR, f"concat_ar_{algo}_{vsz // 2}.json")
            az_json = os.path.join(TOK_DIR, f"concat_az_{algo}_{vsz // 2}.json")
            if os.path.exists(ar_json) and os.path.exists(az_json):
                results.append(_evaluate(f"concat_{algo}_{vsz}", RawConcat, (ar_json, az_json), *per_text))

    # Save before the consistency check, so that a missing or malformed
    # reference CSV cannot discard the (slow) bootstrap results.
    _save_results(results)
    _consistency_check(results)
    print("DONE!", flush=True)


if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 on success.
    sys.exit(main())