teszenofficial committed on
Commit
eed645b
·
verified ·
1 Parent(s): 5277939

Delete trainer_gpu.py

Browse files
Files changed (1) hide show
  1. trainer_gpu.py +0 -573
trainer_gpu.py DELETED
@@ -1,573 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.data import DataLoader, random_split
4
- from torch.optim import AdamW
5
- from torch.optim.lr_scheduler import CosineAnnealingLR
6
- from torch.cuda.amp import autocast, GradScaler
7
- from tqdm import tqdm
8
- import yaml
9
- import os
10
- import pickle
11
- import math
12
- import numpy as np
13
-
14
- from model import MTPMiniModel
15
- from tokenizer import MTPTokenizer
16
- from dataset import MTPDataset, collate_fn
17
-
18
-
19
class MTPTrainer:
    """Trainer (IMPROVED x20) with advanced capabilities.

    Wires together tokenizer, model, datasets, optimizer, LR scheduler,
    mixed precision and checkpoint resumption from a single YAML config.
    """

    def __init__(self, config_path='config.yaml'):
        """Build the full training stack from a YAML configuration file.

        Args:
            config_path: Path to the YAML config with `model`, `data`
                and `training` sections.

        Side effects: reads the config and corpus from disk, may train or
        load a tokenizer model, allocates the model on GPU/CPU, and may
        resume from `checkpoint.pt` if one exists in the working directory.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        # ========== CONFIGURE DEVICE ==========
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print("=" * 70)
        print("MTP MINI x20 - Transformer Avanzado con Razonamiento")
        print("=" * 70)
        print(f"\n🔥 Device: {self.device}")

        if self.device.type == 'cuda':
            print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
            print(f"🔥 VRAM Total: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
            # Enable cuDNN autotuner and TF32 matmuls for throughput.
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            print(f"🔥 Optimizaciones CUDA: Activadas")

            # Gradient checkpointing to save memory (trades compute for VRAM).
            self.use_gradient_checkpointing = self.config['training'].get('use_gradient_checkpointing', True)
            if self.use_gradient_checkpointing:
                print(f"🔥 Gradient Checkpointing: Activado (ahorra VRAM)")
        else:
            print("⚠️ WARNING: Usando CPU - El entrenamiento será MUY lento")
            self.use_gradient_checkpointing = False

        # Mixed precision training (FP16 autocast + loss scaling), CUDA only.
        self.use_mixed_precision = self.device.type == 'cuda' and self.config['training'].get('use_mixed_precision', True)
        if self.use_mixed_precision:
            self.scaler = GradScaler()
            print(f"🔥 Mixed Precision (FP16): Activado")

        torch.set_num_threads(self.config['training']['num_threads'])

        # ========== TOKENIZER ==========
        print("\n[1/7] Inicializando tokenizer mejorado...")
        self.tokenizer = MTPTokenizer()

        tokenizer_path = 'mtp_tokenizer.model'
        if not os.path.exists(tokenizer_path):
            # No saved tokenizer: train a new one on the corpus.
            print(" -> Entrenando nuevo tokenizer...")
            self.tokenizer.train(
                self.config['data']['corpus_path'],
                vocab_size=self.config['model']['vocab_size'],
                model_prefix='mtp_tokenizer'
            )
        else:
            print(f" -> Cargando tokenizer: {tokenizer_path}")
            self.tokenizer.load(tokenizer_path)

        print(f" ✅ Vocabulario: {self.tokenizer.vocab_size()} tokens")

        # ========== MODEL ==========
        print("\n[2/7] Inicializando modelo GRANDE (x20)...")

        model_config = self.config['model']

        # Vocab size comes from the (possibly just-trained) tokenizer, not
        # from the config, so the embedding table always matches it.
        self.model = MTPMiniModel(
            vocab_size=self.tokenizer.vocab_size(),
            d_model=model_config['d_model'],
            n_layers=model_config['n_layers'],
            n_heads=model_config['n_heads'],
            d_ff=model_config['d_ff'],
            max_seq_len=model_config['max_seq_len'],
            dropout=model_config['dropout'],
            use_swiglu=model_config.get('use_swiglu', True),
            use_flash_attention=model_config.get('use_flash_attention', True),
            use_reasoning_layer=model_config.get('use_reasoning_layer', True),
            reasoning_steps=model_config.get('reasoning_steps', 3),
            use_confidence_score=model_config.get('use_confidence_score', True)
        ).to(self.device)

        param_count = self.model.count_parameters()
        print(f" ✅ Parámetros TOTALES: {param_count:,} ({param_count/1e6:.1f}M)")
        print(f" ✅ Arquitectura:")
        print(f" • Capas: {model_config['n_layers']}")
        print(f" • Cabezas de atención: {model_config['n_heads']}")
        print(f" • Dimensión: {model_config['d_model']}")
        print(f" • FFN: {model_config['d_ff']}")
        print(f" • Contexto máximo: {model_config['max_seq_len']} tokens")

        # Report current GPU memory usage.
        if self.device.type == 'cuda':
            memory_allocated = torch.cuda.memory_allocated(0) / 1e9
            memory_reserved = torch.cuda.memory_reserved(0) / 1e9
            print(f" ✅ VRAM usada: {memory_allocated:.2f} GB (reservada: {memory_reserved:.2f} GB)")

        improvements = [
            "RoPE", "RMSNorm", "SwiGLU", "Flash Attention",
            "Reasoning Layers", "Confidence Score", "Anti-Hallucination",
            "Label Smoothing", "Repetition Penalty", "Early Stopping",
            "Mixed Precision", "Gradient Checkpointing"
        ]
        print(f" ✅ Mejoras activas: {', '.join(improvements)}")

        # ========== DATASET ==========
        print("\n[3/7] Cargando dataset grande...")
        full_dataset = MTPDataset(
            self.config['data']['corpus_path'],
            self.tokenizer,
            max_seq_len=model_config['max_seq_len'],
            use_augmentation=self.config['data'].get('use_augmentation', True),
            augmentation_prob=self.config['data'].get('augmentation_prob', 0.4)
        )

        total_examples = len(full_dataset)
        print(f" ✅ Total ejemplos: {total_examples}")

        if total_examples < 100:
            print(f" ⚠️ WARNING: Dataset pequeño ({total_examples} ejemplos)")
            print(f" ⚠️ Se recomienda al menos 1000 ejemplos para este modelo")

        # Train/validation split; seeded generator makes the split
        # reproducible across runs (important when resuming).
        val_split = self.config.get('data', {}).get('validation_split', 0.12)
        val_size = max(1, int(total_examples * val_split))
        train_size = total_examples - val_size

        if train_size > 0:
            self.train_dataset, self.val_dataset = random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(42)
            )

            print(f" ✅ Train: {len(self.train_dataset)} ejemplos ({train_size/total_examples*100:.1f}%)")
            print(f" ✅ Validation: {len(self.val_dataset)} ejemplos ({val_size/total_examples*100:.1f}%)")
        else:
            # Degenerate tiny dataset: validate on the training data itself.
            self.train_dataset = full_dataset
            self.val_dataset = full_dataset
            print(f" ⚠️ Dataset muy pequeño - usando todo para train y validación")

        # Optimized DataLoaders (pinned memory + persistent workers on GPU).
        num_workers = 4 if self.device.type == 'cuda' else 2

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=True,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer.pad_id()),
            num_workers=num_workers,
            pin_memory=True if self.device.type == 'cuda' else False,
            persistent_workers=True if num_workers > 0 else False
        )

        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer.pad_id()),
            num_workers=num_workers,
            pin_memory=True if self.device.type == 'cuda' else False,
            persistent_workers=True if num_workers > 0 else False
        )

        # ========== OPTIMIZER ==========
        print("\n[4/7] Configurando optimizer avanzado...")

        # Parameter groups with differential weight decay: biases, norms and
        # embeddings get no decay; reasoning layers get their own group.
        decay_params = []
        no_decay_params = []
        reasoning_params = []

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if 'reasoning' in name:
                    reasoning_params.append(param)
                elif 'bias' in name or 'norm' in name or 'embedding' in name:
                    no_decay_params.append(param)
                else:
                    decay_params.append(param)

        param_groups = [
            {'params': decay_params, 'weight_decay': self.config['training']['weight_decay']},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

        if reasoning_params:
            # Slightly lower learning rate and half decay for reasoning layers.
            # NOTE(review): the per-group lr is overwritten by the warmup loop
            # in train_epoch, which sets all groups to the same lr — confirm
            # this is intended.
            param_groups.append({
                'params': reasoning_params,
                'weight_decay': self.config['training']['weight_decay'] * 0.5,
                'lr': self.config['training']['learning_rate'] * 0.8
            })
            print(f" ✅ Reasoning params: {sum(p.numel() for p in reasoning_params):,}")

        self.optimizer = AdamW(
            param_groups,
            lr=self.config['training']['learning_rate'],
            betas=(0.9, 0.95),  # betas tuned for LLM training
            eps=1e-8
        )

        print(f" ✅ Optimizer: AdamW")
        print(f" ✅ LR base: {self.config['training']['learning_rate']}")
        print(f" ✅ Weight decay: {self.config['training']['weight_decay']}")

        # ========== SCHEDULER ==========
        print("\n[5/7] Configurando LR scheduler...")
        self.warmup_steps = self.config['training'].get('warmup_steps', 500)
        total_steps = len(self.train_loader) * self.config['training']['epochs']

        if self.config['training'].get('use_lr_scheduler', True):
            # Cosine decay over the post-warmup portion of training.
            self.scheduler = CosineAnnealingLR(
                self.optimizer,
                T_max=total_steps - self.warmup_steps,
                eta_min=self.config['training'].get('min_lr', 0.000005)
            )
            print(f" ✅ Scheduler: Cosine Annealing")
            print(f" ✅ Total steps: {total_steps:,}")
        else:
            self.scheduler = None
            print(f" ✅ Scheduler: None")

        print(f" ✅ Warmup steps: {self.warmup_steps}")

        # ========== TRAINING STATE ==========
        self.start_epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Early stopping: stop after `patience` epochs without an improvement
        # of at least `min_delta` in validation loss.
        self.patience = self.config['training'].get('patience', 8)
        self.min_delta = self.config['training'].get('min_delta', 0.0005)
        self.patience_counter = 0
        print(f" ✅ Early stopping: patience={self.patience}, min_delta={self.min_delta}")

        # Gradient accumulation: optimizer steps every N micro-batches.
        self.accumulation_steps = self.config['training'].get('accumulation_steps', 8)
        effective_batch = self.config['training']['batch_size'] * self.accumulation_steps
        print(f" ✅ Gradient accumulation: {self.accumulation_steps} steps")
        print(f" ✅ Effective batch size: {effective_batch}")

        self.use_eos_weight = self.config['training'].get('use_eos_loss_weight', True)
        if self.use_eos_weight:
            print(f" ✅ EOS token weight: 2.0x")

        # ========== RESUME CHECKPOINT ==========
        print("\n[6/7] Verificando checkpoints...")
        if os.path.exists('checkpoint.pt'):
            print(" -> Cargando checkpoint...")
            self.load_checkpoint('checkpoint.pt')
        else:
            print(" ✅ No hay checkpoint previo")

        print("\n[7/7] ✅ Sistema listo para entrenar!")
        print("=" * 70)
269
-
270
- def get_lr(self):
271
- """Get current learning rate with warmup"""
272
- if self.global_step < self.warmup_steps:
273
- return self.config['training']['learning_rate'] * (self.global_step / self.warmup_steps)
274
- return self.optimizer.param_groups[0]['lr']
275
-
276
    def train_epoch(self, epoch):
        """Train a single epoch, optionally with mixed precision (FP16).

        Args:
            epoch: 0-based epoch index (used only for the progress bar).

        Returns:
            (avg_loss, avg_confidence): mean training loss over all batches
            and mean model confidence (0 if the model exposes none).
        """
        self.model.train()
        total_loss = 0
        total_confidence = 0
        confidence_samples = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")

        self.optimizer.zero_grad()

        for batch_idx, (input_ids, target_ids) in enumerate(progress_bar):
            # Move batch to the training device (non_blocking pairs with
            # pin_memory in the DataLoader).
            input_ids = input_ids.to(self.device, non_blocking=True)
            target_ids = target_ids.to(self.device, non_blocking=True)

            # Forward pass with mixed precision when available.
            if self.use_mixed_precision:
                with autocast():
                    if self.model.use_confidence:
                        logits, loss, confidence = self.model(
                            input_ids, target_ids,
                            use_eos_weight=self.use_eos_weight,
                            return_confidence=True
                        )
                        # Track average confidence over non-pad positions.
                        # NOTE(review): mask assumes pad id is 0 — confirm
                        # against tokenizer.pad_id().
                        mask = (target_ids != 0).float()
                        avg_conf = (confidence * mask).sum() / mask.sum()
                        total_confidence += avg_conf.item()
                        confidence_samples += 1
                    else:
                        logits, loss = self.model(
                            input_ids, target_ids,
                            use_eos_weight=self.use_eos_weight
                        )

                    # Scale loss down so accumulated gradients average out.
                    loss = loss / self.accumulation_steps

                # Backward with loss scaling (FP16 underflow protection).
                self.scaler.scale(loss).backward()
            else:
                # No mixed precision (CPU, or GPU without FP16).
                if self.model.use_confidence:
                    logits, loss, confidence = self.model(
                        input_ids, target_ids,
                        use_eos_weight=self.use_eos_weight,
                        return_confidence=True
                    )
                    mask = (target_ids != 0).float()
                    avg_conf = (confidence * mask).sum() / mask.sum()
                    total_confidence += avg_conf.item()
                    confidence_samples += 1
                else:
                    logits, loss = self.model(
                        input_ids, target_ids,
                        use_eos_weight=self.use_eos_weight
                    )

                loss = loss / self.accumulation_steps
                loss.backward()

            # Optimizer step every accumulation_steps micro-batches.
            # NOTE(review): if len(train_loader) is not a multiple of
            # accumulation_steps, the trailing gradients are discarded by the
            # zero_grad() at the start of the next epoch.
            if (batch_idx + 1) % self.accumulation_steps == 0:
                if self.use_mixed_precision:
                    # Unscale before clipping so the threshold applies to
                    # true gradient magnitudes.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config['training']['max_grad_norm']
                    )

                    # Optimizer step (skipped internally on inf/NaN grads).
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    # Plain gradient clipping.
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config['training']['max_grad_norm']
                    )

                    # Optimizer step.
                    self.optimizer.step()

                # Linear warmup: override every group's lr.
                # NOTE(review): this also clobbers the reduced lr of the
                # reasoning param group set in __init__ — confirm intended.
                if self.global_step < self.warmup_steps:
                    lr = self.get_lr()
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr

                # Cosine schedule only after warmup completes.
                if self.scheduler and self.global_step >= self.warmup_steps:
                    self.scheduler.step()

                self.optimizer.zero_grad()
                self.global_step += 1

            # Undo the accumulation scaling for reporting.
            total_loss += loss.item() * self.accumulation_steps

            # Progress bar metrics.
            postfix = {
                'loss': f"{loss.item() * self.accumulation_steps:.4f}",
                'lr': f"{self.get_lr():.6f}"
            }

            if confidence_samples > 0:
                postfix['conf'] = f"{total_confidence/confidence_samples:.3f}"

            # Refresh VRAM readout every 10 batches to keep overhead low.
            if self.device.type == 'cuda' and batch_idx % 10 == 0:
                vram_used = torch.cuda.memory_allocated(0) / 1e9
                vram_total = torch.cuda.get_device_properties(0).total_memory / 1e9
                postfix['vram'] = f"{vram_used:.1f}/{vram_total:.1f}GB"

            progress_bar.set_postfix(postfix)

        avg_loss = total_loss / len(self.train_loader)
        avg_confidence = total_confidence / confidence_samples if confidence_samples > 0 else 0

        return avg_loss, avg_confidence
395
-
396
- def validate(self):
397
- """Validate model"""
398
- self.model.eval()
399
- total_loss = 0
400
- total_confidence = 0
401
- confidence_samples = 0
402
-
403
- with torch.no_grad():
404
- for input_ids, target_ids in self.val_loader:
405
- input_ids = input_ids.to(self.device, non_blocking=True)
406
- target_ids = target_ids.to(self.device, non_blocking=True)
407
-
408
- if self.model.use_confidence:
409
- logits, loss, confidence = self.model(
410
- input_ids, target_ids,
411
- return_confidence=True
412
- )
413
- mask = (target_ids != 0).float()
414
- avg_conf = (confidence * mask).sum() / mask.sum()
415
- total_confidence += avg_conf.item()
416
- confidence_samples += 1
417
- else:
418
- logits, loss = self.model(input_ids, target_ids)
419
-
420
- total_loss += loss.item()
421
-
422
- avg_loss = total_loss / len(self.val_loader)
423
- avg_confidence = total_confidence / confidence_samples if confidence_samples > 0 else 0
424
-
425
- return avg_loss, avg_confidence
426
-
427
    def train(self):
        """Main training loop.

        For each epoch: trains, validates, applies early stopping on
        validation loss, saves `best_model.pt` on improvement and a periodic
        `checkpoint.pt`. Afterwards reloads the best weights (if present)
        and writes the final pickled model via `save_model()`.
        """
        print("\n" + "=" * 70)
        print("INICIANDO ENTRENAMIENTO AVANZADO")
        print("=" * 70)

        if self.device.type == 'cuda':
            print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
            print(f"🔥 VRAM Disponible: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

        epochs = self.config['training']['epochs']
        print(f"📊 Total epochs: {epochs}")
        print(f"📊 Train batches: {len(self.train_loader)}")
        print(f"📊 Val batches: {len(self.val_loader)}")
        print("=" * 70 + "\n")

        # start_epoch > 0 when resumed from a checkpoint.
        for epoch in range(self.start_epoch, epochs):
            train_loss, train_conf = self.train_epoch(epoch)
            val_loss, val_conf = self.validate()

            # Release cached GPU memory between epochs.
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            # Epoch summary.
            print(f"\n{'='*70}")
            print(f"Epoch {epoch+1}/{epochs} - Resultados")
            print(f"{'='*70}")
            print(f" Train Loss: {train_loss:.4f}")
            print(f" Val Loss: {val_loss:.4f}")
            print(f" Train Confidence: {train_conf:.3f}")
            print(f" Val Confidence: {val_conf:.3f}")
            print(f" Learning Rate: {self.get_lr():.6f}")

            if self.device.type == 'cuda':
                vram_used = torch.cuda.memory_allocated(0) / 1e9
                vram_peak = torch.cuda.max_memory_allocated(0) / 1e9
                print(f" VRAM Used: {vram_used:.2f} GB (peak: {vram_peak:.2f} GB)")
                torch.cuda.reset_peak_memory_stats()

            # Early stopping check: an epoch counts as an improvement only
            # if val loss dropped by more than min_delta.
            improvement = self.best_val_loss - val_loss

            if improvement > self.min_delta:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self.save_checkpoint('best_model.pt', epoch + 1, is_best=True)
                print(f" ✅ ¡NUEVO MEJOR MODELO! (Val Loss: {val_loss:.4f})")
            else:
                self.patience_counter += 1
                print(f" ⏳ No improvement. Patience: {self.patience_counter}/{self.patience}")

            if self.patience_counter >= self.patience:
                print(f"\n⚠️ EARLY STOPPING - Mejor val loss: {self.best_val_loss:.4f}")
                break

            # Periodic checkpoint for resumption.
            if (epoch + 1) % self.config['training']['save_every'] == 0:
                self.save_checkpoint('checkpoint.pt', epoch + 1)

            print(f"{'='*70}\n")

        print("\n" + "=" * 70)
        print("ENTRENAMIENTO COMPLETADO")
        print(f"Mejor Val Loss: {self.best_val_loss:.4f}")
        print("=" * 70)

        # Reload the best weights before exporting the final model.
        if os.path.exists('best_model.pt'):
            print("\n📦 Cargando mejor modelo...")
            checkpoint = torch.load('best_model.pt', map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print("✅ Mejor modelo cargado")

        self.save_model()
502
-
503
- def save_checkpoint(self, path, epoch, is_best=False):
504
- """Save checkpoint"""
505
- checkpoint = {
506
- 'epoch': epoch,
507
- 'global_step': self.global_step,
508
- 'model_state_dict': self.model.state_dict(),
509
- 'optimizer_state_dict': self.optimizer.state_dict(),
510
- 'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
511
- 'best_val_loss': self.best_val_loss,
512
- 'patience_counter': self.patience_counter,
513
- 'config': self.config,
514
- 'scaler_state_dict': self.scaler.state_dict() if self.use_mixed_precision else None
515
- }
516
- torch.save(checkpoint, path)
517
- if not is_best:
518
- print(f" 💾 Checkpoint guardado: {path}")
519
-
520
- def load_checkpoint(self, path):
521
- """Load checkpoint"""
522
- checkpoint = torch.load(path, map_location=self.device)
523
- self.model.load_state_dict(checkpoint['model_state_dict'])
524
- self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
525
- if self.scheduler and checkpoint.get('scheduler_state_dict'):
526
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
527
- if self.use_mixed_precision and checkpoint.get('scaler_state_dict'):
528
- self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
529
- self.start_epoch = checkpoint['epoch']
530
- self.global_step = checkpoint['global_step']
531
- self.best_val_loss = checkpoint.get('best_val_loss', float('inf'))
532
- self.patience_counter = checkpoint.get('patience_counter', 0)
533
- print(f" ✅ Resumido desde epoch {self.start_epoch}")
534
- print(f" ✅ Mejor val loss: {self.best_val_loss:.4f}")
535
-
536
- def save_model(self):
537
- """Save final model"""
538
- os.makedirs('output', exist_ok=True)
539
-
540
- # Mover modelo a CPU para guardar
541
- self.model.to('cpu')
542
-
543
- model_data = {
544
- 'model_state_dict': self.model.state_dict(),
545
- 'config': self.config,
546
- 'vocab_size': self.tokenizer.vocab_size(),
547
- 'tokenizer_path': self.tokenizer.model_path,
548
- 'training_info': {
549
- 'final_epoch': self.start_epoch,
550
- 'best_val_loss': self.best_val_loss,
551
- 'total_parameters': self.model.count_parameters()
552
- }
553
- }
554
-
555
- output_path = 'output/mtp_mini.pkl'
556
- with open(output_path, 'wb') as f:
557
- pickle.dump(model_data, f)
558
-
559
- file_size_mb = os.path.getsize(output_path) / (1024*1024)
560
-
561
- print(f"\n{'='*70}")
562
- print(f"✅ MODELO FINAL GUARDADO")
563
- print(f"{'='*70}")
564
- print(f"📁 Ruta: {output_path}")
565
- print(f"💾 Tamaño: {file_size_mb:.2f} MB")
566
- print(f"🧠 Parámetros: {self.model.count_parameters()/1e6:.1f}M")
567
- print(f"📊 Mejor Val Loss: {self.best_val_loss:.4f}")
568
- print(f"{'='*70}\n")
569
-
570
-
571
if __name__ == '__main__':
    # Script entry point: build the trainer from config.yaml and run training.
    MTPTrainer('config.yaml').train()