""" Train Deeplm on indonesian-nlp/mc4-id — ALL features enabled. Features: - MLA (Multi-head Latent Attention) - MoE (Mixture of Experts: 4 routed + 1 shared, top-k=2) - Hybrid Attention (Softmax + Lightning v2) - Hyper-Connections + Sinkhorn routing - MTP (Multi-Token Prediction, depth=2) - BitNet b1.58 ternary quantization from init - AutoTuner (adaptive LR, GN, WD, momentum, revive, etc.) - Curriculum Router (phase-based category weighting) - Self-Evolution Framework (autonomous hypothesis → experiment → decision) - SmartLogger (anomaly detection, ETA, JSONL logging) - StrictFilter (Indonesian quality filter) - BucketDataset + WeightedBucketSampler (length-bucketed efficient batching) - Gradient checkpointing - TrainingControl (unified control plane for all hyperparams) Usage: python train.py python train.py --max_steps 100000 --batch_size 4 --seq_len 512 python train.py --no_auto_tuner --no_self_evolution python train.py --resume ./deeplm_output/checkpoint-1000 """ import os import sys import json import time import argparse import math from pathlib import Path from collections import defaultdict import torch from torch.utils.data import Dataset sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "deeplm")) from deeplm.config import DeeplmConfig from deeplm.model.deeplm import DeeplmModel from deeplm.quantization.bitnet_quantize import apply_bitnet_quantization from deeplm.training.trainer import Trainer, TrainingArgs from deeplm.training.auto_tuner import AutoTuner from deeplm.training.curriculum_router import CurriculumRouter from deeplm.training.data_pipeline import StrictFilter, TokenCache, BucketDataset, WeightedBucketSampler from deeplm.training.logger import SmartLogger, MetricsTracker from deeplm.training.control import TrainingControl from deeplm.self_evolution.framework import ( SelfEvolutionFramework, TrainingMetrics as SEMetrics, MetaMemory, ) from datasets import load_dataset class C: GREEN = "\033[92m" YELLOW = "\033[93m" RED = "\033[91m" CYAN = "\033[96m" MAGENTA = "\033[95m" BOLD = "\033[1m" END = "\033[0m" def log(msg): print(f"{C.CYAN}[deeplm]{C.END} {msg}", flush=True) # ═══════════════════════════════════════════════════════════════ # Tokenizer loader # ═══════════════════════════════════════════════════════════════ def load_tokenizer(repo_id="samcheng0/deeplm-108m", token=None): from tokenizers import Tokenizer from huggingface_hub import hf_hub_download local_tok = Path("tokenizer.json") if local_tok.exists(): log(f"Loading tokenizer from {local_tok}") tok = Tokenizer.from_file(str(local_tok)) else: log(f"Downloading tokenizer from {repo_id}") path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json", token=token) tok = Tokenizer.from_file(path) local_cfg = Path("tokenizer_config.json") if local_cfg.exists(): with open(local_cfg) as f: cfg = json.load(f) else: cfg = {"vocab_size": 32000} special = cfg.get("special_tokens", {}) bos = special.get("bos_token", "<|begin_of_sentence|>") eos = special.get("eos_token", "<|end_of_sentence|>") pad = special.get("pad_token", "<|pad|>") bos_id = tok.token_to_id(bos) or 1 eos_id = tok.token_to_id(eos) or 2 pad_id = tok.token_to_id(pad) or 0 return tok, bos_id, eos_id, pad_id # ═══════════════════════════════════════════════════════════════ # Dataset # ═══════════════════════════════════════════════════════════════ class Mc4Dataset(Dataset): """mc4-id dataset with length-bucket support and category tagging.""" def __init__(self, texts, tokenizer, max_seq_length, bos, eos, pad, category_map=None): self.texts = texts self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.bos = bos self.eos = eos self.pad = pad self.category_map = category_map or ["general"] * len(texts) def __len__(self): return len(self.texts) def __getitem__(self, idx): text = self.texts[idx] enc = self.tokenizer.encode(text) ids = enc.ids if hasattr(enc, 'ids') else enc if isinstance(ids, list) and len(ids) == 0: ids = [self.bos, self.eos] else: ids = [self.bos] + ids[:self.max_seq_length - 2] + [self.eos] ln = len(ids) if ln < self.max_seq_length: ids = ids + [self.pad] * (self.max_seq_length - ln) else: ids = ids[:self.max_seq_length] attn = [1] * ln + [0] * (self.max_seq_length - ln) labels = ids[:ln] + [-100] * (self.max_seq_length - ln) labels[0] = -100 return { "input_ids": torch.tensor(ids, dtype=torch.long), "attention_mask": torch.tensor(attn, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } def categorize_text(text): """Heuristic category assignment for curriculum routing.""" words = set(text.lower().split()) id_words = {"yang", "dan", "di", "dengan", "ini", "itu", "dari", "dalam", "untuk", "pada", "adalah", "merupakan", "atau", "bahwa", "karena"} reasoning_words = {"karena", "sehingga", "oleh", "jika", "maka", "sebab", "akibat", "disebabkan", "berakibat", "mengakibatkan"} creative_words = {"cerita", "kisah", "dongeng", "novel", "puisi", "syair", "imajinasi", "fantasi", "mimpi"} academic_words = {"penelitian", "studi", "analisis", "metode", "hipotesis", "eksperimen", "data", "hasil", "kesimpulan"} code_words = {"def", "function", "class", "import", "return", "if", "else", "for", "while", "print", "variable", "kode", "program"} overlap_reasoning = len(words & reasoning_words) overlap_creative = len(words & creative_words) overlap_academic = len(words & academic_words) overlap_code = len(words & code_words) max_overlap = max(overlap_reasoning, overlap_creative, overlap_academic, overlap_code) if max_overlap == 0: return "general" if overlap_reasoning == max_overlap: return "reasoning" if overlap_creative == max_overlap: return "creative" if overlap_academic == max_overlap: return "academic" return "code" def load_dataset_indonesian(dataset_name="afrizalha/KamusOne-28M-Indonesian", sample_size=None, cache_dir="/tmp/kamusone_cache"): log(f"Loading {dataset_name}...") os.makedirs(cache_dir, exist_ok=True) cache_file = os.path.join(cache_dir, "texts.jsonl") if os.path.exists(cache_file): log(f"Loading cached texts from {cache_file}") texts = [] with open(cache_file, "r") as f: for line in f: texts.append(json.loads(line)["text"]) log(f" Loaded {len(texts):,} cached texts") return texts log(" Downloading dataset (streaming)...") ds = load_dataset(dataset_name, split="train", streaming=True) texts = [] for item in ds: text = item.get("text", "").strip() if len(text) > 50: texts.append(text) if sample_size and len(texts) >= sample_size: break log(f" Downloaded {len(texts):,} texts, caching...") with open(cache_file, "w") as f: for t in texts: f.write(json.dumps({"text": t}) + "\n") return texts # ═══════════════════════════════════════════════════════════════ # Model builder # ═══════════════════════════════════════════════════════════════ def build_model(config: DeeplmConfig, bitnet=True): log(f"Creating DeeplmModel (vocab={config.vocab_size}, hidden={config.architecture.hidden_size}, " f"layers={config.architecture.num_layers}, heads={config.architecture.num_attention_heads})") model = DeeplmModel(config) if bitnet: log(f"Applying BitNet b1.58 ternary quantization (absmean) from init...") stats = apply_bitnet_quantization(model, scale="absmean", verbose=False) log(f" Quantized {stats['quantized']}/{stats['total_linear']} linear layers") log(f" Total parameters: {model.num_parameters():,}") return model # ═══════════════════════════════════════════════════════════════ # Main # ═══════════════════════════════════════════════════════════════ def main(): parser = argparse.ArgumentParser(description="Train Deeplm — ALL features") parser.add_argument("--output_dir", type=str, default="./deeplm_output") parser.add_argument("--max_steps", type=int, default=50000) parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--grad_accum", type=int, default=4) parser.add_argument("--seq_len", type=int, default=1024) parser.add_argument("--lr", type=float, default=6.0e-4) parser.add_argument("--min_lr", type=float, default=6.0e-6) parser.add_argument("--warmup_steps", type=int, default=150) parser.add_argument("--weight_decay", type=float, default=0.1) parser.add_argument("--max_grad_norm", type=float, default=1.0) parser.add_argument("--logging_steps", type=int, default=10) parser.add_argument("--save_steps", type=int, default=1000) parser.add_argument("--eval_steps", type=int, default=2000) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--sample_size", type=int, default=0, help="0 = all data") parser.add_argument("--resume", type=str, default=None) parser.add_argument("--no_bitnet", action="store_true") parser.add_argument("--no_auto_tuner", action="store_true") parser.add_argument("--no_curriculum", action="store_true") parser.add_argument("--no_self_evolution", action="store_true") parser.add_argument("--hf_token", type=str, default=None) args = parser.parse_args() torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log(f"Device: {device}") # ── 1. Tokenizer ── tokenizer, bos_id, eos_id, pad_id = load_tokenizer(token=args.hf_token) # ── 2. Config — ALL features enabled ── config = DeeplmConfig( vocab_size=32000, max_seq_length=args.seq_len, dtype="float32", ) # Architecture config.architecture.num_layers = 10 config.architecture.hidden_size = 512 config.architecture.intermediate_size = 2048 config.architecture.num_attention_heads = 8 config.architecture.num_key_value_heads = 1 config.architecture.head_dim = 128 config.architecture.rope_head_dim = 64 config.architecture.nope_head_dim = 64 config.architecture.max_seq_length = args.seq_len config.architecture.rope_theta = 50000.0 # MLA config.mla.q_lora_rank = 192 config.mla.kv_lora_rank = 64 config.mla.qk_rope_head_dim = 64 config.mla.qk_nope_head_dim = 64 config.mla.v_head_dim = 128 config.mla.num_heads = 8 config.mla.kv_heads = 1 # MoE config.moe.num_routed_experts = 4 config.moe.num_shared_experts = 1 config.moe.top_k = 2 # MTP config.mtp.num_mtp_layers = 2 config.mtp.mtp_depth = 2 config.mtp.mtp_hidden_size = 512 # Hybrid Attention config.hybrid_attention.softmax_layers = [0, 4, 8] config.hybrid_attention.linear_layers = [1, 2, 3, 5, 6, 7, 9] # Hyper-Connections config.hyper_connections.enabled = True # Output heads config.output_heads.lm_head.type = "tied" config.output_heads.lm_head.bias = False # ── 3. Model ── model = build_model(config, bitnet=not args.no_bitnet) # ── 4. Data ── sample = args.sample_size if args.sample_size > 0 else None texts = load_dataset_indonesian(sample_size=sample) # StrictFilter log("Applying StrictFilter (Indonesian quality)...") filt = StrictFilter(config={ "min_length": 50, "max_length": 4000, "min_char_ratio": 0.25, "max_repetition_ratio": 0.4, "min_words": 10, }) texts = filt.filter(texts, lang="id") log(f" Filter: {filt.summary()}") # Categorize for curriculum routing log("Categorizing texts for curriculum routing...") category_map = [categorize_text(t) for t in texts] cat_counts = defaultdict(int) for c in category_map: cat_counts[c] += 1 for c, n in sorted(cat_counts.items()): log(f" {c}: {n:,}") # Split eval_size = min(1000, len(texts) // 10) train_texts = texts[eval_size:] eval_texts = texts[:eval_size] train_cats = category_map[eval_size:] eval_cats = category_map[:eval_size] log(f" Train: {len(train_texts):,} | Eval: {len(eval_texts):,}") train_ds = Mc4Dataset(train_texts, tokenizer, args.seq_len, bos_id, eos_id, pad_id, category_map=train_cats) eval_ds = Mc4Dataset(eval_texts, tokenizer, args.seq_len, bos_id, eos_id, pad_id, category_map=eval_cats) # ── 5. Trainer ── warmup_ratio = args.warmup_steps / max(args.max_steps, 1) train_args = TrainingArgs( output_dir=args.output_dir, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=args.lr, weight_decay=args.weight_decay, max_grad_norm=args.max_grad_norm, warmup_ratio=warmup_ratio, lr_schedule="cosine", logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps, max_steps=args.max_steps, gradient_checkpointing=True, seed=args.seed, eval_dataset=eval_ds, max_eval_samples=100, use_auto_tuner=not args.no_auto_tuner, ) trainer = Trainer( model=model, config=config, train_dataset=train_ds, eval_dataset=eval_ds, args=train_args, ) # ── 6. SmartLogger ── os.makedirs(args.output_dir, exist_ok=True) gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu" smart_log = SmartLogger( log_dir=args.output_dir, total_steps=args.max_steps, model_params=model.num_parameters(), batch_size=args.batch_size, grad_accum=args.grad_accum, seq_length=args.seq_len, lr=args.lr, vocab_size=config.vocab_size, gpu_name=gpu_name, ) metrics_tracker = MetricsTracker(window=500) # ── 7. Curriculum Router ── if not args.no_curriculum: curriculum_router = CurriculumRouter() log(f"CurriculumRouter: enabled (phase={curriculum_router.get_phase()})") else: curriculum_router = None # ── 8. Self-Evolution Framework ── if not args.no_self_evolution: se_metrics = SEMetrics(window_size=100) se_framework = SelfEvolutionFramework( config=config.self_evolution, model=model, optimizer=trainer.optimizer, metrics=se_metrics, ) se_framework.memory.load() log(f"Self-Evolution: enabled (max_iterations={se_framework.max_iterations})") else: se_framework = None se_metrics = None # ── 9. Training header ── log(f"{C.BOLD}{'='*70}{C.END}") log(f"{C.BOLD} Deeplm Training — ALL FEATURES{C.END}") log(f"{C.BOLD}{'='*70}{C.END}") log(f" Dataset: afrizalha/KamusOne-28M-Indonesian") log(f" MLA: enabled (q_lora=192, kv_lora=64, 8 heads, 1 kv)") log(f" MoE: enabled (4 routed + 1 shared, top-k=2)") log(f" Hybrid Attn: Softmax [{','.join(map(str, config.hybrid_attention.softmax_layers))}] + Lightning") log(f" Hyper-Conn: enabled (Sinkhorn routing)") log(f" MTP: enabled (depth=2, {config.mtp.num_mtp_layers} layers)") log(f" BitNet: {'enabled (absmean)' if not args.no_bitnet else 'disabled'}") log(f" AutoTuner: {'enabled' if not args.no_auto_tuner else 'disabled'}") log(f" Curriculum: {'enabled' if curriculum_router else 'disabled'}") log(f" Self-Evolution: {'enabled' if se_framework else 'disabled'}") log(f" SmartLogger: enabled (JSONL + anomaly detection + ETA)") log(f" Steps: {args.max_steps:,}") log(f" Batch: {args.batch_size} x grad_accum={args.grad_accum} = eff {args.batch_size * args.grad_accum}") log(f" Seq len: {args.seq_len}") log(f" LR: {args.lr:.2e} (warmup={args.warmup_steps})") log(f" Output: {args.output_dir}") log(f"{'='*70}") # ── 10. Custom training loop with ALL features ── trainer.model.train() trainer.optimizer.zero_grad() train_loader = trainer.get_train_dataloader() eval_loader = trainer.get_eval_dataloader() num_training_steps = args.max_steps scheduler = trainer.get_scheduler(num_training_steps) # AutoTuner init if not args.no_auto_tuner: trainer.auto_tuner = AutoTuner( args.lr, args.warmup_steps, num_training_steps, args.max_grad_norm ) if curriculum_router: trainer.auto_tuner._curriculum_router = curriculum_router log(" AutoTuner: initialized") # Resume if args.resume: trainer._load_checkpoint(args.resume, scheduler) micro_step = 0 t_start = time.time() se_evolution_step = 0 se_consolidation_interval = 5000 for epoch in range(train_args.num_train_epochs): trainer.epoch = epoch for step, batch in enumerate(train_loader): batch = {k: v.to(trainer.device) for k, v in batch.items()} # Forward output = trainer.model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], output_mtp_loss=config.mtp.enabled, ) loss = output["loss"] / args.grad_accum loss.backward() trainer.total_loss += loss.item() * args.grad_accum micro_step += 1 if micro_step % args.grad_accum == 0: # AutoTuner adjustments if trainer.auto_tuner is not None: adj = trainer.auto_tuner.get_adjustments(trainer.global_step) trainer.auto_tuner.capture_gradients(trainer.model) control = trainer.auto_tuner.get_training_control(trainer.global_step) trainer.apply_training_control(control) grad_norm_clip = trainer._get_grad_norm_limit() else: grad_norm_clip = args.max_grad_norm torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), grad_norm_clip) trainer.optimizer.step() if trainer.auto_tuner is not None: trainer.auto_tuner.capture_update(trainer.model) trainer.optimizer.zero_grad() scheduler.step() trainer.global_step += 1 # Metrics avg_loss = trainer.total_loss / (args.logging_steps * args.grad_accum) lr = trainer.optimizer.param_groups[0]["lr"] grad_norm = trainer._get_grad_norm() elapsed = time.time() - t_start n_tokens = batch["input_ids"].numel() * args.grad_accum tok_per_sec = n_tokens * trainer.global_step / max(elapsed, 1.0) # AutoTuner record if trainer.auto_tuner is not None: trainer.auto_tuner.record(trainer.global_step, avg_loss, grad_norm) # Curriculum Router if curriculum_router and trainer.auto_tuner is not None: phase = trainer.auto_tuner.phase curriculum_router.set_phase(phase, trainer.global_step) signals = { "plateau": trainer.auto_tuner.state == "plateau", "overfit": getattr(trainer.auto_tuner, 'overfit_signal', False), "sick": len(getattr(trainer.auto_tuner, 'sick_layers', set())) > 0, } curriculum_router.explore(trainer.global_step, signals) # Self-Evolution metrics if se_metrics is not None: se_metrics.record(avg_loss, grad_norm, lr, tok_per_sec) # SmartLogger smart_log.log_step( trainer.global_step, avg_loss, grad_norm, lr, tok_per_sec ) metrics_tracker.update( loss=avg_loss, grad_norm=grad_norm, lr=lr, tok_per_sec=tok_per_sec ) # Self-Evolution periodic check if se_framework is not None and trainer.global_step % se_consolidation_interval == 0: log(f"{C.MAGENTA}[Self-Evolution] Running autonomous round {se_evolution_step + 1}{C.END}") anomalies = se_metrics.detect_anomalies() trend = se_metrics.get_trend() hypothesis = se_framework._generate_hypothesis() experiment = se_framework._design_experiment(hypothesis) changes = se_framework._execute_changes(experiment) bugs = se_framework._diagnose_bugs() fixes = se_framework._apply_fixes(bugs) result = se_framework._evaluate() decision = se_framework._make_decision(result) from deeplm.self_evolution.framework import EvolutionEpisode episode = EvolutionEpisode( step=trainer.global_step, phase="autonomous_check", hypothesis=hypothesis, experiment_design=experiment, changes_applied=changes, metrics_before=se_metrics.get_recent(), experiment_result=result, decision=decision, bugs_found=bugs, fixes_applied=fixes, timestamp=time.time(), ) se_framework.memory.add(episode) if decision == "keep": se_framework._commit_changes(episode) elif decision == "revert": se_framework._revert_changes() log(f" Hypothesis: {hypothesis}") log(f" Result: {result:.2f} | Decision: {decision}") if bugs: log(f" Bugs: {bugs}") if fixes: log(f" Fixes: {fixes}") se_evolution_step += 1 if se_evolution_step % 5 == 0: se_framework.memory.save() stats = se_framework.get_memory_stats() log(f" Memory stats: {stats['total_entries']} entries, " f"keep_rate={stats['consolidation'].get('keep_rate', 0):.2f}") # Logging if trainer.global_step % args.logging_steps == 0: line = (f"Step {trainer.global_step:>6,} | Loss: {avg_loss:.4f} | " f"LR: {lr:.6f} | Grad: {grad_norm:.2f} | " f"Tok/s: {tok_per_sec:,.0f}") if trainer.auto_tuner is not None: at = trainer.auto_tuner line += f" | AT: {at.state} LRx{at.lr_mult:.2f}" if hasattr(at, 'diagnosis') and at.diagnosis != "unknown": line += f" | Diag: {at.diagnosis}" if curriculum_router is not None: weights = curriculum_router.get_weights() top_cats = sorted(weights.items(), key=lambda x: -x[1])[:3] line += f" | Curr: {', '.join(f'{c}={w:.1f}' for c, w in top_cats)}" print(f"{C.CYAN}[train]{C.END} {line}") trainer.total_loss = 0.0 # Evaluation if eval_loader and trainer.global_step % args.eval_steps == 0: eval_loss = trainer.evaluate(eval_loader) if trainer.auto_tuner is not None: trainer.auto_tuner.capture_eval(trainer.global_step, eval_loss) if eval_loss < trainer.best_eval_loss: trainer.best_eval_loss = eval_loss trainer._save_checkpoint(best=True) trainer.model.train() # Checkpoint if trainer.global_step % args.save_steps == 0: trainer._save_checkpoint() if args.max_steps > 0 and trainer.global_step >= args.max_steps: break if args.max_steps > 0 and trainer.global_step >= args.max_steps: break if args.max_steps > 0 and trainer.global_step >= args.max_steps: break # Final save trainer._save_checkpoint(final=True) smart_log.log_summary(trainer.global_step, avg_loss, trainer.best_eval_loss) if se_framework is not None: se_framework.memory.save() log(f"Self-Evolution memory saved ({len(se_framework.memory.entries)} episodes)") elapsed = time.time() - t_start log(f"{C.GREEN}Training complete in {elapsed/3600:.1f}h ({elapsed:.0f}s){C.END}") if __name__ == "__main__": main()