#!/usr/bin/env python3
"""
SEM V6 PyTorch Lightning Training Script

Production-ready training with all best practices:
- Automatic mixed precision (AMP)
- Gradient accumulation
- Learning rate scheduling
- Model checkpointing
- Progress bars and rich logging
- Distributed training support
- Early stopping (optional, commented out in main())
- Gradient clipping
- Learning rate finder
- Model pruning/quantization ready
- TensorBoard logging
- Memory-efficient data streaming with prefetching

Reference:
- PyTorch Lightning 2.x best practices
- AGENTS.md: GPU-first, streaming data, maximum speed
"""

import argparse
import itertools
import logging
import math
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Global timing - track wall-clock from (as close as possible to) absolute start.
# Captured before the heavy third-party imports below so their load time is
# counted against the wall-clock time budget.
_SCRIPT_START_TIME = time.perf_counter()

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets  # For HF streaming
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
    RichModelSummary,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

# SEM V6 imports (project-local packages live under these directories)
sys.path.insert(0, "src")
sys.path.insert(0, "ChebyKan_cuda_op")

from sem_v6.sem_v6 import SEMV6
from sem_v6.data.streaming_dataset import StreamingDataset
from sem_v6.modules.text_decoder import multi_token_prediction_loss
from sem_v6.validation.callbacks import CombinedValidationCallback
from sem_v6.validation.validators import FastValidator, GrammarValidator
from sem_v6.validation.grammar_checker import LanguageToolClient

# Enable Tensor Core acceleration (RTX 3060 and above)
torch.set_float32_matmul_precision('high')


class TimingCallback(pl.Callback):
    """
    Callback to track wall-clock timing for 100s training target.

    Tracks:
    - Time to first batch (includes data loading startup)
    - Training time per step
    - Total elapsed time from script start

    Also requests a graceful stop (``trainer.should_stop = True``) once 2 seconds
    or less of the time budget remain. Note that ``step_count`` counts training
    *batches* (``on_train_batch_end`` calls), not optimizer steps.
    """

    def __init__(self, script_start_time: float, target_time: float = 100.0):
        super().__init__()
        self.script_start_time = script_start_time
        self.target_time = target_time
        self.fit_start_time: Optional[float] = None
        self.first_batch_time: Optional[float] = None
        self.training_start_time: Optional[float] = None
        self.step_count = 0
        self.total_training_time = 0.0

    def on_fit_start(self, trainer, pl_module):
        self.fit_start_time = time.perf_counter()
        elapsed = self.fit_start_time - self.script_start_time
        print(f"\n⏱️ [TIMING] Model initialization took: {elapsed:.2f}s")
        remaining = self.target_time - elapsed
        print(f"⏱️ [TIMING] Time budget remaining: {remaining:.2f}s / {self.target_time}s")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.first_batch_time is None:
            self.first_batch_time = time.perf_counter()
            # on_fit_start normally sets fit_start_time; fall back to script start
            # so a missing hook cannot crash the timing report.
            fit_start = (
                self.fit_start_time
                if self.fit_start_time is not None
                else self.script_start_time
            )
            data_load_time = self.first_batch_time - fit_start
            elapsed = self.first_batch_time - self.script_start_time
            print(f"⏱️ [TIMING] Data loading/first batch took: {data_load_time:.2f}s")
            print(f"⏱️ [TIMING] Total startup overhead: {elapsed:.2f}s")
            remaining = self.target_time - elapsed
            print(f"⏱️ [TIMING] Time budget for training: {remaining:.2f}s")
            self.training_start_time = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.step_count += 1
        elapsed = time.perf_counter() - self.script_start_time
        remaining = self.target_time - elapsed

        # Time-based early stopping: stop with 2s buffer before target
        if remaining <= 2.0:
            print(f"\n⏱️ [TIMING] Stopping training - approaching time limit ({elapsed:.1f}s / {self.target_time}s)")
            trainer.should_stop = True
            return

        # Report throughput every 50 batches
        if self.step_count % 50 == 0 and self.training_start_time is not None:
            training_elapsed = time.perf_counter() - self.training_start_time
            steps_per_sec = self.step_count / training_elapsed if training_elapsed > 0 else 0
            print(f"⏱️ [TIMING] Step {self.step_count}: {elapsed:.1f}s elapsed, {remaining:.1f}s remaining, {steps_per_sec:.2f} steps/s")

    def on_fit_end(self, trainer, pl_module):
        end_time = time.perf_counter()
        total_elapsed = end_time - self.script_start_time
        training_time = (
            end_time - self.training_start_time
            if self.training_start_time is not None
            else 0
        )
        startup_time = total_elapsed - training_time

        print("\n" + "=" * 80)
        print("⏱️ TIMING SUMMARY (100s Target)")
        print("=" * 80)
        print(f" Startup (init + data load): {startup_time:.2f}s")
        print(f" Training time: {training_time:.2f}s")
        print(f" Total wall-clock time: {total_elapsed:.2f}s")
        print("-" * 80)
        print(f" Steps completed: {self.step_count}")
        if training_time > 0:
            print(f" Training speed: {self.step_count / training_time:.2f} steps/s")
        if total_elapsed <= self.target_time:
            print(f"\n ✅ PASSED: Completed in {total_elapsed:.2f}s (under {self.target_time}s target)")
        else:
            overage = total_elapsed - self.target_time
            print(f"\n ❌ FAILED: Took {total_elapsed:.2f}s ({overage:.2f}s over {self.target_time}s target)")
        print("=" * 80)


