"""
SEM V6 PyTorch Lightning Training Script

Command-line training entry point for SEM V6. Configured by this script:
- Automatic mixed precision (AMP) via the Trainer ``precision`` flag
- Gradient accumulation
- Learning rate scheduling (OneCycle, cosine warm restarts, cosine)
- Model checkpointing
- Rich progress bar and model summary
- Distributed training support (Trainer ``devices`` / ``strategy="auto"``)
- Gradient clipping
- TensorBoard logging
- Memory-efficient data streaming with prefetching
- Wall-clock budget tracking (``TimingCallback``)

NOTE(review): ``EarlyStopping`` is imported below but is not added to the
callback list built in ``main``; a learning rate finder and model
pruning/quantization are likewise not configured anywhere in this script.

Reference:
- PyTorch Lightning 2.x best practices
- AGENTS.md: GPU-first, streaming data, maximum speed
"""

import argparse
import logging
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch

# Wall-clock reference for the time budget. It is taken after ``import torch``
# but before the remaining imports, so the import time of Lightning, datasets
# and the sem_v6 package counts towards the measured startup overhead.
_SCRIPT_START_TIME = time.perf_counter()
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
    RichProgressBar,
    RichModelSummary,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

import sys

# Make the project packages importable. These are relative paths, so they are
# resolved against the current working directory — presumably the script is
# meant to be launched from the repository root (TODO confirm).
sys.path.insert(0, "src")
sys.path.insert(0, "ChebyKan_cuda_op")

from sem_v6.sem_v6 import SEMV6
from sem_v6.data.streaming_dataset import StreamingDataset
from sem_v6.modules.text_decoder import multi_token_prediction_loss
from sem_v6.validation.callbacks import CombinedValidationCallback
from sem_v6.validation.validators import FastValidator, GrammarValidator
from sem_v6.validation.grammar_checker import LanguageToolClient

