sem-v6-training / train_lightning.py
icarus112's picture
Upload train_lightning.py with huggingface_hub
e32012b verified
#!/usr/bin/env python3
"""
SEM V6 PyTorch Lightning Training Script
Production-ready training with all best practices:
- Automatic mixed precision (AMP)
- Gradient accumulation
- Learning rate scheduling
- Model checkpointing
- Progress bars and rich logging
- Distributed training support
- Early stopping
- Gradient clipping
- Learning rate finder
- Model pruning/quantization ready
- TensorBoard logging
- Memory-efficient data streaming with prefetching
Reference:
- PyTorch Lightning 2.x best practices
- AGENTS.md: GPU-first, streaming data, maximum speed
"""
import argparse
import logging
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
# Wall-clock reference point for all timing reports. It is captured right after
# `import torch` (torch's own import time is therefore not included), so the
# cost of every import below this line counts toward the reported startup time.
_SCRIPT_START_TIME: float = time.perf_counter()
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import datasets # For HF streaming
from pytorch_lightning.callbacks import (
ModelCheckpoint,
EarlyStopping,
LearningRateMonitor,
RichProgressBar,
RichModelSummary,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
# SEM V6 imports
import sys
sys.path.insert(0, "src")
sys.path.insert(0, "ChebyKan_cuda_op")
from sem_v6.sem_v6 import SEMV6
from sem_v6.data.streaming_dataset import StreamingDataset
from sem_v6.modules.text_decoder import multi_token_prediction_loss
from sem_v6.validation.callbacks import CombinedValidationCallback
from sem_v6.validation.validators import FastValidator, GrammarValidator
from sem_v6.validation.grammar_checker import LanguageToolClient
# Let float32 matmuls use reduced-precision Tensor Core math (TF32 on Ampere
# GPUs such as the RTX 3060 and newer): faster, slightly less precise.
torch.set_float32_matmul_precision("high")
class TimingCallback(pl.Callback):
    """
    Track wall-clock timing against a time budget and stop training before the
    budget is exhausted.

    Prints:
    - Model initialization time (script start -> ``on_fit_start``)
    - Time to the first training batch (includes data loading startup)
    - Throughput every ``report_every`` training batches
    - A summary at ``on_fit_end`` comparing total time with ``target_time``

    ``step_count`` counts ``on_train_batch_end`` calls, i.e. training batches.
    With gradient accumulation this is larger than the optimizer step count.

    Args:
        script_start_time: ``time.perf_counter()`` value captured at script start.
        target_time: Wall-clock budget in seconds, measured from script start.
        stop_buffer: Request a stop once this many seconds (or fewer) remain.
            Default 2.0.
        report_every: Print a throughput line every N batches. Default 50.
    """

    def __init__(
        self,
        script_start_time: float,
        target_time: float = 100.0,
        stop_buffer: float = 2.0,
        report_every: int = 50,
    ):
        super().__init__()
        self.script_start_time = script_start_time
        self.target_time = target_time
        self.stop_buffer = stop_buffer
        # Clamp to >= 1 so the modulo check below can never divide by zero.
        self.report_every = max(1, int(report_every))
        self.fit_start_time: Optional[float] = None
        self.first_batch_time: Optional[float] = None
        self.training_start_time: Optional[float] = None
        self.step_count = 0
        # Public attribute kept for compatibility; not updated by this class.
        self.total_training_time = 0.0

    def on_fit_start(self, trainer, pl_module):
        """Record the fit start and report model-initialization time."""
        self.fit_start_time = time.perf_counter()
        elapsed = self.fit_start_time - self.script_start_time
        print(f"\n⏱️ [TIMING] Model initialization took: {elapsed:.2f}s")
        remaining = self.target_time - elapsed
        print(f"⏱️ [TIMING] Time budget remaining: {remaining:.2f}s / {self.target_time}s")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """On the first batch only, report data-loading / startup overhead."""
        if self.first_batch_time is not None:
            return
        self.first_batch_time = time.perf_counter()
        # on_fit_start normally sets fit_start_time. Fall back to the script
        # start so a missing value cannot raise a TypeError here.
        fit_start = (
            self.fit_start_time
            if self.fit_start_time is not None
            else self.script_start_time
        )
        data_load_time = self.first_batch_time - fit_start
        elapsed = self.first_batch_time - self.script_start_time
        print(f"⏱️ [TIMING] Data loading/first batch took: {data_load_time:.2f}s")
        print(f"⏱️ [TIMING] Total startup overhead: {elapsed:.2f}s")
        remaining = self.target_time - elapsed
        print(f"⏱️ [TIMING] Time budget for training: {remaining:.2f}s")
        self.training_start_time = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Count the batch, enforce the time budget, and report throughput."""
        self.step_count += 1
        now = time.perf_counter()
        elapsed = now - self.script_start_time
        remaining = self.target_time - elapsed
        # Time-based early stop, leaving `stop_buffer` seconds for teardown.
        if remaining <= self.stop_buffer:
            if not trainer.should_stop:
                # Print once, not on every batch until Lightning actually stops.
                print(f"\n⏱️ [TIMING] Stopping training - approaching time limit ({elapsed:.1f}s / {self.target_time}s)")
            trainer.should_stop = True
            return
        if self.step_count % self.report_every == 0 and self.training_start_time is not None:
            training_elapsed = now - self.training_start_time
            steps_per_sec = self.step_count / training_elapsed if training_elapsed > 0 else 0
            print(f"⏱️ [TIMING] Step {self.step_count}: {elapsed:.1f}s elapsed, {remaining:.1f}s remaining, {steps_per_sec:.2f} steps/s")

    def on_fit_end(self, trainer, pl_module):
        """Print the timing summary and whether the run met `target_time`."""
        end_time = time.perf_counter()
        total_elapsed = end_time - self.script_start_time
        training_time = (
            end_time - self.training_start_time
            if self.training_start_time is not None
            else 0
        )
        startup_time = total_elapsed - training_time
        print("\n" + "=" * 80)
        # Report the configured target instead of a hard-coded "100s".
        print(f"⏱️ TIMING SUMMARY ({self.target_time:g}s Target)")
        print("=" * 80)
        print(f" Startup (init + data load): {startup_time:.2f}s")
        print(f" Training time: {training_time:.2f}s")
        print(f" Total wall-clock time: {total_elapsed:.2f}s")
        print("-" * 80)
        print(f" Steps completed: {self.step_count}")
        if training_time > 0:
            print(f" Training speed: {self.step_count / training_time:.2f} steps/s")
        if total_elapsed <= self.target_time:
            print(f"\n ✅ PASSED: Completed in {total_elapsed:.2f}s (under {self.target_time}s target)")
        else:
            overage = total_elapsed - self.target_time
            print(f"\n ❌ FAILED: Took {total_elapsed:.2f}s ({overage:.2f}s over {self.target_time}s target)")
        print("=" * 80)
def latent_prediction_loss(
    z_pred: torch.Tensor,
    x_next: torch.Tensor,
    compressor: nn.Module,
    lambda_sparse: float = 0.001,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the loss in latent space (not SDR reconstruction).

    Based on Lester et al. 2024 "Training LLMs over Neurally Compressed Text":
    predicting a compressed representation is more learnable than
    reconstructing the full SDR. This is the core training objective for SEM V6.

    Loss function:
        L = mean((z_pred - compress(x_next))^2) + lambda * mean(|z_pred|)

    The target ``compress(x_next)`` is computed under ``torch.no_grad()``, so no
    gradient flows into ``compressor``. Non-finite target values are replaced
    (NaN -> 0, +inf -> 100, -inf -> -100).

    Args:
        z_pred: Predicted next latent from the propagator, e.g. (batch, latent_dim).
        x_next: Ground-truth next SDR, e.g. (batch, sdr_dim). It is cast to
            float32 before compression.
        compressor: Module mapping ``x_next`` to the latent space of ``z_pred``.
        lambda_sparse: L1 sparsity coefficient. Default 0.001.

    Returns:
        loss: Differentiable total loss scalar (MSE + L1).
        metrics: Python floats for logging:
            - mse_loss: MSE between predicted and target latent
            - l1_loss: Weighted L1 term (already multiplied by ``lambda_sparse``)
            - total_loss: ``mse_loss + l1_loss``
            - cos_sim: Mean cosine similarity along the last dim (monitoring only)

    Example:
        >>> compressor = nn.Linear(2048, 256)  # Simplified compressor
        >>> z_pred = torch.randn(32, 256, requires_grad=True)
        >>> x_next = torch.randn(32, 2048)
        >>> loss, metrics = latent_prediction_loss(z_pred, x_next, compressor)
        >>> loss.backward()
        >>> sorted(metrics)
        ['cos_sim', 'l1_loss', 'mse_loss', 'total_loss']

    References:
        Lester, B., et al. (2024). "Training LLMs over Neurally Compressed Text."
        arXiv:2404.03626
    """
    # Target latent: compress the ground truth with gradients disabled
    # (the target only, no gradient into the compressor).
    with torch.no_grad():
        z_target = compressor(x_next.float())
        z_target = torch.nan_to_num(z_target, nan=0.0, posinf=100.0, neginf=-100.0)

    # MSE in latent space: the core prediction objective.
    mse_loss = F.mse_loss(z_pred, z_target)
    # L1 sparsity regularization on the prediction.
    l1_loss = lambda_sparse * z_pred.abs().mean()
    total_loss = mse_loss + l1_loss

    with torch.no_grad():
        # Cosine similarity is for monitoring only; it is not part of the loss.
        cos_sim = F.cosine_similarity(z_pred, z_target, dim=-1).mean()
        # Gather all four scalars into one tensor so only ONE device->host sync
        # happens (instead of one per `.item()` call). Casting to the promoted
        # (widest) dtype keeps the values bit-identical to per-item `.item()`.
        scalars = (mse_loss.detach(), l1_loss.detach(), total_loss.detach(), cos_sim)
        common_dtype = scalars[0].dtype
        for scalar in scalars[1:]:
            common_dtype = torch.promote_types(common_dtype, scalar.dtype)
        values = torch.stack([s.to(common_dtype) for s in scalars]).tolist()

    metrics = dict(zip(("mse_loss", "l1_loss", "total_loss", "cos_sim"), values))
    return total_loss, metrics
class SEMV6LightningModule(pl.LightningModule):
    """
    PyTorch Lightning wrapper for SEM V6 training.

    Responsibilities:
    - Build the ``SEMV6`` model (optionally wrapped with ``torch.compile``).
    - Latent-space next-sentence prediction loss (Lester et al. 2024). It is
      combined with a Multi-Token Prediction (MTP) loss when the text decoder
      exists.
    - Metric logging.
    - AdamW optimizer plus a step-based LR schedule selected by ``lr_schedule``.

    Gradient accumulation, mixed precision, gradient clipping and model
    checkpointing are configured on the ``pl.Trainer``, not in this module.
    """

    def __init__(
        self,
        sdr_dim: int = 1368,
        sparsity: float = 0.05,
        num_hyperedges: int = 2000,
        kan_degree: int = 5,
        latent_dim: int = 256,
        lambda_sparse: float = 0.001,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-4,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        compile_model: bool = True,
        lr_schedule: str = "onecycle",
        warmup_pct: float = 0.1,
        lr_div_factor: float = 25.0,
        lr_final_div_factor: float = 10000.0,
        ode_solver: str = "rk4",
        ode_step_size: float = 0.01,
        enable_text_decoder: bool = True,
        num_predict: int = 4,
        mtp_loss_weight: float = 1.0,
    ):
        super().__init__()
        # Store every constructor argument in self.hparams (saved in checkpoints).
        self.save_hyperparameters()

        # SEM V6 model with latent-space architecture.
        self.model = SEMV6(
            sdr_dim=sdr_dim,
            sparsity=sparsity,
            num_hyperedges=num_hyperedges,
            kan_degree=kan_degree,
            latent_dim=latent_dim,
            ode_solver=ode_solver,
            ode_step_size=ode_step_size,
            enable_text_decoder=enable_text_decoder,
            num_predict=num_predict,
        )
        # torch.compile (PyTorch 2.0+); "reduce-overhead" uses CUDA graphs to
        # cut kernel-launch overhead.
        if compile_model:
            self.model = torch.compile(self.model, mode='reduce-overhead')

        # Public attributes kept for compatibility; not written by this class.
        self.train_loss = []
        self.val_loss = []

    def forward(self, text_batch, reward=None, mode="awake"):
        """Forward pass through SEM V6."""
        return self.model(text_batch, reward=reward, mode=mode)

    @staticmethod
    def _make_sentence_pairs(text_batch, max_pairs):
        """
        Build (sentence_i, sentence_i+1) pairs within each document.

        Newlines are treated as sentence breaks and text is split on '. '.
        Fragments shorter than 20 characters are dropped. Pairs never cross
        document boundaries. Collection stops as soon as ``max_pairs`` pairs exist.

        Returns:
            (current_texts, next_texts): two equal-length lists of strings.
        """
        current_texts = []
        next_texts = []
        for document in text_batch:
            stripped = (
                chunk.strip() for chunk in document.replace('\n', '. ').split('. ')
            )
            sentences = [s for s in stripped if len(s) >= 20]
            for current, following in zip(sentences, sentences[1:]):
                current_texts.append(current)
                next_texts.append(following)
                if len(current_texts) >= max_pairs:
                    return current_texts, next_texts
        return current_texts, next_texts

    def training_step(self, batch, batch_idx):
        """
        Training step with latent-space next-sentence prediction.

        Based on Lester et al. 2024: predict the latent representation of the
        NEXT sentence instead of reconstructing the full SDR.

            L = ||z_pred - compress(x_next)||^2 + lambda * ||z_pred||_1

        When the text decoder exists, the result is ``0.1 * L + MTP loss``.

        Args:
            batch: Sequence of documents (strings).
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor, or None when fewer than two sentence pairs can be
            built. Returning None makes Lightning skip this optimizer step.
        """
        # Pairs come from consecutive sentences of the SAME document (never
        # across documents). They are capped at the batch size for a stable
        # GPU load.
        current_texts, next_texts = self._make_sentence_pairs(batch, max_pairs=len(batch))
        if len(current_texts) < 2:
            # NOTE(review): confirm Lightning does not still advance the
            # step-interval LR scheduler for skipped batches. OneCycleLR raises
            # if stepped past total_steps.
            return None

        # Forward pass on the current sentences -> predicted next latent.
        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")
        # Replace non-finite values and bound the prediction before the loss.
        z_pred = torch.nan_to_num(z_pred, nan=0.0, posinf=100.0, neginf=-100.0)
        z_pred = torch.clamp(z_pred, -100.0, 100.0)

        # Targets: SDRs of the next sentences (no gradient through the encoder).
        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        # Multi-Token Prediction loss when the decoder is enabled. The latent
        # loss is then down-weighted to 0.1.
        if self.model.text_decoder is not None:
            mtp_loss, mtp_metrics = multi_token_prediction_loss(
                z_pred=z_pred,
                target_text=next_texts,
                decoder=self.model.text_decoder,
                max_length=128,
                loss_weight=self.hparams.mtp_loss_weight,
            )
            loss = 0.1 * loss + mtp_loss
            metrics.update({f"mtp_{k}": v for k, v in mtp_metrics.items()})

        batch_size = len(current_texts)

        # Primary metrics.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("mse_loss", metrics["mse_loss"], on_step=True, on_epoch=True, batch_size=batch_size)
        self.log("l1_loss", metrics["l1_loss"], on_step=False, on_epoch=True, batch_size=batch_size)
        self.log("cos_sim", metrics["cos_sim"], on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)

        # Meta-controller outputs.
        # NOTE(review): argmax() without `dim` flattens action_logits. If it has
        # a batch dimension this logs a flat index; confirm the intended shape.
        self.log("train_action", float(action_logits.argmax()), on_step=False, on_epoch=True, batch_size=batch_size)
        # float() requires `value` to hold exactly one element.
        self.log("train_value", float(value), on_step=False, on_epoch=True, batch_size=batch_size)

        # MTP metrics, present only if multi_token_prediction_loss reported
        # an "mtp_loss" entry.
        if "mtp_mtp_loss" in metrics:
            self.log("mtp_loss", metrics["mtp_mtp_loss"], on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)
            self.log("mtp_accuracy", metrics.get("mtp_avg_accuracy", 0.0), on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size)

        # Peak allocated CUDA memory (GB), not the current usage.
        if torch.cuda.is_available():
            gpu_mem = torch.cuda.max_memory_allocated() / 1e9
            self.log("gpu_memory_gb", gpu_mem, on_step=False, on_epoch=True, batch_size=batch_size)

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step with the latent prediction loss.

        Uses the same within-document sentence pairing and the same
        sanitization of ``z_pred`` as ``training_step``, so validation metrics
        are comparable with training metrics. Returns None when fewer than two
        pairs can be built.
        """
        current_texts, next_texts = self._make_sentence_pairs(batch, max_pairs=len(batch))
        if len(current_texts) < 2:
            return None

        z_pred, action_logits, value = self(current_texts, reward=0.0, mode="awake")
        z_pred = torch.nan_to_num(z_pred, nan=0.0, posinf=100.0, neginf=-100.0)
        z_pred = torch.clamp(z_pred, -100.0, 100.0)

        with torch.no_grad():
            next_sdrs = self.model.encoder(next_texts)

        loss, metrics = latent_prediction_loss(
            z_pred=z_pred,
            x_next=next_sdrs,
            compressor=self.model.compressor,
            lambda_sparse=self.hparams.lambda_sparse,
        )

        batch_size = len(current_texts)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, batch_size=batch_size)
        self.log("val_mse_loss", metrics["mse_loss"], on_step=False, on_epoch=True, batch_size=batch_size)
        self.log("val_cos_sim", metrics["cos_sim"], on_step=False, on_epoch=True, batch_size=batch_size)
        return loss

    def configure_optimizers(self):
        """
        Configure AdamW with a step-based learning-rate schedule.

        Schedules (``hparams.lr_schedule``):
        1. "onecycle" (also used for unknown values): OneCycleLR with cosine
           annealing and momentum (beta1) cycling.
        2. "warmrestarts": CosineAnnealingWarmRestarts.
        3. "cosine": linear warmup followed by CosineAnnealingLR.

        Only the selected scheduler is constructed. Each LR scheduler writes
        ``initial_lr`` into the optimizer's param groups at construction.
        Building all of them (as before) let OneCycleLR's ``initial_lr`` become
        the base LR of the other schedules.

        Uses:
        - AdamW with weight decay, betas=(0.9, 0.95) (DeepSeek V3 settings)
        - Mixed precision and gradient clipping are configured on the Trainer
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            betas=(0.9, 0.95),  # DeepSeek V3 settings
            eps=1e-8,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": self._build_lr_scheduler(optimizer),
                "interval": "step",  # step-based (not epoch-based)
                "frequency": 1,  # apply every step
            },
        }

    def _build_lr_scheduler(self, optimizer):
        """Construct exactly one LR scheduler for `optimizer` per `hparams.lr_schedule`."""
        hp = self.hparams
        max_steps = hp.max_steps
        lr_sched = torch.optim.lr_scheduler

        if hp.lr_schedule == "warmrestarts":
            # Multiple cosine cycles; each period is double the previous one.
            # T_0 must be a positive integer, so clamp small runs to 1.
            return lr_sched.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=max(1, int(max_steps * hp.warmup_pct)),
                T_mult=2,
                eta_min=hp.learning_rate / 1000,
            )

        if hp.lr_schedule == "cosine":
            # Linear warmup (at most 10% of training, capped by hp.warmup_steps,
            # default 1000), then cosine decay to lr/100. The decay ends exactly
            # at max_steps; CosineAnnealingLR is periodic and would start rising
            # again if stepped beyond T_max.
            warmup_steps = max(0, min(hp.warmup_steps, max_steps // 10))
            decay_steps = max(1, max_steps - warmup_steps)
            eta_min = hp.learning_rate / 100
            if warmup_steps == 0:
                return lr_sched.CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=eta_min)
            # Warmup starts at lr / lr_div_factor, mirroring OneCycleLR's
            # initial LR. LinearLR requires 0 < start_factor <= 1.
            warmup = lr_sched.LinearLR(
                optimizer,
                start_factor=1.0 / max(hp.lr_div_factor, 1.0),
                end_factor=1.0,
                total_iters=warmup_steps,
            )
            decay = lr_sched.CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=eta_min)
            return lr_sched.SequentialLR(
                optimizer, schedulers=[warmup, decay], milestones=[warmup_steps]
            )

        # "onecycle" (default, also used for unrecognized values):
        # warmup -> peak -> cosine decay, with momentum cycling.
        return lr_sched.OneCycleLR(
            optimizer,
            max_lr=hp.learning_rate,
            total_steps=max_steps,
            pct_start=hp.warmup_pct,  # fraction of the cycle spent increasing LR
            anneal_strategy='cos',
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
            div_factor=hp.lr_div_factor,  # initial LR = max_lr / div_factor
            final_div_factor=hp.lr_final_div_factor,  # final LR = initial LR / final_div_factor
        )
