prelington commited on
Commit
ff17dd4
Β·
verified Β·
1 Parent(s): 3a7ae3c

Create training_manager.py

Browse files
Files changed (1) hide show
  1. training_manager.py +174 -0
training_manager.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyPilot Training Manager - Advanced distributed training with monitoring
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
8
+ import wandb
9
+ import numpy as np
10
+ import time
11
+ from datetime import datetime
12
+ import os
13
+
14
+ class CodeDataset(Dataset):
15
+ def __init__(self, tokenized_data):
16
+ self.data = tokenized_data
17
+
18
+ def __len__(self):
19
+ return len(self.data)
20
+
21
+ def __getitem__(self, idx):
22
+ return self.data[idx]
23
+
24
+ class PyPilotTrainingManager:
25
+ def __init__(self, model, model_name="PyPilot"):
26
+ self.model = model
27
+ self.model_name = model_name
28
+ self.training_history = []
29
+ self.best_loss = float('inf')
30
+
31
+ def setup_distributed_training(self, use_fp16=True, use_gradient_checkpointing=True):
32
+ """Configure distributed training options"""
33
+ training_args = TrainingArguments(
34
+ output_dir=f"./pypilot-checkpoints",
35
+ overwrite_output_dir=True,
36
+ num_train_epochs=10,
37
+ per_device_train_batch_size=4,
38
+ per_device_eval_batch_size=4,
39
+ gradient_accumulation_steps=8,
40
+ learning_rate=5e-5,
41
+ weight_decay=0.01,
42
+ warmup_steps=1000,
43
+ logging_dir="./logs",
44
+ logging_steps=500,
45
+ eval_steps=1000,
46
+ save_steps=2000,
47
+ save_total_limit=5,
48
+ prediction_loss_only=True,
49
+ remove_unused_columns=False,
50
+ fp16=use_fp16,
51
+ dataloader_pin_memory=False,
52
+ gradient_checkpointing=use_gradient_checkpointing,
53
+ report_to=["wandb"],
54
+ run_name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
55
+ )
56
+ return training_args
57
+
58
+ def setup_wandb_monitoring(self, project_name="pypilot"):
59
+ """Initialize Weights & Biases for experiment tracking"""
60
+ wandb.init(
61
+ project=project_name,
62
+ name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}",
63
+ config={
64
+ "architecture": "Transformer",
65
+ "dataset": "GitHub Code",
66
+ "epochs": 10,
67
+ "batch_size": 32,
68
+ }
69
+ )
70
+
71
+ def create_advanced_callbacks(self):
72
+ """Create callbacks for training optimization"""
73
+ callbacks = [
74
+ EarlyStoppingCallback(early_stopping_patience=3),
75
+ ]
76
+ return callbacks
77
+
78
+ def compute_metrics(self, eval_pred):
79
+ """Compute advanced metrics for code generation"""
80
+ predictions, labels = eval_pred
81
+ predictions = torch.tensor(predictions)
82
+ labels = torch.tensor(labels)
83
+
84
+ # Calculate perplexity
85
+ loss_fct = nn.CrossEntropyLoss()
86
+ loss = loss_fct(predictions.view(-1, predictions.size(-1)), labels.view(-1))
87
+ perplexity = torch.exp(loss)
88
+
89
+ # Calculate accuracy
90
+ preds = torch.argmax(predictions, dim=-1)
91
+ accuracy = (preds == labels).float().mean()
92
+
93
+ return {
94
+ "perplexity": perplexity.item(),
95
+ "accuracy": accuracy.item(),
96
+ "loss": loss.item()
97
+ }
98
+
99
+ def train_with_advanced_features(self, train_dataset, eval_dataset=None):
100
+ """Start advanced training with all features"""
101
+ print("πŸš€ Starting Advanced PyPilot Training...")
102
+
103
+ # Setup monitoring
104
+ self.setup_wandb_monitoring()
105
+
106
+ # Configure training
107
+ training_args = self.setup_distributed_training()
108
+ callbacks = self.create_advanced_callbacks()
109
+
110
+ # Create trainer
111
+ trainer = Trainer(
112
+ model=self.model,
113
+ args=training_args,
114
+ train_dataset=train_dataset,
115
+ eval_dataset=eval_dataset,
116
+ compute_metrics=self.compute_metrics,
117
+ callbacks=callbacks,
118
+ )
119
+
120
+ # Start training
121
+ print("🎯 Training started with advanced features:")
122
+ print(f" - FP16 Precision: Enabled")
123
+ print(f" - Gradient Checkpointing: Enabled")
124
+ print(f" - Early Stopping: Enabled")
125
+ print(f" - W&B Monitoring: Enabled")
126
+
127
+ trainer.train()
128
+
129
+ # Save final model
130
+ trainer.save_model("./pypilot-final-model")
131
+ print("βœ… Training completed and model saved!")
132
+
133
+ return trainer
134
+
135
+ def hyperparameter_search(self, train_dataset, param_combinations):
136
+ """Perform hyperparameter search"""
137
+ best_params = None
138
+
139
+ for i, params in enumerate(param_combinations):
140
+ print(f"πŸ” Testing hyperparameter combination {i+1}/{len(param_combinations)}")
141
+
142
+ # Update model with new params
143
+ self.update_model_hyperparams(params)
144
+
145
+ # Quick training run to evaluate
146
+ quick_trainer = Trainer(
147
+ model=self.model,
148
+ args=TrainingArguments(
149
+ output_dir=f"./hparam-search-{i}",
150
+ num_train_epochs=1,
151
+ per_device_train_batch_size=params['batch_size'],
152
+ learning_rate=params['learning_rate'],
153
+ ),
154
+ train_dataset=train_dataset,
155
+ )
156
+
157
+ results = quick_trainer.train()
158
+
159
+ if results.training_loss < self.best_loss:
160
+ self.best_loss = results.training_loss
161
+ best_params = params
162
+
163
+ print(f"🎯 Best hyperparameters: {best_params}")
164
+ return best_params
165
+
166
+ if __name__ == "__main__":
167
+ # Example usage
168
+ from modeling_pypilot import PyPilotModel, PyPilotConfig
169
+
170
+ config = PyPilotConfig()
171
+ model = PyPilotModel(config)
172
+
173
+ manager = PyPilotTrainingManager(model)
174
+ print("βœ… Training Manager ready!")