thekusaldarshana commited on
Commit
e51bea7
·
1 Parent(s): e59ea28

Separate Before you Compress

Browse files
Files changed (6) hide show
  1. EVALUATION.md +7 -7
  2. README.md +81 -53
  3. encoder.py +1 -1
  4. gpe_trainer.py +781 -0
  5. meta_config.json +0 -8
  6. router.py +5 -25
EVALUATION.md CHANGED
@@ -1,5 +1,5 @@
1
  ================================================================================
2
- BATTERY 1: SINHALA LINGUISTIC COMPLEXITY (2,000 Edge-Case Words)
3
  ================================================================================
4
 
5
  Category Total Pass Fail
@@ -110,7 +110,7 @@ BATTERY 1: SINHALA LINGUISTIC COMPLEXITY (2,000 Edge-Case Words)
110
  Result: PASS — Tested 500 complex words. Violations: 0, Leading-space violations: 0
111
 
112
  ================================================================================
113
- BATTERY 2: GLITCHED TOKEN DETECTION (v2 Multi-Script)
114
  ================================================================================
115
  Total unified vocab size: 328,020 (SGPE component: 128,001)
116
  Zero-usage SGPE tokens: 1,394
@@ -119,7 +119,7 @@ BATTERY 2: GLITCHED TOKEN DETECTION (v2 Multi-Script)
119
  Result: PASS — Zero: 1394, Near-Zero: 3163, Glitched: 0
120
 
121
  ================================================================================
122
- BATTERY 3: FRONTIER BENCHMARKING (V2 STRATIFIED)
123
  ================================================================================
124
 
125
  1. Tokenization Anatomy (Visual Examples)
@@ -216,7 +216,7 @@ BATTERY 5: BOUNDARY & LEADING SPACE EDGE-CASES
216
  Result: PASS — Violations: 0
217
 
218
  ================================================================================
219
- BATTERY 6: ZERO-BREAKAGE GUARANTEE
220
  ================================================================================
221
  Testing all C + HAL + ZWJ + C pairs...
222
  Testing C + HAL + C pairs (implicit conjuncts)...
@@ -230,7 +230,7 @@ BATTERY 6: ZERO-BREAKAGE GUARANTEE
230
  Result: PASS — Ran 1,703 exhaustive breakage tests. Violations: 0
231
 
232
  ================================================================================
233
- BATTERY 6: ZERO-BREAKAGE GUARANTEE (v2 Multi-Script)
234
  ================================================================================
235
  Testing Devanagari C + HAL + C pairs (implicit conjuncts)...
236
  Testing Devanagari C + vowel_sign...
@@ -238,7 +238,7 @@ BATTERY 6: ZERO-BREAKAGE GUARANTEE (v2 Multi-Script)
238
  Testing Devanagari C + anusvara / visarga / chandrabindu...
239
  Testing Devanagari C + vowel_sign + modifier...
240
 
241
- Result: PASS — Devanagari Violations: 0
242
 
243
  ================================================================================
244
  BATTERY 7: DEVANAGARI LINGUISTIC COMPLEXITY
@@ -280,7 +280,7 @@ BATTERY 8: CODE-SWITCHING INTEGRITY
280
  Result: PASS — Tested 13 code-switching cases. Violations: 0, Crashes: 0
281
 
282
  ================================================================================
283
- BATTERY 9: META-VOCAB ROUND-TRIP (SGPEMetaEncoder)
284
  ================================================================================
285
 
286
  Sentences: 1,499,950
 
1
  ================================================================================
2
+ BATTERY 1: SINHALA LINGUISTIC COMPLEXITY
3
  ================================================================================
4
 
5
  Category Total Pass Fail
 
110
  Result: PASS — Tested 500 complex words. Violations: 0, Leading-space violations: 0
111
 
112
  ================================================================================
113
+ BATTERY 2: GLITCHED TOKEN DETECTION
114
  ================================================================================
115
  Total unified vocab size: 328,020 (SGPE component: 128,001)
116
  Zero-usage SGPE tokens: 1,394
 
119
  Result: PASS — Zero: 1394, Near-Zero: 3163, Glitched: 0
120
 
121
  ================================================================================
122
+ BATTERY 3: FRONTIER BENCHMARKING
123
  ================================================================================
124
 
125
  1. Tokenization Anatomy (Visual Examples)
 
216
  Result: PASS — Violations: 0
217
 
218
  ================================================================================
219
+ BATTERY 6: ZERO-BREAKAGE GUARANTEE (Sinhala)
220
  ================================================================================
221
  Testing all C + HAL + ZWJ + C pairs...
222
  Testing C + HAL + C pairs (implicit conjuncts)...
 
230
  Result: PASS — Ran 1,703 exhaustive breakage tests. Violations: 0
231
 
232
  ================================================================================
233
+ BATTERY 6B: ZERO-BREAKAGE GUARANTEE (Devanagari)
234
  ================================================================================
235
  Testing Devanagari C + HAL + C pairs (implicit conjuncts)...
236
  Testing Devanagari C + vowel_sign...
 
238
  Testing Devanagari C + anusvara / visarga / chandrabindu...
239
  Testing Devanagari C + vowel_sign + modifier...
240
 
241
+ Result: PASS — Ran 1,078 exhaustive breakage tests. Violations: 0
242
 
243
  ================================================================================
244
  BATTERY 7: DEVANAGARI LINGUISTIC COMPLEXITY
 
280
  Result: PASS — Tested 13 code-switching cases. Violations: 0, Crashes: 0
281
 
282
  ================================================================================
283
+ BATTERY 9: META-VOCAB ROUND-TRIP
284
  ================================================================================
285
 
286
  Sentences: 1,499,950
README.md CHANGED
@@ -1,93 +1,121 @@
1
  ---
2
  license: apache-2.0
3
  datasets:
4
- - polyglots/MADLAD_CulturaX_cleaned
5
  language:
6
  - si
 
 
7
  pipeline_tag: feature-extraction
8
  library_name: transformers
9
  tags:
10
  - tokenizer
 
11
  - SGPE
12
  - linguis_trie
13
  - token
14
  - tokenization
 
15
  - remeinium
16
  - transformer
17
  - linguistics
18
  - NLP
19
  - sinhala
 
 
20
  - BPE
21
  - GPE
22
  model-index:
23
- - name: SGPE-Sinhala
24
  results:
25
  - task:
26
  type: feature-extraction
27
  dataset:
28
- name: MADLAD-400 (CulturaX Cleaned Sinhala subset)
29
- type: polyglots/MADLAD_CulturaX_cleaned
30
  metrics:
31
- - name: Token-to-Word Ratio (TWR)
32
  type: twr
33
- value: 1.438
34
  verified: false
35
- - name: Characters per Token (CPT)
36
- type: cpt
37
- value: 4.48
 
 
 
 
38
  verified: false
39
  ---
40
- # Syllable is the Token: SGPE - Syllable-Aware Grapheme Pair Encoding
41
 
