""" PyPilot Training Manager - Advanced distributed training with monitoring """ import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from transformers import TrainingArguments, Trainer, EarlyStoppingCallback import wandb import numpy as np import time from datetime import datetime import os class CodeDataset(Dataset): def __init__(self, tokenized_data): self.data = tokenized_data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class PyPilotTrainingManager: def __init__(self, model, model_name="PyPilot"): self.model = model self.model_name = model_name self.training_history = [] self.best_loss = float('inf') def setup_distributed_training(self, use_fp16=True, use_gradient_checkpointing=True): """Configure distributed training options""" training_args = TrainingArguments( output_dir=f"./pypilot-checkpoints", overwrite_output_dir=True, num_train_epochs=10, per_device_train_batch_size=4, per_device_eval_batch_size=4, gradient_accumulation_steps=8, learning_rate=5e-5, weight_decay=0.01, warmup_steps=1000, logging_dir="./logs", logging_steps=500, eval_steps=1000, save_steps=2000, save_total_limit=5, prediction_loss_only=True, remove_unused_columns=False, fp16=use_fp16, dataloader_pin_memory=False, gradient_checkpointing=use_gradient_checkpointing, report_to=["wandb"], run_name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", ) return training_args def setup_wandb_monitoring(self, project_name="pypilot"): """Initialize Weights & Biases for experiment tracking""" wandb.init( project=project_name, name=f"pypilot-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", config={ "architecture": "Transformer", "dataset": "GitHub Code", "epochs": 10, "batch_size": 32, } ) def create_advanced_callbacks(self): """Create callbacks for training optimization""" callbacks = [ EarlyStoppingCallback(early_stopping_patience=3), ] return callbacks def compute_metrics(self, eval_pred): """Compute advanced metrics for code generation""" predictions, labels = eval_pred predictions = torch.tensor(predictions) labels = torch.tensor(labels) # Calculate perplexity loss_fct = nn.CrossEntropyLoss() loss = loss_fct(predictions.view(-1, predictions.size(-1)), labels.view(-1)) perplexity = torch.exp(loss) # Calculate accuracy preds = torch.argmax(predictions, dim=-1) accuracy = (preds == labels).float().mean() return { "perplexity": perplexity.item(), "accuracy": accuracy.item(), "loss": loss.item() } def train_with_advanced_features(self, train_dataset, eval_dataset=None): """Start advanced training with all features""" print("🚀 Starting Advanced PyPilot Training...") # Setup monitoring self.setup_wandb_monitoring() # Configure training training_args = self.setup_distributed_training() callbacks = self.create_advanced_callbacks() # Create trainer trainer = Trainer( model=self.model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, compute_metrics=self.compute_metrics, callbacks=callbacks, ) # Start training print("🎯 Training started with advanced features:") print(f" - FP16 Precision: Enabled") print(f" - Gradient Checkpointing: Enabled") print(f" - Early Stopping: Enabled") print(f" - W&B Monitoring: Enabled") trainer.train() # Save final model trainer.save_model("./pypilot-final-model") print("✅ Training completed and model saved!") return trainer def hyperparameter_search(self, train_dataset, param_combinations): """Perform hyperparameter search""" best_params = None for i, params in enumerate(param_combinations): print(f"🔍 
Testing hyperparameter combination {i+1}/{len(param_combinations)}") # Update model with new params self.update_model_hyperparams(params) # Quick training run to evaluate quick_trainer = Trainer( model=self.model, args=TrainingArguments( output_dir=f"./hparam-search-{i}", num_train_epochs=1, per_device_train_batch_size=params['batch_size'], learning_rate=params['learning_rate'], ), train_dataset=train_dataset, ) results = quick_trainer.train() if results.training_loss < self.best_loss: self.best_loss = results.training_loss best_params = params print(f"🎯 Best hyperparameters: {best_params}") return best_params if __name__ == "__main__": # Example usage from modeling_pypilot import PyPilotModel, PyPilotConfig config = PyPilotConfig() model = PyPilotModel(config) manager = PyPilotTrainingManager(model) print("✅ Training Manager ready!")
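
    # Hedged usage sketch: the "gpt2" tokenizer and the toy snippets below are
    # placeholders (not the real PyPilot tokenizer or GitHub Code corpus); they
    # only illustrate the shape of samples CodeDataset is expected to hold and
    # how they would be handed to train_with_advanced_features.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    snippets = ["def add(a, b):\n    return a + b", "print('hello world')"]
    samples = []
    for snippet in snippets:
        enc = tokenizer(snippet, truncation=True, max_length=128, padding="max_length")
        samples.append({
            "input_ids": torch.tensor(enc["input_ids"]),
            "attention_mask": torch.tensor(enc["attention_mask"]),
            "labels": torch.tensor(enc["input_ids"]),  # causal LM: labels mirror inputs
        })

    train_dataset = CodeDataset(samples)
    eval_dataset = CodeDataset(samples)

    # Uncomment to launch a full run (needs a W&B login and a GPU):
    # manager.train_with_advanced_features(train_dataset, eval_dataset)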