#!/usr/bin/env python3 """ SymbolicLight V1 pre-training entry point. Examples: python train_base.py --dry_run python train_base.py --batch_size 56 --total_tokens 3000000000 torchrun --nproc_per_node=4 train_base.py --batch_size 14 --total_tokens 3000000000 torchrun --nproc_per_node=4 train_base.py --resume --batch_size 14 --total_tokens 3000000000 """ import os import subprocess os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True') if os.environ.get("SYMBOLICLIGHT_INSECURE_TLS", "").lower() in {"1", "true", "yes"}: print("[Net] [WARN] SYMBOLICLIGHT_INSECURE_TLS is ignored for security.") print("[Net] [WARN] Fix certificates or proxy settings instead of weakening TLS verification.") if False: print("[Net] [WARN] ⚠️ TLS certificate verification WEAKENED for HF downloads") os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', '300') os.environ.setdefault('HF_HUB_ETAG_TIMEOUT', '60') if os.path.exists('/etc/network_turbo'): try: result = subprocess.run( 'bash -c "source /etc/network_turbo && env | grep -i proxy"', shell=True, capture_output=True, text=True, timeout=5 ) for line in result.stdout.splitlines(): if '=' in line: var, value = line.split('=', 1) os.environ[var] = value if os.environ.get('http_proxy') or os.environ.get('https_proxy'): print("[Net] [OK] AutoDL network_turbo proxy loaded") except Exception: pass _USE_MODELSCOPE = False _has_proxy = bool(os.environ.get('http_proxy') or os.environ.get('https_proxy')) if _has_proxy: os.environ.pop('HF_ENDPOINT', None) print("[Net] Proxy detected; using direct endpoint configuration") else: print("[Net] No proxy detected. Set HF_ENDPOINT manually if your environment requires a mirror.") import sys import json import time import math import random import inspect import argparse from pathlib import Path from datetime import datetime import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, IterableDataset import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP import datetime as dt sys.path.insert(0, str(Path(__file__).parent)) from model import SymbolicLightConfig, SymbolicLightModel from data_pipeline import ( DEFAULT_CURRICULUM_PHASE_1_RATIO, DEFAULT_CURRICULUM_PHASE_2_RATIO, DEFAULT_CURRICULUM_PRESET, MemmapDataset, StreamingParquetDataset, format_source_histogram, get_curriculum_phase, resolve_tokenizer_path, ) _PACKAGE_ROOT = Path(__file__).resolve().parent.parent _SL_TOKENIZER_PATH = str(_PACKAGE_ROOT / "tokenizer" / "sl_tokenizer.model") if not os.path.exists(_SL_TOKENIZER_PATH): raise FileNotFoundError(f"Tokenizer model not found: {_SL_TOKENIZER_PATH}") from train_tokenizer import SLTokenizer print(f"[Tokenizer] [OK] SL tokenizer: {_SL_TOKENIZER_PATH}") def setup_distributed(): """Initialize distributed training when launched via torchrun.""" if 'RANK' in os.environ: rank = int(os.environ['RANK']) local_rank = int(os.environ['LOCAL_RANK']) world_size = int(os.environ['WORLD_SIZE']) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) device = torch.device(f'cuda:{local_rank}') backend = os.environ.get('DIST_BACKEND', 'nccl') else: device = torch.device('cpu') backend = 'gloo' timeout = dt.timedelta(minutes=30) try: if device.type == "cuda": dist.init_process_group(backend, timeout=timeout, device_id=device) else: dist.init_process_group(backend, timeout=timeout) except Exception as e: if backend == 'nccl': print(f"[DDP] NCCL init failed ({e}), falling back to gloo...") dist.init_process_group('gloo', timeout=timeout) else: raise dist.barrier() return rank, local_rank, world_size else: return 0, 0, 1 def cleanup_distributed(): if dist.is_initialized(): dist.destroy_process_group() def is_main_process(rank): return rank == 0 DEFAULT_DATA_DIR = "./data/private_corpus" def build_data_recipe(data_dir): """Build a domain-level recipe without exposing source-level dataset names.""" return [ {"name": "reference-web", "local_dir": os.path.join(data_dir, "reference-web"), "split": "train", "weight": 0.25, "text_key": "text"}, {"name": "math-web", "local_dir": os.path.join(data_dir, "math-web"), "split": "train", "weight": 0.20, "text_key": "text"}, {"name": "code-text", "local_dir": os.path.join(data_dir, "code-text"), "split": "train", "weight": 0.15, "text_key": "text"}, {"name": "general-web", "local_dir": os.path.join(data_dir, "general-web"), "split": "train", "weight": 0.15, "text_key": "text"}, {"name": "academic-educational","local_dir": os.path.join(data_dir, "academic-educational"),"split": "train", "weight": 0.10, "text_key": "text"}, {"name": "open-educational", "local_dir": os.path.join(data_dir, "open-educational"), "split": "train", "weight": 0.08, "text_key": "text"}, {"name": "synthetic-narrative", "local_dir": os.path.join(data_dir, "synthetic-narrative"), "split": "train", "weight": 0.05, "text_key": "text"}, {"name": "translation", "local_dir": os.path.join(data_dir, "translation"), "split": "train", "weight": 0.02, "text_key": "translation"}, ] class SmokeTestStreamingDataset(IterableDataset): """Small synthetic stream used only to test the training loop.""" def __init__(self, seq_len=512, vocab_size=57344): self.seq_len = seq_len self.vocab_size = vocab_size self.enc = SLTokenizer(_SL_TOKENIZER_PATH) def __iter__(self): examples = [ "SymbolicLight uses sparse spike-gated computation for language modeling.", "This public smoke-test stream is not part of the reported training data.", "Use your own legally available corpus under the aggregate domain recipe.", ] token_buffer = [] while True: text = examples[len(token_buffer) % len(examples)] tokens = self.enc.encode(text, add_bos=False) token_buffer.extend(tokens) while len(token_buffer) >= self.seq_len + 1: chunk = token_buffer[:self.seq_len + 1] token_buffer = token_buffer[self.seq_len:] x = torch.tensor(chunk[:-1], dtype=torch.long).clamp(0, self.vocab_size - 1) y = torch.tensor(chunk[1:], dtype=torch.long).clamp(0, self.vocab_size - 1) yield x, y def get_lr(step, warmup_steps, total_steps, max_lr, min_lr=1e-5): if step < warmup_steps: return max_lr * step / warmup_steps if step >= total_steps: return min_lr progress = (step - warmup_steps) / (total_steps - warmup_steps) return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress)) def realign_data_iterator(dataloader, skip_target, rank, label="Resume", report_every=1000): """Replay the deterministic streaming pipeline until the saved batch offset.""" data_iter = iter(dataloader) if skip_target <= 0: return data_iter, 0 if is_main_process(rank): print(f"[{label}] Realigning data stream by skipping {skip_target} batches...") skip_count = 0 while skip_count < skip_target: try: next(data_iter) skip_count += 1 except StopIteration: data_iter = iter(dataloader) continue if report_every > 0 and skip_count % report_every == 0 and is_main_process(rank): print(f"[{label}] Skipped {skip_count}/{skip_target} batches...") if is_main_process(rank): print(f"[{label}] Data stream aligned after skipping {skip_count} batches") return data_iter, skip_count def build_mixed_dataset(args, rank, world_size, vocab_size, seed_offset=0): tokenizer_path = resolve_tokenizer_path(args.tokenizer_path) if args.data_bin: return MemmapDataset( data_bin_dir=args.data_bin, seq_len=args.max_seq_len, rank=rank, world_size=world_size, model_vocab_size=vocab_size, seed_offset=seed_offset, strict_no_repeat=not args.allow_source_restarts, ) return StreamingParquetDataset( data_dir=args.data_dir, tokenizer_path=tokenizer_path, seq_len=args.max_seq_len, rank=rank, world_size=world_size, model_vocab_size=vocab_size, seed_offset=seed_offset, max_oversample=args.max_oversample, strict_no_repeat=not args.allow_source_restarts, ) def maybe_switch_curriculum_phase(args, dataset, tokens_seen, current_phase, rank): phase, recipe = get_curriculum_phase( tokens_seen, args.total_tokens, phase1_ratio=args.curriculum_phase1_ratio, phase2_ratio=args.curriculum_phase2_ratio, preset=args.curriculum_preset, ) if phase != current_phase: dataset.set_phase(recipe) current_phase = phase if is_main_process(rank): print(f"\n[Curriculum] Entering Phase {phase}/3 at {tokens_seen / 1e9:.2f}B tokens") if hasattr(dataset, "get_phase_summary"): phase_summary = dataset.get_phase_summary() active_sources = phase_summary.get("active_sources", {}) unresolved_sources = phase_summary.get("unresolved_sources", []) if phase_summary.get("uses_legacy_spans"): print(" NOTE: current data_bin has no explicit source_spans; using legacy source_stats spans") if active_sources: print(f" Active sources: {format_source_histogram(active_sources)}") if unresolved_sources: print(f" Unresolved sources: {', '.join(unresolved_sources)}") for source in recipe: print(f" - {source['name']}: {source['weight'] * 100:.0f}%") return current_phase def merge_source_histograms(histograms): merged = {} for histogram in histograms: if not histogram: continue for source_id, count in histogram.items(): merged[source_id] = merged.get(source_id, 0) + int(count) return merged def merge_source_sampling_stats(stats_list): merged_sources = {} for snapshot in stats_list: if not snapshot: continue for source_id, entry in snapshot.get("sources", {}).items(): merged = merged_sources.setdefault( source_id, { "mode": entry.get("mode", snapshot.get("mode", "unknown")), "budget_is_exact": bool(entry.get("budget_is_exact", False)), "replicated": False, "sampled_train_tokens": 0, "sampled_windows": 0, "completed_files": 0, "active": False, "_all_exhausted": True, "_budget_values": [], "_window_values": [], "_file_values": [], "_remaining_window_values": [], }, ) replicated = bool(entry.get("replicated", False)) merged["replicated"] = merged["replicated"] or replicated merged["budget_is_exact"] = merged["budget_is_exact"] or bool(entry.get("budget_is_exact", False)) merged["sampled_train_tokens"] += int(entry.get("sampled_train_tokens", 0)) merged["sampled_windows"] += int(entry.get("sampled_windows", 0)) merged["completed_files"] += int(entry.get("completed_files", 0)) merged["active"] = merged["active"] or bool(entry.get("active", False)) merged["_all_exhausted"] = merged["_all_exhausted"] and bool(entry.get("exhausted", False)) budget = int(entry.get("unique_token_budget", 0) or 0) total_windows = int(entry.get("total_windows", 0) or 0) total_files = int(entry.get("total_files", 0) or 0) if budget > 0: merged["_budget_values"].append((budget, replicated)) if total_windows > 0: merged["_window_values"].append((total_windows, replicated)) remaining_windows = int(entry.get("remaining_windows", 0) or 0) if remaining_windows > 0: merged["_remaining_window_values"].append((remaining_windows, replicated)) if total_files > 0: merged["_file_values"].append((total_files, replicated)) for source_id, entry in merged_sources.items(): budget_values = entry.pop("_budget_values") window_values = entry.pop("_window_values") file_values = entry.pop("_file_values") remaining_window_values = entry.pop("_remaining_window_values") if budget_values: entry["unique_token_budget"] = ( max(value for value, _ in budget_values) if entry["replicated"] else sum(value for value, _ in budget_values) ) else: entry["unique_token_budget"] = 0 if window_values: entry["total_windows"] = ( max(value for value, _ in window_values) if entry["replicated"] else sum(value for value, _ in window_values) ) else: entry["total_windows"] = 0 if remaining_window_values: entry["remaining_windows"] = ( max(value for value, _ in remaining_window_values) if entry["replicated"] else sum(value for value, _ in remaining_window_values) ) else: entry["remaining_windows"] = 0 if file_values: entry["total_files"] = ( max(value for value, _ in file_values) if entry["replicated"] else sum(value for value, _ in file_values) ) else: entry["total_files"] = 0 entry["exhausted"] = entry.pop("_all_exhausted") budget = entry.get("unique_token_budget", 0) total_windows = entry.get("total_windows", 0) total_files = entry.get("total_files", 0) entry["epoch"] = ( float(entry["sampled_train_tokens"]) / float(budget) if budget > 0 else None ) if total_windows > 0: entry["coverage"] = float(entry["sampled_windows"]) / float(total_windows) elif total_files > 0: entry["coverage"] = float(entry["completed_files"]) / float(total_files) else: entry["coverage"] = None return {"sources": merged_sources} DATA_STATE_BUNDLE_KIND = "symboliclight_ranked_data_state" def collect_checkpoint_data_state(dataset, *, use_direct_dataset, is_ddp, rank, world_size): if not use_direct_dataset or dataset is None or not hasattr(dataset, "state_dict"): return None local_state = dataset.state_dict() if not is_ddp: return local_state gathered_states = [None for _ in range(world_size)] dist.all_gather_object(gathered_states, local_state) if not is_main_process(rank): return None return { "__kind__": DATA_STATE_BUNDLE_KIND, "__version__": 1, "world_size": int(world_size), "per_rank": { str(rank_id): gathered_states[rank_id] for rank_id in range(world_size) if gathered_states[rank_id] is not None }, } def resolve_rank_checkpoint_data_state(data_state, *, rank, world_size): if not data_state: return None, None if isinstance(data_state, dict) and data_state.get("__kind__") == DATA_STATE_BUNDLE_KIND: per_rank = data_state.get("per_rank", {}) selected = per_rank.get(str(rank)) stored_world_size = int(data_state.get("world_size", len(per_rank) or 1)) warning = None if stored_world_size != int(world_size): warning = ( f"checkpoint data_state was saved with world_size={stored_world_size}, " f"current world_size={world_size}; attempting rank-wise restore" ) if selected is None: available = ", ".join(sorted(str(key) for key in per_rank)) or "none" return None, ( f"checkpoint has no sampler state for rank {rank} " f"(available ranks: {available}); skipping dataset restore" ) return selected, warning if int(world_size) > 1: return None, ( "checkpoint only contains a single shared data_state. " "Skipping sampler restore under DDP to avoid cross-rank data duplication." ) return data_state, None def restore_checkpoint_data_state(dataset, data_state, *, rank, world_size, label="Resume"): if dataset is None or not hasattr(dataset, "load_state_dict") or not data_state: return False selected_state, warning = resolve_rank_checkpoint_data_state( data_state, rank=rank, world_size=world_size, ) if warning and is_main_process(rank): print(f"[{label}] WARNING: {warning}") if selected_state is None: return False try: dataset.load_state_dict(selected_state) if is_main_process(rank): print(f"[{label}] Data sampler state restored") return True except Exception as exc: if is_main_process(rank): print(f"[{label}] WARNING: failed to restore data sampler state ({exc})") return False def summarize_source_sampling_stats(snapshot, *, warn_threshold=0.80, top_k=4): if not snapshot: return "", [], {} sources = snapshot.get("sources", {}) if not sources: return "", [], {} ranked = [] compact = {} warnings = [] def _fmt_metric(value): if value is None: return "n/a" return f"{value:.3f}" if abs(float(value)) < 0.01 else f"{value:.2f}" for source_id, entry in sources.items(): epoch = entry.get("epoch") coverage = entry.get("coverage") mode = entry.get("mode", "unknown") sampled_train_tokens = int(entry.get("sampled_train_tokens", 0) or 0) score = epoch if epoch is not None else (coverage if coverage is not None else 0.0) if sampled_train_tokens <= 0 and (coverage is None or coverage <= 0): continue if mode == "memmap_exact": if epoch is None: continue label = f"{source_id}:{_fmt_metric(epoch)}ep" if coverage is not None: label += f"/{coverage * 100:.0f}%cov" else: observed = epoch if epoch is not None else 0.0 label = f"{source_id}:{_fmt_metric(observed)}obs" if coverage is not None: label += f"/{coverage * 100:.0f}%files" ranked.append((score, source_id, label)) compact[source_id] = { "mode": mode, "epoch": epoch, "coverage": coverage, "sampled_train_tokens": sampled_train_tokens, "unique_token_budget": entry.get("unique_token_budget", 0), "sampled_windows": entry.get("sampled_windows", 0), "total_windows": entry.get("total_windows", 0), "remaining_windows": entry.get("remaining_windows", 0), "completed_files": entry.get("completed_files", 0), "total_files": entry.get("total_files", 0), "replicated": entry.get("replicated", False), "active": entry.get("active", False), "exhausted": entry.get("exhausted", False), } trigger_value = epoch if entry.get("budget_is_exact", False) else coverage if warn_threshold > 0 and trigger_value is not None and trigger_value >= warn_threshold: if entry.get("budget_is_exact", False): warnings.append(f"{source_id}:{_fmt_metric(epoch)}ep") else: warnings.append(f"{source_id}:{coverage * 100:.0f}%files") ranked.sort(key=lambda item: (-item[0], item[1])) summary = ", ".join(label for _, _, label in ranked[: max(1, int(top_k))]) warnings = sorted(set(warnings)) return summary, warnings[: max(1, int(top_k))], compact def parse_args(): p = argparse.ArgumentParser(description="SymbolicLight 0.8B Trainer (DDP + Auto Aux CE)") p.add_argument("--data_bin", type=str, default=None, help="Pretokenized memmap directory (train.bin + train.meta.json)") p.add_argument("--tokenizer_path", type=str, default=_SL_TOKENIZER_PATH, help="SentencePiece tokenizer path for parquet/memmap pipeline") p.add_argument("--curriculum_phase1_ratio", type=float, default=DEFAULT_CURRICULUM_PHASE_1_RATIO, help="Fraction of total tokens reserved for phase 1 before phase 2") p.add_argument("--curriculum_phase2_ratio", type=float, default=DEFAULT_CURRICULUM_PHASE_2_RATIO, help="Fraction of total tokens reserved for phase 1+2 before phase 3") p.add_argument("--curriculum_preset", type=str, default=DEFAULT_CURRICULUM_PRESET, choices=["default", "30b08", "30b08_30b", "50b08"], help="Curriculum source-weight preset") p.add_argument("--max_oversample", type=float, default=5.0, help="Cap phase source weight to natural source share * max_oversample") p.add_argument("--allow_source_restarts", action="store_true", help="Allow exhausted sources/windows to restart and repeat") p.add_argument("--source_epoch_warn", type=float, default=0.80, help="Warn when a source epoch/coverage reaches this threshold") p.add_argument("--data_dir", type=str, default=DEFAULT_DATA_DIR, help="Root directory containing parquet sources") p.add_argument("--dataset", type=str, default="mixed", choices=["mixed", "smoke"], help="Dataset mode: mixed curriculum sampler or synthetic smoke test") p.add_argument("--total_tokens", type=int, default=3_000_000_000, help="Total training tokens") p.add_argument("--vocab_size", type=int, default=57344, help="Vocabulary size") p.add_argument("--embed_dim", type=int, default=1536) p.add_argument("--n_layers", type=int, default=22) p.add_argument("--n_heads", type=int, default=24) p.add_argument("--head_dim", type=int, default=64) p.add_argument("--intermediate_dim", type=int, default=6144) p.add_argument("--max_seq_len", type=int, default=512, help="Sequence length") p.add_argument("--batch_size", type=int, default=13, help="Per-device batch size (lower it to avoid OOM while using grad_accum to preserve the effective batch)") p.add_argument("--grad_accum", type=int, default=2, help="Gradient accumulation steps (with batch_size=13 this helps preserve the effective batch size)") p.add_argument("--lr", type=float, default=3e-4) p.add_argument("--warmup_steps", type=int, default=2000, help="Warmup steps") p.add_argument("--weight_decay", type=float, default=0.1) p.add_argument("--max_grad_norm", type=float, default=1.0) p.add_argument("--fp16", dest="fp16", action="store_true", help="Enable mixed precision training (default: on)") p.add_argument("--no_fp16", dest="fp16", action="store_false", help="Disable mixed precision and run in FP32 for debugging") p.add_argument("--grad_checkpoint", dest="grad_checkpoint", action="store_true", help="Enable activation checkpointing (default: on)") p.add_argument("--no_grad_checkpoint", dest="grad_checkpoint", action="store_false", help="Disable activation checkpointing") p.set_defaults(fp16=True, grad_checkpoint=True) p.add_argument("--num_workers", type=int, default=0, help="DataLoader workers. Streaming parquet on multi-GPU is safer with 0.") p.add_argument("--save_dir", type=str, default="checkpoints_0p8b") p.add_argument("--save_every", type=int, default=2000, help="Save a checkpoint every N steps") p.add_argument("--log_every", type=int, default=10) p.add_argument("--keep_checkpoints", type=int, default=3, help="Number of recent checkpoints to keep") p.add_argument("--resume", action="store_true") p.add_argument("--seed", type=int, default=42, help="Global random seed") p.add_argument("--sparse_attn_window", type=int, default=512, help="Sparse attention sliding window size (default: 512, covers full seq_len)") p.add_argument("--disable_sparse_attn", action="store_true", help="Disable Sparse Local Attention (Decay-Only / No-Attn ablation)") p.add_argument("--disable_dynamic_prior", action="store_true", help="Disable Dynamic Bayesian Prior (Static Prior ablation)") p.add_argument("--use_topk_mask", action="store_true", help="[W1 Ablation] Replace LIF spike gating with a fixed top-k mask") p.add_argument("--topk_sparsity", type=float, default=0.89, help="[W1 Ablation] Target sparsity for top-k mask") p.add_argument("--dry_run", action="store_true", help="Run a short smoke test") return p.parse_args() def train(args): print(f"[DEBUG] Rank {os.environ.get('RANK', '?')}: entering train()", flush=True) rank, local_rank, world_size = setup_distributed() is_ddp = world_size > 1 device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") print(f"[DEBUG] Rank {rank}: DDP init done, device={device}", flush=True) if args.dry_run: args.total_tokens = 200 * args.batch_size * args.max_seq_len * world_size args.save_every = 100 args.log_every = 10 if is_main_process(rank): print("\n" + "!" * 60) print(" DRY RUN MODE - validating the streaming data pipeline and model") print("!" * 60) if args.embed_dim != args.n_heads * args.head_dim: raise ValueError( f"embed_dim ({args.embed_dim}) must equal n_heads * head_dim " f"({args.n_heads} * {args.head_dim} = {args.n_heads * args.head_dim})" ) if is_main_process(rank): print(f"\n{'=' * 60}") print(f" SymbolicLight V1 Pre-Training") print(f"{'=' * 60}") print(f"Device: {device}") print(f"World size: {world_size} GPU(s)") n_params_est = ( args.vocab_size * args.embed_dim + args.n_layers * ( 3 * args.embed_dim * args.embed_dim + args.embed_dim * args.embed_dim + 2 * args.embed_dim * args.intermediate_dim ) ) print( f"Model profile: ~{n_params_est / 1e6:.1f}M params | d_model={args.embed_dim}, layers={args.n_layers}, " f"heads={args.n_heads}, head_dim={args.head_dim}, ffn={args.intermediate_dim}" ) if args.disable_sparse_attn: print(f"ABLATION: Sparse Attention DISABLED (Decay-Only mode)") if args.disable_dynamic_prior: print(f"ABLATION: Dynamic Prior DISABLED (Static Prior mode)") if getattr(args, 'use_topk_mask', False): print(f"ABLATION: Top-K Mask ENABLED (sparsity={args.topk_sparsity:.0%}, replacing LIF)") if device.type == "cuda": gpu_name = torch.cuda.get_device_name(local_rank) gpu_mem = torch.cuda.get_device_properties(local_rank).total_memory / 1e9 print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)") args._initial_lr = args.lr torch.manual_seed(args.seed) random.seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) if is_main_process(rank): print(f"Seed: {args.seed}") config = SymbolicLightConfig( vocab_size=args.vocab_size, embed_dim=args.embed_dim, n_layers=args.n_layers, n_heads=args.n_heads, head_dim=args.head_dim, intermediate_dim=args.intermediate_dim, max_seq_len=args.max_seq_len, enable_stdp=False, enable_sparse_attn=not args.disable_sparse_attn, sparse_attn_window=getattr(args, 'sparse_attn_window', 512), enable_dynamic_prior=not args.disable_dynamic_prior, use_topk_mask=getattr(args, 'use_topk_mask', False), topk_sparsity=getattr(args, 'topk_sparsity', 0.89), ) if is_main_process(rank): print(f"[DEBUG] Creating model...", flush=True) model = SymbolicLightModel(config).to(device) if is_main_process(rank): print(f"[DEBUG] Model created and moved to {device}", flush=True) if is_ddp: if device.type == "cuda": model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False) else: model = DDP(model, find_unused_parameters=False) if is_main_process(rank): print(f"[DDP] DistributedDataParallel enabled on {world_size} GPUs") if is_main_process(rank): print(f"[DEBUG] Creating dataset...", flush=True) dataset = None dataloader = None if args.dataset == "smoke": dataset = SmokeTestStreamingDataset(seq_len=args.max_seq_len, vocab_size=args.vocab_size) dataloader_kwargs = dict( dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.num_workers > 0, ) if args.num_workers > 0: dataloader_kwargs["prefetch_factor"] = 2 dataloader = DataLoader(**dataloader_kwargs) if is_main_process(rank): print(f"[DEBUG] Dataset handle prepared", flush=True) tokens_per_step = args.batch_size * args.max_seq_len * args.grad_accum * world_size total_steps = args.total_tokens // tokens_per_step if is_main_process(rank): print(f"\n[Training Plan]") print(f" Total tokens: {args.total_tokens / 1e9:.1f}B") print(f" Tokens per step: {tokens_per_step:,}") print(f" Total steps: {total_steps:,}") print(f" Per-GPU batch: {args.batch_size}") print(f" Grad accum: {args.grad_accum}") print(f" Effective batch: {args.batch_size * args.grad_accum * world_size}") print(f" Seq len: {args.max_seq_len}") print(f" Warmup: {args.warmup_steps} steps") print(f" LR: {args.lr}") if args.dataset == "mixed": source_mode = f"memmap ({args.data_bin})" if args.data_bin else f"streaming parquet ({args.data_dir})" print(f" Data mode: {source_mode}") print( f" Curriculum: preset={args.curriculum_preset} | " f"P1<{args.curriculum_phase1_ratio:.2f}, P2<{args.curriculum_phase2_ratio:.2f}, else P3" ) print(f" Oversample cap: {args.max_oversample:.1f}x") repeat_policy = "allow restarts" if args.allow_source_restarts else "strict no-repeat" print(f" Repeat policy: {repeat_policy}") print(f" Source warn @: {args.source_epoch_warn:.2f}") else: print(f" DataLoader: workers={args.num_workers}, pin_memory={args.num_workers > 0}") raw_model = model.module if is_ddp else model if args.grad_checkpoint: raw_model.gradient_checkpointing_enable() if is_main_process(rank): print("[Memory] Gradient checkpointing: ON") else: raw_model.gradient_checkpointing_disable() if is_main_process(rank): print("[Memory] Gradient checkpointing: OFF") decay_params = [] no_decay_params = [] for name, param in raw_model.named_parameters(): if param.requires_grad: if "bias" in name or "norm" in name or "log_prior" in name or "prior_weight" in name: no_decay_params.append(param) else: decay_params.append(param) optimizer = torch.optim.AdamW([ {"params": decay_params, "weight_decay": args.weight_decay}, {"params": no_decay_params, "weight_decay": 0.0}, ], lr=args.lr, betas=(0.9, 0.95)) use_amp = args.fp16 and device.type == "cuda" use_bf16 = use_amp and torch.cuda.is_bf16_supported() use_fp16 = use_amp and not use_bf16 scaler = torch.amp.GradScaler('cuda', enabled=use_fp16) if is_main_process(rank): if use_bf16: print(f"[Memory] Mixed Precision: BF16 (optimal for SNN)") elif use_fp16: print(f"[Memory] Mixed Precision: FP16 (fallback, BF16 not supported)") else: print(f"[Memory] Mixed Precision: OFF (FP32)") global_step = 0 tokens_seen = 0 best_loss = float("inf") train_log = [] current_phase = 0 resume_ckpt = None data_samples_seen = 0 save_dir = Path(args.save_dir) save_dir.mkdir(parents=True, exist_ok=True) if args.resume: ckpt_path = save_dir / "latest.pt" if ckpt_path.exists(): resume_ckpt = torch.load(ckpt_path, map_location=device, weights_only=True, mmap=True) ckpt = resume_ckpt state = {k: v for k, v in ckpt["model"].items() if 'v_mem' not in k} raw_model.load_state_dict(state, strict=False) optimizer.load_state_dict(ckpt["optimizer"]) scaler.load_state_dict(ckpt["scaler"]) global_step = ckpt["global_step"] tokens_seen = ckpt.get("tokens_seen", global_step * tokens_per_step) best_loss = ckpt.get("best_loss", float("inf")) data_samples_seen = ckpt.get("data_samples_seen", 0) current_phase = ckpt.get("curriculum_phase", 0) if "spike_encoder_vmem" in ckpt: raw_model.spike_encoder.v_mem = ckpt["spike_encoder_vmem"].to(device) if is_main_process(rank): print(f"[Resume] LIF membrane potential restored") else: if is_main_process(rank): print(f"[Resume] WARNING: No LIF v_mem in checkpoint (old format).") print(f" Using 3-step warmup buffer to smooth transition...") if is_main_process(rank): print(f"[Resume] Loaded: step={global_step}, tokens_seen={tokens_seen / 1e9:.2f}B") print(f"[Resume] Data offset: {data_samples_seen} samples will be skipped") else: if is_main_process(rank): print(f"[Resume] No checkpoint found at {ckpt_path}, starting fresh") if args.dataset == "mixed": if is_main_process(rank): print(f"[DEBUG] Creating mixed dataset...", flush=True) dataset = build_mixed_dataset( args, rank, world_size, config.vocab_size, seed_offset=data_samples_seen, ) restored_data_state = restore_checkpoint_data_state( dataset, resume_ckpt.get("data_state") if resume_ckpt else None, rank=rank, world_size=world_size, label="Resume", ) phase_anchor = current_phase if restored_data_state else 0 current_phase = maybe_switch_curriculum_phase(args, dataset, tokens_seen, phase_anchor, rank) elif is_main_process(rank): print(f"[DEBUG] Dataset created", flush=True) if is_main_process(rank): print(f"\nStarting training...\n") model.train() train_start = time.time() epoch_loss = 0.0 epoch_steps = 0 accum_loss = 0.0 micro_step = 0 resume_warmup_remaining = 0 use_direct_dataset = args.dataset == "mixed" data_iter = None if use_direct_dataset else iter(dataloader) MAX_SKIP = 10000 skip_target = 0 if use_direct_dataset else data_samples_seen if skip_target > 0 and is_main_process(rank): print(f"[Resume] Skipping {skip_target} samples (of {data_samples_seen} total)...") skip_count = 0 while skip_count < skip_target: try: next(data_iter) skip_count += 1 except StopIteration: data_iter = iter(dataloader) continue if skip_target > 0 and is_main_process(rank): print(f"[Resume] Skipped {skip_count} samples, data stream aligned") if args.resume and global_step > 0: if 'ckpt' in dir() and "spike_encoder_vmem" not in ckpt: resume_warmup_remaining = 3 if is_main_process(rank): print(f"[Resume] Warmup buffer: {resume_warmup_remaining} steps at 1/10 LR") while tokens_seen < args.total_tokens: if use_direct_dataset: current_phase = maybe_switch_curriculum_phase(args, dataset, tokens_seen, current_phase, rank) x, y = dataset.get_batch(args.batch_size, device) else: try: x, y = next(data_iter) except StopIteration: data_iter = iter(dataloader) x, y = next(data_iter) x, y = x.to(device), y.to(device) lr = get_lr(global_step, args.warmup_steps, total_steps, args.lr) if resume_warmup_remaining > 0: lr = lr * 0.1 resume_warmup_remaining -= 1 for pg in optimizer.param_groups: pg["lr"] = lr data_samples_seen += args.batch_size amp_dtype = torch.bfloat16 if use_bf16 else torch.float16 with torch.amp.autocast('cuda', dtype=amp_dtype, enabled=use_amp): logits = model(x) main_loss = F.cross_entropy( logits.reshape(-1, config.vocab_size), y.reshape(-1), ) flat_logits = logits.reshape(-1, config.vocab_size) n_sample = min(128, flat_logits.size(0)) z_idx = torch.randint(flat_logits.size(0), (n_sample,), device=logits.device) log_z = torch.logsumexp(flat_logits[z_idx], dim=-1) z_loss = 1e-4 * (log_z ** 2).mean() loss = (main_loss + z_loss) / args.grad_accum scaler.scale(loss).backward() accum_loss += main_loss.item() tokens_seen += x.numel() * world_size micro_step += 1 if micro_step >= args.grad_accum: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) current_loss = accum_loss / args.grad_accum accum_loss = 0.0 micro_step = 0 epoch_loss += current_loss epoch_steps += 1 should_log = global_step % args.log_every == 0 and global_step > 0 merged_recent_source_hist = {} merged_source_sampling = {} if should_log: local_recent_source_hist = ( dataset.consume_recent_source_histogram() if use_direct_dataset and hasattr(dataset, "consume_recent_source_histogram") else {} ) local_source_sampling = ( dataset.get_source_sampling_stats() if use_direct_dataset and hasattr(dataset, "get_source_sampling_stats") else {} ) if is_ddp: gathered_stats = [None for _ in range(world_size)] dist.all_gather_object( gathered_stats, { "hist": local_recent_source_hist, "sampling": local_source_sampling, }, ) else: gathered_stats = [ { "hist": local_recent_source_hist, "sampling": local_source_sampling, } ] if is_main_process(rank): merged_recent_source_hist = merge_source_histograms( [item.get("hist", {}) for item in gathered_stats] ) merged_source_sampling = merge_source_sampling_stats( [item.get("sampling", {}) for item in gathered_stats] ) if should_log and is_main_process(rank): ppl = math.exp(min(current_loss, 20)) elapsed = time.time() - train_start tokens_per_sec = tokens_seen / elapsed with torch.no_grad(): spikes, _ = raw_model.spike_encoder(x[:1, :32]) sparsity = 1.0 - spikes.mean().item() progress = tokens_seen / args.total_tokens * 100 eta_seconds = (args.total_tokens - tokens_seen) / max(tokens_per_sec, 1) eta_hours = eta_seconds / 3600 source_epoch_summary, source_epoch_warnings, source_epoch_snapshot = summarize_source_sampling_stats( merged_source_sampling, warn_threshold=args.source_epoch_warn, ) gpu_info = f" | GPUs: {world_size}" if is_ddp else "" source_info = ( f" | Src: {format_source_histogram(merged_recent_source_hist)}" if merged_recent_source_hist else "" ) source_epoch_info = f" | SrcEpoch: {source_epoch_summary}" if source_epoch_summary else "" print(f"Step {global_step:6d}/{total_steps} | " f"Loss: {current_loss:.4f} | " f"PPL: {ppl:8.1f} | " f"LR: {lr:.2e} | " f"Sparsity: {sparsity * 100:.1f}% | " f"Tok/s: {tokens_per_sec:.0f} | " f"Progress: {progress:.1f}% | " f"ETA: {eta_hours:.1f}h{gpu_info}{source_info}{source_epoch_info}") if source_epoch_warnings: print(f" [SourceWarn] {', '.join(source_epoch_warnings)}") log_entry = { "step": global_step, "loss": current_loss, "ppl": ppl, "lr": lr, "sparsity": sparsity, "tokens_seen": tokens_seen, "tokens_per_sec": tokens_per_sec, } if merged_recent_source_hist: log_entry["source_histogram"] = merged_recent_source_hist if source_epoch_snapshot: log_entry["source_sampling"] = source_epoch_snapshot log_entry["source_sampling_warnings"] = source_epoch_warnings train_log.append(log_entry) should_rollback = False if current_loss > 15.0 and global_step > args.warmup_steps: should_rollback = True if is_ddp: rollback_tensor = torch.tensor([1.0 if should_rollback else 0.0], device=device) dist.all_reduce(rollback_tensor, op=dist.ReduceOp.MAX) should_rollback = rollback_tensor.item() > 0.5 if should_rollback: if is_main_process(rank): print(f"\n[ALERT] Loss spike detected: {current_loss:.2f} > 15.0!") if is_ddp: print(" At least one rank requested rollback; restoring all ranks to the latest checkpoint...") print(" Rolling back to the latest checkpoint...") ckpt_path = save_dir / "latest.pt" if ckpt_path.exists(): ckpt = torch.load(ckpt_path, map_location=device, weights_only=True, mmap=True) state = {k: v for k, v in ckpt["model"].items() if 'v_mem' not in k} raw_model.load_state_dict(state, strict=False) optimizer.load_state_dict(ckpt["optimizer"]) scaler.load_state_dict(ckpt["scaler"]) global_step = ckpt["global_step"] tokens_seen = ckpt.get("tokens_seen", global_step * tokens_per_step) data_samples_seen = ckpt.get("data_samples_seen", 0) current_phase = ckpt.get("curriculum_phase", current_phase) if "spike_encoder_vmem" in ckpt and ckpt["spike_encoder_vmem"] is not None: raw_model.spike_encoder.v_mem = ckpt["spike_encoder_vmem"].to(device) min_lr = args.lr * 0.125 if not hasattr(args, '_initial_lr') else args._initial_lr * 0.125 args.lr = max(args.lr * 0.5, min_lr) if is_main_process(rank): print(f" Rolled back to step {global_step}, LR -> {args.lr:.2e} (min: {min_lr:.2e})") if is_ddp: dist.barrier() if use_direct_dataset: dataset = build_mixed_dataset( args, rank, world_size, config.vocab_size, seed_offset=data_samples_seen, ) restored_data_state = restore_checkpoint_data_state( dataset, ckpt.get("data_state"), rank=rank, world_size=world_size, label="Rollback", ) phase_anchor = current_phase if restored_data_state else 0 current_phase = maybe_switch_curriculum_phase(args, dataset, tokens_seen, phase_anchor, rank) else: data_iter, _ = realign_data_iterator(dataloader, data_samples_seen, rank, label="Rollback") accum_loss = 0.0 accum_aux = 0.0 micro_step = 0 continue should_save = global_step % args.save_every == 0 and global_step > 0 checkpoint_data_state = None if should_save: checkpoint_data_state = collect_checkpoint_data_state( dataset, use_direct_dataset=use_direct_dataset, is_ddp=is_ddp, rank=rank, world_size=world_size, ) if should_save and is_main_process(rank): _save_checkpoint(raw_model, optimizer, scaler, global_step, tokens_seen, best_loss, config, save_dir, train_log, data_samples_seen=data_samples_seen, curriculum_phase=current_phase, data_state=checkpoint_data_state, keep_n=args.keep_checkpoints) global_step += 1 total_time = time.time() - train_start avg_loss = epoch_loss / max(epoch_steps, 1) final_ppl = math.exp(min(avg_loss, 20)) final_data_state = collect_checkpoint_data_state( dataset, use_direct_dataset=use_direct_dataset, is_ddp=is_ddp, rank=rank, world_size=world_size, ) if is_main_process(rank): print(f"\n{'=' * 60}") print(f" Training Complete!") print(f"{'=' * 60}") print(f"Total steps: {global_step:,}") print(f"Total tokens: {tokens_seen / 1e9:.2f}B") print(f"Total time: {total_time / 3600:.1f} hours") print(f"Avg tok/s: {tokens_seen / total_time:.0f}") print(f"Final avg loss: {avg_loss:.4f}") print(f"Final PPL: {final_ppl:.2f}") _save_checkpoint(raw_model, optimizer, scaler, global_step, tokens_seen, best_loss, config, save_dir, train_log, data_samples_seen=data_samples_seen, curriculum_phase=current_phase, data_state=final_data_state) log_path = save_dir / "train_log.json" with open(log_path, "w", encoding="utf-8") as f: json.dump(train_log, f, indent=2, ensure_ascii=False) print(f"Log saved to {log_path}") cleanup_distributed() def _save_checkpoint(model, optimizer, scaler, step, tokens_seen, best_loss, config, save_dir, train_log, data_samples_seen=0, curriculum_phase=0, data_state=None, keep_n=3): raw_model = model.module if hasattr(model, 'module') else model ckpt = { "model": raw_model.state_dict(), "optimizer": optimizer.state_dict(), "scaler": scaler.state_dict(), "global_step": step, "tokens_seen": tokens_seen, "best_loss": best_loss, "config": config.__dict__, "spike_encoder_vmem": raw_model.spike_encoder.v_mem.cpu() if raw_model.spike_encoder.v_mem is not None else None, "data_samples_seen": data_samples_seen, "curriculum_phase": curriculum_phase, "data_state": data_state, } torch.save(ckpt, save_dir / "latest.pt") torch.save(ckpt, save_dir / f"step_{step}.pt") print(f" [Checkpoint] Saved step {step} (tokens: {tokens_seen / 1e9:.2f}B, v_mem: saved)") if keep_n > 0: import glob as _glob saved = sorted( _glob.glob(str(save_dir / "step_*.pt")), key=lambda p: int(Path(p).stem.split('_')[1]) ) while len(saved) > keep_n: old = saved.pop(0) os.remove(old) print(f" [Checkpoint] Deleted old: {Path(old).name}") if __name__ == "__main__": args = parse_args() if 'RANK' not in os.environ or int(os.environ.get('RANK', 0)) == 0: print(f"\n{'=' * 60}") print(f" SymbolicLight V1 Training") print(f" Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print(f"{'=' * 60}") print(f"Config: {vars(args)}") train(args)