"""
Main training script for LLM training on TPU v4-32.
Optimized for 128K token context length and 30-day training.
"""
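
# Example invocation (illustrative only; the script name and all file paths are
# placeholders -- adjust them for your environment):
#
#   python train.py \
#       --model_size 7b \
#       --train_file data/train.jsonl \
#       --eval_file data/val.jsonl \
#       --tokenizer_file tokenizer.model \
#       --output_dir output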

import os
import time
import argparse
import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    logger.warning("Weights & Biases not available. WandB logging will be disabled.")
    WANDB_AVAILABLE = False

from model.llm import LLM, LLMConfig
from data.tokenizer import SentencePieceTokenizer
from data.dataset import TextDataset, load_jsonl_dataset, StreamingDataset
from data.dataloader import TPUDataLoader
from training.trainer import Trainer, TrainingState, TrainingConfig as TrainerConfig
from training.optimizer import create_adamw_optimizer, create_lion_optimizer
from training.scheduler import create_linear_warmup_cosine_decay_schedule
from parallelism.data_parallel import DataParallel
from parallelism.tensor_parallel import TensorParallel
from config import TrainingConfig, get_model_config
from utils.checkpoint import save_checkpoint, load_checkpoint
from utils.logging import setup_logger, log_metrics, create_summary_writer, log_metrics_to_tensorboard


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Train LLM on TPU v4-32")

    # Model arguments
    parser.add_argument("--model_size", type=str, default="7b", choices=["7b", "13b", "70b", "175b", "600b"],
                        help="Model size")

    # Optimization arguments
    parser.add_argument("--learning_rate", type=float, default=3e-4,
                        help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients")
    parser.add_argument("--max_steps", type=int, default=100000,
                        help="Maximum number of training steps")
    parser.add_argument("--warmup_steps", type=int, default=1000,
                        help="Number of warmup steps")

    # Data arguments
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training file or HuggingFace dataset name")
    parser.add_argument("--eval_file", type=str, default="",
                        help="Path to evaluation file or HuggingFace dataset name")
    parser.add_argument("--tokenizer_file", type=str, required=True,
                        help="Path to tokenizer file")
    parser.add_argument("--max_seq_length", type=int, default=131072,
                        help="Maximum sequence length (default: 128K tokens)")
    parser.add_argument("--use_streaming", action=argparse.BooleanOptionalAction, default=True,
                        help="Use streaming dataset for efficient training")
    parser.add_argument("--streaming_buffer_size", type=int, default=10000,
                        help="Buffer size for streaming dataset")
    parser.add_argument("--text_column", type=str, default="text",
                        help="Name of text column in dataset")
    parser.add_argument("--preprocessing_num_workers", type=int, default=16,
                        help="Number of workers for dataset preprocessing")

    # Parallelism arguments
    parser.add_argument("--parallelism_type", type=str, default="data", choices=["data", "tensor"],
                        help="Type of parallelism")
    parser.add_argument("--tensor_parallel_size", type=int, default=8,
                        help="Number of tensor parallel devices")

    # Memory and efficiency arguments
    parser.add_argument("--use_flash_attention", action=argparse.BooleanOptionalAction, default=True,
                        help="Use flash attention for efficiency")
    parser.add_argument("--use_gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True,
                        help="Use gradient checkpointing to save memory")

    # Long-context arguments
    parser.add_argument("--use_rope_scaling", action=argparse.BooleanOptionalAction, default=True,
                        help="Use RoPE scaling for longer contexts")
    parser.add_argument("--rope_scaling_factor", type=float, default=0.5,
                        help="Scaling factor for RoPE frequencies")

    # Reasoning-layer arguments
    parser.add_argument("--use_reasoning_layer", action=argparse.BooleanOptionalAction, default=True,
                        help="Use additional reasoning layers")
    parser.add_argument("--num_reasoning_layers", type=int, default=None,
                        help="Number of additional reasoning layers (overrides model config)")

    # Output and checkpointing arguments
    parser.add_argument("--output_dir", type=str, default="output",
                        help="Output directory")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Number of steps between logging")
    parser.add_argument("--save_steps", type=int, default=1000,
                        help="Number of steps between checkpoints")
    parser.add_argument("--eval_steps", type=int, default=1000,
                        help="Number of steps between evaluations")

    # Monitoring arguments
    parser.add_argument("--use_wandb", action=argparse.BooleanOptionalAction, default=True,
                        help="Use Weights & Biases for logging")
    parser.add_argument("--wandb_project", type=str, default="llm-training",
                        help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Weights & Biases entity name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Weights & Biases run name")
    parser.add_argument("--log_memory_usage", action=argparse.BooleanOptionalAction, default=True,
                        help="Log memory usage during training")
    parser.add_argument("--profile_steps", type=int, default=100,
                        help="Number of steps between profiling")

    # Miscellaneous arguments
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--resume_from_checkpoint", type=str, default="",
                        help="Path to checkpoint to resume from")

    return parser.parse_args()


def create_config(args):
    """Create training configuration."""
    model_config = get_model_config(args.model_size)

    # Allow the CLI to override the number of reasoning layers.
    if args.num_reasoning_layers is not None:
        model_config.num_reasoning_layers = args.num_reasoning_layers

    # Propagate efficiency and long-context options into the model config.
    model_config.use_flash_attention = args.use_flash_attention
    model_config.use_gradient_checkpointing = args.use_gradient_checkpointing
    model_config.use_rope_scaling = args.use_rope_scaling
    model_config.rope_scaling_factor = args.rope_scaling_factor
    model_config.use_reasoning_layer = args.use_reasoning_layer

    config = TrainingConfig(
        output_dir=args.output_dir,
        model_config=model_config,

        # Optimization
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,

        # Data (the streaming options below are read back by load_dataset)
        train_file=args.train_file,
        eval_file=args.eval_file,
        tokenizer_file=args.tokenizer_file,
        max_seq_length=args.max_seq_length,
        use_streaming=args.use_streaming,
        streaming_buffer_size=args.streaming_buffer_size,
        text_column=args.text_column,
        preprocessing_num_workers=args.preprocessing_num_workers,

        # Parallelism
        parallelism_type=args.parallelism_type,
        tensor_parallel_size=args.tensor_parallel_size,

        # Memory efficiency
        use_flash_attention=args.use_flash_attention,
        use_gradient_checkpointing=args.use_gradient_checkpointing,

        # Long context
        use_rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,

        # Reasoning layers (model_config already reflects any CLI override)
        use_reasoning_layer=args.use_reasoning_layer,
        num_reasoning_layers=model_config.num_reasoning_layers,
        reasoning_intermediate_size=model_config.reasoning_intermediate_size,

        # Logging and checkpointing
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,

        seed=args.seed
    )

    return config


def setup_parallelism(config):
    """Set up parallelism."""
    if config.parallelism_type == "data":
        return DataParallel()
    elif config.parallelism_type == "tensor":
        return TensorParallel(num_tp=config.tensor_parallel_size)
    else:
        raise ValueError(f"Parallelism type {config.parallelism_type} not supported")


def create_model(config):
    """Create model."""
    return LLM(config.model_config)


def create_optimizer(config, num_train_steps):
    """Create optimizer."""
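    # The schedule presumably ramps linearly from zero to the peak learning
    # rate over warmup_steps, then follows a cosine decay down to 10% of the
    # peak (final_learning_rate_factor=0.1) over the remaining steps.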
    lr_schedule = create_linear_warmup_cosine_decay_schedule(
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        decay_steps=num_train_steps - config.warmup_steps,
        final_learning_rate_factor=0.1
    )

    if config.optimizer == "adamw":
        return create_adamw_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2,
            eps=config.adam_epsilon
        )
    elif config.optimizer == "lion":
        return create_lion_optimizer(
            learning_rate=lr_schedule,
            weight_decay=config.weight_decay,
            b1=config.adam_beta1,
            b2=config.adam_beta2
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer} not supported")


def create_train_state(config, model, optimizer, rng):
    """Create training state."""
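    # Initialize parameters with a minimal (batch=1, seq_len=1) dummy input;
    # real batch and sequence shapes are supplied at call time.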
    dummy_input = jnp.ones((1, 1), dtype=jnp.int32)

    params_rng, dropout_rng = jax.random.split(rng)
    params = model.init(params_rng, dummy_input)

    return TrainingState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        dropout_rng=dropout_rng,
        loss_scale=1.0
    )


def load_tokenizer(config):
    """Load tokenizer."""
    return SentencePieceTokenizer(config.tokenizer_file)


def load_dataset(config, tokenizer):
    """Load dataset with streaming support for efficient training."""
    if config.use_streaming:
        logger.info(f"Loading streaming dataset from {config.train_file}")
        train_dataset = StreamingDataset(
            tokenizer=tokenizer,
            dataset_path=config.train_file,
            max_seq_length=config.max_seq_length,
            streaming=True,
            buffer_size=config.streaming_buffer_size,
            seed=config.seed,
            text_column=config.text_column,
            preprocessing_num_workers=config.preprocessing_num_workers
        )
        logger.info("Streaming dataset loaded successfully")
    else:
        logger.info(f"Loading standard dataset from {config.train_file}")
        train_dataset = load_jsonl_dataset(
            file_path=config.train_file,
            tokenizer=tokenizer,
            max_length=config.max_seq_length
        )
        logger.info(f"Dataset loaded with {len(train_dataset)} examples")

    eval_dataset = None
    if config.eval_file:
        if config.use_streaming:
            logger.info(f"Loading streaming evaluation dataset from {config.eval_file}")
            eval_dataset = StreamingDataset(
                tokenizer=tokenizer,
                dataset_path=config.eval_file,
                max_seq_length=config.max_seq_length,
                streaming=False,
                buffer_size=config.streaming_buffer_size,
                seed=config.seed,
                text_column=config.text_column,
                preprocessing_num_workers=config.preprocessing_num_workers
            )
            logger.info("Streaming evaluation dataset loaded successfully")
        else:
            logger.info(f"Loading standard evaluation dataset from {config.eval_file}")
            eval_dataset = load_jsonl_dataset(
                file_path=config.eval_file,
                tokenizer=tokenizer,
                max_length=config.max_seq_length
            )
            logger.info(f"Evaluation dataset loaded with {len(eval_dataset)} examples")

    return train_dataset, eval_dataset


def create_data_loaders(config, train_dataset, eval_dataset, tokenizer):
    """Create data loaders."""
    train_loader = TPUDataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        pad_token_id=tokenizer.pad_token_id
    )

    eval_loader = None
    if eval_dataset is not None:
        eval_loader = TPUDataLoader(
            dataset=eval_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            drop_last=False,
            pad_token_id=tokenizer.pad_token_id
        )

    return train_loader, eval_loader


def main():
    """Main function optimized for TPU v4-32."""
    args = parse_args()

    print("TPU Configuration:")
    print(f"Number of TPU devices: {jax.device_count()}")
    print(f"TPU devices: {jax.devices()}")
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX process count: {jax.process_count()}")
    print(f"JAX local devices: {jax.local_devices()}")
    print(f"JAX local device count: {jax.local_device_count()}")

    config = create_config(args)

    os.makedirs(config.output_dir, exist_ok=True)

    logger = setup_logger(
        name="tpu_train",
        log_file=os.path.join(config.output_dir, "train.log")
    )

    logger.info(f"Configuration: {config}")

    if args.use_wandb and WANDB_AVAILABLE:
        logger.info("Initializing Weights & Biases")
        wandb_run_name = args.wandb_run_name or f"{args.model_size}-{time.strftime('%Y%m%d-%H%M%S')}"
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=wandb_run_name,
            config={
                "model_size": args.model_size,
                "learning_rate": args.learning_rate,
                "batch_size": args.batch_size,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "max_steps": args.max_steps,
                "warmup_steps": args.warmup_steps,
                "max_seq_length": args.max_seq_length,
                "parallelism_type": args.parallelism_type,
                "tensor_parallel_size": args.tensor_parallel_size,
                "use_flash_attention": args.use_flash_attention,
                "use_gradient_checkpointing": args.use_gradient_checkpointing,
                "use_rope_scaling": args.use_rope_scaling,
                "rope_scaling_factor": args.rope_scaling_factor,
                "use_reasoning_layer": args.use_reasoning_layer,
                "num_reasoning_layers": args.num_reasoning_layers,
                "use_streaming": args.use_streaming,
                "streaming_buffer_size": args.streaming_buffer_size,
                "text_column": args.text_column,
                "preprocessing_num_workers": args.preprocessing_num_workers,
                "seed": args.seed,
            }
        )
        logger.info(f"Weights & Biases initialized with run name: {wandb_run_name}")
    elif args.use_wandb and not WANDB_AVAILABLE:
        logger.warning("Weights & Biases not available. Install wandb package to enable logging.")
    else:
        logger.info("Weights & Biases logging disabled.")

    logger.info(f"Training on TPU v4-32 with {jax.device_count()} devices")
    logger.info(f"Model size: {args.model_size} ({config.model_config.hidden_size} hidden size, "
                f"{config.model_config.num_hidden_layers} layers)")
    logger.info(f"Max sequence length: {args.max_seq_length} tokens")
    logger.info(f"Batch size: {args.batch_size} per device, {args.batch_size * jax.device_count()} global")
    logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
    logger.info(f"Effective batch size: {args.batch_size * jax.device_count() * args.gradient_accumulation_steps}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info(f"Warmup steps: {args.warmup_steps}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Parallelism type: {args.parallelism_type}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")
    logger.info(f"Using streaming dataset: {args.use_streaming}")
    logger.info(f"Using flash attention: {args.use_flash_attention}")
    logger.info(f"Using gradient checkpointing: {args.use_gradient_checkpointing}")
    logger.info(f"Using RoPE scaling: {args.use_rope_scaling}")
    logger.info(f"RoPE scaling factor: {args.rope_scaling_factor}")
    logger.info(f"Using reasoning layer: {args.use_reasoning_layer}")
    logger.info(f"Number of reasoning layers: {config.model_config.num_reasoning_layers}")
    logger.info(f"Random seed: {args.seed}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"Logging steps: {args.logging_steps}")
    logger.info(f"Save steps: {args.save_steps}")
    logger.info(f"Eval steps: {args.eval_steps}")
    logger.info(f"Profile steps: {args.profile_steps}")
    logger.info(f"Using Weights & Biases: {args.use_wandb and WANDB_AVAILABLE}")
    logger.info(f"Logging memory usage: {args.log_memory_usage}")
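
    # Back-of-the-envelope parameter count (a rough sketch, not an exact
    # census): per transformer layer it assumes ~4*h^2 for the attention
    # projections, ~2*h*intermediate for the MLP, and ~4*h for layer norms,
    # plus embeddings and the LM head. Details such as GQA, gated MLPs,
    # biases, or weight tying will shift the real number.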
    param_count = (
        # Token embeddings
        config.model_config.vocab_size * config.model_config.hidden_size +
        # Transformer layers: attention projections, MLP, and layer norms
        config.model_config.num_hidden_layers * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Optional reasoning layers
        (config.model_config.num_reasoning_layers if config.model_config.use_reasoning_layer else 0) * (
            4 * config.model_config.hidden_size * config.model_config.hidden_size +
            2 * config.model_config.hidden_size * config.model_config.reasoning_intermediate_size +
            4 * config.model_config.hidden_size
        ) +
        # Final layer norm
        config.model_config.hidden_size +
        # Output projection (LM head)
        config.model_config.hidden_size * config.model_config.vocab_size
    )

    logger.info(f"Approximate parameter count: {param_count / 1e9:.2f} billion parameters")
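
    # Rough memory estimate: parameters in bf16 (2 bytes) or fp32 (4 bytes),
    # optimizer state assumed to be ~2x the parameter memory (e.g. two Adam
    # moments in the same dtype; fp32 moments would need more), and ~20% extra
    # for activations, which presumes gradient checkpointing is enabled.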
    bytes_per_param = 2 if config.dtype == jnp.bfloat16 else 4
    model_size_gb = param_count * bytes_per_param / 1e9
    optimizer_size_gb = model_size_gb * 2
    activation_size_gb = model_size_gb * 0.2
    total_memory_gb = model_size_gb + optimizer_size_gb + activation_size_gb

    logger.info("Estimated memory requirements:")
    logger.info(f"  Model parameters: {model_size_gb:.2f} GB")
    logger.info(f"  Optimizer states: {optimizer_size_gb:.2f} GB")
    logger.info(f"  Activations: {activation_size_gb:.2f} GB")
    logger.info(f"  Total: {total_memory_gb:.2f} GB")
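
    # Each TPU v4 chip has roughly 32 GB of HBM, so aggregate memory scales
    # with the number of devices reported by JAX.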
    tpu_memory_gb = 32 * jax.device_count()
    logger.info(f"Available TPU memory: {tpu_memory_gb:.2f} GB")
    if total_memory_gb > tpu_memory_gb * 0.9:
        logger.warning(f"Memory requirements ({total_memory_gb:.2f} GB) may exceed available TPU memory ({tpu_memory_gb:.2f} GB)")
        logger.warning("Consider enabling gradient checkpointing and using a smaller batch size")

    rng = jax.random.PRNGKey(config.seed)

    start_time = time.time()

    parallel = setup_parallelism(config)

    model = create_model(config)
    logger.info(f"Model created in {time.time() - start_time:.2f} seconds")

    optimizer = create_optimizer(config, config.max_steps)

    state_start_time = time.time()
    state = create_train_state(config, model, optimizer, rng)
    logger.info(f"Training state created in {time.time() - state_start_time:.2f} seconds")

    shard_start_time = time.time()
    state = state.replace(params=parallel.shard_params(state.params))
    logger.info(f"Parameters sharded in {time.time() - shard_start_time:.2f} seconds")

    if args.resume_from_checkpoint:
        checkpoint_start_time = time.time()
        state, step = load_checkpoint(args.resume_from_checkpoint, state)
        logger.info(f"Checkpoint loaded in {time.time() - checkpoint_start_time:.2f} seconds")

    tokenizer_start_time = time.time()
    tokenizer = load_tokenizer(config)
    logger.info(f"Tokenizer loaded in {time.time() - tokenizer_start_time:.2f} seconds")

    dataset_start_time = time.time()
    train_dataset, eval_dataset = load_dataset(config, tokenizer)
    logger.info(f"Datasets loaded in {time.time() - dataset_start_time:.2f} seconds")

    dataloader_start_time = time.time()
    train_loader, eval_loader = create_data_loaders(
        config,
        train_dataset,
        eval_dataset,
        tokenizer
    )
    logger.info(f"Data loaders created in {time.time() - dataloader_start_time:.2f} seconds")

    summary_writer = create_summary_writer(
        os.path.join(config.output_dir, "tensorboard")
    )

    trainer_config = TrainerConfig(
        model_config=config.model_config,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=config.max_grad_norm,
        adam_beta1=config.adam_beta1,
        adam_beta2=config.adam_beta2,
        adam_epsilon=config.adam_epsilon,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        output_dir=config.output_dir,
        seed=config.seed,
        dtype=config.dtype,
        # TPU-specific optimizations
        use_pjit=True,
        use_scan=True,
        use_remat=config.model_config.use_gradient_checkpointing,
        use_sharded_optim=True,
        profile_steps=args.profile_steps,
        async_checkpointing=True,
    )

    trainer = Trainer(
        config=trainer_config,
        model=model,
        train_dataloader=train_loader,
        eval_dataloader=eval_loader,
        state=state,
        parallel=parallel,
    )

    logger.info(f"Total initialization time: {time.time() - start_time:.2f} seconds")
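
    # Rough ETA: assumes ~5 minutes per optimizer step, which is only an
    # illustrative placeholder; actual step time depends on model size,
    # batch size, context length, and compilation.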
    steps_per_day = 24 * 60 * 60 / (5 * 60)
    estimated_days = config.max_steps / steps_per_day
    logger.info(f"Estimated training time: {estimated_days:.2f} days for {config.max_steps} steps")

    try:
        train_start_time = time.time()
        trainer.train()
        train_duration = time.time() - train_start_time
        logger.info(f"Training completed in {train_duration / 3600:.2f} hours")
        logger.info(f"Average training speed: {config.max_steps / train_duration:.2f} steps/second")
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise


if __name__ == "__main__":
    main()