def latent_prediction_loss(
    z_pred: torch.Tensor,
    x_next: torch.Tensor,
    compressor: nn.Module,
    lambda_sparse: float = 0.001,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute loss in latent space (not SDR reconstruction).

    Based on Lester et al. 2024 "Training LLMs over Neurally Compressed Text":
    predicting compressed representation is more learnable than reconstructing
    the full SDR. This is the core training objective for SEM V6.

    Loss function:
        L = ||z_pred - compress(x_next)||² + λ||z_pred||₁

    Where:
        - z_pred: Output from propagator (predicted next latent)
        - x_next: Ground truth next SDR
        - compress(): NeuralCompressor module
        - λ: Sparsity coefficient (default 0.001)

    Args:
        z_pred: Predicted next latent (batch, latent_dim) from propagator
        x_next: Ground truth next SDR (batch, sdr_dim) - will be compressed
        compressor: NeuralCompressor module to compress x_next
        lambda_sparse: L1 sparsity coefficient for regularization. Default: 0.001

    Returns:
        loss: Total loss scalar (MSE + L1)
        metrics: Dict with individual loss components:
            - mse_loss: Mean squared error between predicted and target latent
            - l1_loss: L1 regularization on predicted latent (already scaled by lambda)
            - total_loss: Combined loss
            - cos_sim: Cosine similarity for monitoring prediction quality

    Example:
        >>> compressor = nn.Linear(2048, 256)  # Simplified compressor
        >>> z_pred = torch.randn(32, 256, requires_grad=True)
        >>> x_next = torch.randn(32, 2048)
        >>> loss, metrics = latent_prediction_loss(z_pred, x_next, compressor)
        >>> loss.backward()

    References:
        Lester, B., et al. (2024). "Training LLMs over Neurally Compressed Text."
        arXiv:2404.03626
    """
    # Compress ground truth to latent space (stop gradient - target only)
    with torch.no_grad():
        z_target = compressor(x_next.float())
        # NaN safety for z_target
        z_target = torch.nan_to_num(z_target, nan=0.0, posinf=100.0, neginf=-100.0)

    # MSE in latent space - core prediction objective
    mse_loss = F.mse_loss(z_pred, z_target)

    # L1 sparsity regularization - encourages sparse latent representations
    l1_loss = lambda_sparse * z_pred.abs().mean()

    # Total loss
    total_loss = mse_loss + l1_loss

    # Cosine similarity for monitoring prediction quality (not in loss)
    with torch.no_grad():
        cos_sim = F.cosine_similarity(z_pred, z_target, dim=-1).mean()

    metrics = {
        "mse_loss": mse_loss.item(),
        "l1_loss": l1_loss.item(),
        "total_loss": total_loss.item(),
        "cos_sim": cos_sim.item(),
    }

    return total_loss, metrics


def _make_sentence_pairs(
    documents: List[str], max_pairs: int, min_chars: int = 20
) -> Tuple[List[str], List[str]]:
    """Build (sentence_i, sentence_i+1) pairs from consecutive sentences of each document.

    Pairs never cross a document boundary, so unrelated documents are never
    treated as "next" text for each other.

    Args:
        documents: Raw document strings.
        max_pairs: Stop as soon as this many pairs have been collected.
        min_chars: Sentences shorter than this (after stripping) are dropped.

    Returns:
        (current_texts, next_texts), two equal-length lists.
    """
    current_texts: List[str] = []
    next_texts: List[str] = []
    for document in documents:
        # Split into sentences on '. ' and newlines (newlines become '. ' first)
        sentences = []
        for chunk in document.replace('\n', '. ').split('. '):
            sentence = chunk.strip()
            if len(sentence) >= min_chars:  # Filter out short fragments
                sentences.append(sentence)

        for current, following in zip(sentences, sentences[1:]):
            current_texts.append(current)
            next_texts.append(following)
            # Cap total pairs for consistent GPU utilization
            if len(current_texts) >= max_pairs:
                return current_texts, next_texts
    return current_texts, next_texts


class SEMV6LightningModule(pl.LightningModule):
    """
    PyTorch Lightning wrapper for SEM V6 training.

    Implements all best practices:
    - Automatic optimization with gradient accumulation
    - Learning rate scheduling
    - Logging and metrics
    - Checkpoint management
    - Mixed precision training
    """

    def __init__(
        self,
        sdr_dim: int = 1368,
        sparsity: float = 0.05,
        num_hyperedges: int = 2000,
        kan_degree: int = 5,
        latent_dim: int = 256,
        lambda_sparse: float = 0.001,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        compile_model: bool = True,
        lr_schedule: str = "onecycle",
        warmup_pct: float = 0.1,
        lr_div_factor: float = 25.0,
        lr_final_div_factor: float = 10000.0,
        ode_solver: str = "rk4",
        ode_step_size: float = 0.01,
        enable_text_decoder: bool = True,
        num_predict: int = 4,
        mtp_loss_weight: float = 1.0,
    ):
        super().__init__()

        # Save hyperparameters for checkpointing
        self.save_hyperparameters()

        # Initialize SEM V6 model with latent space architecture
        self.model = SEMV6(
            sdr_dim=sdr_dim,
            sparsity=sparsity,
            num_hyperedges=num_hyperedges,
            kan_degree=kan_degree,
            latent_dim=latent_dim,
            ode_solver=ode_solver,
            ode_step_size=ode_step_size,
            enable_text_decoder=enable_text_decoder,
            num_predict=num_predict,
        )

        # Compile model for faster execution (PyTorch 2.0+)
        if compile_model:
            self.model = torch.compile(self.model, mode='reduce-overhead')

        # Training metrics
        self.train_loss = []
        self.val_loss = []

    def forward(self, text_batch, reward=None, mode="awake"):
        """Forward pass through SEM V6."""
        return self.model(text_batch, reward=reward, mode=mode)

    def training_step(self, batch, batch_idx):
        """
        Training step with latent-space next-token prediction.

        Based on Lester et al. 2024: predict latent representation of NEXT
        sentence, not reconstruct full SDR. This is more efficient and learnable.

        Loss function:
            L = ||z_pred - compress(x_next)||² + λ||z_pred||₁
        (scaled by 0.1 and added to the MTP loss when the text decoder is enabled)

        Returns None (skipping the optimizer step) when fewer than 2 sentence
        pairs can be built from the batch.

        Lightning handles:
        - Gradient accumulation
        - Mixed precision
        - Gradient clipping
        - Optimizer stepping
        """
        # Unpack batch (text samples) - need pairs for next-token prediction
        text_batch = batch

        # Pair consecutive sentences *within* each document, capped at the
        # original batch size.
        current_texts, next_texts = _make_sentence_pairs(
            text_batch, max_pairs=len(text_batch)
        )

        # Handle edge case: no valid pairs found
        if len(current_texts) < 2:
            return None

        # Forward pass on current texts -> get z_pred
        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")

        # NaN safety: use nan_to_num which preserves gradients better
        z_pred = torch.nan_to_num(z_pred, nan=0.0, posinf=100.0, neginf=-100.0)
        z_pred = torch.clamp(z_pred, -100.0, 100.0)

        # Get next SDRs (targets) - encode next texts to SDR
        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        # Compute latent prediction loss (Lester 2024)
        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        # Add Multi-Token Prediction loss if decoder is enabled
        if self.model.text_decoder is not None:
            mtp_loss, mtp_metrics = multi_token_prediction_loss(
                z_pred=z_pred,
                target_text=next_texts,
                decoder=self.model.text_decoder,
                max_length=128,
                loss_weight=self.hparams.mtp_loss_weight,
            )
            loss = 0.1 * loss + mtp_loss
            metrics.update({f"mtp_{k}": v for k, v in mtp_metrics.items()})

        # Batch size for logging
        batch_size = len(current_texts)

        # Log primary metrics
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )
        self.log(
            "mse_loss",
            metrics["mse_loss"],
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
        )
        self.log(
            "l1_loss",
            metrics["l1_loss"],
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        self.log(
            "cos_sim",
            metrics["cos_sim"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )

        # Meta-controller metrics
        self.log(
            "train_action",
            float(action_logits.argmax()),
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )
        self.log(
            "train_value",
            float(value),
            on_step=False,
            on_epoch=True,
            batch_size=batch_size,
        )

        # MTP metrics (if decoder is enabled)
        if "mtp_mtp_loss" in metrics:
            self.log(
                "mtp_loss",
                metrics["mtp_mtp_loss"],
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                batch_size=batch_size,
            )
            self.log(
                "mtp_accuracy",
                metrics.get("mtp_avg_accuracy", 0.0),
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                batch_size=batch_size,
            )

        # GPU memory usage
        if torch.cuda.is_available():
            gpu_mem = torch.cuda.max_memory_allocated() / 1e9
            self.log(
                "gpu_memory_gb",
                gpu_mem,
                on_step=False,
                on_epoch=True,
                batch_size=batch_size,
            )

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step with latent prediction loss.

        Uses the same within-document sentence pairing as ``training_step`` so
        validation measures the same objective (previously it paired adjacent,
        unrelated documents). Returns None when fewer than 2 pairs exist.
        """
        text_batch = batch

        current_texts, next_texts = _make_sentence_pairs(
            text_batch, max_pairs=len(text_batch)
        )

        # For next-token prediction, we need at least 2 pairs
        if len(current_texts) < 2:
            return None

        # Forward pass
        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")

        # Same NaN safety as training so the metrics are comparable
        z_pred = torch.nan_to_num(z_pred, nan=0.0, posinf=100.0, neginf=-100.0)
        z_pred = torch.clamp(z_pred, -100.0, 100.0)

        # Get next SDRs (targets)
        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        # Compute latent prediction loss
        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        batch_size = len(current_texts)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("val_mse_loss", metrics["mse_loss"], on_step=False, on_epoch=True, batch_size=batch_size)
        self.log("val_cos_sim", metrics["cos_sim"], on_step=False, on_epoch=True, batch_size=batch_size)

        return loss

    def configure_optimizers(self):
        """
        Configure optimizer with aggressive learning rate schedules.

        Supports multiple scheduling strategies (``hparams.lr_schedule``):
        1. "onecycle": OneCycleLR triangular schedule (fastest convergence; default,
           also used for unknown values)
        2. "warmrestarts": CosineAnnealingWarmRestarts for escaping local minima
        3. "cosine": linear warmup then cosine decay to learning_rate / 100

        Only the selected scheduler is constructed. Every torch LR scheduler
        mutates the optimizer's param groups when it is created (OneCycleLR
        overwrites ``initial_lr`` with ``max_lr / div_factor``), so constructing
        unused schedulers would silently change the base LR of the returned one.

        Uses:
        - AdamW optimizer with weight decay (DeepSeek V3 settings)
        - Mixed precision training via Lightning (16-mixed)
        - Gradient clipping via Trainer
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(0.9, 0.95),  # DeepSeek V3 settings
            eps=1e-8,
        )

        max_steps = self.hparams.max_steps
        base_lr = self.hparams.learning_rate
        schedule = self.hparams.lr_schedule

        if schedule == "warmrestarts":
            # Multiple annealing cycles with periodic restarts; helps escape
            # local minima by periodically reheating LR. T_0 must be >= 1.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=max(1, int(max_steps * self.hparams.warmup_pct)),  # Initial period
                T_mult=2,  # Double the period after each restart (must be int)
                eta_min=base_lr / 1000,  # Minimum LR floor
            )
        elif schedule == "cosine":
            # Linear warmup then smooth cosine decay to base_lr / 100, held
            # constant afterwards (a bare CosineAnnealingLR would climb back up
            # past T_max).
            warmup_steps = max(0, min(self.hparams.warmup_steps, max_steps // 10))
            decay_steps = max(1, max_steps - warmup_steps)
            min_ratio = 1.0 / 100.0  # eta_min = base_lr / 100

            def warmup_cosine(step: int) -> float:
                # Multiplier applied to base_lr at scheduler step `step`.
                if step < warmup_steps:
                    return (step + 1) / warmup_steps
                progress = min(1.0, (step - warmup_steps) / decay_steps)
                return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
        else:
            # "onecycle" (RECOMMENDED) and the default for unknown values.
            # Triangular schedule: warmup -> peak -> cosine decay.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=base_lr,
                total_steps=max_steps,
                pct_start=self.hparams.warmup_pct,  # % of cycle spent increasing LR
                anneal_strategy='cos',  # Cosine annealing during decay phase
                # Momentum (AdamW beta1) is cycled inversely to LR:
                # max_momentum at start, base_momentum at peak LR, max_momentum at end
                cycle_momentum=True,
                base_momentum=0.85,
                max_momentum=0.95,
                div_factor=self.hparams.lr_div_factor,  # Initial LR = max_LR / div_factor
                final_div_factor=self.hparams.lr_final_div_factor,  # Final LR = initial_LR / final_div_factor
            )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # Step-based (not epoch-based) for granular control
                "frequency": 1,  # Apply every step
            },
        }


class _HFTextBatchIterable(torch.utils.data.IterableDataset):
    """Wrap a HuggingFace streaming dataset and yield lists of ``batch_size`` texts.

    Defined at module level (not inside ``train_dataloader``) so it can be
    pickled when DataLoader workers use the ``spawn`` start method. A trailing
    partial batch is dropped.
    """

    def __init__(self, hf_dataset, batch_size: int):
        super().__init__()
        self.hf_dataset = hf_dataset
        self.batch_size = batch_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        dataset = self.hf_dataset
        if worker_info is not None:
            # Simple round-robin sharding over the stream:
            # worker 0 gets items 0, N, 2N...; worker 1 gets 1, N+1, ...
            # (each worker still reads the full stream and skips the rest)
            dataset = itertools.islice(
                dataset, worker_info.id, None, worker_info.num_workers
            )

        batch = []
        for item in dataset:
            # Full text (no truncation), handled by background workers
            batch.append(item["text"])
            if len(batch) == self.batch_size:
                yield batch
                batch = []


class StreamingDataModule(pl.LightningDataModule):
    """
    Lightning DataModule for streaming FineWeb data.

    Two sources:
    - ``data_dir is None``: stream HuggingFaceFW/fineweb-edu (sample-10BT) text
      batches from the HuggingFace Hub.
    - otherwise: ``*.pt`` / ``*.npy`` shard files in ``data_dir`` read via the
      project's ``StreamingDataset``.
    """

    def __init__(
        self,
        data_dir: Optional[Path] = None,
        batch_size: int = 32,
        num_workers: int = 8,
        val_split: float = 0.0,  # No validation by default (streaming setup)
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split

        self.train_shards = []
        self.val_shards = []

    def setup(self, stage: Optional[str] = None):
        """Discover and partition shard files."""
        if self.data_dir is None:
            print("No data_dir provided - using HuggingFace streaming")
            return

        # Find all shard files
        all_shards = sorted(
            list(self.data_dir.glob("*.pt")) + list(self.data_dir.glob("*.npy"))
        )

        if len(all_shards) == 0:
            raise ValueError(f"No shard files found in {self.data_dir}")

        # Split into train/val if needed
        if self.val_split > 0:
            val_count = int(len(all_shards) * self.val_split)
            self.val_shards = all_shards[:val_count]
            self.train_shards = all_shards[val_count:]
        else:
            self.train_shards = all_shards
            self.val_shards = []

        print(
            f"Found {len(self.train_shards)} training shards, {len(self.val_shards)} validation shards"
        )

    def train_dataloader(self):
        """Create training DataLoader with streaming dataset."""
        if self.data_dir is None:
            # Stream from HuggingFace
            dataset = datasets.load_dataset(
                "HuggingFaceFW/fineweb-edu",
                name="sample-10BT",
                split="train",
                streaming=True,
            )

            iterable_dataset = _HFTextBatchIterable(dataset, self.batch_size)

            return DataLoader(
                iterable_dataset,
                batch_size=None,  # Generator handles batching
                num_workers=self.num_workers,  # Enable workers!
                pin_memory=True,
                prefetch_factor=2 if self.num_workers > 0 else None,
            )

        # NOTE(review): the dataset is given a CUDA device while DataLoader
        # workers may be forked processes — confirm StreamingDataset does not
        # touch CUDA inside workers.
        dataset = StreamingDataset(
            shard_paths=self.train_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,  # Background prefetching
            cycle=True,  # Infinite cycling
        )

        return DataLoader(
            dataset,
            batch_size=None,  # StreamingDataset handles batching
            num_workers=self.num_workers,
            pin_memory=True,  # Fast GPU transfer
            persistent_workers=True if self.num_workers > 0 else False,
            prefetch_factor=4 if self.num_workers > 0 else None,
        )

    def val_dataloader(self):
        """Create validation DataLoader (if validation shards exist)."""
        if len(self.val_shards) == 0:
            return None

        dataset = StreamingDataset(
            shard_paths=self.val_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,
            cycle=False,  # Don't cycle validation data
        )

        # Fewer workers for validation. persistent_workers must follow the
        # *validation* worker count: DataLoader rejects persistent_workers=True
        # with num_workers=0 (e.g. self.num_workers == 1).
        val_workers = self.num_workers // 2
        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=val_workers,
            pin_memory=True,
            persistent_workers=val_workers > 0,
        )


def main():
    """Main training entry point."""
    parser = argparse.ArgumentParser(description="SEM V6 Lightning Training")

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=Path,
        required=False,
        default=None,
        help="Directory with shard files",
    )
    # Batch size (6 default for 6GB VRAM, use 16 with --fast-100s)
    parser.add_argument("--batch_size", type=int, default=6, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=8, help="DataLoader workers")

    # Model arguments
    # Dim 768 (3x 256) - Aggressive scaling
    parser.add_argument("--sdr_dim", type=int, default=768, help="SDR dimension")
    parser.add_argument("--sparsity", type=float, default=0.05, help="SDR sparsity")
    parser.add_argument(
        "--kan_degree", type=int, default=3, help="ChebyKAN polynomial degree"
    )
    parser.add_argument(
        "--num_hyperedges", type=int, default=2000, help="Number of hyperedges"
    )
    parser.add_argument(
        "--latent_dim", type=int, default=256, help="Latent space dimension (Lester 2024)"
    )
    parser.add_argument(
        "--lambda_sparse", type=float, default=0.001, help="L1 sparsity coefficient for latent loss"
    )

    # ODE solver arguments (for Module C: SolitonPropagator)
    parser.add_argument(
        "--ode_solver", type=str, default="rk4",
        choices=["rk4", "euler", "dopri5", "midpoint"],
        help="ODE solver method: rk4 (stable, default), euler (fast for training)"
    )
    parser.add_argument(
        "--ode_step_size", type=float, default=0.01,
        help="Step size for fixed-step ODE solvers (default 0.01, use 0.05 with euler)"
    )

    # Training arguments
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Base learning rate"
    )
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="AdamW weight decay (L2 regularization)")
    parser.add_argument(
        "--max_steps", type=int, default=100000, help="Max training steps"
    )
    parser.add_argument(
        "--accumulate_grad_batches", type=int, default=1,
        help="Gradient accumulation steps (effective_batch = batch_size * this)"
    )
    parser.add_argument(
        "--gradient_clip_val", type=float, default=1.0, help="Gradient clipping norm threshold"
    )

    # Learning rate schedule arguments
    parser.add_argument(
        "--lr_schedule", type=str, default="onecycle",
        choices=["onecycle", "warmrestarts", "cosine"],
        help="Learning rate schedule strategy"
    )
    parser.add_argument(
        "--warmup_pct", type=float, default=0.1,
        help="Percentage of training used for warmup (0.0-1.0)"
    )
    parser.add_argument(
        "--lr_div_factor", type=float, default=25.0,
        help="OneCycleLR: Initial LR = max_LR / div_factor"
    )
    parser.add_argument(
        "--lr_final_div_factor", type=float, default=10000.0,
        help="OneCycleLR: Final LR = initial_LR / final_div_factor"
    )

    # Lightning Trainer arguments
    parser.add_argument(
        "--accelerator", type=str, default="gpu", choices=["gpu", "cpu"]
    )
    parser.add_argument("--devices", type=int, default=1, help="Number of GPUs")
    parser.add_argument(
        "--precision", type=str, default="16-mixed", help="Training precision"
    )
    parser.add_argument(
        "--log_every_n_steps", type=int, default=10, help="Logging frequency"
    )
    parser.add_argument(
        "--val_check_interval", type=float, default=1000, help="Validation frequency"
    )

    # Checkpointing
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint directory",
    )
    parser.add_argument(
        "--save_top_k", type=int, default=3, help="Save top K checkpoints"
    )
    parser.add_argument(
        "--no-compile",
        action="store_true",
        help="Disable torch.compile (useful for debugging)",
    )

    # Fast 100-second training mode (agentic coherence optimization)
    parser.add_argument(
        "--fast-100s",
        action="store_true",
        help="Enable optimized settings for 100-second agentic coherence training"
    )
    parser.add_argument(
        "--target-time",
        type=float,
        default=160.0,
        help="Target wall-clock time in seconds (default: 160.0, includes ~125s startup overhead)"
    )

    # Multi-Token Prediction (MTP) arguments
    # NOTE: --enable-mtp is kept for CLI compatibility only; MTP is on unless
    # --disable-mtp is given (see the override after parse_args()).
    parser.add_argument(
        "--enable-mtp",
        action="store_true",
        default=True,
        help="Enable Multi-Token Prediction text decoder (default: True)"
    )
    parser.add_argument(
        "--disable-mtp",
        action="store_true",
        help="Disable Multi-Token Prediction text decoder"
    )
    parser.add_argument(
        "--num-predict",
        type=int,
        default=4,
        help="Number of future tokens to predict with MTP (default: 4)"
    )
    parser.add_argument(
        "--mtp-loss-weight",
        type=float,
        default=1.0,
        help="Weight for MTP loss (default: 1.0)"
    )

    args = parser.parse_args()

    # Handle MTP enable/disable
    args.enable_mtp = not args.disable_mtp

    # Apply --fast-100s optimizations if flag is set
    # Based on research: OneCycleLR super-convergence, euler solver, no accumulation
    if args.fast_100s:
        # === TRAINING HYPERPARAMETERS ===
        args.batch_size = 16  # Larger batches, no accumulation
        args.learning_rate = 0.002  # 2x for super-convergence
        args.weight_decay = 1e-5  # Reduced for super-convergence
        args.lr_div_factor = 10.0  # Initial LR = 0.002 / 10 = 0.0002
        args.lr_final_div_factor = 100.0  # Final LR = 0.0002 / 100 = 0.000002
        args.accumulate_grad_batches = 1  # CRITICAL: no accumulation for 100 updates

        # === MODEL CAPACITY (balanced for 6GB VRAM with 128k vocab) ===
        # Per CLAUDE.md: "Model parameters CAN be increased"
        # BUT 128k vocab tokenizer + text decoder needs ~1GB alone
        args.sdr_dim = 1024  # 768 → 1024: moderate increase
        args.latent_dim = 256  # Keep 256: balances with 128k vocab decoder
        args.kan_degree = 4  # 3 → 4: slightly more expressive
        args.num_hyperedges = 2000  # Keep 2000: sufficient for short training

        # === ODE SOLVER (speed optimization) ===
        args.ode_solver = "euler"  # Fast ODE solver (4x fewer evaluations)
        args.ode_step_size = 0.05  # Larger step size (5x fewer steps)

        # === TIME BUDGET CALCULATION ===
        # Observed overhead: model init + tokenizer + HuggingFace streaming startup
        # - Model init: ~7s
        # - Llama tokenizer download: ~5s
        # - HuggingFace streaming first batch: ~50s
        # - torch.compile / GPU warmup pause: ~60s (mysterious quiet period)
        # Conservative estimate: 125s overhead
        # At ~4 steps/sec with remaining time
        estimated_overhead_s = 125.0
        available_training_s = args.target_time - estimated_overhead_s
        estimated_steps_per_sec = 4.0  # Conservative estimate
        args.max_steps = max(50, int(available_training_s * estimated_steps_per_sec))

        elapsed_so_far = time.perf_counter() - _SCRIPT_START_TIME
        print("=" * 80)
        print("FAST 100s MODE ENABLED - Optimized for agentic coherence")
        print(f" Target time: {args.target_time}s (elapsed so far: {elapsed_so_far:.1f}s)")
        print(f" Estimated overhead: {estimated_overhead_s}s (model ~7s + tokenizer ~5s + HF stream ~50s + compile/warmup ~60s)")
        print(f" Training budget: {available_training_s}s → max {args.max_steps} steps")
        print("-" * 40)
        print(" MODEL CAPACITY (balanced for 6GB + 128k vocab):")
        print(f" sdr_dim: {args.sdr_dim}, latent_dim: {args.latent_dim}")
        print(f" kan_degree: {args.kan_degree}, num_hyperedges: {args.num_hyperedges}")
        print("-" * 40)
        print(" TRAINING:")
        print(" batch_size=16, lr=0.002, weight_decay=1e-5, accumulation=1")
        print(" ODE solver: euler with step_size=0.05")
        print("=" * 80)

    # Validate CUDA availability (explicit check: `assert` is stripped under -O)
    if args.accelerator == "gpu" and not torch.cuda.is_available():
        parser.error("GPU required but CUDA not available")

    # Create data module
    datamodule = StreamingDataModule(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Create model with latent-space architecture (Lester 2024)
    compile_model = not args.no_compile
    model = SEMV6LightningModule(
        sdr_dim=args.sdr_dim,
        sparsity=args.sparsity,
        num_hyperedges=args.num_hyperedges,
        kan_degree=args.kan_degree,
        latent_dim=args.latent_dim,
        lambda_sparse=args.lambda_sparse,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        max_steps=args.max_steps,
        compile_model=compile_model,
        lr_schedule=args.lr_schedule,
        warmup_pct=args.warmup_pct,
        lr_div_factor=args.lr_div_factor,
        lr_final_div_factor=args.lr_final_div_factor,
        ode_solver=args.ode_solver,
        ode_step_size=args.ode_step_size,
        enable_text_decoder=args.enable_mtp,
        num_predict=args.num_predict,
        mtp_loss_weight=args.mtp_loss_weight,
    )

    # Setup callbacks
    callbacks = [
        # Timing callback for 100s target tracking
        TimingCallback(
            script_start_time=_SCRIPT_START_TIME,
            target_time=args.target_time,
        ),
        # Model checkpointing (save best models)
        ModelCheckpoint(
            dirpath=args.checkpoint_dir,
            filename="semv6-{step:06d}-{train_loss:.4f}",
            monitor="train_loss",
            mode="min",
            save_top_k=args.save_top_k,
            every_n_train_steps=1000,
            save_last=True,
        ),
        # Learning rate monitoring
        LearningRateMonitor(logging_interval="step"),
        # Rich progress bar
        RichProgressBar(),
        # Model summary (printed even though enable_model_summary=False below,
        # because it is passed explicitly as a callback)
        RichModelSummary(max_depth=2),
    ]

    # Validation callbacks for English coherence monitoring
    test_prompts = [
        "The capital of France is",
        "Water boils at",
        "The human body has",
        "Plants produce oxygen through",
        "The speed of light is",
    ]

    # Create validators
    fast_validator = FastValidator(test_prompts=test_prompts)
    grammar_client = LanguageToolClient(url="http://localhost:8081")
    grammar_validator = GrammarValidator(
        client=grammar_client,
        test_prompts=test_prompts
    )

    # Add combined validation callback
    callbacks.append(
        CombinedValidationCallback(
            fast_validator=fast_validator,
            grammar_validator=grammar_validator,
            test_prompts=test_prompts,
            frequency=200,
        )
    )

    # Early stopping (optional, disable for continuous training)
    # callbacks.append(
    #     EarlyStopping(
    #         monitor='val_loss',
    #         patience=10,
    #         mode='min',
    #     )
    # )

    # Setup logger
    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="semv6_training",
    )

    # Create Lightning Trainer optimized for maximum speed
    trainer = pl.Trainer(
        # ==================== HARDWARE CONFIGURATION ====================
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,  # 16-mixed: FP16 computation, FP32 master weights
        # For even faster training on RTX 3060, use "16" (pure FP16, less stable)
        # Default "16-mixed" is safest for ODE-based models

        # ==================== TRAINING HYPERPARAMETERS ====================
        max_steps=args.max_steps,
        # Gradient accumulation: effective_batch = batch_size * accumulate_grad_batches
        # With 6GB VRAM: batch_size=6, accumulate=4 -> effective=24 (good convergence)
        accumulate_grad_batches=args.accumulate_grad_batches,
        gradient_clip_val=args.gradient_clip_val,
        gradient_clip_algorithm="norm",  # L2 norm clipping (more stable than 'value')

        # ==================== PERFORMANCE OPTIMIZATIONS ====================
        # cuDNN benchmark mode: optimize kernel selection for your hardware
        benchmark=True,
        # Allow non-deterministic ops for speed (disable reproducibility)
        deterministic=False,
        # TF32 is enabled via torch.set_float32_matmul_precision('high') at module level

        # ==================== LOGGING & MONITORING ====================
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=args.log_every_n_steps,
        val_check_interval=args.val_check_interval,
        # Disables only Lightning's *default* summary; the explicit
        # RichModelSummary callback above still prints one.
        enable_model_summary=False,
        enable_progress_bar=True,  # Rich progress bar for real-time feedback

        # ==================== VALIDATION STRATEGY ====================
        # Disable validation for streaming (infinite dataset, no validation set)
        limit_val_batches=0,
        num_sanity_val_steps=0,

        # ==================== CHECKPOINTING ====================
        # Enables Lightning's model-checkpoint saving (the ModelCheckpoint
        # callback above). This is NOT gradient/activation checkpointing.
        enable_checkpointing=True,

        # ==================== DEBUGGING (PRODUCTION: DISABLED) ====================
        # Uncomment for debugging:
        # detect_anomaly=True,  # Detect NaN/Inf propagation
        # profiler="pytorch",  # PyTorch profiler (slows training 10-30%)

        # ==================== DISTRIBUTED TRAINING (SINGLE GPU) ====================
        strategy="auto",  # Auto-detect best strategy (single GPU -> no strategy)
        sync_batchnorm=False,  # No batch norm sync for single GPU
    )

    # Print training info
    model_init_time = time.perf_counter() - _SCRIPT_START_TIME
    print("=" * 80)
    print("SEM V6 PyTorch Lightning Training (Latent Prediction - Lester 2024)")
    print("=" * 80)
    print(f"⏱️ Model created in {model_init_time:.1f}s (target: {args.target_time}s)")
    print("-" * 80)
    print(f"Data directory: {args.data_dir}")
    print(f"Batch size: {args.batch_size}")
    print(f"Gradient accumulation: {args.accumulate_grad_batches}")
    print(f"Effective batch size: {args.batch_size * args.accumulate_grad_batches}")
    print(f"Max steps: {args.max_steps}")
    print(f"Precision: {args.precision}")
    print(f"Devices: {args.devices} {args.accelerator}")
    print("-" * 80)
    print("LEARNING RATE SCHEDULE:")
    print(f" Strategy: {args.lr_schedule.upper()}")
    print(f" Base LR: {args.learning_rate}")
    print(f" Weight decay (L2): {args.weight_decay}")
    print(f" Warmup percentage: {args.warmup_pct * 100:.1f}% ({int(args.max_steps * args.warmup_pct)} steps)")
    if args.lr_schedule == "onecycle":
        print(f" Initial LR: {args.learning_rate / args.lr_div_factor:.6f}")
        print(f" Max LR: {args.learning_rate}")
        print(f" Final LR: {args.learning_rate / args.lr_div_factor / args.lr_final_div_factor:.8f}")
        # OneCycleLR cycles momentum inversely to LR
        print(" Momentum cycling: 0.95 -> 0.85 -> 0.95")
    elif args.lr_schedule == "warmrestarts":
        print(f" Restarts every: {max(1, int(args.max_steps * args.warmup_pct))} steps")
        print(" Period multiplier: 2.0x after each restart")
        print(f" Min LR: {args.learning_rate / 1000:.8f}")
    else:  # cosine
        print(f" Warmup steps: {min(model.hparams.warmup_steps, args.max_steps // 10)}")
        print(f" Min LR: {args.learning_rate / 100:.8f}")
    print("-" * 80)
    print(f"SDR dimension: {args.sdr_dim}")
    print(f"Latent dimension: {args.latent_dim} (compression ratio: {args.sdr_dim / args.latent_dim:.1f}x)")
    print(f"L1 sparsity coefficient: {args.lambda_sparse}")
    print("-" * 80)
    print("ODE SOLVER (Module C: SolitonPropagator):")
    print(f" Solver: {args.ode_solver}")
    print(f" Step size: {args.ode_step_size}")
    print("-" * 80)
    print("MULTI-TOKEN PREDICTION (MTP):")
    print(f" Enabled: {args.enable_mtp}")
    if args.enable_mtp:
        print(f" Num tokens: {args.num_predict}")
        print(f" Loss weight: {args.mtp_loss_weight}")
        print(" Tokenizer: meta-llama/Meta-Llama-3-8B")
    print("=" * 80)

    # Train the model
    trainer.fit(model, datamodule=datamodule)

    print("=" * 80)
    print("Training complete!")
    print(f"Best checkpoint: {trainer.checkpoint_callback.best_model_path}")
    print("=" * 80)


if __name__ == "__main__":
    main()