# Allow reduced-precision (e.g. TF32) kernels for float32 matrix multiplies.
torch.set_float32_matmul_precision('high')
|
|
|
|
class TimingCallback(pl.Callback):
    """
    Callback that tracks wall-clock time against a configurable time budget.

    Tracks:
    - Time to first batch (includes data loading startup)
    - Training throughput (steps per second)
    - Total elapsed time from script start

    Once fewer than ``STOP_MARGIN_S`` seconds of the budget remain after a
    training batch, ``trainer.should_stop`` is set so the fit loop ends.

    Args:
        script_start_time: ``time.perf_counter()`` reading taken when the
            script started; all "elapsed" figures are relative to it.
        target_time: Wall-clock budget in seconds. Default: 100.0
    """

    # Request a stop once no more than this many seconds of budget remain.
    STOP_MARGIN_S = 2.0
    # Print a throughput line every this many completed training batches.
    REPORT_EVERY_N_STEPS = 50

    def __init__(self, script_start_time: float, target_time: float = 100.0):
        super().__init__()
        self.script_start_time = script_start_time
        self.target_time = target_time
        self.fit_start_time: Optional[float] = None
        self.first_batch_time: Optional[float] = None
        self.training_start_time: Optional[float] = None
        self.step_count = 0
        # Filled in by on_fit_end with the measured training duration.
        self.total_training_time = 0.0

    def on_fit_start(self, trainer, pl_module):
        """Report how much of the budget was spent before ``fit`` started."""
        self.fit_start_time = time.perf_counter()
        elapsed = self.fit_start_time - self.script_start_time
        print(f"\n⏱️ [TIMING] Model initialization took: {elapsed:.2f}s")
        remaining = self.target_time - elapsed
        print(f"⏱️ [TIMING] Time budget remaining: {remaining:.2f}s / {self.target_time}s")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """On the first batch only, report startup overhead and start the training clock."""
        if self.first_batch_time is not None:
            return

        self.first_batch_time = time.perf_counter()
        # If on_fit_start never ran, measure from script start instead of
        # failing on ``float - None``.
        fit_start = (
            self.fit_start_time
            if self.fit_start_time is not None
            else self.script_start_time
        )
        data_load_time = self.first_batch_time - fit_start
        elapsed = self.first_batch_time - self.script_start_time
        print(f"⏱️ [TIMING] Data loading/first batch took: {data_load_time:.2f}s")
        print(f"⏱️ [TIMING] Total startup overhead: {elapsed:.2f}s")
        remaining = self.target_time - elapsed
        print(f"⏱️ [TIMING] Time budget for training: {remaining:.2f}s")
        self.training_start_time = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Count the batch, stop training near the budget limit, and report progress."""
        self.step_count += 1
        elapsed = time.perf_counter() - self.script_start_time
        remaining = self.target_time - elapsed

        # Leave a small margin so shutdown still fits inside the budget.
        if remaining <= self.STOP_MARGIN_S:
            print(f"\n⏱️ [TIMING] Stopping training - approaching time limit ({elapsed:.1f}s / {self.target_time}s)")
            trainer.should_stop = True
            return

        if self.step_count % self.REPORT_EVERY_N_STEPS == 0:
            if self.training_start_time is not None:
                training_elapsed = time.perf_counter() - self.training_start_time
            else:
                training_elapsed = 0.0
            steps_per_sec = self.step_count / training_elapsed if training_elapsed > 0 else 0
            print(f"⏱️ [TIMING] Step {self.step_count}: {elapsed:.1f}s elapsed, {remaining:.1f}s remaining, {steps_per_sec:.2f} steps/s")

    def on_fit_end(self, trainer, pl_module):
        """Print the timing summary and whether the run met the budget."""
        end_time = time.perf_counter()
        total_elapsed = end_time - self.script_start_time
        if self.training_start_time is not None:
            training_time = end_time - self.training_start_time
        else:
            # No batch was ever started, so everything counts as startup.
            training_time = 0.0
        self.total_training_time = training_time
        startup_time = total_elapsed - training_time

        print("\n" + "=" * 80)
        # Show the configured budget rather than a fixed "100s".
        print(f"⏱️ TIMING SUMMARY ({self.target_time:g}s Target)")
        print("=" * 80)
        print(f" Startup (init + data load): {startup_time:.2f}s")
        print(f" Training time: {training_time:.2f}s")
        print(f" Total wall-clock time: {total_elapsed:.2f}s")
        print("-" * 80)
        print(f" Steps completed: {self.step_count}")
        if training_time > 0:
            print(f" Training speed: {self.step_count / training_time:.2f} steps/s")

        if total_elapsed <= self.target_time:
            print(f"\n ✅ PASSED: Completed in {total_elapsed:.2f}s (under {self.target_time}s target)")
        else:
            overage = total_elapsed - self.target_time
            print(f"\n ❌ FAILED: Took {total_elapsed:.2f}s ({overage:.2f}s over {self.target_time}s target)")
        print("=" * 80)
|
|
|
|
def latent_prediction_loss(
    z_pred: torch.Tensor,
    x_next: torch.Tensor,
    compressor: nn.Module,
    lambda_sparse: float = 0.001,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Latent-space prediction loss (no SDR reconstruction).

    Following Lester et al. 2024, "Training LLMs over Neurally Compressed
    Text" (arXiv:2404.03626), the prediction is compared against the
    *compressed* form of the next SDR rather than the SDR itself:

        L = MSE(z_pred, compress(x_next)) + lambda_sparse * mean(|z_pred|)

    The target is computed without gradient tracking, so only ``z_pred``
    receives gradients from this loss. NaN/inf values in the target are
    replaced by 0 / +100 / -100 before the comparison.

    Args:
        z_pred: Predicted next latent.
        x_next: Ground-truth next SDR; cast to float and passed to ``compressor``.
        compressor: Callable mapping ``x_next`` to the latent target.
        lambda_sparse: Weight of the L1 term on ``z_pred``. Default: 0.001

    Returns:
        A ``(loss, metrics)`` pair. ``loss`` is the differentiable total;
        ``metrics`` holds Python floats under the keys ``mse_loss``,
        ``l1_loss``, ``total_loss`` and ``cos_sim`` (mean cosine similarity
        between prediction and target along the last dimension).

    Example:
        compressor = nn.Linear(2048, 256)
        z_pred = torch.randn(32, 256, requires_grad=True)
        x_next = torch.randn(32, 2048)
        loss, metrics = latent_prediction_loss(z_pred, x_next, compressor)
        loss.backward()
    """
    # Build the (sanitised) target latent with autograd switched off.
    with torch.no_grad():
        target_latent = torch.nan_to_num(
            compressor(x_next.float()),
            nan=0.0,
            posinf=100.0,
            neginf=-100.0,
        )

    # Prediction error plus the sparsity penalty on the prediction.
    reconstruction_term = F.mse_loss(z_pred, target_latent)
    sparsity_term = lambda_sparse * z_pred.abs().mean()
    combined = reconstruction_term + sparsity_term

    # Monitoring only: directional agreement, never part of the gradient.
    with torch.no_grad():
        alignment = F.cosine_similarity(z_pred, target_latent, dim=-1).mean()

    components = (
        ("mse_loss", reconstruction_term),
        ("l1_loss", sparsity_term),
        ("total_loss", combined),
        ("cos_sim", alignment),
    )
    return combined, {name: tensor.item() for name, tensor in components}
