teszenofficial committed on
Commit
5277939
·
verified ·
1 Parent(s): 40c7704

Delete trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +0 -446
trainer.py DELETED
@@ -1,446 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import DataLoader, random_split
4
- from torch.optim import AdamW
5
- from torch.optim.lr_scheduler import CosineAnnealingLR
6
- from tqdm import tqdm
7
- import yaml
8
- import os
9
- import pickle
10
- import math
11
-
12
- from model import MTPMiniModel
13
- from tokenizer import MTPTokenizer
14
- from dataset import MTPDataset, collate_fn
15
-
16
-
17
class MTPTrainer:
    """Improved x20 trainer (basic CPU/GPU version).

    Wires together tokenizer, model, datasets, optimizer and LR schedule
    from a YAML config file, then drives the training loop (see ``train``).
    """

    def __init__(self, config_path='config.yaml'):
        """Build every training component from the YAML config at *config_path*.

        Side effects: may train and/or load a tokenizer model on disk,
        loads and splits the corpus, and resumes from ``checkpoint.pt``
        when one exists in the working directory.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # Device detection
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Cap CPU thread usage per config.
        torch.set_num_threads(self.config['training']['num_threads'])

        print("=" * 70)
        print("MTP MINI x20 - Transformer Avanzado")
        print("=" * 70)
        print(f"\n🔥 Device: {self.device}")

        if self.device.type == 'cuda':
            print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
            print(f"🔥 VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

        # Tokenizer: reuse the on-disk model when present, otherwise train one.
        print("\n[1/6] Inicializando tokenizer mejorado...")
        self.tokenizer = MTPTokenizer()

        tokenizer_path = 'mtp_tokenizer.model'
        if not os.path.exists(tokenizer_path):
            print(" -> Entrenando nuevo tokenizer...")
            self.tokenizer.train(
                self.config['data']['corpus_path'],
                vocab_size=self.config['model']['vocab_size'],
                model_prefix='mtp_tokenizer'
            )
        else:
            print(f" -> Cargando tokenizer: {tokenizer_path}")
            self.tokenizer.load(tokenizer_path)

        print(f" ✅ Vocabulario: {self.tokenizer.vocab_size()} tokens")

        # Model
        print("\n[2/6] Inicializando modelo GRANDE...")

        model_config = self.config['model']

        self.model = MTPMiniModel(
            vocab_size=self.tokenizer.vocab_size(),
            d_model=model_config['d_model'],
            n_layers=model_config['n_layers'],
            n_heads=model_config['n_heads'],
            d_ff=model_config['d_ff'],
            max_seq_len=model_config['max_seq_len'],
            dropout=model_config['dropout'],
            use_swiglu=model_config.get('use_swiglu', True),
            use_flash_attention=model_config.get('use_flash_attention', True),
            use_reasoning_layer=model_config.get('use_reasoning_layer', True),
            reasoning_steps=model_config.get('reasoning_steps', 3),
            use_confidence_score=model_config.get('use_confidence_score', True)
        )

        param_count = self.model.count_parameters()
        print(f" ✅ Parámetros: {param_count:,} ({param_count/1e6:.1f}M)")
        print(f" ✅ Arquitectura: {model_config['n_layers']} layers, "
              f"{model_config['n_heads']} heads, dim={model_config['d_model']}")

        improvements = [
            "RoPE", "RMSNorm", "SwiGLU", "Flash Attention",
            "Reasoning Layers", "Confidence Score", "Anti-Hallucination",
            "Label Smoothing", "Advanced Repetition Penalty", "Early Stopping"
        ]
        print(f" ✅ Mejoras: {', '.join(improvements)}")

        # Move to device
        self.model.to(self.device)

        if self.device.type == 'cuda':
            memory_allocated = torch.cuda.memory_allocated(0) / 1e9
            print(f" ✅ VRAM usada: {memory_allocated:.2f} GB")

        # Dataset
        print("\n[3/6] Cargando dataset con filtrado de calidad...")
        full_dataset = MTPDataset(
            self.config['data']['corpus_path'],
            self.tokenizer,
            max_seq_len=model_config['max_seq_len'],
            use_augmentation=self.config['data'].get('use_augmentation', True),
            augmentation_prob=self.config['data'].get('augmentation_prob', 0.4),
            min_quality_score=self.config['data'].get('min_quality_score', 0.3),
            remove_duplicates=self.config['data'].get('remove_duplicates', True)
        )

        total_examples = len(full_dataset)

        if total_examples < 100:
            print(f" ⚠️ WARNING: Dataset pequeño ({total_examples} ejemplos)")
            print(f" ⚠️ Se recomienda 1000+ ejemplos para óptimo rendimiento")

        # Deterministic train/validation split (seeded generator below).
        val_split = self.config.get('data', {}).get('validation_split', 0.12)
        val_size = max(1, int(total_examples * val_split))
        train_size = total_examples - val_size

        if train_size > 0:
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )

            print(f" ✅ Train: {len(self.train_dataset)} ejemplos")
            print(f" ✅ Validation: {len(self.val_dataset)} ejemplos")
        else:
            # Degenerate corpus: validate on the training data rather than crash.
            self.train_dataset = full_dataset
            self.val_dataset = full_dataset
            print(f" ⚠️ Dataset muy pequeño - usando todo para train y val")

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=True,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer.pad_id()),
            num_workers=0
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer.pad_id()),
            num_workers=0
        )

        # Optimizer with differentiated parameter groups
        print("\n[4/6] Configurando optimizer avanzado...")

        decay_params = []
        no_decay_params = []
        reasoning_params = []

        # Biases, norms and embeddings receive no weight decay; parameters
        # whose name contains 'reasoning' get reduced decay and LR below.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if 'reasoning' in name:
                    reasoning_params.append(param)
                elif 'bias' in name or 'norm' in name or 'embedding' in name:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

        param_groups = [
            {'params': decay_params, 'weight_decay': self.config['training']['weight_decay']},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        if reasoning_params:
            param_groups.append({
                'params': reasoning_params,
                'weight_decay': self.config['training']['weight_decay'] * 0.5,
                'lr': self.config['training']['learning_rate'] * 0.8
            })

        self.optimizer = AdamW(
            param_groups,
            lr=self.config['training']['learning_rate'],
            betas=(0.9, 0.95),
            eps=1e-8
        )

        print(f" ✅ Optimizer: AdamW")
        print(f" ✅ LR: {self.config['training']['learning_rate']}")
        print(f" ✅ Weight decay: {self.config['training']['weight_decay']}")

        # Learning rate scheduler (cosine decay applied after linear warmup)
        print("\n[5/6] Configurando LR scheduler...")
        self.warmup_steps = self.config['training'].get('warmup_steps', 500)
        total_steps = len(self.train_loader) * self.config['training']['epochs']

        if self.config['training'].get('use_lr_scheduler', True):
            self.scheduler = CosineAnnealingLR(
                self.optimizer,
                T_max=total_steps - self.warmup_steps,
                eta_min=self.config['training'].get('min_lr', 0.000005)
            )
            print(f" ✅ Scheduler: Cosine Annealing")
        else:
            self.scheduler = None
            print(f" ✅ Scheduler: None")

        print(f" ✅ Warmup steps: {self.warmup_steps}")

        self.start_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Early stopping
        self.patience = self.config['training'].get('patience', 8)
        self.min_delta = self.config['training'].get('min_delta', 0.0005)
        self.patience_counter = 0
        print(f" ✅ Early stopping: patience={self.patience}, min_delta={self.min_delta}")

        # Gradient accumulation
        self.accumulation_steps = self.config['training'].get('accumulation_steps', 8)
        effective_batch = self.config['training']['batch_size'] * self.accumulation_steps
        print(f" ✅ Gradient accumulation: {self.accumulation_steps} steps")
        print(f" ✅ Effective batch size: {effective_batch}")

        self.use_eos_weight = self.config['training'].get('use_eos_loss_weight', True)

        # Resume checkpoint
        if os.path.exists('checkpoint.pt'):
            print("\n[6/6] Cargando checkpoint...")
            self.load_checkpoint('checkpoint.pt')
        else:
            print("\n[6/6] Listo para entrenar!")
228
-
229
- def get_lr(self):
230
- """Get current learning rate with warmup"""
231
- if self.global_step < self.warmup_steps:
232
- return self.config['training']['learning_rate'] * (self.global_step / self.warmup_steps)
233
- return self.optimizer.param_groups[0]['lr']
234
-
235
- def train_epoch(self, epoch):
236
- """Train one epoch"""
237
- self.model.train()
238
- total_loss = 0
239
- total_confidence = 0
240
- confidence_samples = 0
241
-
242
- progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")
243
-
244
- self.optimizer.zero_grad()
245
-
246
- for batch_idx, (input_ids, target_ids) in enumerate(progress_bar):
247
- input_ids = input_ids.to(self.device, non_blocking=True)
248
- target_ids = target_ids.to(self.device, non_blocking=True)
249
-
250
- if self.model.use_confidence:
251
- logits, loss, confidence = self.model(
252
- input_ids, target_ids,
253
- use_eos_weight=self.use_eos_weight,
254
- return_confidence=True
255
- )
256
- mask = (target_ids != 0).float()
257
- avg_conf = (confidence * mask).sum() / mask.sum()
258
- total_confidence += avg_conf.item()
259
- confidence_samples += 1
260
- else:
261
- logits, loss = self.model(
262
- input_ids, target_ids,
263
- use_eos_weight=self.use_eos_weight
264
- )
265
-
266
- loss = loss / self.accumulation_steps
267
- loss.backward()
268
-
269
- if (batch_idx + 1) % self.accumulation_steps == 0:
270
- torch.nn.utils.clip_grad_norm_(
271
- self.model.parameters(),
272
- self.config['training']['max_grad_norm']
273
- )
274
-
275
- if self.global_step < self.warmup_steps:
276
- lr = self.get_lr()
277
- for param_group in self.optimizer.param_groups:
278
- param_group['lr'] = lr
279
-
280
- self.optimizer.step()
281
-
282
- if self.scheduler and self.global_step >= self.warmup_steps:
283
- self.scheduler.step()
284
-
285
- self.optimizer.zero_grad()
286
- self.global_step += 1
287
-
288
- total_loss += loss.item() * self.accumulation_steps
289
-
290
- postfix = {
291
- 'loss': f"{loss.item() * self.accumulation_steps:.4f}",
292
- 'lr': f"{self.get_lr():.6f}"
293
- }
294
-
295
- if confidence_samples > 0:
296
- postfix['conf'] = f"{total_confidence/confidence_samples:.3f}"
297
-
298
- progress_bar.set_postfix(postfix)
299
-
300
- avg_loss = total_loss / len(self.train_loader)
301
- avg_confidence = total_confidence / confidence_samples if confidence_samples > 0 else 0
302
-
303
- return avg_loss, avg_confidence
304
-
305
- def validate(self):
306
- """Validate model"""
307
- self.model.eval()
308
- total_loss = 0
309
- total_confidence = 0
310
- confidence_samples = 0
311
-
312
- with torch.no_grad():
313
- for input_ids, target_ids in self.val_loader:
314
- input_ids = input_ids.to(self.device, non_blocking=True)
315
- target_ids = target_ids.to(self.device, non_blocking=True)
316
-
317
- if self.model.use_confidence:
318
- logits, loss, confidence = self.model(
319
- input_ids, target_ids,
320
- return_confidence=True
321
- )
322
- mask = (target_ids != 0).float()
323
- avg_conf = (confidence * mask).sum() / mask.sum()
324
- total_confidence += avg_conf.item()
325
- confidence_samples += 1
326
- else:
327
- logits, loss = self.model(input_ids, target_ids)
328
-
329
- total_loss += loss.item()
330
-
331
- avg_loss = total_loss / len(self.val_loader)
332
- avg_confidence = total_confidence / confidence_samples if confidence_samples > 0 else 0
333
-
334
- return avg_loss, avg_confidence
335
-
336
    def train(self):
        """Main training loop.

        Runs up to ``config['training']['epochs']`` epochs with per-epoch
        validation, early stopping (``patience`` / ``min_delta``), periodic
        checkpointing every ``save_every`` epochs, and finally restores the
        best checkpoint (when present) before exporting via ``save_model``.
        """
        print("\n" + "=" * 70)
        print("INICIANDO ENTRENAMIENTO")
        print("=" * 70)

        epochs = self.config['training']['epochs']

        for epoch in range(self.start_epoch, epochs):
            train_loss, train_conf = self.train_epoch(epoch)
            val_loss, val_conf = self.validate()

            # Release cached CUDA blocks between epochs.
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            print(f"\nEpoch {epoch+1}/{epochs}")
            print(f" Train Loss: {train_loss:.4f}")
            print(f" Val Loss: {val_loss:.4f}")
            print(f" Train Confidence: {train_conf:.3f}")
            print(f" Val Confidence: {val_conf:.3f}")
            print(f" LR: {self.get_lr():.6f}")

            # Early stopping: only improvements of at least min_delta reset
            # the patience counter and refresh the best checkpoint.
            if val_loss < (self.best_val_loss - self.min_delta):
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self.save_checkpoint('best_model.pt', epoch + 1, is_best=True)
                print(f" ✅ Nuevo mejor modelo! (Val Loss: {val_loss:.4f})")
            else:
                self.patience_counter += 1
                print(f" -> No improvement. Patience: {self.patience_counter}/{self.patience}")

            if self.patience_counter >= self.patience:
                print(f"\n⚠️ Early stopping. Mejor val loss: {self.best_val_loss:.4f}")
                break

            # Periodic resumable checkpoint (epoch + 1 = next epoch to run).
            if (epoch + 1) % self.config['training']['save_every'] == 0:
                self.save_checkpoint('checkpoint.pt', epoch + 1)

        print("\n" + "=" * 70)
        print("ENTRENAMIENTO COMPLETADO")
        print(f"Mejor Val Loss: {self.best_val_loss:.4f}")
        print("=" * 70)

        # Restore the best-performing weights before the final export.
        if os.path.exists('best_model.pt'):
            print("\nCargando mejor modelo...")
            checkpoint = torch.load('best_model.pt', map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        self.save_model()
386
-
387
- def save_checkpoint(self, path, epoch, is_best=False):
388
- """Save checkpoint"""
389
- checkpoint = {
390
- 'epoch': epoch,
391
- 'global_step': self.global_step,
392
- 'model_state_dict': self.model.state_dict(),
393
- 'optimizer_state_dict': self.optimizer.state_dict(),
394
- 'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
395
- 'best_val_loss': self.best_val_loss,
396
- 'patience_counter': self.patience_counter,
397
- 'config': self.config
398
- }
399
- torch.save(checkpoint, path)
400
- if not is_best:
401
- print(f" 💾 Checkpoint guardado: {path}")
402
-
403
    def load_checkpoint(self, path):
        """Restore model/optimizer/scheduler state and training counters.

        Counterpart of ``save_checkpoint``. Keys absent from older
        checkpoints ('best_val_loss', 'patience_counter') fall back to
        fresh defaults via ``dict.get``.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and checkpoint['scheduler_state_dict']:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # 'epoch' was stored as epoch + 1 at save time, so this is the next
        # epoch to run — range(self.start_epoch, epochs) resumes correctly.
        self.start_epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        self.patience_counter = checkpoint.get('patience_counter', 0)
        print(f" ✅ Resumido desde epoch {self.start_epoch}")
        print(f" ✅ Mejor val loss: {self.best_val_loss:.4f}")