class StreamingDataModule(pl.LightningDataModule):
    """
    Lightning DataModule that yields batches of raw text documents.

    Two data sources:
    - ``data_dir is None``: stream ``HuggingFaceFW/fineweb-edu`` (config
      ``sample-10BT``, split ``train``) from the Hugging Face Hub, grouped into
      lists of ``batch_size`` strings.
    - ``data_dir`` set: local ``*.pt`` / ``*.npy`` shard files read through
      ``StreamingDataset``, optionally split into train/validation shards.
    """

    class _HFTextBatches(torch.utils.data.IterableDataset):
        """
        Group a Hugging Face text stream into lists of ``batch_size`` strings.

        Defined at class scope rather than inside a method so that DataLoader
        workers started with the ``spawn`` method can pickle it. Locally
        defined classes cannot be pickled.
        """

        def __init__(self, hf_dataset, batch_size):
            super().__init__()
            self.hf_dataset = hf_dataset
            self.batch_size = batch_size

        def __iter__(self):
            stream = self.hf_dataset
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                # Round-robin split: worker k keeps items k, k+N, k+2N, ...
                # Each worker still reads the whole upstream stream and
                # discards the items it does not own.
                # NOTE(review): HF streaming datasets may already split their
                # shards across PyTorch workers when iterated inside a worker.
                # If so, this islice subsamples a second time; confirm.
                import itertools
                stream = itertools.islice(
                    stream, worker_info.id, None, worker_info.num_workers
                )
            batch = []
            for item in stream:
                # Full document text (no truncation).
                batch.append(item["text"])
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            # A trailing partial batch is dropped.

    def __init__(
        self,
        data_dir: Optional[Path] = None,
        batch_size: int = 32,
        num_workers: int = 8,
        val_split: float = 0.0,  # no validation by default (streaming setup)
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.train_shards = []
        self.val_shards = []

    def setup(self, stage: Optional[str] = None):
        """
        Discover shard files under ``data_dir`` and split them into train/val.

        With ``val_split > 0`` the first ``int(n * val_split)`` shards (sorted
        by path) become validation shards. Does nothing when ``data_dir`` is None.

        Raises:
            ValueError: if ``data_dir`` contains no ``*.pt`` / ``*.npy`` files.
        """
        if self.data_dir is None:
            print("No data_dir provided - using HuggingFace streaming")
            return

        all_shards = sorted(
            list(self.data_dir.glob("*.pt")) + list(self.data_dir.glob("*.npy"))
        )
        if not all_shards:
            raise ValueError(f"No shard files found in {self.data_dir}")

        if self.val_split > 0:
            val_count = int(len(all_shards) * self.val_split)
            self.val_shards = all_shards[:val_count]
            self.train_shards = all_shards[val_count:]
        else:
            self.train_shards = all_shards
            self.val_shards = []

        print(
            f"Found {len(self.train_shards)} training shards, {len(self.val_shards)} validation shards"
        )

    def train_dataloader(self):
        """Create the training DataLoader (HF streaming or local shards)."""
        if self.data_dir is None:
            hf_stream = datasets.load_dataset(
                "HuggingFaceFW/fineweb-edu",
                name="sample-10BT",
                split="train",
                streaming=True,
            )
            return DataLoader(
                self._HFTextBatches(hf_stream, self.batch_size),
                batch_size=None,  # the dataset already yields whole batches
                num_workers=self.num_workers,
                pin_memory=True,
                prefetch_factor=2 if self.num_workers > 0 else None,
            )

        # NOTE(review): if StreamingDataset materializes CUDA tensors inside
        # worker processes, num_workers > 0 / pin_memory may fail; verify.
        dataset = StreamingDataset(
            shard_paths=self.train_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,  # background prefetching
            cycle=True,  # infinite cycling
        )
        return DataLoader(
            dataset,
            batch_size=None,  # StreamingDataset handles batching
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
            prefetch_factor=4 if self.num_workers > 0 else None,
        )

    def val_dataloader(self):
        """Create the validation DataLoader, or return None if there are no validation shards."""
        if not self.val_shards:
            return None

        dataset = StreamingDataset(
            shard_paths=self.val_shards,
            batch_size=self.batch_size,
            device="cuda" if torch.cuda.is_available() else "cpu",
            prefetch=True,
            cycle=False,  # single pass over validation data
        )
        # Fewer workers for validation. persistent_workers must depend on the
        # actual worker count: DataLoader raises ValueError when
        # persistent_workers=True with num_workers == 0 (e.g. num_workers=1 -> 0).
        val_workers = self.num_workers // 2
        return DataLoader(
            dataset,
            batch_size=None,
            num_workers=val_workers,
            pin_memory=True,
            persistent_workers=val_workers > 0,
        )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for SEM V6 Lightning training."""
    parser = argparse.ArgumentParser(description="SEM V6 Lightning Training")

    # Data arguments
    parser.add_argument(
        "--data_dir",
        type=Path,
        required=False,
        default=None,
        help="Directory with shard files",
    )
    # Batch size: 6 by default for 6GB VRAM; --fast-100s overrides it to 16.
    parser.add_argument("--batch_size", type=int, default=6, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=8, help="DataLoader workers")

    # Model arguments
    parser.add_argument("--sdr_dim", type=int, default=768, help="SDR dimension")
    parser.add_argument("--sparsity", type=float, default=0.05, help="SDR sparsity")
    parser.add_argument("--kan_degree", type=int, default=3, help="ChebyKAN polynomial degree")
    parser.add_argument("--num_hyperedges", type=int, default=2000, help="Number of hyperedges")
    parser.add_argument("--latent_dim", type=int, default=256, help="Latent space dimension (Lester 2024)")
    parser.add_argument("--lambda_sparse", type=float, default=0.001, help="L1 sparsity coefficient for latent loss")

    # ODE solver arguments (for Module C: SolitonPropagator)
    parser.add_argument(
        "--ode_solver", type=str, default="rk4",
        choices=["rk4", "euler", "dopri5", "midpoint"],
        help="ODE solver method: rk4 (stable, default), euler (fast for training)",
    )
    parser.add_argument(
        "--ode_step_size", type=float, default=0.01,
        help="Step size for fixed-step ODE solvers (default 0.01, use 0.05 with euler)",
    )

    # Training arguments
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Base learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="AdamW weight decay (L2 regularization)")
    parser.add_argument("--max_steps", type=int, default=100000, help="Max training steps")
    parser.add_argument(
        "--accumulate_grad_batches", type=int, default=1,
        help="Gradient accumulation steps (effective_batch = batch_size * this)",
    )
    parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="Gradient clipping norm threshold")

    # Learning rate schedule arguments
    parser.add_argument(
        "--lr_schedule", type=str, default="onecycle",
        choices=["onecycle", "warmrestarts", "cosine"],
        help="Learning rate schedule strategy",
    )
    parser.add_argument(
        "--warmup_pct", type=float, default=0.1,
        help="Percentage of training used for warmup (0.0-1.0)",
    )
    parser.add_argument(
        "--lr_div_factor", type=float, default=25.0,
        help="OneCycleLR: Initial LR = max_LR / div_factor",
    )
    parser.add_argument(
        "--lr_final_div_factor", type=float, default=10000.0,
        help="OneCycleLR: Final LR = initial_LR / final_div_factor",
    )

    # Lightning Trainer arguments
    parser.add_argument("--accelerator", type=str, default="gpu", choices=["gpu", "cpu"])
    parser.add_argument("--devices", type=int, default=1, help="Number of GPUs")
    parser.add_argument("--precision", type=str, default="16-mixed", help="Training precision")
    parser.add_argument("--log_every_n_steps", type=int, default=10, help="Logging frequency")
    parser.add_argument("--val_check_interval", type=float, default=1000, help="Validation frequency")

    # Checkpointing
    parser.add_argument(
        "--checkpoint_dir",
        type=Path,
        default=Path("checkpoints"),
        help="Checkpoint directory",
    )
    parser.add_argument("--save_top_k", type=int, default=3, help="Save top K checkpoints")
    parser.add_argument(
        "--no-compile",
        action="store_true",
        help="Disable torch.compile (useful for debugging)",
    )

    # Fast 100-second training mode (agentic coherence optimization)
    parser.add_argument(
        "--fast-100s",
        action="store_true",
        help="Enable optimized settings for 100-second agentic coherence training",
    )
    parser.add_argument(
        "--target-time",
        type=float,
        default=160.0,
        help="Target wall-clock time in seconds (default: 160.0, includes ~125s startup overhead)",
    )

    # Multi-Token Prediction (MTP) arguments.
    # --enable-mtp already defaults to True, so only --disable-mtp changes
    # anything (main() derives enable_mtp from --disable-mtp).
    parser.add_argument(
        "--enable-mtp",
        action="store_true",
        default=True,
        help="Enable Multi-Token Prediction text decoder (default: True)",
    )
    parser.add_argument(
        "--disable-mtp",
        action="store_true",
        help="Disable Multi-Token Prediction text decoder",
    )
    parser.add_argument(
        "--num-predict",
        type=int,
        default=4,
        help="Number of future tokens to predict with MTP (default: 4)",
    )
    parser.add_argument(
        "--mtp-loss-weight",
        type=float,
        default=1.0,
        help="Weight for MTP loss (default: 1.0)",
    )
    return parser


def _apply_fast_100s_overrides(args: argparse.Namespace) -> None:
    """
    Overwrite `args` in place with the --fast-100s preset and print a summary.

    The preset follows OneCycleLR super-convergence, uses the euler solver and
    no gradient accumulation. ``max_steps`` is derived from ``target_time``
    minus a fixed 125s startup-overhead estimate at 4 steps/s, with a minimum
    of 50 steps.
    """
    # Training hyperparameters
    args.batch_size = 16  # larger batches, no accumulation
    args.learning_rate = 0.002  # 2x for super-convergence
    args.weight_decay = 1e-5  # reduced for super-convergence
    args.lr_div_factor = 10.0  # initial LR = 0.002 / 10 = 0.0002
    args.lr_final_div_factor = 100.0  # final LR = 0.0002 / 100 = 0.000002
    args.accumulate_grad_batches = 1  # every batch is an optimizer update

    # Model capacity, balanced for 6GB VRAM with a 128k-vocab text decoder.
    args.sdr_dim = 1024
    args.latent_dim = 256
    args.kan_degree = 4
    args.num_hyperedges = 2000

    # ODE solver (speed): euler with a 5x larger step size.
    args.ode_solver = "euler"
    args.ode_step_size = 0.05

    # Time budget. The overhead estimate covers model init (~7s), tokenizer
    # (~5s), HF streaming first batch (~50s) and compile/GPU warmup (~60s),
    # rounded up to 125s.
    estimated_overhead_s = 125.0
    available_training_s = args.target_time - estimated_overhead_s
    estimated_steps_per_sec = 4.0  # conservative estimate
    args.max_steps = max(50, int(available_training_s * estimated_steps_per_sec))

    elapsed_so_far = time.perf_counter() - _SCRIPT_START_TIME
    print("=" * 80)
    print("FAST 100s MODE ENABLED - Optimized for agentic coherence")
    print(f" Target time: {args.target_time}s (elapsed so far: {elapsed_so_far:.1f}s)")
    print(f" Estimated overhead: {estimated_overhead_s}s (model ~7s + tokenizer ~5s + HF stream ~50s + compile/warmup ~60s)")
    print(f" Training budget: {available_training_s}s → max {args.max_steps} steps")
    print("-" * 40)
    print(" MODEL CAPACITY (balanced for 6GB + 128k vocab):")
    print(f" sdr_dim: {args.sdr_dim}, latent_dim: {args.latent_dim}")
    print(f" kan_degree: {args.kan_degree}, num_hyperedges: {args.num_hyperedges}")
    print("-" * 40)
    print(" TRAINING:")
    print(" batch_size=16, lr=0.002, weight_decay=1e-5, accumulation=1")
    print(" ODE solver: euler with step_size=0.05")
    print("=" * 80)


def _build_callbacks(args: argparse.Namespace) -> list:
    """Create Trainer callbacks: timing, checkpointing, LR monitor, UI and validation."""
    # With every_n_train_steps set, ModelCheckpoint saves only on those step
    # boundaries. A fixed 1000 therefore never saves a run shorter than 1000
    # steps (e.g. --fast-100s, ~140 steps). Keep 1000 for long runs, but make
    # sure at least two saves fit inside max_steps.
    checkpoint_every = min(1000, max(1, args.max_steps // 2))

    callbacks = [
        # Wall-clock budget tracking / time-based stop.
        TimingCallback(
            script_start_time=_SCRIPT_START_TIME,
            target_time=args.target_time,
        ),
        # Keep the best checkpoints by training loss, plus the last one.
        ModelCheckpoint(
            dirpath=args.checkpoint_dir,
            filename="semv6-{step:06d}-{train_loss:.4f}",
            monitor="train_loss",
            mode="min",
            save_top_k=args.save_top_k,
            every_n_train_steps=checkpoint_every,
            save_last=True,
        ),
        LearningRateMonitor(logging_interval="step"),
        RichProgressBar(),
        RichModelSummary(max_depth=2),
    ]

    # English-coherence monitoring on fixed prompts.
    test_prompts = [
        "The capital of France is",
        "Water boils at",
        "The human body has",
        "Plants produce oxygen through",
        "The speed of light is",
    ]
    fast_validator = FastValidator(test_prompts=test_prompts)
    grammar_client = LanguageToolClient(url="http://localhost:8081")
    grammar_validator = GrammarValidator(
        client=grammar_client,
        test_prompts=test_prompts,
    )
    callbacks.append(
        CombinedValidationCallback(
            fast_validator=fast_validator,
            grammar_validator=grammar_validator,
            test_prompts=test_prompts,
            frequency=200,
        )
    )
    # EarlyStopping on val_loss is intentionally not used: validation is
    # disabled on the Trainer (limit_val_batches=0).
    return callbacks


def _print_run_summary(args: argparse.Namespace, model_init_time: float) -> None:
    """Print the resolved training configuration before `trainer.fit`."""
    print("=" * 80)
    print("SEM V6 PyTorch Lightning Training (Latent Prediction - Lester 2024)")
    print("=" * 80)
    print(f"⏱️ Model created in {model_init_time:.1f}s (target: {args.target_time}s)")
    print("-" * 80)
    print(f"Data directory: {args.data_dir}")
    print(f"Batch size: {args.batch_size}")
    print(f"Gradient accumulation: {args.accumulate_grad_batches}")
    print(f"Effective batch size: {args.batch_size * args.accumulate_grad_batches}")
    print(f"Max steps: {args.max_steps}")
    print(f"Precision: {args.precision}")
    print(f"Devices: {args.devices} {args.accelerator}")
    print("-" * 80)
    print("LEARNING RATE SCHEDULE:")
    print(f" Strategy: {args.lr_schedule.upper()}")
    print(f" Base LR: {args.learning_rate}")
    print(f" Weight decay (L2): {args.weight_decay}")
    print(f" Warmup percentage: {args.warmup_pct * 100:.1f}% ({int(args.max_steps * args.warmup_pct)} steps)")
    if args.lr_schedule == "onecycle":
        print(f" Initial LR: {args.learning_rate / args.lr_div_factor:.6f}")
        print(f" Max LR: {args.learning_rate}")
        print(f" Final LR: {args.learning_rate / args.lr_div_factor / args.lr_final_div_factor:.8f}")
        print(" Momentum cycling: 0.85 -> 0.95 -> 0.85")
    elif args.lr_schedule == "warmrestarts":
        print(f" Restarts every: {int(args.max_steps * args.warmup_pct)} steps")
        print(" Period multiplier: 2.0x after each restart")
        print(f" Min LR: {args.learning_rate / 1000:.8f}")
    else:  # cosine
        print(f" Min LR: {args.learning_rate / 100:.8f}")
    print("-" * 80)
    print(f"SDR dimension: {args.sdr_dim}")
    print(f"Latent dimension: {args.latent_dim} (compression ratio: {args.sdr_dim / args.latent_dim:.1f}x)")
    print(f"L1 sparsity coefficient: {args.lambda_sparse}")
    print("-" * 80)
    print("ODE SOLVER (Module C: SolitonPropagator):")
    print(f" Solver: {args.ode_solver}")
    print(f" Step size: {args.ode_step_size}")
    print("-" * 80)
    print("MULTI-TOKEN PREDICTION (MTP):")
    print(f" Enabled: {args.enable_mtp}")
    if args.enable_mtp:
        print(f" Num tokens: {args.num_predict}")
        print(f" Loss weight: {args.mtp_loss_weight}")
        print(" Tokenizer: meta-llama/Meta-Llama-3-8B")
    print("=" * 80)


def main():
    """
    Main training entry point.

    Parses CLI arguments, optionally applies the --fast-100s preset, then
    builds the data module, model, callbacks, logger and Trainer, and runs
    ``trainer.fit``. Exits with a usage error (``parser.error``) when a GPU is
    requested but CUDA is unavailable.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # MTP is on unless --disable-mtp is given.
    args.enable_mtp = not args.disable_mtp

    if args.fast_100s:
        _apply_fast_100s_overrides(args)

    # A real CLI error instead of `assert`, which `python -O` strips.
    if args.accelerator == "gpu" and not torch.cuda.is_available():
        parser.error("GPU required but CUDA not available")

    datamodule = StreamingDataModule(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # Model with latent-space architecture (Lester 2024).
    model = SEMV6LightningModule(
        sdr_dim=args.sdr_dim,
        sparsity=args.sparsity,
        num_hyperedges=args.num_hyperedges,
        kan_degree=args.kan_degree,
        latent_dim=args.latent_dim,
        lambda_sparse=args.lambda_sparse,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        max_steps=args.max_steps,
        compile_model=not args.no_compile,
        lr_schedule=args.lr_schedule,
        warmup_pct=args.warmup_pct,
        lr_div_factor=args.lr_div_factor,
        lr_final_div_factor=args.lr_final_div_factor,
        ode_solver=args.ode_solver,
        ode_step_size=args.ode_step_size,
        enable_text_decoder=args.enable_mtp,
        num_predict=args.num_predict,
        mtp_loss_weight=args.mtp_loss_weight,
    )

    callbacks = _build_callbacks(args)

    logger = TensorBoardLogger(
        save_dir="lightning_logs",
        name="semv6_training",
    )

    trainer = pl.Trainer(
        # Hardware
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,  # default "16-mixed": Lightning AMP with FP16 autocast
        # Training length and optimization
        max_steps=args.max_steps,
        # effective_batch = batch_size * accumulate_grad_batches
        accumulate_grad_batches=args.accumulate_grad_batches,
        gradient_clip_val=args.gradient_clip_val,
        gradient_clip_algorithm="norm",  # clip by global L2 norm
        # Performance: cuDNN kernel autotuning, non-deterministic ops allowed.
        # TF32 matmuls are enabled module-wide via set_float32_matmul_precision.
        benchmark=True,
        deterministic=False,
        # Logging and monitoring
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=args.log_every_n_steps,
        val_check_interval=args.val_check_interval,
        # NOTE(review): a RichModelSummary callback is still passed in
        # `callbacks`; confirm whether this flag actually suppresses it.
        enable_model_summary=False,
        enable_progress_bar=True,
        # Validation disabled: streaming setup, no validation set.
        limit_val_batches=0,
        num_sanity_val_steps=0,
        # Model checkpointing (the ModelCheckpoint callback above). This is not
        # gradient/activation checkpointing.
        enable_checkpointing=True,
        # For debugging, consider detect_anomaly=True (NaN/Inf propagation) or
        # profiler="pytorch" (slows training).
        strategy="auto",  # single device -> no distributed strategy
        sync_batchnorm=False,
    )

    model_init_time = time.perf_counter() - _SCRIPT_START_TIME
    _print_run_summary(args, model_init_time)

    trainer.fit(model, datamodule=datamodule)

    print("=" * 80)
    print("Training complete!")
    print(f"Best checkpoint: {trainer.checkpoint_callback.best_model_path}")
    print("=" * 80)
if __name__ == "__main__":
    # Script entry point: importing this module does not start training.
    main()