42
- **Remeinium Research**
43
- [remeinium.com](https://remeinium.com) | [Paper](https://arxiv.org/abs/...) | [Tokenizer](https://huggingface.co/remeinium/SGPE-Tokenizer) | [Dataset](https://huggingface.co/datasets/remeinium/SGPE_Cleaned)
44
 
45
- ---
46
 
47
  ## The Next Architectural Primitive in Tokenization
48
 
49
  Large language models remain linguistically blind to Abugida scripts. Byte-Pair Encoding and its descendants routinely shatter complex conjuncts — atomic multi-codepoint grapheme clusters that constitute the fundamental phonetic units of Indic and Southeast Asian writing systems — into meaningless sub-character fragments. The result is degraded reasoning, inflated inference costs, and a systemic “Token Tax” that disproportionately burdens more than one billion speakers.
50
 
51
- **SGPE introduces the clean separation of concerns the field has been missing.**
52
-
53
- **Layer 1 (LinguisTrie)** enforces linguistic integrity by construction: a deterministic $O(N)$ finite automaton segments raw Unicode into well-formed syllables with a formal zero-breakage guarantee.
54
- **Layer 2 (GPE)** then performs statistical pair merging exclusively over this linguistically sound stream, inheriting the guarantee by design.
55
-
56
- Sinhala serves as the high-complexity proof-of-concept. The same architecture generalizes directly to Devanagari, Tamil, Khmer, Myanmar, and the broader Abugida family through script-specific character-class mappings and conjunct rules.
57
-
58
- ---
59
-
60
- ## Results on 59.3 Million Characters
61
 
62
- | Tokenizer | TWR ↓ | Tokens | Chars/Token | Reduction vs SGPE |
63
- |------------------------|---------|-------------|---------------|-------------------|
64
- | **SGPE (ours)** | **1.438** | **13.26 M** | **4.48** | — |
65
- | OpenAI o200k_base | 3.515 | 32.39 M | 1.83 | 59.1 % |
66
- | Llama 4 Scout | 3.673 | 33.85 M | 1.75 | 60.8 % |
67
- | DeepSeek V3 | 5.965 | 54.98 M | 1.08 | 75.8 % |
68
 
69
- - **Zero-Breakage Guarantee** validated on 1,703 exhaustive conjunct formations (0 violations).
70
- - Full-corpus round-trip reconstruction: 0 non-UNK mismatches.
71
- - UNK rate: 0.46 % (rare compounds only; no structural errors).
 
 
 
72
 
73
- SGPE reclaims more than half the context window for Abugida text while preserving perfect orthographic and semantic integrity.
74
 
75
  ---
76
 
77
- ## Architecture
78
-
79
- SGPE is deliberately bimodal:
80
-
81
- 1. **LinguisTrie (Layer 1)**
82
- Deterministic finite automaton operating in a single left-to-right pass with constant-time transitions and $O(1)$ auxiliary space. Guarantees that no conjunct, pili, virama, or ZWJ sequence is ever fragmented.
83
-
84
- 2. **Grapheme Pair Encoding (Layer 2)**
85
- Standard BPE performed exclusively on the atomic syllable stream, with three critical constraints:
86
- - Syllabic initialization (base vocabulary consists of linguistically valid units)
87
- - Boundary-aware scoping (merges restricted to within-word spans)
88
- - Frequency pruning (rare syllables mapped to [UNK] sentinel before merging)
89
-
90
- The decoupling is the core scientific contribution: linguistic correctness is enforced by construction rather than hoped for statistically.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  ---
93
 
@@ -108,10 +136,12 @@ print(tokenizer.encode(text))
108
 
109
  ## Resources
110
 
111
- - **Research Paper**: “The Syllable is the Token: Breaking the Token Tax with SGPE” (Remeinium Research, February 2026)
112
- - **Pre-trained Tokenizer**: [Hugging Face](https://huggingface.co/remeinium/SGPE-Tokenizer)
113
- - **Cleaned Training Corpus**: [Hugging Face](https://huggingface.co/datasets/remeinium/SGPE_Cleaned)
114
- - **Full Code & Evaluation Harness**: [GitHub](https://github.com/remeinium/SGPE)
 
 
115
 
116
  ---
117
 
@@ -119,8 +149,6 @@ print(tokenizer.encode(text))
119
 
120
  Apache License 2.0 — see [LICENSE](LICENSE).
121
 
122
- ---
123
-
124
  **Remeinium Research | Remeinium AI | Intelligence for a Greater Tomorrow**
125
 
126
  ---
 
1
  ---
2
  license: apache-2.0
3
  datasets:
4
+ - Remeinium/WWHO_30m
5
  language:
6
  - si
7
+ - hi
8
+ - en
9
  pipeline_tag: feature-extraction
10
  library_name: transformers
11
  tags:
12
  - tokenizer
13
+ - WWHO
14
  - SGPE
15
  - linguis_trie
16
  - token
17
  - tokenization
18
+ - Syllable
19
  - remeinium
20
  - transformer
21
  - linguistics
22
  - NLP
23
  - sinhala
24
+ - hindi
25
+ - english
26
  - BPE
27
  - GPE
28
  model-index:
29
+ - name: WWHO
30
  results:
31
  - task:
32
  type: feature-extraction
33
  dataset:
34
+ name: WWHO_30m
35
+ type: Remeinium/WWHO_30m
36
  metrics:
37
+ - name: Token-to-Word Ratio (TWR) - Sinhala
38
  type: twr
39
+ value: 1.274
40
  verified: false
41
+ - name: Token-to-Word Ratio (TWR) - Hindi
42
+ type: twr
43
+ value: 1.181
44
+ verified: false
45
+ - name: Token-to-Word Ratio (TWR) - Overall
46
+ type: twr
47
+ value: 1.240
48
  verified: false
49
  ---
50
+ # Separate before you Compress
51
 
52
+ <!-- **Remeinium Research**
53
+ [remeinium.com](https://remeinium.com) | [Paper](https://arxiv.org/abs/...) | [Tokenizer](https://huggingface.co/remeinium/WWHO) | [Dataset](https://huggingface.co/datasets/remeinium/WWHO_Cleaned_30m)
54
 
55
+ --- -->
56
 
57
  ## The Next Architectural Primitive in Tokenization
58
 
59
  Large language models remain linguistically blind to Abugida scripts. Byte-Pair Encoding and its descendants routinely shatter complex conjuncts — atomic multi-codepoint grapheme clusters that constitute the fundamental phonetic units of Indic and Southeast Asian writing systems — into meaningless sub-character fragments. The result is degraded reasoning, inflated inference costs, and a systemic “Token Tax” that disproportionately burdens more than one billion speakers.
60
 
61
+ **WWHO (Where-What-How Often) introduces the clean separation of concerns the field has been missing.**
 
 
 
 
 
 
 
 
 
62
 
63
+ By decoupling linguistic structural constraints from statistical compression, WWHO builds a unified meta-vocabulary space:
 
 
 
 
 
64
 
65
+ 1. **Layer 1 (Where): Code-Switching Router**
66
+ A linear $O(N)$ block scanner that evaluates characters in $O(1)$ time to inherently identify script boundaries, routing Latin text to proven frontier tokenizers (like `o200k_base`) while sending Abugida text for specialized processing.
67
+ 2. **Layer 2 (What): LinguisTrie**
68
+ Enforces linguistic integrity by construction: a DFA based syllabifier segments raw Unicode into well-formed syllables with a formal zero-breakage guarantee.
69
+ 3. **Layer 3 (How Often): SGPE & Meta-Vocabulary**
70
+ Performs statistical pair merging exclusively over this linguistically sound stream, safely projecting the resulting tokens into a unified, mathematically offset ID space.
71
 
72
+ Sinhala and Devanagari serve as the high-complexity proofs-of-concept. The same architecture generalizes directly to Tamil, Khmer, Myanmar, and the broader Abugida family.
73
 
74
  ---
75
 
76
+ ## Multi-Script Stratified Benchmarks (122.2M Characters)
77
+
78
+ We evaluated WWHO against frontier models across a 1.5 million sentence code-switched corpus containing Sinhala, Hindi (Devanagari), and English.
79
+
80
+ ### 1. Sinhala Efficiency
81
+ |Tokenizer | Tokens | TWR | Chr/Tok | % Reduction
82
+ |---|---|---|---|---|
83
+ |**SGPE(WWHO) | 6,654,288 | 1.274 | 4.83 | -**
84
+ |OpenAI (o200k_base) | 17,360,196 | 3.324 | 1.85 | 61.7%
85
+ |Llama 4 Scout | 18,157,707 | 3.476 | 1.77 | 63.4%
86
+ |DeepSeek V3 | 29,152,698 | 5.581 | 1.10 | 77.2%
87
+
88
+ ### 2. Hindi (Devanagari) Efficiency
89
+ |Tokenizer | Tokens | TWR | Chr/Tok | % Reduction
90
+ |---|---|---|---|---|
91
+ |**SGPE(WWHO) | 13,433,554 | 1.181 | 4.29 | -**
92
+ |OpenAI (o200k_base) | 18,394,075 | 1.617 | 3.13 | 27.0%
93
+ |Llama 4 Scout | 19,566,121 | 1.720 | 2.94 | 31.3%
94
+ |DeepSeek V3 | 31,682,218 | 2.786 | 1.82 | 57.6%
95
+
96
+ ### 3. English
97
+ |Tokenizer | Tokens | TWR | Chr/Tok | % Reduction
98
+ |---|---|---|---|---|
99
+ |**SGPE(WWHO) | 7,240,147 | 1.330 | 4.46 | -**
100
+ |OpenAI (o200k_base) | 7,420,527 | 1.364 | 4.35 | 2.4%
101
+ |Llama 4 Scout | 7,512,843 | 1.381 | 4.30 | 3.6%
102
+ |DeepSeek V3 | 7,904,670 | 1.453 | 4.09 | 8.4%
103
+
104
+ *(Note: Because WWHO routes Latin text directly to the native Tiktoken sequence, English performance is mathematically identical. The minor delta in total tokens emerges solely from boundary crossing mechanics.)*
105
+
106
+ ### 4. Overall (Mixed-Script)
107
+ |Tokenizer | Tokens | TWR | Chr/Tok | % Reduction
108
+ |---|---|---|---|---|
109
+ |**SGPE(WWHO) | 27,327,989 | 1.240 | 4.47 | -**
110
+ |OpenAI (o200k_base) | 43,174,798 | 1.959 | 2.83 | 36.7%
111
+ |Llama 4 Scout | 45,236,671 | 2.053 | 2.70 | 39.6%
112
+ |DeepSeek V3 | 68,739,586 | 3.119 | 1.78 | 60.2%
113
+
114
+ - **Zero-Breakage Guarantee**: Validated through exhaustive testing permutations across all supported Abugida scripts (0 violations).
115
+ - **Full-corpus reconstruction**: 1.5M code-switched sentences encoded and decoded with 0 non-UNK mismatches.
116
+ - **UNK rate**: 0.08 % (restricted strictly to rare compounds without violating structural boundaries).
117
+
118
+ WWHO radically compresses the context window for Abugida text, effectively ending the Token Tax without penalizing existing state-of-the-art programming and reasoning capabilities.
119
 
120
  ---
121
 
 
136
 
137
  ## Resources
138
 
139
+ <!--
140
+ - **Research Paper**: “The Syllable is the Token: Breaking the Token Tax with SGPE” (Remeinium Research, February 2026) -->
141
+ - **Pre-trained Tokenizer**: [Hugging Face](https://huggingface.co/remeinium/WWHO)
142
+ - **Cleaned Training Corpus**: [Hugging Face](https://huggingface.co/datasets/remeinium/WWHO_30m)
143
+ - **Full Code & Evaluation Harness**: [GitHub](https://github.com/remeinium/WWHO)
144
+
145
 
146
  ---
147
 
 
149
 
150
  Apache License 2.0 — see [LICENSE](LICENSE).
151
 
 
 
152
  **Remeinium Research | Remeinium AI | Intelligence for a Greater Tomorrow**
153
 
154
  ---
encoder.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  ==========================================
3
- WWHO Encoder (Unified Meta-Vocabulary)
4
  ==========================================
5
  """
6
 
 
1
  """
2
  ==========================================
3
+ WWHO Encoder
4
  ==========================================
5
  """
6
 
gpe_trainer.py ADDED
@@ -0,0 +1,781 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WWHO(SGPE) GPE Trainer
3
+ """
4
+
5
+ import argparse
6
+ import gc
7
+ import heapq
8
+ import json
9
+ import logging
10
+ import os
11
+ import pickle
12
+ import re
13
+ import time
14
+ from collections import Counter, defaultdict
15
+ from multiprocessing import Pool, cpu_count
16
+
17
+ from tqdm import tqdm
18
+
19
+ from router import CodeSwitchSegmenter
20
+ from export import export_hf_tokenizer
21
+
22
+ # ─── Logging ──────
23
+
24
+ try:
25
+ import psutil as _psutil
26
+ def _ram_mb() -> str:
27
+ p = _psutil.Process()
28
+ rss = p.memory_info().rss / 1024**2
29
+ avail = _psutil.virtual_memory().available / 1024**2
30
+ return f"RSS={rss:.0f}MB avail={avail:.0f}MB"
31
+ except ImportError:
32
+ def _ram_mb() -> str:
33
+ try:
34
+ with open("/proc/meminfo") as f:
35
+ info = {l.split(":")[0]: int(l.split()[1])
36
+ for l in f if ":" in l}
37
+ avail = info.get("MemAvailable", 0) // 1024
38
+ return f"avail={avail}MB"
39
+ except Exception:
40
+ return "ram=N/A"
41
+
42
+ _logger: logging.Logger | None = None
43
+
44
+ def _log(msg: str):
45
+ full = f"[{time.strftime('%H:%M:%S')}] [{_ram_mb()}] {msg}"
46
+ print(full, flush=True)
47
+ if _logger:
48
+ _logger.info(full)
49
+
50
+ def _setup_logging(output_dir: str):
51
+ global _logger
52
+ os.makedirs(output_dir, exist_ok=True)
53
+ log_path = os.path.join(output_dir, "training.log")
54
+ logging.basicConfig(
55
+ filename=log_path,
56
+ level=logging.INFO,
57
+ format="%(message)s",
58
+ )
59
+ _logger = logging.getLogger("wwho_trainer")
60
+ _log(f"Log started: {log_path}")
61
+
62
+
63
+ SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
64
+
65
+ # ─── Multiprocessing ──────
66
+ _worker_segmenter: CodeSwitchSegmenter | None = None
67
+ _worker_dfa_map: dict | None = None
68
+ _worker_script_mode: str = "mixed"
69
+
70
+
71
+ def _init_worker(script_mode: str):
72
+ global _worker_segmenter, _worker_dfa_map, _worker_script_mode
73
+ from linguis_trie import load_dfa_map
74
+
75
+ _worker_script_mode = script_mode
76
+ _worker_dfa_map = load_dfa_map(script_mode)
77
+
78
+ language_blocks = {lang: dfa.unicode_blocks for lang, dfa in _worker_dfa_map.items()}
79
+ _worker_segmenter = CodeSwitchSegmenter(language_blocks)
80
+
81
+
82
+ def _pretokenize_line(text: str) -> list[str]:
83
+ tokens: list[str] = []
84
+
85
+ for seg in _worker_segmenter.segment(text):
86
+ if seg.language == "latin":
87
+ tokens.append(seg.text)
88
+ else:
89
+ dfa = _worker_dfa_map.get(seg.language)
90
+ if not dfa:
91
+ tokens.append(seg.text)
92
+ continue
93
+ syllables = dfa.tokenize(seg.text, leading_space=seg.has_leading_space)
94
+ tokens.extend(syllables)
95
+
96
+ return tokens
97
+
98
+
99
+ def _is_boundary_token(token: str) -> bool:
100
+ for ch in token:
101
+ if _worker_segmenter:
102
+ lang = _worker_segmenter._get_char_language(ch)
103
+ if lang is not None and lang != "latin":
104
+ return False
105
+ return True
106
+
107
+ def segment_into_words(syllables: list[str]) -> list[list[str]]:
108
+ words: list[list[str]] = []
109
+ current: list[str] = []
110
+
111
+ for tok in syllables:
112
+ if _is_boundary_token(tok):
113
+ if current:
114
+ words.append(current)
115
+ current = []
116
+ words.append([tok])
117
+ else:
118
+ if tok[0] in (' ', '\t', '\n', '\r') and current:
119
+ words.append(current)
120
+ current = []
121
+ current.append(tok)
122
+
123
+ if current:
124
+ words.append(current)
125
+ return words
126
+
127
+
128
+ # ─── Symbol Table ──────
129
+
130
+ class SymbolTable:
131
+ def __init__(self):
132
+ self._str_to_id: dict[str, int] = {}
133
+ self._id_to_str: list[str] = []
134
+
135
+ def get_or_add(self, token: str) -> int:
136
+ if token in self._str_to_id:
137
+ return self._str_to_id[token]
138
+ new_id = len(self._id_to_str)
139
+ self._str_to_id[token] = new_id
140
+ self._id_to_str.append(token)
141
+ return new_id
142
+
143
+ def add_merged(self, a_id: int, b_id: int) -> int:
144
+ merged_str = self._id_to_str[a_id] + self._id_to_str[b_id]
145
+ return self.get_or_add(merged_str)
146
+
147
+ def to_str(self, token_id: int) -> str:
148
+ return self._id_to_str[token_id]
149
+
150
+ def to_id(self, token: str) -> int | None:
151
+ return self._str_to_id.get(token)
152
+
153
+ def __len__(self) -> int:
154
+ return len(self._id_to_str)
155
+
156
+
157
+ # ─── GPETrainer ──────
158
+
159
+ class GPETrainer:
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_size: int = 128_000,
164
+ min_freq: int = 2,
165
+ num_workers: int | None = None,
166
+ checkpoint_every: int = 20_000,
167
+ prune_freq: int = 100,
168
+ script_mode: str = "mixed",
169
+ ):
170
+ self.target_vocab_size = vocab_size
171
+ self.min_freq = min_freq
172
+ self.num_workers = num_workers or max(1, cpu_count() - 1)
173
+ self.checkpoint_every = checkpoint_every
174
+ self.prune_freq = prune_freq
175
+ self.script_mode = script_mode
176
+ self.merges: list[tuple[int, int]] = []
177
+ self.symbols = SymbolTable()
178
+
179
+ def stream_and_count(
180
+ self, train_file: str, output_dir: str = "output"
181
+ ) -> tuple[Counter, set[str]]:
182
+ # ── 1. Count lines ──────
183
+ print(" counting lines...", end=" ", flush=True)
184
+ with open(train_file, "r", encoding="utf-8") as f:
185
+ num_lines = sum(1 for _ in f)
186
+ print(f"{num_lines:,}")
187
+
188
+ CHUNK_SIZE = 5_000_000
189
+ BATCH = 4_096
190
+
191
+ partial_dir = os.path.join(output_dir, "_partial_counters")
192
+ os.makedirs(partial_dir, exist_ok=True)
193
+
194
+ _init_worker(self.script_mode)
195
+
196
+ total_lines = 0
197
+ chunk_idx = 0
198
+ partial_paths: list[str] = []
199
+
200
+ PARTIAL_PRUNE = 2
201
+ def _save_partial(counter: Counter, idx: int, n_sent: int):
202
+ if PARTIAL_PRUNE > 1:
203
+ to_save = Counter(
204
+ {k: v for k, v in counter.items() if v >= PARTIAL_PRUNE}
205
+ )
206
+ else:
207
+ to_save = counter
208
+ pkl_path = os.path.join(partial_dir, f"partial_{idx:04d}.pkl")
209
+ with open(pkl_path, "wb") as pf:
210
+ pickle.dump(to_save, pf, protocol=pickle.HIGHEST_PROTOCOL)
211
+ partial_paths.append(pkl_path)
212
+ pkl_mb = os.path.getsize(pkl_path) / 1024**2
213
+ pbar.write(
214
+ f" chunk {idx+1} done: {n_sent:,} sent "
215
+ f"-> {len(to_save):,} word types (pruned from {len(counter):,}) "
216
+ f"-> {pkl_path} ({pkl_mb:.0f} MB)"
217
+ )
218
+ _log(f"CHUNK {idx+1} saved: {n_sent:,} sent, "
219
+ f"{len(to_save):,} word types, {pkl_mb:.0f} MB")
220
+ del to_save
221
+ counter.clear()
222
+ gc.collect()
223
+ _log(f"CHUNK {idx+1} post-gc")
224
+
225
+ chunk_counter: Counter = Counter()
226
+ chunk_sent = 0
227
+ batch_buf: list[str] = []
228
+
229
+ pool = Pool(
230
+ processes=self.num_workers,
231
+ initializer=_init_worker,
232
+ initargs=(self.script_mode,),
233
+ )
234
+
235
+ with open(train_file, "r", encoding="utf-8") as f:
236
+ pbar = tqdm(f, total=num_lines, unit=" sent",
237
+ desc=f" pre-tokenizing [chunk 1]")
238
+
239
+ for raw_line in pbar:
240
+ try:
241
+ obj = json.loads(raw_line)
242
+ text = obj.get("text", "").strip()
243
+ except json.JSONDecodeError:
244
+ text = raw_line.strip()
245
+ if not text:
246
+ continue
247
+
248
+ batch_buf.append(text)
249
+ total_lines += 1
250
+ chunk_sent += 1
251
+
252
+ if len(batch_buf) >= BATCH:
253
+ self._process_batch(pool, batch_buf, chunk_counter)
254
+ batch_buf = []
255
+ if chunk_sent >= CHUNK_SIZE:
256
+ if batch_buf:
257
+ self._process_batch(pool, batch_buf, chunk_counter)
258
+ batch_buf = []
259
+ pool.close()
260
+ pool.join()
261
+ pool = None
262
+ gc.collect()
263
+
264
+ _save_partial(chunk_counter, chunk_idx, chunk_sent)
265
+ chunk_idx += 1
266
+ chunk_sent = 0
267
+ pbar.set_description(
268
+ f" pre-tokenizing [chunk {chunk_idx + 1}]"
269
+ )
270
+ gc.collect()
271
+
272
+ pool = Pool(
273
+ processes=self.num_workers,
274
+ initializer=_init_worker,
275
+ initargs=(self.script_mode,),
276
+ )
277
+
278
+ if batch_buf:
279
+ self._process_batch(pool, batch_buf, chunk_counter)
280
+ pool.close()
281
+ pool.join()
282
+ gc.collect()
283
+
284
+ if chunk_counter:
285
+ _save_partial(chunk_counter, chunk_idx, chunk_sent)
286
+ chunk_idx += 1
287
+
288
+ pbar.close()
289
+
290
+ print(f" {total_lines:,} sentences -> {chunk_idx} chunks processed")
291
+
292
+ # ── 3. Sequential merge with intermediate pruning ──────
293
+ _log(f"MERGE START: {len(partial_paths)} partial counters, min_freq={self.min_freq}")
294
+ N = len(partial_paths)
295
+ word_counter: Counter = Counter()
296
+ for i, pkl_path in enumerate(partial_paths):
297
+ _log(f"MERGE [{i+1}/{N}] loading {pkl_path}")
298
+ with open(pkl_path, "rb") as pf:
299
+ partial: Counter = pickle.load(pf)
300
+ _log(f"MERGE [{i+1}/{N}] loaded {len(partial):,} types, updating master...")
301
+ word_counter.update(partial)
302
+ del partial
303
+ gc.collect()
304
+ _log(f"MERGE [{i+1}/{N}] after update+gc: {len(word_counter):,} types")
305
+
306
+ remaining = N - i - 1
307
+ safe_prune = max(1, self.min_freq - remaining)
308
+ before = len(word_counter)
309
+
310
+ if safe_prune > 1:
311
+ word_counter = Counter(
312
+ {k: v for k, v in word_counter.items() if v >= safe_prune}
313
+ )
314
+
315
+ if i > 0 and i % 5 == 0:
316
+ hard_threshold = max(2, self.min_freq // 2)
317
+ word_counter = Counter(
318
+ {k: v for k, v in word_counter.items() if v >= hard_threshold}
319
+ )
320
+ _log(f"MERGE [{i+1}/{N}] HARD PRUNE TRIGGERED (threshold={hard_threshold})")
321
+
322
+ gc.collect()
323
+ pruned_n = before - len(word_counter)
324
+
325
+ if pruned_n > 0:
326
+ msg = (f" [{i+1}/{N}] merged -> {len(word_counter):,} types "
327
+ f"(pruned {pruned_n:,})")
328
+ print(msg, flush=True)
329
+ _log(f"MERGE [{i+1}/{N}] post-prune: {len(word_counter):,} types "
330
+ f"(removed {pruned_n:,})")
331
+ else:
332
+ print(f" [{i+1}/{N}] merged -> {len(word_counter):,} types", flush=True)
333
+ _log(f"MERGE [{i+1}/{N}] no prune needed, {len(word_counter):,} types")
334
+
335
+ os.remove(pkl_path)
336
+ _log(f"MERGE [{i+1}/{N}] deleted {pkl_path}")
337
+
338
+ try:
339
+ os.rmdir(partial_dir)
340
+ except OSError:
341
+ pass
342
+
343
+ n_types = len(word_counter)
344
+ n_instances = sum(word_counter.values())
345
+ print(f"\n Final: {total_lines:,} sent -> {n_types:,} word types "
346
+ f"({n_instances:,} instances)")
347
+ return word_counter, set()
348
+
349
+ def _process_batch(
350
+ self,
351
+ pool: Pool,
352
+ batch: list[str],
353
+ word_counter: Counter,
354
+ ):
355
+ syllable_streams = pool.map(_pretokenize_line, batch, chunksize=128)
356
+
357
+ for stream in syllable_streams:
358
+ words = segment_into_words(stream)
359
+ for w in words:
360
+ if not w:
361
+ continue
362
+ if not _is_boundary_token(w[0]):
363
+ word_counter[tuple(w)] += 1
364
+
365
+ @staticmethod
366
+ def compute_syllable_freqs(word_counter: Counter) -> Counter:
367
+ syl_freq: Counter[str] = Counter()
368
+ for word_tuple, word_freq in word_counter.items():
369
+ for syl in word_tuple:
370
+ syl_freq[syl] += word_freq
371
+ return syl_freq
372
+
373
+ def build_word_types(
374
+ self,
375
+ word_counter: Counter,
376
+ boundary_tokens: set[str],
377
+ syl_freq: Counter | None = None,
378
+ ) -> tuple[list[list[int]], list[int]]:
379
+ UNK_SENTINEL = -1
380
+ pruned_set: set[str] = set()
381
+
382
+ if syl_freq is not None and self.prune_freq > 0:
383
+ for syl, freq in syl_freq.items():
384
+ if freq < self.prune_freq:
385
+ pruned_set.add(syl)
386
+
387
+ word_types: list[list[int]] = []
388
+ word_freqs: list[int] = []
389
+ pruned_word_count = 0
390
+
391
+ for word_tuple, freq in word_counter.items():
392
+ ids = []
393
+ for tok in word_tuple:
394
+ if tok in pruned_set:
395
+ ids.append(UNK_SENTINEL)
396
+ else:
397
+ ids.append(self.symbols.get_or_add(tok))
398
+ word_types.append(ids)
399
+ word_freqs.append(freq)
400
+ if UNK_SENTINEL in ids:
401
+ pruned_word_count += 1
402
+
403
+ if pruned_set:
404
+ print(f" pruned {len(pruned_set):,} rare syllables (freq < {self.prune_freq})")
405
+ print(f" {pruned_word_count:,} word types contain [UNK] syllables")
406
+
407
+ return word_types, word_freqs
408
+
409
+ @staticmethod
410
+ def build_token_index(word_types: list[list[int]]) -> dict[int, set[int]]:
411
+ index: dict[int, set[int]] = defaultdict(set)
412
+ for wt_idx, wt in enumerate(word_types):
413
+ for tid in wt:
414
+ if tid >= 0:
415
+ index[tid].add(wt_idx)
416
+ return dict(index)
417
+
418
+ def count_all_pairs(
419
+ self,
420
+ word_types: list[list[int]],
421
+ word_freqs: list[int],
422
+ ) -> dict[tuple[int, int], int]:
423
+ counts: dict[tuple[int, int], int] = defaultdict(int)
424
+ for wt_idx, wt in enumerate(word_types):
425
+ f = word_freqs[wt_idx]
426
+ for i in range(len(wt) - 1):
427
+ a, b = wt[i], wt[i + 1]
428
+ if a < 0 or b < 0:
429
+ continue
430
+ counts[(a, b)] += f
431
+ return dict(counts)
432
+
433
+ @staticmethod
434
+ def _build_heap(pair_counts: dict) -> list:
435
+ heap = [(-freq, pair) for pair, freq in pair_counts.items() if freq > 0]
436
+ heapq.heapify(heap)
437
+ return heap
438
+
439
+ @staticmethod
440
+ def _heap_push(heap, pair, freq):
441
+ if freq > 0:
442
+ heapq.heappush(heap, (-freq, pair))
443
+
444
+ def _pop_best(self, heap, pair_counts):
445
+ while heap:
446
+ neg_freq, pair = heapq.heappop(heap)
447
+ actual = pair_counts.get(pair, 0)
448
+ if actual <= 0:
449
+ continue
450
+ if actual != -neg_freq:
451
+ self._heap_push(heap, pair, actual)
452
+ continue
453
+ return pair, actual
454
+ return None, 0
455
+
456
    def merge_and_update(
        self,
        word_types: list[list[int]],
        word_freqs: list[int],
        pair: tuple[int, int],
        pair_counts: dict[tuple[int, int], int],
        token_index: dict[int, set[int]],
        merged_id: int,
        heap: list,
    ) -> int:
        """Apply one BPE merge (a, b) -> merged_id in place and update all state.

        Rewrites every word type containing the pair, incrementally adjusts
        `pair_counts` for the neighbor pairs destroyed/created by the merge,
        maintains the inverted `token_index`, and pushes only *increased*
        counts onto `heap` (decreases are handled lazily by `_pop_best`).

        Returns the corpus-weighted number of merge applications.
        """
        a, b = pair
        total_applied = 0
        # Only word types containing BOTH ids can contain the adjacent pair.
        candidates = list(token_index.get(a, set()) & token_index.get(b, set()))
        # The merged pair itself can never recur with these ids.
        pair_counts.pop(pair, None)
        # Pairs whose counts changed in this pass; flushed to the heap at the end.
        dirty_pairs: dict[tuple[int, int], int] = {}

        for wt_idx in candidates:
            wt = word_types[wt_idx]
            freq = word_freqs[wt_idx]
            if len(wt) < 2:
                continue
            new_wt: list[int] = []
            i = 0
            changed = False

            while i < len(wt):
                if i + 1 < len(wt) and wt[i] == a and wt[i + 1] == b:
                    # Destroy the left neighbor pair (prev, a), if prev is not [UNK].
                    if new_wt and new_wt[-1] >= 0:
                        lp = (new_wt[-1], a)
                        pair_counts[lp] = pair_counts.get(lp, 0) - freq
                        dirty_pairs[lp] = pair_counts[lp]
                    # Destroy the right neighbor pair (b, next), if next is not [UNK].
                    if i + 2 < len(wt) and wt[i + 2] >= 0:
                        rp = (b, wt[i + 2])
                        pair_counts[rp] = pair_counts.get(rp, 0) - freq
                        dirty_pairs[rp] = pair_counts[rp]
                    new_wt.append(merged_id)
                    total_applied += freq
                    changed = True
                    # Create the new left pair (prev, merged_id).
                    if len(new_wt) >= 2 and new_wt[-2] >= 0:
                        lp = (new_wt[-2], merged_id)
                        pair_counts[lp] = pair_counts.get(lp, 0) + freq
                        dirty_pairs[lp] = pair_counts[lp]
                    # Create the new right pair (merged_id, next).
                    if i + 2 < len(wt) and wt[i + 2] >= 0:
                        rp = (merged_id, wt[i + 2])
                        pair_counts[rp] = pair_counts.get(rp, 0) + freq
                        dirty_pairs[rp] = pair_counts[rp]
                    i += 2  # consumed both members of the pair
                else:
                    new_wt.append(wt[i])
                    i += 1

            if changed:
                word_types[wt_idx] = new_wt
                # Register the new token in the inverted index.
                if merged_id not in token_index:
                    token_index[merged_id] = set()
                token_index[merged_id].add(wt_idx)
                # Drop index entries for a/b if they no longer occur in this word.
                remaining = set(new_wt)
                if a not in remaining and wt_idx in token_index.get(a, set()):
                    token_index[a].discard(wt_idx)
                if b not in remaining and wt_idx in token_index.get(b, set()):
                    token_index[b].discard(wt_idx)

        # Garbage-collect empty index buckets for the merged-away ids.
        for tok_id in (a, b):
            if tok_id in token_index and not token_index[tok_id]:
                del token_index[tok_id]

        # Flush changed counts: dead pairs are removed, live ones re-queued.
        for p, cnt in dirty_pairs.items():
            if cnt <= 0:
                pair_counts.pop(p, None)
            else:
                self._heap_push(heap, p, cnt)

        return total_applied
529
+
530
+ def save_checkpoint(self, step: int, output_dir: str, elapsed: float):
531
+ merge_strs = [
532
+ [self.symbols.to_str(a), self.symbols.to_str(b)]
533
+ for a, b in self.merges
534
+ ]
535
+ ckpt = {
536
+ "step": step,
537
+ "script_mode": self.script_mode,
538
+ "merges": merge_strs,
539
+ "elapsed_seconds": round(elapsed, 1),
540
+ }
541
+ path = os.path.join(output_dir, f"checkpoint_{step}.json")
542
+ with open(path, "w", encoding="utf-8") as f:
543
+ json.dump(ckpt, f, ensure_ascii=False)
544
+ size_mb = os.path.getsize(path) / (1024 * 1024)
545
+ return path, size_mb
546
+
547
+ def load_checkpoint(self, ckpt_path: str):
548
+ with open(ckpt_path, "r", encoding="utf-8") as f:
549
+ ckpt = json.load(f)
550
+ print(f" loaded checkpoint: step {ckpt['step']}, "
551
+ f"{len(ckpt['merges'])} merges, "
552
+ f"{ckpt['elapsed_seconds']:.1f}s elapsed")
553
+ return ckpt
554
+
555
+ def replay_merges(self, merge_strs, word_types, word_freqs, token_index, pair_counts):
556
+ print(f" replaying {len(merge_strs)} merges...", flush=True)
557
+ t0 = time.time()
558
+ dummy_heap: list = []
559
+ for a_str, b_str in tqdm(merge_strs, desc=" replaying", unit=" merge"):
560
+ a_id = self.symbols.to_id(a_str)
561
+ b_id = self.symbols.to_id(b_str)
562
+ if a_id is None or b_id is None:
563
+ continue
564
+ merged_id = self.symbols.add_merged(a_id, b_id)
565
+ self.merges.append((a_id, b_id))
566
+ self.merge_and_update(
567
+ word_types, word_freqs, (a_id, b_id), pair_counts,
568
+ token_index, merged_id, dummy_heap,
569
+ )
570
+ print(f" replayed {len(self.merges)} merges in {time.time()-t0:.1f}s")
571
+
572
    def train(self, train_file: str, output_dir: str = "output",
              resume_path: str | None = None):
        """Run the full SGPE training pipeline end to end.

        Stages: (1) stream + pre-tokenize the corpus, (2) convert words to
        symbol-id sequences (optionally pruning rare syllables to [UNK]),
        (3) build the inverted index and initial pair counts, (4) run the
        greedy merge loop with lazy-heap bookkeeping and periodic
        checkpoints, (5) export vocab + tokenizer artifacts.

        resume_path, if given, replays a previous checkpoint's merges before
        continuing the loop.
        """
        os.makedirs(output_dir, exist_ok=True)

        print(f"WWHO (SGPE) GPE Trainer — script_mode={self.script_mode}, "
              f"workers={self.num_workers}")
        print(f"Training file: {train_file}\n")

        # Stage 1: stream corpus and count pre-tokenized words.
        print("[1/5] Streaming pre-tokenization (CodeSwitchRouter)...")
        t_start = time.time()
        word_counter, boundary_tokens = self.stream_and_count(train_file, output_dir)

        # Stage 2: map syllables to ids, optionally pruning rare ones.
        print("\n[2/5] Building ID corpus...")
        syl_freq = None
        if self.prune_freq > 0:
            syl_freq = self.compute_syllable_freqs(word_counter)
            total_syls = len(syl_freq)
            surviving = sum(1 for f in syl_freq.values() if f >= self.prune_freq)
            print(f" syllable pruning: {total_syls:,} unique syllables, "
                  f"{surviving:,} survive (freq >= {self.prune_freq})")

        word_types, word_freqs = self.build_word_types(
            word_counter, boundary_tokens, syl_freq=syl_freq,
        )
        # Free the raw counters — only the id corpus is needed from here on.
        del word_counter, syl_freq

        base_vocab = len(self.symbols)
        total_instances = sum(word_freqs)
        print(f" base vocab (syllables + boundaries): {base_vocab:,}")
        print(f" word types: {len(word_types):,} ({total_instances:,} instances)")

        # Stage 3: inverted index + initial adjacent-pair counts.
        print("\n[3/5] Building index and counting pairs...")
        token_index = self.build_token_index(word_types)
        pair_counts = self.count_all_pairs(word_types, word_freqs)
        print(f" {len(pair_counts):,} unique pairs")

        start_step = 0
        elapsed_prior = 0.0
        if resume_path:
            print(f"\n Resuming from {resume_path}...")
            ckpt = self.load_checkpoint(resume_path)
            self.replay_merges(
                ckpt["merges"], word_types, word_freqs, token_index, pair_counts,
            )
            start_step = ckpt["step"]
            elapsed_prior = ckpt["elapsed_seconds"]
            # Recount from scratch: replay's incremental counts can drift
            # when some rules were skipped.
            pair_counts = self.count_all_pairs(word_types, word_freqs)
            print(f" rebuilt pair counts: {len(pair_counts):,} unique pairs")

        # Merge budget = target vocab minus specials minus the base alphabet.
        total_vocab_needed = self.target_vocab_size - len(SPECIAL_TOKENS)
        num_merges = max(0, total_vocab_needed - base_vocab)
        remaining = num_merges - start_step
        print(f"\n merge budget: {num_merges:,} "
              f"(starting at {start_step}, {remaining:,} remaining, min_freq={self.min_freq})")

        # Stage 4: greedy merge loop.
        print(f"\n[4/5] Merge loop...")
        heap = self._build_heap(pair_counts)
        t0 = time.time()
        pbar = tqdm(range(start_step + 1, num_merges + 1),
                    desc=" merging", unit=" merge")

        for step in pbar:
            best_pair, freq = self._pop_best(heap, pair_counts)
            # Stop early when no pair clears the frequency floor.
            if best_pair is None or freq < self.min_freq:
                pbar.write(f" stopping at step {step}: "
                           f"{'no pairs' if best_pair is None else f'freq={freq} < {self.min_freq}'}")
                break

            a_id, b_id = best_pair
            merged_id = self.symbols.add_merged(a_id, b_id)
            self.merges.append(best_pair)

            n_applied = self.merge_and_update(
                word_types, word_freqs, best_pair, pair_counts,
                token_index, merged_id, heap,
            )

            # Verbose trace for the first few merges, then every 1000th.
            if step <= 10 or step % 1000 == 0:
                a_s = self.symbols.to_str(a_id)
                b_s = self.symbols.to_str(b_id)
                m_s = self.symbols.to_str(merged_id)
                elapsed = time.time() - t0 + elapsed_prior
                pbar.write(f" [{step:>7}/{num_merges}] "
                           f"'{a_s}' + '{b_s}' -> '{m_s}' "
                           f"(freq={freq:,}, applied={n_applied:,}) [{elapsed:.1f}s]")

            if self.checkpoint_every > 0 and step % self.checkpoint_every == 0:
                elapsed = time.time() - t0 + elapsed_prior
                path, sz = self.save_checkpoint(step, output_dir, elapsed)
                pbar.write(f" >> checkpoint: {path} ({sz:.2f} MB)")

            pbar.set_postfix(freq=freq, vocab=len(self.symbols))

        pbar.close()
        merge_elapsed = time.time() - t0
        total_elapsed = merge_elapsed + elapsed_prior
        print(f" done: {len(self.merges)} merges in {merge_elapsed:.1f}s "
              f"(total {total_elapsed:.1f}s)")

        # Stage 5: export vocab.json / tokenizer.json and final checkpoint.
        print("\n[5/5] Building vocabulary and exporting...")
        self._save_output(word_types, word_freqs, boundary_tokens, output_dir)

        wall = time.time() - t_start
        print(f"\nTotal wall time: {wall:.1f}s ({wall/60:.1f} min)")
676
+
677
    def _save_output(self, word_types, word_freqs, boundary_tokens, output_dir):
        """Assemble the final vocabulary and write vocab.json + tokenizer.json.

        Token ids are assigned in three tiers: special tokens first, then
        surviving symbols ordered by final corpus frequency, then any
        remaining symbols as padding up to target_vocab_size.

        NOTE(review): `boundary_tokens` is accepted but not read in this
        body — presumably kept for signature compatibility; confirm.
        """
        # Final corpus frequency of each symbol after all merges applied.
        final_freq: Counter[int] = Counter()
        for wt_idx, wt in enumerate(word_types):
            f = word_freqs[wt_idx]
            for tid in wt:
                if tid >= 0:  # skip [UNK] sentinels
                    final_freq[tid] += f

        # Tier 1: special tokens occupy the lowest ids.
        vocab: dict[str, int] = {}
        for i, st in enumerate(SPECIAL_TOKENS):
            vocab[st] = i
        next_id = len(SPECIAL_TOKENS)

        # Tier 2: used symbols, most frequent first.
        for tid, _ in final_freq.most_common():
            if len(vocab) >= self.target_vocab_size:
                break
            tok_str = self.symbols.to_str(tid)
            if tok_str not in vocab:
                vocab[tok_str] = next_id
                next_id += 1

        # Tier 3: pad out with any remaining (zero-usage) symbols.
        for sid in range(len(self.symbols)):
            if len(vocab) >= self.target_vocab_size:
                break
            s = self.symbols.to_str(sid)
            if s not in vocab:
                vocab[s] = next_id
                next_id += 1

        print(f" vocab size: {len(vocab):,}")
        print(f" merge rules: {len(self.merges):,}")

        # Merge rules serialized as string pairs for portability.
        merge_strs = [
            [self.symbols.to_str(a), self.symbols.to_str(b)]
            for a, b in self.merges
        ]

        output = {
            "version": "wwho_sgpe",
            "script_mode": self.script_mode,
            "vocab_size": len(vocab),
            "special_tokens": SPECIAL_TOKENS,
            "num_merges": len(self.merges),
            "prune_freq": self.prune_freq,
            "leading_space": True,
            "merges": merge_strs,
            "vocab": vocab,
        }

        path = os.path.join(output_dir, "vocab.json")
        with open(path, "w", encoding="utf-8") as f:
            json.dump(output, f, ensure_ascii=False, indent=2)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f" saved: {path} ({size_mb:.2f} MB)")

        # Final checkpoint at the last merge step (elapsed not tracked here).
        self.save_checkpoint(len(self.merges), output_dir, 0)

        # HuggingFace-compatible tokenizer export.
        hf_path = os.path.join(output_dir, "tokenizer.json")
        export_hf_tokenizer(vocab, merge_strs, SPECIAL_TOKENS, hf_path,
                            script_mode=self.script_mode)

        print(f"\n{'='*60}")
        print(f"TRAINING COMPLETE — WWHO")
        print(f" Script mode: {self.script_mode}")
        print(f" Vocab size: {len(vocab):,}")
        print(f" Merge rules: {len(self.merges):,}")
        print(f" Word types: {len(word_types):,}")
        print(f"{'='*60}")
745
+
746
+
747
def main():
    """CLI entry point: parse arguments, set up logging, run a training job."""
    arg_parser = argparse.ArgumentParser(description="WWHO (SGPE) GPE Trainer")
    arg_parser.add_argument("--train_file", type=str, default="dataset/mixed_train.jsonl")
    arg_parser.add_argument("--vocab_size", type=int, default=128_000,
                            help="Target SGPE vocab size (default 128K)")
    arg_parser.add_argument("--min_freq", type=int, default=2)
    arg_parser.add_argument("--prune_freq", type=int, default=100,
                            help="Drop syllables below this corpus frequency to [UNK]")
    arg_parser.add_argument("--output_dir", type=str, default="output")
    arg_parser.add_argument("--num_workers", type=int, default=None)
    arg_parser.add_argument("--checkpoint_every", type=int, default=20_000)
    arg_parser.add_argument("--resume", type=str, default=None)
    arg_parser.add_argument("--script_mode", type=str, default="mixed",
                            choices=["sinhala", "devanagari", "mixed"],
                            help="Which Indic script(s) to merge in BPE "
                                 "(English/code always stays as boundary tokens)")
    opts = arg_parser.parse_args()

    _setup_logging(opts.output_dir)
    _log(f"Starting WWHO (SGPE) trainer: train_file={opts.train_file} "
         f"vocab_size={opts.vocab_size} script_mode={opts.script_mode} "
         f"prune_freq={opts.prune_freq} min_freq={opts.min_freq}")

    GPETrainer(
        vocab_size=opts.vocab_size,
        min_freq=opts.min_freq,
        num_workers=opts.num_workers,
        checkpoint_every=opts.checkpoint_every,
        prune_freq=opts.prune_freq,
        script_mode=opts.script_mode,
    ).train(opts.train_file, opts.output_dir, resume_path=opts.resume)
778
+
779
+
780
# Script entry point: run the CLI trainer only when executed directly.
if __name__ == "__main__":
    main()
meta_config.json DELETED
@@ -1,8 +0,0 @@
1
- {
2
- "tiktoken_model": "o200k_base",
3
- "tiktoken_vocab_size": 200019,
4
- "sgpe_vocab_size": 128000,
5
- "sgpe_id_offset": 200019,
6
- "script_mode": "mixed",
7
- "sgpe_vocab_path": "vocab.json"
8
- }
 
 
 
 
 
 
 
 
 
router.py CHANGED
@@ -10,24 +10,12 @@ import re
10
  from dataclasses import dataclass
11
 
12
  import tiktoken
13
- # ---------------------------------------------------------------------------
14
- # Script-block detection
15
- # ---------------------------------------------------------------------------
16
-
17
- def _is_indic_joiner(ch: str) -> bool:
18
- # True if ZWJ or ZWNJ
19
- return ch in ('\u200C', '\u200D')
20
-
21
-
22
- # ---------------------------------------------------------------------------
23
- # Segment dataclass
24
- # ---------------------------------------------------------------------------
25
 
26
  @dataclass
27
  class TextSegment:
28
  text: str
29
- language: str # "latin", "sinhala", "devanagari", etc
30
- has_leading_space: bool = False # True if a boundary space was absorbed
31
 
32
 
33
  # ---------------------------------------------------------------------------
@@ -75,7 +63,7 @@ class CodeSwitchSegmenter:
75
  ch2 = text[pos]
76
  lang2 = self._get_char_language(ch2)
77
  if lang2 is not None and lang2 != "__joiner__":
78
- break # Found distinct Indic start
79
  pos += 1
80
 
81
  latino_only = text[start:pos]
@@ -92,7 +80,7 @@ class CodeSwitchSegmenter:
92
  indic_start = pos
93
  current_lang = self._get_char_language(text[pos])
94
  if current_lang == "__joiner__" or current_lang is None:
95
- current_lang = "__unknown__" # fallback
96
 
97
  while pos < n:
98
  c = text[pos]
@@ -101,7 +89,7 @@ class CodeSwitchSegmenter:
101
  pos += 1
102
  elif c_lang is not None:
103
  if current_lang == "__unknown__":
104
- current_lang = c_lang # adapt
105
  elif c_lang != current_lang:
106
  break
107
  pos += 1
@@ -114,7 +102,6 @@ class CodeSwitchSegmenter:
114
  has_leading_space=True
115
  ))
116
  else:
117
- # ─── 2. Accumulate Indic block (no prior Latin with space) ───
118
  indic_start = pos
119
  current_lang = ch_lang
120
 
@@ -138,12 +125,6 @@ class CodeSwitchSegmenter:
138
 
139
  return segments
140
 
141
-
142
-
143
- # ---------------------------------------------------------------------------
144
- # Router
145
- # ---------------------------------------------------------------------------
146
-
147
  # ---------------------------------------------------------------------------
148
  # Self-test
149
  # ---------------------------------------------------------------------------
@@ -172,7 +153,6 @@ if __name__ == "__main__":
172
  "AI (Artificial Intelligence) සහ देवनागरी text.",
173
  ]
174
 
175
- # _test segmenter independently
176
  language_blocks = {
177
  "sinhala": [(0x0d80, 0x0dff)],
178
  "devanagari": [(0x0900, 0x097f)]
 
10
  from dataclasses import dataclass
11
 
12
  import tiktoken
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  @dataclass
15
  class TextSegment:
16
  text: str
17
+ language: str
18
+ has_leading_space: bool = False
19
 
20
 
21
  # ---------------------------------------------------------------------------
 
63
  ch2 = text[pos]
64
  lang2 = self._get_char_language(ch2)
65
  if lang2 is not None and lang2 != "__joiner__":
66
+ break
67
  pos += 1
68
 
69
  latino_only = text[start:pos]
 
80
  indic_start = pos
81
  current_lang = self._get_char_language(text[pos])
82
  if current_lang == "__joiner__" or current_lang is None:
83
+ current_lang = "__unknown__"
84
 
85
  while pos < n:
86
  c = text[pos]
 
89
  pos += 1
90
  elif c_lang is not None:
91
  if current_lang == "__unknown__":
92
+ current_lang = c_lang
93
  elif c_lang != current_lang:
94
  break
95
  pos += 1
 
102
  has_leading_space=True
103
  ))
104
  else:
 
105
  indic_start = pos
106
  current_lang = ch_lang
107
 
 
125
 
126
  return segments
127
 
 
 
 
 
 
 
128
  # ---------------------------------------------------------------------------
129
  # Self-test
130
  # ---------------------------------------------------------------------------
 
153
  "AI (Artificial Intelligence) සහ देवनागरी text.",
154
  ]
155
 
 
156
  language_blocks = {
157
  "sinhala": [(0x0d80, 0x0dff)],
158
  "devanagari": [(0x0900, 0x097f)]