416
-
417
    def save_model(self):
        """Export the final model plus metadata to ``output/mtp_mini.pkl``.

        NOTE(review): the model is moved to CPU and left there; a caller
        that keeps using it afterwards must move it back to ``self.device``.
        NOTE(review): the artifact is a pickle — only load it from trusted
        sources (pickle.load executes arbitrary code).
        """
        os.makedirs('output', exist_ok=True)

        # Serialize CPU tensors so the artifact loads on CUDA-less machines.
        self.model.to('cpu')

        model_data = {
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
            'vocab_size': self.tokenizer.vocab_size(),
            'tokenizer_path': self.tokenizer.model_path,
            'training_info': {
                'final_epoch': self.start_epoch,
                'best_val_loss': self.best_val_loss,
                'total_parameters': self.model.count_parameters()
            }
        }

        output_path = 'output/mtp_mini.pkl'
        with open(output_path, 'wb') as f:
            pickle.dump(model_data, f)

        print(f"\n✅ Modelo final guardado: {output_path}")
        print(f"💾 Tamaño: {os.path.getsize(output_path) / (1024*1024):.2f} MB")
        print(f"🧠 Parámetros: {self.model.count_parameters()/1e6:.1f}M")
442
-
443
-
444
def main() -> None:
    """Entry point: build a trainer from config.yaml and start training."""
    trainer = MTPTrainer('config.yaml')
    trainer.train()


if __name__ == '__main__':
    main()