|
|
|
|
class SEMV6LightningModule(pl.LightningModule):
    """
    PyTorch Lightning wrapper for SEM V6 training.

    Wraps an ``SEMV6`` model and provides:
    - a training step using latent next-sentence prediction, optionally
      combined with a multi-token-prediction (MTP) text loss
    - a validation step using the latent prediction loss
    - AdamW with a selectable step-wise learning rate schedule

    All constructor arguments are stored via ``save_hyperparameters`` and are
    available as ``self.hparams``.

    Args:
        sdr_dim, sparsity, num_hyperedges, kan_degree, latent_dim,
        ode_solver, ode_step_size, num_predict: Forwarded to ``SEMV6``.
        enable_text_decoder: Forwarded to ``SEMV6``; the MTP loss is only
            added when the resulting ``model.text_decoder`` is not ``None``.
        lambda_sparse: L1 coefficient of the latent prediction loss.
        learning_rate: AdamW learning rate (peak LR for OneCycle).
        weight_decay: AdamW weight decay.
        warmup_steps: Stored as a hyperparameter.
            NOTE(review): not read anywhere in this class.
        max_steps: Total number of scheduler steps the schedules are sized for.
        compile_model: Wrap the model with ``torch.compile`` when True.
        lr_schedule: ``"onecycle"``, ``"warmrestarts"`` or ``"cosine"``;
            any other value falls back to OneCycle.
        warmup_pct: OneCycle ``pct_start``; also sets the first restart
            period for warm restarts.
        lr_div_factor, lr_final_div_factor: OneCycle LR range parameters.
        mtp_loss_weight: ``loss_weight`` passed to the MTP loss.
    """

    def __init__(
        self,
        sdr_dim: int = 1368,
        sparsity: float = 0.05,
        num_hyperedges: int = 2000,
        kan_degree: int = 5,
        latent_dim: int = 256,
        lambda_sparse: float = 0.001,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        compile_model: bool = True,
        lr_schedule: str = "onecycle",
        warmup_pct: float = 0.1,
        lr_div_factor: float = 25.0,
        lr_final_div_factor: float = 10000.0,
        ode_solver: str = "rk4",
        ode_step_size: float = 0.01,
        enable_text_decoder: bool = True,
        num_predict: int = 4,
        mtp_loss_weight: float = 1.0,
    ):
        super().__init__()

        # Record every constructor argument in self.hparams / checkpoints.
        self.save_hyperparameters()

        self.model = SEMV6(
            sdr_dim=sdr_dim,
            sparsity=sparsity,
            num_hyperedges=num_hyperedges,
            kan_degree=kan_degree,
            latent_dim=latent_dim,
            ode_solver=ode_solver,
            ode_step_size=ode_step_size,
            enable_text_decoder=enable_text_decoder,
            num_predict=num_predict,
        )

        if compile_model:
            self.model = torch.compile(self.model, mode='reduce-overhead')

        # Kept for backward compatibility; not written to by this class.
        self.train_loss = []
        self.val_loss = []

    def forward(self, text_batch, reward=None, mode="awake"):
        """Forward pass through SEM V6."""
        return self.model(text_batch, reward=reward, mode=mode)

    @staticmethod
    def _sentence_pairs(documents, max_pairs):
        """
        Build (current sentence, next sentence) pairs from raw documents.

        Each document is split on ``'. '`` after newlines are turned into
        ``'. '``; stripped chunks shorter than 20 characters are discarded.
        Consecutive surviving sentences of the same document form a pair.
        Collection stops as soon as ``max_pairs`` pairs exist.

        Returns:
            Two equally long lists: current sentences and their successors.
        """
        current_texts = []
        next_texts = []
        for document in documents:
            sentences = []
            for chunk in document.replace('\n', '. ').split('. '):
                sentence = chunk.strip()
                if len(sentence) >= 20:
                    sentences.append(sentence)

            for current, following in zip(sentences, sentences[1:]):
                current_texts.append(current)
                next_texts.append(following)
                if len(current_texts) >= max_pairs:
                    return current_texts, next_texts
        return current_texts, next_texts

    def training_step(self, batch, batch_idx):
        """
        Training step with latent-space next-sentence prediction.

        The batch (an iterable of document strings) is turned into at most
        ``len(batch)`` consecutive-sentence pairs. The model predicts a latent
        from each current sentence; the target is the compressed encoding of
        the following sentence:

            L = ||z_pred - compress(x_next)||² + λ||z_pred||₁

        When a text decoder is present the final loss is
        ``0.1 * latent_loss + mtp_loss``.

        Returns:
            The loss tensor, or ``None`` when fewer than two sentence pairs
            could be formed from the batch.
        """
        text_batch = batch

        current_texts, next_texts = self._sentence_pairs(
            text_batch, max_pairs=len(text_batch)
        )

        # Too little usable text in this batch.
        if len(current_texts) < 2:
            return None

        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")

        # Keep the prediction finite and bounded before it enters the loss.
        z_pred = torch.nan_to_num(z_pred, nan=0.0, posinf=100.0, neginf=-100.0)
        z_pred = torch.clamp(z_pred, -100.0, 100.0)

        # Targets are encoded without gradient tracking.
        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        if self.model.text_decoder is not None:
            mtp_loss, mtp_metrics = multi_token_prediction_loss(
                z_pred=z_pred,
                target_text=next_texts,
                decoder=self.model.text_decoder,
                max_length=128,
                loss_weight=self.hparams.mtp_loss_weight,
            )
            # The latent loss is down-weighted relative to the MTP loss.
            loss = 0.1 * loss + mtp_loss
            metrics.update({f"mtp_{k}": v for k, v in mtp_metrics.items()})

        # The number of pairs, not len(batch), is the effective batch size.
        batch_size = len(current_texts)
        step_and_epoch = dict(on_step=True, on_epoch=True, batch_size=batch_size)
        epoch_only = dict(on_step=False, on_epoch=True, batch_size=batch_size)

        self.log("train_loss", loss, prog_bar=True, **step_and_epoch)
        self.log("mse_loss", metrics["mse_loss"], **step_and_epoch)
        self.log("l1_loss", metrics["l1_loss"], **epoch_only)
        self.log("cos_sim", metrics["cos_sim"], prog_bar=True, **step_and_epoch)

        self.log("train_action", float(action_logits.argmax()), **epoch_only)
        self.log("train_value", float(value), **epoch_only)

        if "mtp_mtp_loss" in metrics:
            self.log("mtp_loss", metrics["mtp_mtp_loss"], prog_bar=True, **step_and_epoch)
            self.log(
                "mtp_accuracy",
                metrics.get("mtp_avg_accuracy", 0.0),
                prog_bar=True,
                **step_and_epoch,
            )

        if torch.cuda.is_available():
            gpu_mem = torch.cuda.max_memory_allocated() / 1e9
            self.log("gpu_memory_gb", gpu_mem, **epoch_only)

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step with latent prediction loss.

        Adjacent batch items are paired directly (item ``i`` predicts item
        ``i + 1``); no sentence splitting is done here.

        Returns:
            The loss tensor, or ``None`` when the batch has fewer than two items.
        """
        text_batch = batch

        if len(text_batch) < 2:
            return None

        current_texts = text_batch[:-1]
        next_texts = text_batch[1:]

        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")

        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        batch_size = len(current_texts)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("val_mse_loss", metrics["mse_loss"], on_step=False, on_epoch=True, batch_size=batch_size)
        self.log("val_cos_sim", metrics["cos_sim"], on_step=False, on_epoch=True, batch_size=batch_size)

        return loss

    def _build_scheduler(self, optimizer):
        """
        Create the single LR scheduler selected by ``hparams.lr_schedule``.

        Only the selected scheduler is constructed. Scheduler constructors
        modify the optimizer's parameter groups (current ``lr``, stored
        ``initial_lr``, and for OneCycle the momentum/betas), so building
        several on one optimizer makes the later ones start from the state
        left behind by the earlier ones.
        """
        hp = self.hparams

        if hp.lr_schedule == "warmrestarts":
            # Cosine decay with restarts; each period is twice the previous.
            return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                # T_0 must be a positive integer.
                T_0=max(1, int(hp.max_steps * hp.warmup_pct)),
                T_mult=2,
                eta_min=hp.learning_rate / 1000,
            )

        if hp.lr_schedule == "cosine":
            warmup_steps = min(1000, hp.max_steps // 10)
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=max(1, hp.max_steps - warmup_steps),
                eta_min=hp.learning_rate / 100,
            )

        # "onecycle" and any unrecognised value.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=hp.learning_rate,
            total_steps=hp.max_steps,
            pct_start=hp.warmup_pct,
            anneal_strategy='cos',
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
            div_factor=hp.lr_div_factor,
            final_div_factor=hp.lr_final_div_factor,
        )

    def configure_optimizers(self):
        """
        Configure AdamW and a per-step learning rate schedule.

        Supported ``lr_schedule`` values:
        1. ``"onecycle"``: OneCycleLR with cosine annealing and momentum cycling
        2. ``"warmrestarts"``: CosineAnnealingWarmRestarts
        3. ``"cosine"``: CosineAnnealingLR
        Any other value falls back to OneCycleLR.

        Returns:
            A Lightning optimizer configuration dict whose scheduler is
            stepped after every optimizer step.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": self._build_scheduler(optimizer),
                "interval": "step",
                "frequency": 1,
            },
        }
