ace-1 committed
Commit 09246b1 · verified · 1 parent: 63ab8c9

Publish mgpt2 sft checkpoint (step 1262, val_loss 1.240358)
README.md ADDED
@@ -0,0 +1,159 @@
---
language:
- en
- hi
- kn
license: mit
tags:
- causal-lm
- multilingual
- indic
- hindi
- kannada
- instruction-tuned
- text-generation-inference
pipeline_tag: text-generation
base_model: ace-1/mgpt2-pretrain
---

# mgpt2-sft — Multilingual GPT-2 (Instruction-Tuned)

`mgpt2` fine-tuned on **30,000 multilingual instruction–response pairs** across 5 language variants:
English, Hindi (Devanagari), Hindi (Latin transliteration), Kannada (Kannada script), and Kannada
(Latin transliteration). Training data comes from ai4bharat/indic-align (Anudesh, Dolly-T, OpenAssistant-T).

Built on top of the pretrained `mgpt2` base — same 124M-parameter architecture, same custom multilingual
tokenizer. Trained with masked cross-entropy (loss computed over response tokens only).
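A minimal illustration of the masked objective (not the actual training code; the token ids, shapes, and the prompt length of 3 are made up):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 6, 50304)                    # (B, T, vocab) from the model
targets = torch.tensor([[11, 42, 7, 99, 3, 50256]])  # hypothetical token ids

masked = targets.clone()
masked[:, :3] = -100                                 # mask the 3 prompt tokens
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    masked.view(-1),
    ignore_index=-100,                               # prompt positions contribute no loss
)
```

With `ignore_index=-100`, the loss averages only over the response positions, so gradients never reward memorising the prompt.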
## Quick start

```python
import sys, torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download

local = snapshot_download("ace-1/mgpt2-sft")
sys.path.insert(0, local)
from model import GPT
from tokenizer.regex_tokenizer import RegexTokenizer

ckpt = torch.load(f"{local}/pytorch_model.pt", weights_only=False, map_location="cpu")
model = GPT(ckpt["config"])
model.load_state_dict(ckpt["model"])
model.eval()

enc = RegexTokenizer()
enc.load(f"{local}/tokenizer/artifacts/mgpt2.model")

# Prompts are plain text; no chat template is needed
prompts = [
    "What is the capital of Karnataka?",  # English
    "कर्नाटक की राजधानी क्या है?",             # Hindi (Devanagari)
    "ಕರ್ನಾಟಕದ ರಾಜಧಾನಿ ಯಾವುದು?",       # Kannada script
]

for prompt in prompts:
    ids = enc.encode(prompt)
    x = torch.tensor(ids, dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        for _ in range(120):
            logits, _ = model(x[:, -1024:])                    # crop to context length
            probs = F.softmax(logits[:, -1, :] / 0.7, dim=-1)  # temperature 0.7
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == 50256:                        # <|endoftext|>
                break
            x = torch.cat([x, next_id], dim=1)
    print(f"Prompt : {prompt}")
    print(f"Response: {enc.decode(x[0, len(ids):].tolist())}")
    print()
```

## Intended use

**Good for:**
- Multilingual Q&A and instruction following (en/hi/kn, native + romanised scripts)
- A starting point for downstream fine-tuning on Indic NLP tasks
- Research on multilingual instruction tuning at small scale

**Not for:** safety-critical applications. Native-script variants (Devanagari, Kannada) are more reliable than
transliterated Latin variants, which are prone to mid-generation script drift (a known limitation —
see training notes).

## Model details

| Property | Value |
|---|---|
| Architecture | GPT-2 (12 layers / 12 heads / 768d) |
| Parameters | ~124M |
| Vocabulary | 50,257 (mgpt2 BPE), padded to 50,304 |
| Context length | 1,024 tokens |
| Training stage | SFT (instruction-tuned) |
| Git commit | `d07224070033` |

## Training configuration

| Parameter | Value |
|---|---|
| `seed` | `1337` |
| `batch_size` | `64` |
| `micro_batch_size` | `8` |
| `epochs` | `3` |
| `warmup_steps` | `50` |
| `max_lr` | `0.0003` |
| `min_lr_ratio` | `0.1` |
| `weight_decay` | `0.1` |
| `eval_interval` | `50` |

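`batch_size` 64 with `micro_batch_size` 8 implies 8 gradient-accumulation steps per optimizer step on a single device, and the warmup/min-LR values suggest linear warmup followed by cosine decay. A hedged sketch of such a schedule (the exact schedule used in training is an assumption; the constants come from the table above):

```python
import math

def get_lr(step, max_steps=1262, warmup_steps=50, max_lr=3e-4, min_lr_ratio=0.1):
    """Linear warmup to max_lr, then cosine decay to max_lr * min_lr_ratio."""
    min_lr = max_lr * min_lr_ratio
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)
```

At step 49 the rate peaks at 3e-4, then decays smoothly toward 3e-5 by step 1262.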
## Evaluation

| Metric | Value | Notes |
|---|---|---|
| Val loss (masked CE) | 1.2404 | Response tokens only, held-out SFT set |
| Val PPL (SFT set) | 3.46 | Not comparable to pretrain LM PPL |
| Training steps | 1,262 | 3 epochs over 30K examples |

> SFT val PPL is measured on the SFT held-out set (a narrower domain) and is **not comparable**
> to the pretrain LM eval PPL (12.4), which measures general language-modelling ability.

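Perplexity here is just the exponential of the mean masked cross-entropy:

```python
import math

val_loss = 1.2404         # masked CE from the table above
ppl = math.exp(val_loss)  # perplexity = exp(mean cross-entropy)
print(round(ppl, 2))      # 3.46
```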
## Training data

| Language | Count | Source |
|---|---|---|
| English (`eng_Latn`) | 16,500 | [ai4bharat/indic-align](https://huggingface.co/datasets/ai4bharat/indic-align) Anudesh |
| Hindi Devanagari (`hin_Deva`) | 5,400 | indic-align Dolly-T + OpenAssistant-T |
| Kannada script (`kan_Knda`) | 3,900 | indic-align Dolly-T + OpenAssistant-T |
| Hindi Latin translit (`hin_Latn`) | 2,100 | indic-align Dolly-T + OpenAssistant-T |
| Kannada Latin translit (`kan_Latn`) | 2,100 | indic-align Dolly-T + OpenAssistant-T |

30,000 examples total, with a 90/10 train/val split. Masked CE — loss computed over response tokens only.

## Tokenizer

Custom multilingual regex + BPE tokenizer (`mgpt2`), trained on the same corpus mixture.
Same vocabulary size as tiktoken-gpt2 (50,257 tokens), but with Indic-aware merge priorities:

| Bucket | tiktoken-gpt2 | **mgpt2** | Δ |
|---|---:|---:|---:|
| Overall | 480 tok/kB | **223 tok/kB** | −54% |
| Devanagari | 592 tok/kB | **215 tok/kB** | −64% |
| Kannada | 981 tok/kB | **213 tok/kB** | −78% |
| Latin | 257 tok/kB | **230 tok/kB** | −10% |

Tokenizer published separately: [ace-1/mgpt2-tokenizer](https://huggingface.co/ace-1/mgpt2-tokenizer)

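tok/kB is tokens emitted per kilobyte of UTF-8 input (lower is better). A sketch of how such a figure can be computed; the exact evaluation corpus and rounding used for the table are assumptions:

```python
def tokens_per_kb(num_tokens: int, text: str) -> float:
    """Tokens emitted per kilobyte (1024 bytes) of UTF-8 encoded text."""
    kb = len(text.encode("utf-8")) / 1024
    return num_tokens / kb

# 1 KiB of ASCII encoded into 230 tokens -> 230 tok/kB (the Latin bucket's rate)
print(tokens_per_kb(230, "a" * 1024))  # 230.0
```

Multi-byte Indic scripts are where byte-level BPE without Indic merges explodes: each Devanagari or Kannada character is 3 UTF-8 bytes, so a tokenizer without dedicated merges emits several tokens per character.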
## Known limitations

- **Transliterated Latin script drift.** `hin_Latn` and `kan_Latn` outputs may switch scripts mid-generation. Cause: ASCII tokens are shared with English, so there is no Unicode anchor. Mitigated but not eliminated at this data scale.
- **124M parameters.** Factual accuracy and multi-step reasoning are limited.
- **No safety alignment.** The SFT model was trained on benign instruction data only; it may attempt to answer harmful prompts. Use the DPO variant for light safety alignment.
- **Research checkpoint** — not evaluated for production use.

## Citation

```bibtex
@misc{mgpt2,
  title = {mgpt2: Multilingual GPT-2 with custom Indic tokenizer},
  year  = {2026},
  note  = {Pretrain → SFT → DPO pipeline for English/Hindi/Kannada},
  url   = {https://huggingface.co/ace-1/mgpt2-sft}
}
```
config.json ADDED
@@ -0,0 +1,12 @@
{
  "architectures": [
    "GPT"
  ],
  "model_type": "mgpt2",
  "block_size": 1024,
  "vocab_size": 50304,
  "n_layer": 12,
  "n_head": 12,
  "n_embd": 768,
  "tokenizer_kind": "mgpt2_regex_bpe"
}
model.py ADDED
@@ -0,0 +1,184 @@
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import inspect


@dataclass
class GPTConfig:
    block_size: int = 1024   # sequence length
    vocab_size: int = 50257  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12        # number of layers
    n_head: int = 12         # number of attention heads
    n_embd: int = 768        # embedding dimension


class CausalSelfAttention(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)

        q = q.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Naive attention, kept for reference:
        # att = q @ k.transpose(-2, -1) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        # att = F.softmax(att, dim=-1)
        # y = att @ v
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # (B, T, C)
        x = x + self.mlp(self.ln_2(x))   # (B, T, C)
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),               # token embedding table
            wpe=nn.Embedding(config.block_size, config.n_embd),               # position embedding table
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),  # transformer layers
            ln_f=nn.LayerNorm(config.n_embd),                                 # final layer norm
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)  # language modeling head

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()  # (B, T) = batch size, sequence length
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)  # (B, T, n_embd)
        pos_emb = self.transformer.wpe(pos)  # (T, n_embd)
        x = tok_emb + pos_emb                # (B, T, n_embd)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)  # (B, T, n_embd)
        logits = self.lm_head(x)      # (B, T, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print(f"loading weights from pretrained gpt {model_type}..")

        config_args = {
            "gpt2":        dict(n_layer=12, n_head=12, n_embd=768),
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024),
            "gpt2-large":  dict(n_layer=36, n_head=20, n_embd=1280),
            "gpt2-xl":     dict(n_layer=48, n_head=25, n_embd=1600),
        }[model_type]
        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024

        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = [k for k in sd.keys() if not k.endswith('.attn.bias')]

        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith('.attn.bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        # these weights are stored transposed in the HF checkpoint (Conv1D vs Linear)
        transposed_keys = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        assert len(sd_keys_hf) == len(sd_keys), f"Mismatch: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(suffix) for suffix in transposed_keys):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].T)
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all parameters that require gradients
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        # create optim groups. Any parameters that are 2D are going to be weight decayed,
        # i.e. all weight tensors in matmuls + embeddings. All biases and layernorms are not.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        non_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': non_decay_params, 'weight_decay': 0.0},
        ]
        # create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
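The weight-sharing line in `GPT.__init__` ties the token embedding and the LM head to a single tensor. A self-contained sketch of the effect, with sizes taken from the SFT config:

```python
import torch.nn as nn

# Minimal illustration of the weight-sharing scheme in GPT.__init__:
# the token embedding and the LM head point at the same Parameter, so the
# model stores one (vocab_size, n_embd) matrix instead of two.
vocab_size, n_embd = 50304, 768
wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
wte.weight = lm_head.weight  # tie: both modules now share storage

assert wte.weight.data_ptr() == lm_head.weight.data_ptr()
print(vocab_size * n_embd)   # 38633472 parameters shared rather than duplicated
```

Besides halving those parameters, tying keeps the input and output token representations consistent, which empirically helps small models.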
pytorch_model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3876db428ed2c0ab1c6152b7e1221be21d8a01a9f52f752c6ffc17c988121cbb
size 497958335
tokenization_mgpt2.py ADDED
@@ -0,0 +1,3 @@
from tokenizer.hf_tokenizer import MGPT2Tokenizer

__all__ = ['MGPT2Tokenizer']
tokenizer/__init__.py ADDED
@@ -0,0 +1,15 @@
from .base import Tokenizer
from .basic import BasicTokenizer
from .regex_tokenizer import RegexTokenizer
from .gpt4 import GPT4Tokenizer
from .patterns import GPT4_SPLIT_PATTERN, INDIC_SPLIT_PATTERN

__all__ = [
    "Tokenizer",
    "BasicTokenizer",
    "RegexTokenizer",
    "GPT4Tokenizer",
    "GPT4_SPLIT_PATTERN",
    "INDIC_SPLIT_PATTERN",
]
tokenizer/artifacts/mgpt2.model ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2911100f93f224a36cfd6a40de8739a12f3fe7b0b885cd0edc961c6e5e6c4b1
size 463596
tokenizer/base.py ADDED
@@ -0,0 +1,158 @@
"""
A minimal implementation of Byte-Pair Encoding (BPE) tokenization.

BPE is a subword tokenization algorithm that iteratively merges the most frequent pairs of bytes or characters
to build a vocabulary of subword tokens. This implementation is inspired by Andrej Karpathy's minbpe
(https://github.com/karpathy/minbpe).
"""
import unicodedata


def get_stats(ids, freq):
    # count adjacent-pair frequencies into `freq` in place
    for pair in zip(ids[:-1], ids[1:]):
        freq[pair] = freq.get(pair, 0) + 1


def merge(ids, pair, idx):
    # replace every non-overlapping occurrence of `pair` in `ids` with token `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def visualise_tokens(token_values: list[bytes]) -> None:
    background = [f"\u001b[48;5;{i}m" for i in [167, 179, 185, 77, 80, 68, 134]]
    # If token boundaries do not occur at unicode character boundaries, it's unclear how best to
    # visualise the token. Here, we'll just use the unicode replacement character to represent some
    # fraction of a character.
    unicode_token_values = [x.decode("utf-8", errors="replace") for x in token_values]

    running_length = 0
    last_color = None
    for token in unicode_token_values:
        color = background[running_length % len(background)]
        if color == last_color:
            color = background[(running_length + 1) % len(background)]
            assert color != last_color
        last_color = color
        running_length += len(token)
        print(color + token, end="")
    print("\u001b[0m")


# two helper functions for printing tokens safely...
def replace_control_characters(s: str) -> str:
    # we don't want to print control characters
    # which distort the output (e.g. \n or much worse)
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117
    # http://www.unicode.org/reports/tr44/#GC_Values_Table
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # this character is ok
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape
    return "".join(chars)


def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s

# --------------------------------------------------------------------------------------------------
class Tokenizer:
    def __init__(self):
        self.merges = {}          # (int, int) -> int
        self.pattern = ""         # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.inverse_special_tokens = {}  # int -> str
        self.vocab = self._build_vocab()  # int -> bytes

    def _build_vocab(self):
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def decode(self, ids) -> str:
        raise NotImplementedError

    def encode(self, text, verbose=False) -> list[int]:
        raise NotImplementedError

    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - the model file is the critical one, intended for load()
        - the vocab file is just a pretty-printed version for human inspection only
        """
        # write the model: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # the merges dict
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is a leaf token, just print it
                    # (this should just be the first 256 tokens, the bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save(), but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
        self.vocab = self._build_vocab()
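`get_stats` and `merge` above are the two primitives that drive BPE training. A tiny worked example (the helpers are restated so the snippet runs standalone):

```python
# Restated from tokenizer/base.py so this snippet is self-contained.
def get_stats(ids, freq):
    # count adjacent-pair frequencies into `freq` in place
    for pair in zip(ids[:-1], ids[1:]):
        freq[pair] = freq.get(pair, 0) + 1

def merge(ids, pair, idx):
    # replace each non-overlapping occurrence of `pair` with new token `idx`
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

ids = [1, 2, 1, 2, 3]
freq = {}
get_stats(ids, freq)
print(freq[(1, 2)])            # 2: (1, 2) is the most frequent pair
print(merge(ids, (1, 2), 256)) # [256, 256, 3]
```

One training iteration is exactly this: count pairs, pick the most frequent, mint a new token id for it, and rewrite the sequence.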
tokenizer/hf_tokenizer.py ADDED
@@ -0,0 +1,91 @@
from __future__ import annotations

import os
from typing import Any, Optional

from transformers import PreTrainedTokenizer

from tokenizer.regex_tokenizer import RegexTokenizer


class MGPT2Tokenizer(PreTrainedTokenizer):
    """
    Hugging Face-compatible (slow) tokenizer wrapper around `RegexTokenizer`.

    This is intended for publishing alongside the model using `trust_remote_code=True`.
    """

    model_input_names = ["input_ids", "attention_mask"]
    # Let `PreTrainedTokenizer.from_pretrained()` know which file it should pass to `__init__`.
    vocab_files_names = {"model_file": "tokenizer.model"}

    def __init__(self, model_file: str, **kwargs: Any):
        if not model_file.endswith(".model"):
            raise ValueError(f"model_file must end with .model, got: {model_file}")

        self._tok = RegexTokenizer()
        self._tok.load(model_file)

        # Bind common special tokens if present in the trained tokenizer.
        special = self._tok.special_tokens
        kwargs.setdefault("eos_token", "<|endoftext|>" if "<|endoftext|>" in special else None)
        kwargs.setdefault("unk_token", None)
        kwargs.setdefault("pad_token", None)
        kwargs.setdefault("bos_token", None)

        super().__init__(**kwargs)

        self.model_file = model_file

    @property
    def vocab_size(self) -> int:
        # vocab is sparse only if merges are incomplete; generally size is max_id + 1
        return max(self._tok.vocab.keys()) + 1

    def get_vocab(self) -> dict[str, int]:
        # Provide a stable token-string mapping for HF internals.
        inv_special = self._tok.inverse_special_tokens
        vocab: dict[str, int] = {}
        for i in range(self.vocab_size):
            if i in inv_special:
                vocab[inv_special[i]] = i
            else:
                vocab[f"<|bytebpe_{i}|>"] = i
        return vocab

    def _tokenize(self, text: str, **kwargs: Any) -> list[str]:
        ids = self._tok.encode(text, allowed_special="all")
        inv_special = self._tok.inverse_special_tokens
        out: list[str] = []
        for i in ids:
            out.append(inv_special.get(i, f"<|bytebpe_{i}|>"))
        return out

    def _convert_token_to_id(self, token: str) -> int:
        if token in self._tok.special_tokens:
            return self._tok.special_tokens[token]
        if token.startswith("<|bytebpe_") and token.endswith("|>"):
            inner = token[len("<|bytebpe_") : -len("|>")]
            return int(inner)
        raise KeyError(f"Unknown token string: {token!r}")

    def _convert_id_to_token(self, index: int) -> str:
        return self._tok.inverse_special_tokens.get(index, f"<|bytebpe_{index}|>")

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
        ids = [self._convert_token_to_id(t) for t in tokens]
        return self._tok.decode(ids)

    def build_inputs_with_special_tokens(self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None) -> list[int]:
        if token_ids_1 is not None:
            raise ValueError("This tokenizer does not support pair inputs.")
        return token_ids_0

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix or "tokenizer"
        out_prefix = os.path.join(save_directory, prefix)
        # Save in the native `.model`/`.vocab` format (human + machine readable for this repo).
        self._tok.save(out_prefix)
        return (out_prefix + ".model",)
tokenizer/patterns.py ADDED
@@ -0,0 +1,13 @@
"""
Regex patterns used by tokenizers in this package.

Keep patterns centralized so experiments, training scripts, and notebooks
stay in sync.
"""

# Default GPT-4-ish split pattern (as used in `RegexTokenizer` and `GPT4Tokenizer`)
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# Indic-focused experimental pattern (Hindi Devanagari + Kannada ranges and punctuation)
INDIC_SPLIT_PATTERN = r"""(?i) 's|'t|'re|'ve|'m|'ll|'d| ?\b[\p{L}\u0900-\u0963|\u0966-\u097F]+\b| ?\b[\p{L}\u0C80-\u0C9E|\u0CA0-\u0CFF]+\b| ?[\p{N}]+| ?[.,!?;:'\"-]| ?[\u0964-\u0965]| ?[\u0C9E-\u0C9F]| ?[^\s\p{L}\p{N}\u0900-\u097F\u0C80-\u0CFF]+| \s+(?!\S)| \s+"""
tokenizer/regex_tokenizer.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from .base import get_stats, merge, visualise_tokens
3
+ from .basic import BasicTokenizer
4
+ from .patterns import GPT4_SPLIT_PATTERN
5
+ except ImportError: # allow running as a script from inside `tokenizer/`
6
+ from base import get_stats, merge, visualise_tokens
7
+ from basic import BasicTokenizer
8
+ from patterns import GPT4_SPLIT_PATTERN
9
+ from collections import Counter, defaultdict
10
+ import heapq
11
+ import regex as re
12
+ from tqdm import tqdm
13
+ import time
14
+
15
+ class RegexTokenizer(BasicTokenizer):
16
+ def __init__(self, regex: str = GPT4_SPLIT_PATTERN):
17
+ super().__init__()
18
+ self.pattern = regex
19
+ self.regex = re.compile(self.pattern)
20
+
21
+ def register_special_tokens(self, special_tokens: dict[str, int]):
22
+ self.special_tokens = special_tokens
23
+ self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
24
+
25
+ @staticmethod
26
+ def _merge_word(word: tuple[int, ...], pair: tuple[int, int], new_id: int) -> tuple[int, ...]:
27
+ """Merge all non-overlapping occurrences of `pair` in `word`."""
28
+ out: list[int] = []
29
+ i = 0
30
+ while i < len(word):
31
+ if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
32
+ out.append(new_id)
33
+ i += 2
34
+ else:
35
+ out.append(word[i])
36
+ i += 1
37
+ return tuple(out)
38
+
39
+ @staticmethod
40
+ def _pair_occurrences(word: tuple[int, ...]) -> dict[tuple[int, int], int]:
41
+ """Return unweighted pair -> count for a single word/chunk."""
42
+ if len(word) < 2:
43
+ return {}
44
+ counts: dict[tuple[int, int], int] = {}
45
+ a = word[0]
46
+ for b in word[1:]:
47
+ p = (a, b)
48
+ counts[p] = counts.get(p, 0) + 1
49
+ a = b
50
+ return counts
51
+
52
+ def train(
53
+ self,
54
+ text: str,
55
+ vocab_size: int = 50_257,
56
+ verbose: bool = False,
57
+ *,
58
+ min_chunk_freq: int = 1,
59
+ max_chunks: int | None = None,
60
+ ):
61
+ assert vocab_size >= 256, "Vocab size must be at least 256"
62
+ num_merges = vocab_size - 256
63
+
64
+ # Count chunk frequencies without storing a giant list of chunks.
65
+ # Each unique chunk becomes a "word" in classic BPE training.
66
+ chunk_counts: Counter[bytes] = Counter()
67
+ for m in self.regex.finditer(text):
68
+ s = m.group(0)
69
+ if s:
70
+ chunk_counts[s.encode("utf-8")] += 1
71
+
72
+ # Heuristic speed knobs: ignore rare chunks and/or cap unique chunk types.
73
+ # This massively reduces training state on web-scale corpora and keeps code simple.
74
+ if min_chunk_freq > 1:
75
+ chunk_counts = Counter({b: f for b, f in chunk_counts.items() if f >= min_chunk_freq})
76
+ if max_chunks is not None and len(chunk_counts) > max_chunks:
77
+ chunk_counts = Counter(dict(chunk_counts.most_common(max_chunks)))
78
+
79
+ # words: tuple(symbol_ids) -> frequency
80
+ words: dict[tuple[int, ...], int] = {}
81
+ for b, freq in chunk_counts.items():
82
+ words[tuple(b)] = freq
83
+
84
+ # Global pair stats and a reverse index pair -> set(words containing it)
85
+ pair_counts: dict[tuple[int, int], int] = defaultdict(int)
86
+ pair_to_words: dict[tuple[int, int], set[tuple[int, ...]]] = defaultdict(set)
87
+ for w, freq in words.items():
88
+ local = self._pair_occurrences(w)
89
+ for p, occ in local.items():
90
+ pair_counts[p] += freq * occ
91
+ pair_to_words[p].add(w)
92
+
93
+ # Max-heap for fast "most frequent pair" selection (lazy updates).
94
+ heap: list[tuple[int, tuple[int, int]]] = [(-c, p) for p, c in pair_counts.items()]
95
+ heapq.heapify(heap)
96
+
97
+ merges = {}
98
+ vocab = {idx: bytes([idx]) for idx in range(256)}
99
+
100
+ def bump_pair(p: tuple[int, int], delta: int) -> None:
101
+ if delta == 0:
102
+ return
103
+ new = pair_counts.get(p, 0) + delta
104
+ if new <= 0:
105
+ pair_counts.pop(p, None)
106
+ pair_to_words.pop(p, None)
107
+ return
108
+ pair_counts[p] = new
109
+ heapq.heappush(heap, (-new, p))
110
+
+         for i in tqdm(range(num_merges), desc="Training tokenizer"):
+             start_time = time.time()
+
+             # Pop stale heap entries until the top matches current counts.
+             while heap:
+                 negc, p = heap[0]
+                 c = pair_counts.get(p, 0)
+                 if c > 0 and -negc == c:
+                     break
+                 heapq.heappop(heap)
+             if not heap:
+                 break
+
+             pair = heap[0][1]
+             count = pair_counts.get(pair, 0)
+             if count <= 0:
+                 break
+
+             idx = 256 + i
+             merges[pair] = idx
+             vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
+
+             affected = list(pair_to_words.get(pair, ()))
+             if not affected:
+                 pair_counts.pop(pair, None)
+                 pair_to_words.pop(pair, None)
+                 continue
+
+             # Apply merge to all words that contain the best pair.
+             for w in affected:
+                 freq = words.get(w)
+                 if not freq:
+                     continue
+
+                 new_w = self._merge_word(w, pair, idx)
+                 if new_w == w:
+                     continue
+
+                 # Remove old word contributions
+                 old_local = self._pair_occurrences(w)
+                 for p, occ in old_local.items():
+                     bump_pair(p, -freq * occ)
+                     s = pair_to_words.get(p)
+                     if s is not None:
+                         s.discard(w)
+                         if not s:
+                             pair_to_words.pop(p, None)
+
+                 # Update words dict (merge words that collapse to the same new tuple)
+                 del words[w]
+                 words[new_w] = words.get(new_w, 0) + freq
+
+                 # Add new word contributions
+                 new_local = self._pair_occurrences(new_w)
+                 for p, occ in new_local.items():
+                     bump_pair(p, freq * occ)
+                     pair_to_words[p].add(new_w)
+
+             # This pair should be fully merged away.
+             pair_counts.pop(pair, None)
+             pair_to_words.pop(pair, None)
+
+             if verbose and i % 10 == 0:
+                 time_taken = time.time() - start_time
+                 tqdm.write(
+                     f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) "
+                     f"had {count} occurrences (took {time_taken:.2f}s)"
+                 )
+
+         self.merges = merges
+         self.vocab = vocab
+
+     def decode(self, ids) -> str:
+         part_bytes = []
+         for idx in ids:
+             if idx in self.vocab:
+                 part_bytes.append(self.vocab[idx])  # idx can be >= 256 after merging
+             elif idx in getattr(self, "inverse_special_tokens", {}):
+                 part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
+             else:
+                 raise ValueError(f"id={idx} not in vocab or special_tokens")
+         text_bytes = b"".join(part_bytes)
+         return text_bytes.decode(encoding="utf-8", errors="replace")
+
+     def _encode_chunk(self, chunk_bytes: bytes, verbose=False) -> list[int]:
+         tokens = list(chunk_bytes)
+         while len(tokens) >= 2:
+             if verbose:
+                 visualise_tokens([self.vocab[token] for token in tokens])  # token can be >= 256 after merging
+             stats = {}
+             get_stats(tokens, stats)
+             # Greedily pick the pair that was merged earliest during training.
+             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
+             if pair not in self.merges:
+                 break
+             idx = self.merges[pair]
+             tokens = merge(tokens, pair, idx)
+         return tokens
+
+     def encode_ordinary(self, text, verbose=False) -> list[int]:
+         chunk_texts = re.findall(self.regex, text)
+         ids_list = []
+         for i, chunk in enumerate(chunk_texts):
+             if verbose:
+                 print()
+                 print(f"encoding chunk {i+1}/{len(chunk_texts)}: {chunk}")
+             chunk_bytes = chunk.encode("utf-8")  # raw bytes
+             ids = self._encode_chunk(chunk_bytes, verbose)
+             ids_list.extend(ids)
+         return ids_list
+
+     def encode(self, text, verbose=False, allowed_special="none") -> list[int]:
+         special = {}
+         if allowed_special == "all":
+             special = self.special_tokens
+         elif allowed_special == "none":
+             special = {}
+         elif allowed_special == "none_raise":
+             special = {}
+             assert all(token not in text for token in self.special_tokens), "Text contains special tokens that are not allowed"
+         elif isinstance(allowed_special, set):
+             special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
+         else:
+             raise ValueError(f"allowed_special={allowed_special} not understood.")
+         if not special:
+             # No special tokens requested: encode the whole text directly.
+             return self.encode_ordinary(text, verbose)
+         # Split text around special tokens; capturing group keeps the delimiters.
+         special_pattern = "(" + "|".join(re.escape(token) for token in special) + ")"
+         parts = re.split(special_pattern, text)
+         ids = []
+         for part in parts:
+             if part in special:
+                 ids.append(special[part])
+             else:
+                 ids.extend(self.encode_ordinary(part, verbose))
+         return ids
+
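The lazy max-heap pattern used in `train` above (push fresh entries on every count change, discard stale entries at pop time) can be sketched in isolation. This is a hypothetical standalone example, not code from this repo; the names `bump` and `best_pair` are illustrative:

```python
import heapq

# Current pair counts; the heap may additionally hold stale (outdated) entries.
counts = {("a", "b"): 5, ("b", "c"): 3}
heap = [(-c, p) for p, c in counts.items()]
heapq.heapify(heap)

def bump(pair, delta):
    """Update a pair's count and push a fresh heap entry; old entries go stale."""
    new = counts.get(pair, 0) + delta
    if new <= 0:
        counts.pop(pair, None)
        return
    counts[pair] = new
    heapq.heappush(heap, (-new, pair))

def best_pair():
    """Discard stale heap entries until the top agrees with the live counts."""
    while heap:
        negc, pair = heap[0]
        if counts.get(pair, 0) == -negc:
            return pair
        heapq.heappop(heap)
    return None

bump(("b", "c"), 4)   # count is now 7; the stale (-3, ("b", "c")) entry stays behind
assert best_pair() == ("b", "c")
bump(("b", "c"), -7)  # count drops to 0; the pair is removed entirely
assert best_pair() == ("a", "b")
```

The trade-off is the same as in `bump_pair`: updates are O(log n) pushes with no heap deletions, at the cost of stale entries that are skipped lazily when selecting the next merge.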