teszenofficial committed on
Commit
f32c9d2
·
verified ·
1 Parent(s): 516e5ea

Upload 6 files

Browse files
Files changed (6) hide show
  1. config.yaml +38 -0
  2. dataset.py +124 -0
  3. model.py +331 -0
  4. mtp_tokenizer.model +3 -0
  5. mtp_tokenizer.vocab +0 -0
  6. tokenizer.py +138 -0
config.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ augmentation_prob: 0.3
3
+ corpus_path: data.jsonl
4
+ max_text_length: 3000
5
+ min_text_length: 30
6
+ use_augmentation: true
7
+ validation_split: 0.15
8
+ generation:
9
+ default_max_tokens: 200
10
+ default_repetition_penalty: 1.2
11
+ default_temperature: 0.8
12
+ default_top_k: 50
13
+ default_top_p: 0.95
14
+ min_response_length: 30
15
+ model:
16
+ d_ff: 4096
17
+ d_model: 1024
18
+ dropout: 0.1
19
+ max_seq_len: 2048
20
+ n_heads: 16
21
+ n_layers: 24
22
+ vocab_size: 8000
23
+ training:
24
+ accumulation_steps: 8
25
+ batch_size: 2
26
+ epochs: 30
27
+ label_smoothing: 0.1
28
+ learning_rate: 0.0003
29
+ max_grad_norm: 1.0
30
+ min_delta: 0.0005
31
+ min_lr: 1.0e-06
32
+ num_threads: 4
33
+ patience: 7
34
+ save_every: 3
35
+ use_amp: true
36
+ use_lr_scheduler: true
37
+ warmup_steps: 500
38
+ weight_decay: 0.1
dataset.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import json
4
+ import random
5
+
6
+
7
+ class MTPDataset(Dataset):
8
+ """Dataset optimizado con augmentación inteligente"""
9
+
10
+ def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
11
+ use_augmentation=True, augmentation_prob=0.3):
12
+ self.tokenizer = tokenizer
13
+ self.max_seq_len = max_seq_len
14
+ self.use_augmentation = use_augmentation
15
+ self.augmentation_prob = augmentation_prob
16
+ self.data = []
17
+
18
+ # Load corpus
19
+ print(f" → Cargando corpus: {corpus_path}")
20
+ with open(corpus_path, 'r', encoding='utf-8') as f:
21
+ for line in f:
22
+ line = line.strip()
23
+ if line:
24
+ try:
25
+ entry = json.loads(line)
26
+ if 'instruction' in entry and 'response' in entry:
27
+ # Validar que no estén vacíos
28
+ if entry['instruction'].strip() and entry['response'].strip():
29
+ self.data.append(entry)
30
+ except json.JSONDecodeError:
31
+ continue
32
+
33
+ print(f" ✓ Cargados {len(self.data)} ejemplos válidos")
34
+ if use_augmentation:
35
+ print(f" ✓ Augmentación activada (prob={augmentation_prob})")
36
+
37
+ def __len__(self):
38
+ return len(self.data)
39
+
40
+ def augment_text(self, text):
41
+ """Augmentación mejorada de texto"""
42
+ if not self.use_augmentation or random.random() > self.augmentation_prob:
43
+ return text
44
+
45
+ # 1. Variación en espacios y formato
46
+ if random.random() < 0.3:
47
+ text = text.strip()
48
+
49
+ # 2. Variación en puntuación final
50
+ if random.random() < 0.25:
51
+ if text.endswith('.'):
52
+ # A veces remover punto final
53
+ if random.random() < 0.5:
54
+ text = text[:-1]
55
+ elif not text.endswith(('.', '!', '?', ':')):
56
+ # A veces agregar punto
57
+ if random.random() < 0.5:
58
+ text = text + '.'
59
+
60
+ # 3. Variación en mayúsculas iniciales (muy ocasional)
61
+ if random.random() < 0.1 and len(text) > 0:
62
+ if text[0].isupper():
63
+ text = text[0].lower() + text[1:]
64
+ elif text[0].islower():
65
+ text = text[0].upper() + text[1:]
66
+
67
+ return text
68
+
69
+ def __getitem__(self, idx):
70
+ entry = self.data[idx]
71
+
72
+ instruction = entry['instruction']
73
+ response = entry['response']
74
+
75
+ # Aplicar augmentación
76
+ instruction = self.augment_text(instruction)
77
+ response = self.augment_text(response)
78
+
79
+ # Formato optimizado para entrenamiento
80
+ full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"
81
+
82
+ # Tokenize
83
+ tokens = self.tokenizer.encode(full_text)
84
+
85
+ # Add BOS and EOS
86
+ tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]
87
+
88
+ # Truncate if too long (mantener BOS y EOS)
89
+ if len(tokens) > self.max_seq_len:
90
+ tokens = [tokens[0]] + tokens[1:self.max_seq_len-1] + [self.tokenizer.eos_id()]
91
+
92
+ # Pad token ID será -100 para ignorar en loss
93
+ input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
94
+ target_ids = torch.tensor(tokens[1:], dtype=torch.long)
95
+
96
+ return input_ids, target_ids
97
+
98
+
99
+ def collate_fn(batch, pad_id=0):
100
+ """Collate function optimizada con padding dinámico"""
101
+ input_ids = [item[0] for item in batch]
102
+ target_ids = [item[1] for item in batch]
103
+
104
+ # Find max length in this batch (dynamic padding)
105
+ max_len = max(len(ids) for ids in input_ids)
106
+
107
+ # Pad sequences
108
+ input_ids_padded = []
109
+ target_ids_padded = []
110
+
111
+ for inp, tgt in zip(input_ids, target_ids):
112
+ pad_len = max_len - len(inp)
113
+
114
+ # Pad input with pad_id
115
+ input_ids_padded.append(
116
+ torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)])
117
+ )
118
+
119
+ # Pad target with -100 (ignore_index in CrossEntropyLoss)
120
+ target_ids_padded.append(
121
+ torch.cat([tgt, torch.full((pad_len,), -100, dtype=torch.long)])
122
+ )
123
+
124
+ return torch.stack(input_ids_padded), torch.stack(target_ids_padded)
model.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ class RotaryPositionalEmbedding(nn.Module):
8
+ """RoPE - Rotary Position Embedding con scaling mejorado"""
9
+
10
+ def __init__(self, dim, max_seq_len=4096, base=10000):
11
+ super().__init__()
12
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
13
+ self.register_buffer('inv_freq', inv_freq)
14
+ self.max_seq_len = max_seq_len
15
+
16
+ def forward(self, seq_len, device):
17
+ t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
18
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
19
+ emb = torch.cat((freqs, freqs), dim=-1)
20
+ return emb.cos(), emb.sin()
21
+
22
+
23
+ def apply_rotary_pos_emb(q, k, cos, sin):
24
+ """Aplica RoPE a queries y keys"""
25
+ def rotate_half(x):
26
+ x1, x2 = x.chunk(2, dim=-1)
27
+ return torch.cat((-x2, x1), dim=-1)
28
+
29
+ q_embed = (q * cos) + (rotate_half(q) * sin)
30
+ k_embed = (k * cos) + (rotate_half(k) * sin)
31
+ return q_embed, k_embed
32
+
33
+
34
+ class MultiQueryAttention(nn.Module):
35
+ """Multi-Query Attention (MQA) - Más eficiente que MHA"""
36
+
37
+ def __init__(self, d_model, n_heads, dropout=0.1, max_seq_len=4096):
38
+ super().__init__()
39
+ assert d_model % n_heads == 0
40
+
41
+ self.d_model = d_model
42
+ self.n_heads = n_heads
43
+ self.d_k = d_model // n_heads
44
+
45
+ # Multi-query: Q tiene múltiples heads, K y V tienen 1 head
46
+ self.q_linear = nn.Linear(d_model, d_model, bias=False)
47
+ self.k_linear = nn.Linear(d_model, self.d_k, bias=False)
48
+ self.v_linear = nn.Linear(d_model, self.d_k, bias=False)
49
+ self.out_linear = nn.Linear(d_model, d_model, bias=False)
50
+
51
+ self.dropout = nn.Dropout(dropout)
52
+ self.attn_dropout = nn.Dropout(dropout)
53
+ self.rope = RotaryPositionalEmbedding(self.d_k, max_seq_len)
54
+
55
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
56
+
57
+ def forward(self, x, mask=None, use_cache=False, past_kv=None):
58
+ batch_size, seq_len, d_model = x.size()
59
+
60
+ # Q: [batch, seq, n_heads, d_k]
61
+ Q = self.q_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
62
+
63
+ # K, V: [batch, seq, d_k] -> expandir a [batch, n_heads, seq, d_k]
64
+ K = self.k_linear(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)
65
+ V = self.v_linear(x).unsqueeze(1).expand(-1, self.n_heads, -1, -1)
66
+
67
+ # Apply RoPE
68
+ cos, sin = self.rope(seq_len, x.device)
69
+ cos = cos[None, None, :, :]
70
+ sin = sin[None, None, :, :]
71
+ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)
72
+
73
+ # KV cache para inferencia
74
+ if use_cache:
75
+ if past_kv is not None:
76
+ K = torch.cat([past_kv[0], K], dim=2)
77
+ V = torch.cat([past_kv[1], V], dim=2)
78
+ cache = (K, V)
79
+ else:
80
+ cache = None
81
+
82
+ # Attention
83
+ if self.flash and mask is None and not use_cache:
84
+ context = F.scaled_dot_product_attention(
85
+ Q, K, V,
86
+ attn_mask=None,
87
+ dropout_p=self.dropout.p if self.training else 0.0,
88
+ is_causal=True
89
+ )
90
+ else:
91
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
92
+ if mask is not None:
93
+ scores = scores.masked_fill(mask == 0, float('-inf'))
94
+ attn_weights = F.softmax(scores, dim=-1)
95
+ attn_weights = self.attn_dropout(attn_weights)
96
+ context = torch.matmul(attn_weights, V)
97
+
98
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
99
+ output = self.out_linear(context)
100
+ return self.dropout(output), cache
101
+
102
+
103
+ class SwiGLU(nn.Module):
104
+ """SwiGLU activation con eficiencia mejorada"""
105
+
106
+ def __init__(self, d_model, d_ff, dropout=0.1):
107
+ super().__init__()
108
+ # FFN de GPT-3: 4x expansion
109
+ self.w1 = nn.Linear(d_model, d_ff, bias=False)
110
+ self.w2 = nn.Linear(d_ff, d_model, bias=False)
111
+ self.w3 = nn.Linear(d_model, d_ff, bias=False)
112
+ self.dropout = nn.Dropout(dropout)
113
+
114
+ def forward(self, x):
115
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
116
+
117
+
118
+ class RMSNorm(nn.Module):
119
+ """RMSNorm - Más estable que LayerNorm"""
120
+
121
+ def __init__(self, dim, eps=1e-6):
122
+ super().__init__()
123
+ self.eps = eps
124
+ self.weight = nn.Parameter(torch.ones(dim))
125
+
126
+ def forward(self, x):
127
+ norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
128
+ return x * norm * self.weight
129
+
130
+
131
+ class TransformerBlock(nn.Module):
132
+ """Transformer Block optimizado estilo GPT-3"""
133
+
134
+ def __init__(self, d_model, n_heads, d_ff, dropout=0.1, max_seq_len=4096):
135
+ super().__init__()
136
+ self.attention = MultiQueryAttention(d_model, n_heads, dropout, max_seq_len)
137
+ self.feed_forward = SwiGLU(d_model, d_ff, dropout)
138
+
139
+ self.norm1 = RMSNorm(d_model)
140
+ self.norm2 = RMSNorm(d_model)
141
+
142
+ def forward(self, x, mask=None, use_cache=False, past_kv=None):
143
+ # Pre-norm architecture (mejor que post-norm)
144
+ attn_out, cache = self.attention(self.norm1(x), mask, use_cache, past_kv)
145
+ x = x + attn_out
146
+ x = x + self.feed_forward(self.norm2(x))
147
+ return x, cache
148
+
149
+
150
+ class MTPModel(nn.Module):
151
+ """MTP 3 - Arquitectura mejorada nivel GPT-3"""
152
+
153
+ def __init__(self, vocab_size, d_model=1024, n_layers=24, n_heads=16,
154
+ d_ff=4096, max_seq_len=2048, dropout=0.1):
155
+ super().__init__()
156
+
157
+ self.vocab_size = vocab_size
158
+ self.d_model = d_model
159
+ self.max_seq_len = max_seq_len
160
+
161
+ # Embeddings con escalado
162
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
163
+ self.dropout = nn.Dropout(dropout)
164
+
165
+ # Transformer blocks
166
+ self.blocks = nn.ModuleList([
167
+ TransformerBlock(d_model, n_heads, d_ff, dropout, max_seq_len)
168
+ for _ in range(n_layers)
169
+ ])
170
+
171
+ # Final norm y projection
172
+ self.norm_f = RMSNorm(d_model)
173
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
174
+
175
+ # Weight tying (reduce parámetros)
176
+ self.token_embedding.weight = self.lm_head.weight
177
+
178
+ # Inicialización mejorada (GPT-3 style)
179
+ self.apply(self._init_weights)
180
+
181
+ # Escalado especial para residual connections
182
+ for pn, p in self.named_parameters():
183
+ if pn.endswith('w2.weight') or pn.endswith('out_linear.weight'):
184
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * n_layers))
185
+
186
+ def _init_weights(self, module):
187
+ if isinstance(module, nn.Linear):
188
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
189
+ if module.bias is not None:
190
+ torch.nn.init.zeros_(module.bias)
191
+ elif isinstance(module, nn.Embedding):
192
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
193
+
194
+ def forward(self, input_ids, targets=None):
195
+ batch_size, seq_len = input_ids.size()
196
+
197
+ # Causal mask
198
+ mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).view(1, 1, seq_len, seq_len)
199
+
200
+ # Embeddings con escalado
201
+ x = self.dropout(self.token_embedding(input_ids) * math.sqrt(self.d_model))
202
+
203
+ # Transformer blocks
204
+ for block in self.blocks:
205
+ x, _ = block(x, mask)
206
+
207
+ # Final norm y projection
208
+ x = self.norm_f(x)
209
+ logits = self.lm_head(x)
210
+
211
+ loss = None
212
+ if targets is not None:
213
+ # Label smoothing para mejor generalización
214
+ loss = F.cross_entropy(
215
+ logits.view(-1, self.vocab_size),
216
+ targets.view(-1),
217
+ label_smoothing=0.1,
218
+ ignore_index=-100
219
+ )
220
+
221
+ return logits, loss
222
+
223
+ @torch.no_grad()
224
+ def generate(self, input_ids, max_new_tokens=200, temperature=0.8,
225
+ top_k=50, top_p=0.95, repetition_penalty=1.2,
226
+ min_length=30, eos_token_id=3):
227
+ """Generación optimizada con KV cache"""
228
+ self.eval()
229
+
230
+ device = input_ids.device
231
+ generated = input_ids.clone()
232
+ past_kvs = [None] * len(self.blocks)
233
+ generated_text_tokens = 0
234
+
235
+ for step in range(max_new_tokens):
236
+ # Use cache para tokens ya procesados
237
+ if step == 0:
238
+ current_input = generated
239
+ use_cache = False
240
+ else:
241
+ current_input = generated[:, -1:]
242
+ use_cache = True
243
+
244
+ # Truncate si excede max_seq_len
245
+ if current_input.size(1) > self.max_seq_len:
246
+ current_input = current_input[:, -self.max_seq_len:]
247
+ use_cache = False
248
+ past_kvs = [None] * len(self.blocks)
249
+
250
+ # Forward pass
251
+ batch_size, seq_len = current_input.size()
252
+ mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(1, 1, seq_len, seq_len)
253
+
254
+ x = self.token_embedding(current_input) * math.sqrt(self.d_model)
255
+
256
+ new_past_kvs = []
257
+ for i, block in enumerate(self.blocks):
258
+ x, cache = block(x, mask, use_cache, past_kvs[i] if use_cache else None)
259
+ new_past_kvs.append(cache)
260
+
261
+ if use_cache:
262
+ past_kvs = new_past_kvs
263
+
264
+ x = self.norm_f(x)
265
+ logits = self.lm_head(x[:, -1, :])
266
+
267
+ # Repetition penalty
268
+ if repetition_penalty != 1.0:
269
+ for token_id in set(generated[0].tolist()):
270
+ if logits[0, token_id] < 0:
271
+ logits[0, token_id] *= repetition_penalty
272
+ else:
273
+ logits[0, token_id] /= repetition_penalty
274
+
275
+ # Penalizar tokens muy repetidos
276
+ if generated.size(1) > 20:
277
+ recent = generated[0, -20:].tolist()
278
+ for token_id in set(recent):
279
+ count = recent.count(token_id)
280
+ if count > 3:
281
+ logits[0, token_id] -= count * 3.0
282
+
283
+ # Control de longitud mínima
284
+ if generated_text_tokens < min_length:
285
+ logits[0, eos_token_id] = float('-inf')
286
+ else:
287
+ # Boost EOS gradualmente
288
+ eos_boost = min((generated_text_tokens - min_length) * 0.15, 3.0)
289
+ logits[0, eos_token_id] += eos_boost
290
+
291
+ # Temperature scaling
292
+ logits = logits / temperature
293
+
294
+ # Top-k filtering
295
+ if top_k > 0:
296
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
297
+ logits[logits < v[:, [-1]]] = float('-inf')
298
+
299
+ # Top-p (nucleus) filtering
300
+ if top_p < 1.0:
301
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
302
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
303
+ sorted_indices_to_remove = cumulative_probs > top_p
304
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
305
+ sorted_indices_to_remove[:, 0] = 0
306
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
307
+ logits[indices_to_remove] = float('-inf')
308
+
309
+ # Sample
310
+ probs = F.softmax(logits, dim=-1)
311
+ next_token = torch.multinomial(probs, num_samples=1)
312
+
313
+ # Check EOS
314
+ if next_token.item() == eos_token_id and generated_text_tokens >= min_length:
315
+ break
316
+
317
+ generated = torch.cat([generated, next_token], dim=1)
318
+ generated_text_tokens += 1
319
+
320
+ return generated
321
+
322
+ def count_parameters(self):
323
+ """Cuenta parámetros entrenables"""
324
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
325
+
326
+ def get_num_params(self, non_embedding=True):
327
+ """Cuenta parámetros excluyendo embeddings si se requiere"""
328
+ n_params = sum(p.numel() for p in self.parameters())
329
+ if non_embedding:
330
+ n_params -= self.token_embedding.weight.numel()
331
+ return n_params
mtp_tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e13a9965e5241f5f677f6689f568b4bbaf68c16c2605d4907d0d608dede95adb
3
+ size 124448
mtp_tokenizer.vocab ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sentencepiece as spm
2
+ import os
3
+ import json
4
+
5
+
6
+ class MTPTokenizer:
7
+ """Tokenizer using SentencePiece BPE"""
8
+
9
+ def __init__(self, model_path=None):
10
+ self.sp = None
11
+ self.model_path = model_path
12
+
13
+ if model_path and os.path.exists(model_path):
14
+ self.load(model_path)
15
+
16
+ def train(self, corpus_path, vocab_size=4000, model_prefix='mtp_tokenizer'):
17
+ """Train SentencePiece BPE tokenizer on corpus"""
18
+
19
+ # Extract text from JSONL corpus
20
+ texts = []
21
+ with open(corpus_path, 'r', encoding='utf-8') as f:
22
+ for line in f:
23
+ data = json.loads(line)
24
+ if 'instruction' in data:
25
+ texts.append(data['instruction'])
26
+ if 'response' in data:
27
+ texts.append(data['response'])
28
+
29
+ # Save temporary text file
30
+ temp_file = 'temp_corpus.txt'
31
+ with open(temp_file, 'w', encoding='utf-8') as f:
32
+ f.write('\n'.join(texts))
33
+
34
+ # Calculate optimal vocab size based on corpus
35
+ total_chars = sum(len(text) for text in texts)
36
+ max_vocab = min(vocab_size, int(total_chars * 0.15)) # Heuristic: ~15% of chars
37
+
38
+ print(f" → Corpus stats: {len(texts)} texts, {total_chars} characters")
39
+ print(f" → Adjusted vocab size: {max_vocab} (requested: {vocab_size})")
40
+
41
+ # Train SentencePiece with adjusted parameters
42
+ try:
43
+ spm.SentencePieceTrainer.train(
44
+ input=temp_file,
45
+ model_prefix=model_prefix,
46
+ vocab_size=max_vocab,
47
+ model_type='bpe',
48
+ pad_id=0,
49
+ unk_id=1,
50
+ bos_id=2,
51
+ eos_id=3,
52
+ character_coverage=1.0,
53
+ normalization_rule_name='identity',
54
+ num_threads=4,
55
+ split_digits=True,
56
+ allow_whitespace_only_pieces=False,
57
+ byte_fallback=False,
58
+ max_sentencepiece_length=16
59
+ )
60
+ except RuntimeError as e:
61
+ if "Vocabulary size too high" in str(e):
62
+ # Extract suggested max from error and retry
63
+ import re
64
+ match = re.search(r'value <= (\d+)', str(e))
65
+ if match:
66
+ suggested_max = int(match.group(1))
67
+ print(f" → Retrying with vocab size: {suggested_max}")
68
+ spm.SentencePieceTrainer.train(
69
+ input=temp_file,
70
+ model_prefix=model_prefix,
71
+ vocab_size=suggested_max,
72
+ model_type='bpe',
73
+ pad_id=0,
74
+ unk_id=1,
75
+ bos_id=2,
76
+ eos_id=3,
77
+ character_coverage=1.0,
78
+ normalization_rule_name='identity',
79
+ num_threads=4,
80
+ split_digits=True,
81
+ allow_whitespace_only_pieces=False,
82
+ byte_fallback=False,
83
+ max_sentencepiece_length=16
84
+ )
85
+ else:
86
+ raise
87
+ else:
88
+ raise
89
+
90
+ # Clean up
91
+ os.remove(temp_file)
92
+
93
+ # Load the trained model
94
+ self.model_path = f"{model_prefix}.model"
95
+ self.load(self.model_path)
96
+
97
+ print(f"✓ Tokenizer trained: {self.vocab_size()} tokens")
98
+ print(f"✓ Model saved: {self.model_path}")
99
+
100
+ def load(self, model_path):
101
+ """Load trained tokenizer"""
102
+ self.sp = spm.SentencePieceProcessor()
103
+ self.sp.load(model_path)
104
+ self.model_path = model_path
105
+
106
+ def encode(self, text):
107
+ """Encode text to token IDs"""
108
+ if self.sp is None:
109
+ raise ValueError("Tokenizer not loaded. Train or load a model first.")
110
+ return self.sp.encode_as_ids(text)
111
+
112
+ def decode(self, ids):
113
+ """Decode token IDs to text"""
114
+ if self.sp is None:
115
+ raise ValueError("Tokenizer not loaded. Train or load a model first.")
116
+ return self.sp.decode_ids(ids)
117
+
118
+ def vocab_size(self):
119
+ """Get vocabulary size"""
120
+ if self.sp is None:
121
+ return 0
122
+ return self.sp.get_piece_size()
123
+
124
+ def bos_id(self):
125
+ """Beginning of sentence token ID"""
126
+ return self.sp.bos_id()
127
+
128
+ def eos_id(self):
129
+ """End of sentence token ID"""
130
+ return self.sp.eos_id()
131
+
132
+ def pad_id(self):
133
+ """Padding token ID"""
134
+ return self.sp.pad_id()
135
+
136
+ def unk_id(self):
137
+ """Unknown token ID"""
138
+ return self.sp.unk_id()