|
|
|
|
class _HFTextBatchIterable(torch.utils.data.IterableDataset):
    """
    Yield lists of ``batch_size`` text strings from an iterable of records.

    Each record must support ``record["text"]``. With DataLoader workers,
    worker ``k`` of ``n`` keeps records ``k, k + n, k + 2n, ...`` so the
    workers do not emit duplicates. A trailing incomplete batch is dropped.

    Defined at module level (not inside a method) so instances can be
    pickled when DataLoader workers are started with the spawn method.
    """

    def __init__(self, hf_dataset, batch_size):
        super().__init__()
        self.hf_dataset = hf_dataset
        self.batch_size = batch_size

    def __iter__(self):
        import itertools

        stream = self.hf_dataset
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Round-robin sharding across workers.
            stream = itertools.islice(
                stream, worker_info.id, None, worker_info.num_workers
            )

        batch = []
        for item in stream:
            batch.append(item["text"])
            if len(batch) == self.batch_size:
                yield batch
                batch = []


class StreamingDataModule(pl.LightningDataModule):
    """
    Lightning DataModule for streaming training data.

    Two sources are supported:
    - ``data_dir`` given: ``*.pt`` / ``*.npy`` shard files read through
      ``StreamingDataset``, optionally split into train/validation shards.
    - ``data_dir`` is ``None``: the HuggingFace ``HuggingFaceFW/fineweb-edu``
      dataset (``sample-10BT``) streamed and batched into lists of strings.

    Batching is done by the datasets themselves, so every DataLoader is
    created with ``batch_size=None``.

    Args:
        data_dir: Directory containing shard files, or ``None`` for
            HuggingFace streaming.
        batch_size: Number of items per batch.
        num_workers: DataLoader worker count for training; validation uses
            half as many.
        val_split: Fraction of shards (taken from the start of the sorted
            shard list) reserved for validation. Default: 0.0
    """

    def __init__(
        self,
        data_dir: Optional[Path] = None,
        batch_size: int = 32,
        num_workers: int = 8,
        val_split: float = 0.0,
    ):
        super().__init__()
        # Accept plain strings as well; setup() relies on Path.glob.
        self.data_dir = Path(data_dir) if data_dir is not None else None
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split

        self.train_shards = []
        self.val_shards = []

    def setup(self, stage: Optional[str] = None):
        """
        Discover and partition shard files.

        Raises:
            ValueError: If ``data_dir`` is set but contains no shard files.
        """
        if self.data_dir is None:
            print("No data_dir provided - using HuggingFace streaming")
            return

        all_shards = sorted(
            [*self.data_dir.glob("*.pt"), *self.data_dir.glob("*.npy")]
        )

        if not all_shards:
            raise ValueError(f"No shard files found in {self.data_dir}")

        if self.val_split > 0:
            val_count = int(len(all_shards) * self.val_split)
            self.val_shards = all_shards[:val_count]
            self.train_shards = all_shards[val_count:]
        else:
            self.train_shards = all_shards
            self.val_shards = []

        print(
            f"Found {len(self.train_shards)} training shards, {len(self.val_shards)} validation shards"
        )

    def train_dataloader(self):
        """Create training DataLoader with streaming dataset."""
        if self.data_dir is None:
            hf_stream = datasets.load_dataset(
                "HuggingFaceFW/fineweb-edu",
                name="sample-10BT",
                split="train",
                streaming=True,
            )

            return DataLoader(
                _HFTextBatchIterable(hf_stream, self.batch_size),
                batch_size=None,
                num_workers=self.num_workers,
                pin_memory=True,
                # prefetch_factor may only be set when workers are used.
                prefetch_factor=2 if self.num_workers > 0 else None,
            )

        dataset = StreamingDataset(
            shard_paths=self.train_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,
            cycle=True,
        )

        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
            prefetch_factor=4 if self.num_workers > 0 else None,
        )

    def val_dataloader(self):
        """Create validation DataLoader (if validation shards exist)."""
        if not self.val_shards:
            return None

        dataset = StreamingDataset(
            shard_paths=self.val_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,
            cycle=False,
        )

        # persistent_workers must follow the worker count actually passed to
        # this loader: with num_workers == 1 the halved count is 0, and
        # DataLoader rejects persistent_workers=True without workers.
        val_workers = self.num_workers // 2
        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=val_workers,
            pin_memory=True,
            persistent_workers=val_workers > 0,
        )
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser with all training options."""
    parser = argparse.ArgumentParser(description="SEM V6 Lightning Training")

    # Data
    parser.add_argument(
        "--data_dir",
        type=Path,
        required=False,
        default=None,
        help="Directory with shard files",
    )
    parser.add_argument("--batch_size", type=int, default=6, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=8, help="DataLoader workers")

    # Model
    parser.add_argument("--sdr_dim", type=int, default=768, help="SDR dimension")
    parser.add_argument("--sparsity", type=float, default=0.05, help="SDR sparsity")
    parser.add_argument(
        "--kan_degree", type=int, default=3, help="ChebyKAN polynomial degree"
    )
    parser.add_argument(
        "--num_hyperedges", type=int, default=2000, help="Number of hyperedges"
    )
    parser.add_argument(
        "--latent_dim", type=int, default=256, help="Latent space dimension (Lester 2024)"
    )
    parser.add_argument(
        "--lambda_sparse", type=float, default=0.001, help="L1 sparsity coefficient for latent loss"
    )

    # ODE solver
    parser.add_argument(
        "--ode_solver", type=str, default="rk4",
        choices=["rk4", "euler", "dopri5", "midpoint"],
        help="ODE solver method: rk4 (stable, default), euler (fast for training)"
    )
    parser.add_argument(
        "--ode_step_size", type=float, default=0.01,
        help="Step size for fixed-step ODE solvers (default 0.01, use 0.05 with euler)"
    )

    # Optimisation
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="Base learning rate"
    )
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="AdamW weight decay (L2 regularization)")
    parser.add_argument(
        "--max_steps", type=int, default=100000, help="Max training steps"
    )
    parser.add_argument(
        "--accumulate_grad_batches", type=int, default=1,
        help="Gradient accumulation steps (effective_batch = batch_size * this)"
    )
    parser.add_argument(
        "--gradient_clip_val", type=float, default=1.0, help="Gradient clipping norm threshold"
    )

    # Learning rate schedule
    parser.add_argument(
        "--lr_schedule", type=str, default="onecycle",
        choices=["onecycle", "warmrestarts", "cosine"],
        help="Learning rate schedule strategy"
    )
    parser.add_argument(
        "--warmup_pct", type=float, default=0.1,
        help="Percentage of training used for warmup (0.0-1.0)"
    )
    parser.add_argument(
        "--lr_div_factor", type=float, default=25.0,
        help="OneCycleLR: Initial LR = max_LR / div_factor"
    )
    parser.add_argument(
        "--lr_final_div_factor", type=float, default=10000.0,
        help="OneCycleLR: Final LR = initial_LR / final_div_factor"
    )

    # Trainer
    parser.add_argument(
        "--accelerator", type=str, default="gpu", choices=["gpu", "cpu"]
    )
    parser.add_argument("--devices", type=int, default=1, help="Number of GPUs")
    parser.add_argument(
        "--precision", type=str, default="16-mixed", help="Training precision"
    )
    parser.add_argument(
        "--log_every_n_steps", type=int, default=10, help="Logging frequency"
    )
    parser.add_argument(
        "--val_check_interval", type=float, default=1000, help="Validation frequency"
    )

    # Checkpointing / compilation
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint directory",
    )
    parser.add_argument(
        "--save_top_k", type=int, default=3, help="Save top K checkpoints"
    )
    parser.add_argument(
        "--no-compile",
        action="store_true",
        help="Disable torch.compile (useful for debugging)",
    )

    # Time-budgeted mode
    parser.add_argument(
        "--fast-100s",
        action="store_true",
        help="Enable optimized settings for 100-second agentic coherence training"
    )
    parser.add_argument(
        "--target-time",
        type=float,
        default=160.0,
        help="Target wall-clock time in seconds (default: 160.0, includes ~125s startup overhead)"
    )

    # Multi-token prediction. --enable-mtp is accepted for compatibility but
    # has no effect of its own: main() derives enable_mtp from --disable-mtp.
    parser.add_argument(
        "--enable-mtp",
        action="store_true",
        default=True,
        help="Enable Multi-Token Prediction text decoder (default: True)"
    )
    parser.add_argument(
        "--disable-mtp",
        action="store_true",
        help="Disable Multi-Token Prediction text decoder"
    )
    parser.add_argument(
        "--num-predict",
        type=int,
        default=4,
        help="Number of future tokens to predict with MTP (default: 4)"
    )
    parser.add_argument(
        "--mtp-loss-weight",
        type=float,
        default=1.0,
        help="Weight for MTP loss (default: 1.0)"
    )

    return parser


def _apply_fast_100s_preset(args):
    """Overwrite ``args`` in place with the ``--fast-100s`` preset and print it."""
    # Training hyperparameters
    args.batch_size = 16
    args.learning_rate = 0.002
    args.weight_decay = 1e-5
    args.lr_div_factor = 10.0
    args.lr_final_div_factor = 100.0
    args.accumulate_grad_batches = 1

    # Model capacity
    args.sdr_dim = 1024
    args.latent_dim = 256
    args.kan_degree = 4
    args.num_hyperedges = 2000

    # ODE solver
    args.ode_solver = "euler"
    args.ode_step_size = 0.05

    # Size max_steps from the time left after the estimated startup overhead,
    # but never go below 50 steps.
    estimated_overhead_s = 125.0
    available_training_s = args.target_time - estimated_overhead_s
    estimated_steps_per_sec = 4.0
    args.max_steps = max(50, int(available_training_s * estimated_steps_per_sec))

    elapsed_so_far = time.perf_counter() - _SCRIPT_START_TIME
    print("=" * 80)
    print("FAST 100s MODE ENABLED - Optimized for agentic coherence")
    print(f" Target time: {args.target_time}s (elapsed so far: {elapsed_so_far:.1f}s)")
    print(f" Estimated overhead: {estimated_overhead_s}s (model ~7s + tokenizer ~5s + HF stream ~50s + compile/warmup ~60s)")
    print(f" Training budget: {available_training_s}s → max {args.max_steps} steps")
    print("-" * 40)
    print(" MODEL CAPACITY (balanced for 6GB + 128k vocab):")
    print(f" sdr_dim: {args.sdr_dim}, latent_dim: {args.latent_dim}")
    print(f" kan_degree: {args.kan_degree}, num_hyperedges: {args.num_hyperedges}")
    print("-" * 40)
    print(" TRAINING:")
    print(" batch_size=16, lr=0.002, weight_decay=1e-5, accumulation=1")
    print(" ODE solver: euler with step_size=0.05")
    print("=" * 80)


def _build_callbacks(args):
    """Assemble the Trainer callbacks: timing, checkpointing, monitoring, validation."""
    callbacks = [
        TimingCallback(
            script_start_time=_SCRIPT_START_TIME,
            target_time=args.target_time,
        ),
        ModelCheckpoint(
            dirpath=args.checkpoint_dir,
            filename="semv6-{step:06d}-{train_loss:.4f}",
            monitor="train_loss",
            mode="min",
            save_top_k=args.save_top_k,
            every_n_train_steps=1000,
            save_last=True,
        ),
        LearningRateMonitor(logging_interval="step"),
        RichProgressBar(),
        RichModelSummary(max_depth=2),
    ]

    # Prompts shared by the generation-quality validators.
    test_prompts = [
        "The capital of France is",
        "Water boils at",
        "The human body has",
        "Plants produce oxygen through",
        "The speed of light is",
    ]

    fast_validator = FastValidator(test_prompts=test_prompts)
    grammar_client = LanguageToolClient(url="http://localhost:8081")
    grammar_validator = GrammarValidator(
        client=grammar_client,
        test_prompts=test_prompts
    )

    callbacks.append(
        CombinedValidationCallback(
            fast_validator=fast_validator,
            grammar_validator=grammar_validator,
            test_prompts=test_prompts,
            frequency=200,
        )
    )
    return callbacks


def _print_run_summary(args):
    """Print the resolved training configuration before fitting starts."""
    model_init_time = time.perf_counter() - _SCRIPT_START_TIME
    print("=" * 80)
    print("SEM V6 PyTorch Lightning Training (Latent Prediction - Lester 2024)")
    print("=" * 80)
    print(f"⏱️ Model created in {model_init_time:.1f}s (target: {args.target_time}s)")
    print("-" * 80)
    print(f"Data directory: {args.data_dir}")
    print(f"Batch size: {args.batch_size}")
    print(f"Gradient accumulation: {args.accumulate_grad_batches}")
    print(f"Effective batch size: {args.batch_size * args.accumulate_grad_batches}")
    print(f"Max steps: {args.max_steps}")
    print(f"Precision: {args.precision}")
    print(f"Devices: {args.devices} {args.accelerator}")
    print("-" * 80)
    print("LEARNING RATE SCHEDULE:")
    print(f" Strategy: {args.lr_schedule.upper()}")
    print(f" Base LR: {args.learning_rate}")
    print(f" Weight decay (L2): {args.weight_decay}")
    print(f" Warmup percentage: {args.warmup_pct * 100:.1f}% ({int(args.max_steps * args.warmup_pct)} steps)")
    if args.lr_schedule == "onecycle":
        print(f" Initial LR: {args.learning_rate / args.lr_div_factor:.6f}")
        print(f" Max LR: {args.learning_rate}")
        print(f" Final LR: {args.learning_rate / args.lr_div_factor / args.lr_final_div_factor:.8f}")
        print(" Momentum cycling: 0.85 -> 0.95 -> 0.85")
    elif args.lr_schedule == "warmrestarts":
        print(f" Restarts every: {int(args.max_steps * args.warmup_pct)} steps")
        print(" Period multiplier: 2.0x after each restart")
        print(f" Min LR: {args.learning_rate / 1000:.8f}")
    else:
        print(f" Min LR: {args.learning_rate / 100:.8f}")
    print("-" * 80)
    print(f"SDR dimension: {args.sdr_dim}")
    print(f"Latent dimension: {args.latent_dim} (compression ratio: {args.sdr_dim / args.latent_dim:.1f}x)")
    print(f"L1 sparsity coefficient: {args.lambda_sparse}")
    print("-" * 80)
    print("ODE SOLVER (Module C: SolitonPropagator):")
    print(f" Solver: {args.ode_solver}")
    print(f" Step size: {args.ode_step_size}")
    print("-" * 80)
    print("MULTI-TOKEN PREDICTION (MTP):")
    print(f" Enabled: {args.enable_mtp}")
    if args.enable_mtp:
        print(f" Num tokens: {args.num_predict}")
        print(f" Loss weight: {args.mtp_loss_weight}")
        print(" Tokenizer: meta-llama/Meta-Llama-3-8B")
    print("=" * 80)


def main():
    """
    Main training entry point.

    Parses the command line, optionally applies the ``--fast-100s`` preset,
    builds the data module, model, callbacks, logger and Trainer, prints the
    resolved configuration and runs ``trainer.fit``.

    Raises:
        RuntimeError: If ``--accelerator gpu`` is requested but CUDA is not
            available.
    """
    args = _build_arg_parser().parse_args()

    # MTP is on unless explicitly disabled.
    args.enable_mtp = not args.disable_mtp

    if args.fast_100s:
        _apply_fast_100s_preset(args)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if args.accelerator == "gpu" and not torch.cuda.is_available():
        raise RuntimeError("GPU required but CUDA not available")

    datamodule = StreamingDataModule(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    compile_model = not args.no_compile
    model = SEMV6LightningModule(
        sdr_dim=args.sdr_dim,
        sparsity=args.sparsity,
        num_hyperedges=args.num_hyperedges,
        kan_degree=args.kan_degree,
        latent_dim=args.latent_dim,
        lambda_sparse=args.lambda_sparse,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        max_steps=args.max_steps,
        compile_model=compile_model,
        lr_schedule=args.lr_schedule,
        warmup_pct=args.warmup_pct,
        lr_div_factor=args.lr_div_factor,
        lr_final_div_factor=args.lr_final_div_factor,
        ode_solver=args.ode_solver,
        ode_step_size=args.ode_step_size,
        enable_text_decoder=args.enable_mtp,
        num_predict=args.num_predict,
        mtp_loss_weight=args.mtp_loss_weight,
    )

    callbacks = _build_callbacks(args)

    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="semv6_training",
    )

    trainer = pl.Trainer(
        # Hardware
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        # Duration
        max_steps=args.max_steps,
        # Optimisation
        accumulate_grad_batches=args.accumulate_grad_batches,
        gradient_clip_val=args.gradient_clip_val,
        gradient_clip_algorithm="norm",
        # Speed over reproducibility
        benchmark=True,
        deterministic=False,
        # Logging
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=args.log_every_n_steps,
        val_check_interval=args.val_check_interval,
        # RichModelSummary is supplied as a callback instead.
        enable_model_summary=False,
        enable_progress_bar=True,
        # Lightning validation loops are switched off.
        limit_val_batches=0,
        num_sanity_val_steps=0,
        enable_checkpointing=True,
        strategy="auto",
        sync_batchnorm=False,
    )

    _print_run_summary(args)

    trainer.fit(model, datamodule=datamodule)

    print("=" * 80)
    print("Training complete!")
    print(f"Best checkpoint: {trainer.checkpoint_callback.best_model_path}")
    print("=" * 80)


if __name__ == "__main__":
    main()
|
|