# deeplm-108m / train.py
# Uploaded to samcheng0/deeplm-108m with huggingface_hub (commit fe7d2c6).
"""
Train Deeplm on afrizalha/KamusOne-28M-Indonesian — ALL features enabled.
Features:
- MLA (Multi-head Latent Attention)
- MoE (Mixture of Experts: 4 routed + 1 shared, top-k=2)
- Hybrid Attention (Softmax + Lightning v2)
- Hyper-Connections + Sinkhorn routing
- MTP (Multi-Token Prediction, depth=2)
- BitNet b1.58 ternary quantization from init
- AutoTuner (adaptive LR, GN, WD, momentum, revive, etc.)
- Curriculum Router (phase-based category weighting)
- Self-Evolution Framework (autonomous hypothesis → experiment → decision)
- SmartLogger (anomaly detection, ETA, JSONL logging)
- StrictFilter (Indonesian quality filter)
- BucketDataset + WeightedBucketSampler (length-bucketed efficient batching)
- Gradient checkpointing
- TrainingControl (unified control plane for all hyperparams)
Usage:
python train.py
python train.py --max_steps 100000 --batch_size 4 --seq_len 512
python train.py --no_auto_tuner --no_self_evolution
python train.py --resume ./deeplm_output/checkpoint-1000
"""
import argparse
import json
import math
import os
import re
import sys
import time
from collections import defaultdict
from pathlib import Path

import torch
from torch.utils.data import Dataset
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "deeplm"))
from deeplm.config import DeeplmConfig
from deeplm.model.deeplm import DeeplmModel
from deeplm.quantization.bitnet_quantize import apply_bitnet_quantization
from deeplm.training.trainer import Trainer, TrainingArgs
from deeplm.training.auto_tuner import AutoTuner
from deeplm.training.curriculum_router import CurriculumRouter
from deeplm.training.data_pipeline import StrictFilter, TokenCache, BucketDataset, WeightedBucketSampler
from deeplm.training.logger import SmartLogger, MetricsTracker
from deeplm.training.control import TrainingControl
from deeplm.self_evolution.framework import (
SelfEvolutionFramework, TrainingMetrics as SEMetrics, MetaMemory,
)
from datasets import load_dataset
class C:
    """ANSI escape sequences used to colorize console output."""

    # Bright foreground colours.
    GREEN, YELLOW, RED = "\033[92m", "\033[93m", "\033[91m"
    CYAN, MAGENTA = "\033[96m", "\033[95m"
    # Text attributes: bold on, and reset-all.
    BOLD, END = "\033[1m", "\033[0m"
def log(msg):
    """Print ``msg`` to stdout behind a cyan ``[deeplm]`` tag, flushing immediately."""
    tag = f"{C.CYAN}[deeplm]{C.END}"
    print(f"{tag} {msg}", flush=True)
# ═══════════════════════════════════════════════════════════════
# Tokenizer loader
# ═══════════════════════════════════════════════════════════════
def load_tokenizer(repo_id="samcheng0/deeplm-108m", token=None):
    """Load the tokenizer and resolve the BOS / EOS / PAD token ids.

    Prefers a local ``tokenizer.json`` in the working directory; otherwise
    downloads ``tokenizer.json`` from the Hugging Face Hub repo ``repo_id``.
    Special-token strings come from the ``special_tokens`` mapping of a local
    ``tokenizer_config.json`` when present, else the Deeplm defaults.

    Args:
        repo_id: Hub repository to download ``tokenizer.json`` from.
        token: Optional Hub access token.

    Returns:
        ``(tokenizer, bos_id, eos_id, pad_id)``. Each id falls back to 1 / 2 / 0
        respectively only when its token string is not in the vocabulary.
    """
    from tokenizers import Tokenizer
    from huggingface_hub import hf_hub_download

    local_tok = Path("tokenizer.json")
    if local_tok.exists():
        log(f"Loading tokenizer from {local_tok}")
        tok = Tokenizer.from_file(str(local_tok))
    else:
        log(f"Downloading tokenizer from {repo_id}")
        path = hf_hub_download(repo_id=repo_id, filename="tokenizer.json", token=token)
        tok = Tokenizer.from_file(path)

    local_cfg = Path("tokenizer_config.json")
    cfg = {}
    if local_cfg.exists():
        with open(local_cfg, encoding="utf-8") as f:
            cfg = json.load(f)
    # `or {}` also tolerates an explicit `"special_tokens": null` in the file.
    special = cfg.get("special_tokens") or {}

    def resolve(token_str, fallback):
        # token_to_id returns None for unknown tokens. Compare against None
        # explicitly: `id or fallback` would discard a legitimate id of 0.
        token_id = tok.token_to_id(token_str)
        return fallback if token_id is None else token_id

    bos_id = resolve(special.get("bos_token", "<|begin_of_sentence|>"), 1)
    eos_id = resolve(special.get("eos_token", "<|end_of_sentence|>"), 2)
    pad_id = resolve(special.get("pad_token", "<|pad|>"), 0)
    return tok, bos_id, eos_id, pad_id
# ═══════════════════════════════════════════════════════════════
# Dataset
# ═══════════════════════════════════════════════════════════════
class Mc4Dataset(Dataset):
    """Fixed-length causal-LM dataset over raw texts, tokenized lazily per item.

    Each sample is ``[bos] + tokens + [eos]``, with ``tokens`` truncated so the
    whole sequence fits in ``max_seq_length``, then right-padded with ``pad``.
    The class name is historical: it works for any list of texts.

    Args:
        texts: Sequence of raw strings.
        tokenizer: Object whose ``encode(text)`` returns either an object with
            an ``.ids`` attribute or a sequence of integer ids.
        max_seq_length: Length of every returned tensor; must be >= 2 so that
            BOS and EOS always fit.
        bos, eos, pad: Special token ids.
        category_map: Optional per-text category labels; defaults to
            ``"general"`` for every text. Stored only — ``__getitem__`` does
            not read it.

    Raises:
        ValueError: If ``max_seq_length < 2``.
    """

    def __init__(self, texts, tokenizer, max_seq_length, bos, eos, pad,
                 category_map=None):
        if max_seq_length < 2:
            # Below 2 the BOS/EOS framing cannot fit, and input_ids,
            # attention_mask and labels would come out with different lengths.
            raise ValueError(
                f"max_seq_length must be >= 2 to fit BOS and EOS, got {max_seq_length}"
            )
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.bos = bos
        self.eos = eos
        self.pad = pad
        self.category_map = category_map or ["general"] * len(texts)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` long tensors of length ``max_seq_length``."""
        enc = self.tokenizer.encode(self.texts[idx])
        # Copy into a list so tuples and other sequences concatenate like lists.
        tokens = list(enc.ids if hasattr(enc, "ids") else enc)
        # Reserve two slots for BOS/EOS; an empty encoding becomes [bos, eos].
        ids = [self.bos] + tokens[:self.max_seq_length - 2] + [self.eos]
        n_real = len(ids)
        n_pad = self.max_seq_length - n_real  # >= 0 because max_seq_length >= 2
        ids = ids + [self.pad] * n_pad
        attn = [1] * n_real + [0] * n_pad
        # Labels copy the real tokens; padding and position 0 are ignored (-100).
        # NOTE(review): labels are not shifted here, so this presumably relies on
        # the model shifting labels internally — confirm in DeeplmModel.forward.
        labels = ids[:n_real] + [-100] * n_pad
        labels[0] = -100
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
# Keyword sets for the categorize_text heuristic, built once at import time
# instead of on every call (the function is applied to every corpus text).
_REASONING_WORDS = frozenset({
    "karena", "sehingga", "oleh", "jika", "maka", "sebab",
    "akibat", "disebabkan", "berakibat", "mengakibatkan",
})
_CREATIVE_WORDS = frozenset({
    "cerita", "kisah", "dongeng", "novel", "puisi", "syair",
    "imajinasi", "fantasi", "mimpi",
})
_ACADEMIC_WORDS = frozenset({
    "penelitian", "studi", "analisis", "metode", "hipotesis",
    "eksperimen", "data", "hasil", "kesimpulan",
})
_CODE_WORDS = frozenset({
    "def", "function", "class", "import", "return", "if", "else",
    "for", "while", "print", "variable", "kode", "program",
})
# Runs of word characters, so punctuation such as in "karena," or "print(x)"
# does not prevent a keyword match.
_WORD_RE = re.compile(r"\w+")


def categorize_text(text):
    """Heuristically assign a curriculum category to ``text``.

    Counts distinct keywords (case-insensitive) from each category's word list
    and returns the category with the most hits. Ties are broken in the order
    reasoning > creative > academic > code.

    Returns:
        One of ``"reasoning"``, ``"creative"``, ``"academic"``, ``"code"``,
        or ``"general"`` when no keyword matches.
    """
    words = set(_WORD_RE.findall(text.lower()))
    # Listed in tie-break priority: max() returns the first maximal entry.
    scores = (
        ("reasoning", len(words & _REASONING_WORDS)),
        ("creative", len(words & _CREATIVE_WORDS)),
        ("academic", len(words & _ACADEMIC_WORDS)),
        ("code", len(words & _CODE_WORDS)),
    )
    category, best = max(scores, key=lambda pair: pair[1])
    return category if best > 0 else "general"
def load_dataset_indonesian(dataset_name="afrizalha/KamusOne-28M-Indonesian", sample_size=None, cache_dir="/tmp/kamusone_cache"):
    """Return a list of raw training texts, streaming from the Hub on first use.

    Texts come from the ``text`` field of the ``train`` split. Entries with 50
    or fewer characters after stripping are skipped. The result is cached as
    JSON Lines under ``cache_dir`` and reused by later calls with the same
    ``dataset_name`` and ``sample_size``.

    Args:
        dataset_name: Dataset id passed to ``datasets.load_dataset``.
        sample_size: Stop after this many kept texts; ``None`` or 0 keeps all.
        cache_dir: Directory for the cache file (created if missing).

    Returns:
        List of stripped text strings.
    """
    log(f"Loading {dataset_name}...")
    os.makedirs(cache_dir, exist_ok=True)
    # Key the cache on dataset and sample size. A single shared "texts.jsonl"
    # let a small sampled run (or another dataset) silently shadow later runs.
    dataset_slug = dataset_name.replace("/", "__")
    size_tag = f"n{sample_size}" if sample_size else "all"
    cache_file = os.path.join(cache_dir, f"{dataset_slug}.{size_tag}.jsonl")

    if os.path.exists(cache_file):
        log(f"Loading cached texts from {cache_file}")
        with open(cache_file, "r", encoding="utf-8") as f:
            texts = [json.loads(line)["text"] for line in f if line.strip()]
        log(f" Loaded {len(texts):,} cached texts")
        return texts

    log(" Downloading dataset (streaming)...")
    ds = load_dataset(dataset_name, split="train", streaming=True)
    texts = []
    for item in ds:
        # `or ""` also covers rows whose text field is present but None.
        text = (item.get("text") or "").strip()
        if len(text) > 50:
            texts.append(text)
            if sample_size and len(texts) >= sample_size:
                break

    log(f" Downloaded {len(texts):,} texts, caching...")
    # Write to a temp file and rename it into place. An interrupted run then
    # never leaves a truncated cache that the next run would trust.
    tmp_file = f"{cache_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps({"text": t}) + "\n" for t in texts)
    os.replace(tmp_file, cache_file)
    return texts
# ═══════════════════════════════════════════════════════════════
# Model builder
# ═══════════════════════════════════════════════════════════════
def build_model(config: DeeplmConfig, bitnet=True):
    """Instantiate a ``DeeplmModel`` from ``config``, optionally BitNet-quantized.

    Args:
        config: Full Deeplm configuration; vocab size and the main
            architecture sizes are logged before construction.
        bitnet: If true, run ``apply_bitnet_quantization`` with
            ``scale="absmean"`` on the freshly built model.

    Returns:
        The constructed model.
    """
    arch = config.architecture
    log(f"Creating DeeplmModel (vocab={config.vocab_size}, hidden={arch.hidden_size}, "
        f"layers={arch.num_layers}, heads={arch.num_attention_heads})")
    model = DeeplmModel(config)

    if bitnet:
        log("Applying BitNet b1.58 ternary quantization (absmean) from init...")
        quant_stats = apply_bitnet_quantization(model, scale="absmean", verbose=False)
        n_quantized, n_linear = quant_stats['quantized'], quant_stats['total_linear']
        log(f" Quantized {n_quantized}/{n_linear} linear layers")

    param_count = model.num_parameters()
    log(f" Total parameters: {param_count:,}")
    return model
# ═══════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════
def _parse_args():
    """Parse the command-line options of this training script."""
    parser = argparse.ArgumentParser(description="Train Deeplm — ALL features")
    parser.add_argument("--output_dir", type=str, default="./deeplm_output")
    parser.add_argument("--max_steps", type=int, default=50000)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=4)
    parser.add_argument("--seq_len", type=int, default=1024)
    parser.add_argument("--lr", type=float, default=6.0e-4)
    # NOTE(review): --min_lr is parsed but never passed to TrainingArgs/AutoTuner.
    parser.add_argument("--min_lr", type=float, default=6.0e-6)
    parser.add_argument("--warmup_steps", type=int, default=150)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--sample_size", type=int, default=0, help="0 = all data")
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--no_bitnet", action="store_true")
    parser.add_argument("--no_auto_tuner", action="store_true")
    parser.add_argument("--no_curriculum", action="store_true")
    parser.add_argument("--no_self_evolution", action="store_true")
    parser.add_argument("--hf_token", type=str, default=None)
    return parser.parse_args()


def _build_config(seq_len):
    """Return a DeeplmConfig for ``seq_len``-token contexts with every feature enabled."""
    config = DeeplmConfig(
        vocab_size=32000,
        max_seq_length=seq_len,
        dtype="float32",
    )
    # Architecture
    arch = config.architecture
    arch.num_layers = 10
    arch.hidden_size = 512
    arch.intermediate_size = 2048
    arch.num_attention_heads = 8
    arch.num_key_value_heads = 1
    arch.head_dim = 128
    arch.rope_head_dim = 64
    arch.nope_head_dim = 64
    arch.max_seq_length = seq_len
    arch.rope_theta = 50000.0
    # MLA
    mla = config.mla
    mla.q_lora_rank = 192
    mla.kv_lora_rank = 64
    mla.qk_rope_head_dim = 64
    mla.qk_nope_head_dim = 64
    mla.v_head_dim = 128
    mla.num_heads = 8
    mla.kv_heads = 1
    # MoE
    config.moe.num_routed_experts = 4
    config.moe.num_shared_experts = 1
    config.moe.top_k = 2
    # MTP
    config.mtp.num_mtp_layers = 2
    config.mtp.mtp_depth = 2
    config.mtp.mtp_hidden_size = 512
    # Hybrid attention: softmax at layers 0/4/8, linear attention elsewhere.
    config.hybrid_attention.softmax_layers = [0, 4, 8]
    config.hybrid_attention.linear_layers = [1, 2, 3, 5, 6, 7, 9]
    # Hyper-Connections
    config.hyper_connections.enabled = True
    # Output heads
    config.output_heads.lm_head.type = "tied"
    config.output_heads.lm_head.bias = False
    return config


def _self_evolution_round(se_framework, se_metrics, global_step, round_no):
    """Run one self-evolution round at ``global_step`` and record it.

    Calls the framework's hypothesis -> experiment -> bug-fix -> evaluate ->
    decide steps in order, stores an ``EvolutionEpisode`` in memory, then
    commits or reverts according to the decision. Memory is saved every 5th
    round (``round_no`` is 1-based).
    """
    from deeplm.self_evolution.framework import EvolutionEpisode

    log(f"{C.MAGENTA}[Self-Evolution] Running autonomous round {round_no}{C.END}")
    # Results are currently unused; the calls are kept in case they update
    # se_metrics' internal state (not visible from here).
    se_metrics.detect_anomalies()
    se_metrics.get_trend()
    hypothesis = se_framework._generate_hypothesis()
    experiment = se_framework._design_experiment(hypothesis)
    changes = se_framework._execute_changes(experiment)
    bugs = se_framework._diagnose_bugs()
    fixes = se_framework._apply_fixes(bugs)
    result = se_framework._evaluate()
    decision = se_framework._make_decision(result)
    episode = EvolutionEpisode(
        step=global_step,
        phase="autonomous_check",
        hypothesis=hypothesis,
        experiment_design=experiment,
        changes_applied=changes,
        metrics_before=se_metrics.get_recent(),
        experiment_result=result,
        decision=decision,
        bugs_found=bugs,
        fixes_applied=fixes,
        timestamp=time.time(),
    )
    se_framework.memory.add(episode)
    if decision == "keep":
        se_framework._commit_changes(episode)
    elif decision == "revert":
        se_framework._revert_changes()
    log(f" Hypothesis: {hypothesis}")
    log(f" Result: {result:.2f} | Decision: {decision}")
    if bugs:
        log(f" Bugs: {bugs}")
    if fixes:
        log(f" Fixes: {fixes}")
    if round_no % 5 == 0:
        se_framework.memory.save()
        stats = se_framework.get_memory_stats()
        log(f" Memory stats: {stats['total_entries']} entries, "
            f"keep_rate={stats['consolidation'].get('keep_rate', 0):.2f}")


def main():
    """Train Deeplm end to end with every feature enabled.

    Pipeline: parse CLI args -> tokenizer -> config/model -> load, filter and
    categorize texts -> Trainer -> SmartLogger / CurriculumRouter /
    SelfEvolutionFramework -> custom gradient-accumulation loop -> final save.
    """
    args = _parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"Device: {device}")

    # ── 1. Tokenizer ──
    tokenizer, bos_id, eos_id, pad_id = load_tokenizer(token=args.hf_token)

    # ── 2. Config — ALL features enabled ──
    config = _build_config(args.seq_len)

    # ── 3. Model ──
    model = build_model(config, bitnet=not args.no_bitnet)

    # ── 4. Data ──
    sample = args.sample_size if args.sample_size > 0 else None
    texts = load_dataset_indonesian(sample_size=sample)

    log("Applying StrictFilter (Indonesian quality)...")
    filt = StrictFilter(config={
        "min_length": 50,
        "max_length": 4000,
        "min_char_ratio": 0.25,
        "max_repetition_ratio": 0.4,
        "min_words": 10,
    })
    texts = filt.filter(texts, lang="id")
    log(f" Filter: {filt.summary()}")

    log("Categorizing texts for curriculum routing...")
    category_map = [categorize_text(t) for t in texts]
    cat_counts = defaultdict(int)
    for cat in category_map:
        cat_counts[cat] += 1
    for cat, n in sorted(cat_counts.items()):
        log(f" {cat}: {n:,}")

    # Hold out the first 10% of texts (capped at 1000) for evaluation.
    eval_size = min(1000, len(texts) // 10)
    train_texts, eval_texts = texts[eval_size:], texts[:eval_size]
    train_cats, eval_cats = category_map[eval_size:], category_map[:eval_size]
    log(f" Train: {len(train_texts):,} | Eval: {len(eval_texts):,}")
    train_ds = Mc4Dataset(train_texts, tokenizer, args.seq_len, bos_id, eos_id, pad_id,
                          category_map=train_cats)
    eval_ds = Mc4Dataset(eval_texts, tokenizer, args.seq_len, bos_id, eos_id, pad_id,
                         category_map=eval_cats)

    # ── 5. Trainer ──
    warmup_ratio = args.warmup_steps / max(args.max_steps, 1)
    train_args = TrainingArgs(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=warmup_ratio,
        lr_schedule="cosine",
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        max_steps=args.max_steps,
        gradient_checkpointing=True,
        seed=args.seed,
        eval_dataset=eval_ds,
        max_eval_samples=100,
        use_auto_tuner=not args.no_auto_tuner,
    )
    trainer = Trainer(
        model=model,
        config=config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=train_args,
    )

    # ── 6. SmartLogger ──
    os.makedirs(args.output_dir, exist_ok=True)
    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
    smart_log = SmartLogger(
        log_dir=args.output_dir,
        total_steps=args.max_steps,
        model_params=model.num_parameters(),
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        seq_length=args.seq_len,
        lr=args.lr,
        vocab_size=config.vocab_size,
        gpu_name=gpu_name,
    )
    metrics_tracker = MetricsTracker(window=500)

    # ── 7. Curriculum Router ──
    curriculum_router = None
    if not args.no_curriculum:
        curriculum_router = CurriculumRouter()
        log(f"CurriculumRouter: enabled (phase={curriculum_router.get_phase()})")

    # ── 8. Self-Evolution Framework ──
    se_framework = None
    se_metrics = None
    if not args.no_self_evolution:
        se_metrics = SEMetrics(window_size=100)
        se_framework = SelfEvolutionFramework(
            config=config.self_evolution,
            model=model,
            optimizer=trainer.optimizer,
            metrics=se_metrics,
        )
        se_framework.memory.load()
        log(f"Self-Evolution: enabled (max_iterations={se_framework.max_iterations})")

    # ── 9. Training header ──
    rule = "=" * 70
    log(f"{C.BOLD}{rule}{C.END}")
    log(f"{C.BOLD} Deeplm Training — ALL FEATURES{C.END}")
    log(f"{C.BOLD}{rule}{C.END}")
    log(" Dataset: afrizalha/KamusOne-28M-Indonesian")
    log(f" MLA: enabled (q_lora={config.mla.q_lora_rank}, kv_lora={config.mla.kv_lora_rank}, "
        f"{config.mla.num_heads} heads, {config.mla.kv_heads} kv)")
    log(f" MoE: enabled ({config.moe.num_routed_experts} routed + "
        f"{config.moe.num_shared_experts} shared, top-k={config.moe.top_k})")
    log(f" Hybrid Attn: Softmax [{','.join(map(str, config.hybrid_attention.softmax_layers))}] + Lightning")
    log(" Hyper-Conn: enabled (Sinkhorn routing)")
    log(f" MTP: enabled (depth={config.mtp.mtp_depth}, {config.mtp.num_mtp_layers} layers)")
    log(f" BitNet: {'enabled (absmean)' if not args.no_bitnet else 'disabled'}")
    log(f" AutoTuner: {'enabled' if not args.no_auto_tuner else 'disabled'}")
    log(f" Curriculum: {'enabled' if curriculum_router is not None else 'disabled'}")
    log(f" Self-Evolution: {'enabled' if se_framework is not None else 'disabled'}")
    log(" SmartLogger: enabled (JSONL + anomaly detection + ETA)")
    log(f" Steps: {args.max_steps:,}")
    log(f" Batch: {args.batch_size} x grad_accum={args.grad_accum} = eff {args.batch_size * args.grad_accum}")
    log(f" Seq len: {args.seq_len}")
    log(f" LR: {args.lr:.2e} (warmup={args.warmup_steps})")
    log(f" Output: {args.output_dir}")
    log(rule)

    # ── 10. Custom training loop with ALL features ──
    trainer.model.train()
    trainer.optimizer.zero_grad()
    train_loader = trainer.get_train_dataloader()
    eval_loader = trainer.get_eval_dataloader()
    num_training_steps = args.max_steps
    scheduler = trainer.get_scheduler(num_training_steps)

    if not args.no_auto_tuner:
        trainer.auto_tuner = AutoTuner(
            args.lr, args.warmup_steps, num_training_steps, args.max_grad_norm
        )
        if curriculum_router is not None:
            trainer.auto_tuner._curriculum_router = curriculum_router
        log(" AutoTuner: initialized")

    if args.resume:
        trainer._load_checkpoint(args.resume, scheduler)

    def reached_max_steps():
        return args.max_steps > 0 and trainer.global_step >= args.max_steps

    micro_step = 0
    t_start = time.time()
    # Steps restored from a checkpoint took no time in this process, so
    # throughput is computed from steps taken since this point only.
    start_step = trainer.global_step
    se_evolution_step = 0
    se_consolidation_interval = 5000
    accum_loss = 0.0          # un-scaled micro-batch losses in the current accumulation window
    step_loss = float("nan")  # loss of the latest optimizer step; defined even if none runs
    steps_in_window = 0       # optimizer steps since trainer.total_loss was last reset
    trainer.total_loss = 0.0

    # NOTE(review): the epoch count comes from TrainingArgs' default; if it is
    # small, training may stop before --max_steps. Confirm the intended value.
    for epoch in range(train_args.num_train_epochs):
        if reached_max_steps():
            break
        trainer.epoch = epoch
        for batch in train_loader:
            batch = {k: v.to(trainer.device) for k, v in batch.items()}

            # Forward / backward on one micro-batch.
            output = trainer.model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                output_mtp_loss=config.mtp.enabled,
            )
            loss = output["loss"] / args.grad_accum
            loss.backward()
            micro_loss = loss.item() * args.grad_accum
            trainer.total_loss += micro_loss
            accum_loss += micro_loss
            micro_step += 1
            if micro_step % args.grad_accum != 0:
                continue

            # ── Optimizer step ──
            if trainer.auto_tuner is not None:
                # Return value unused; called for any internal tuner bookkeeping.
                trainer.auto_tuner.get_adjustments(trainer.global_step)
                trainer.auto_tuner.capture_gradients(trainer.model)
                control = trainer.auto_tuner.get_training_control(trainer.global_step)
                trainer.apply_training_control(control)
                grad_norm_clip = trainer._get_grad_norm_limit()
            else:
                grad_norm_clip = args.max_grad_norm
            # clip_grad_norm_ returns the total norm *before* clipping. Capture it
            # here: the gradients are zeroed below, so measuring later reads zeros.
            grad_norm = float(
                torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), grad_norm_clip)
            )
            trainer.optimizer.step()
            if trainer.auto_tuner is not None:
                trainer.auto_tuner.capture_update(trainer.model)
            trainer.optimizer.zero_grad()
            scheduler.step()
            trainer.global_step += 1
            steps_in_window += 1

            # ── Per-step metrics ──
            # Mean loss of this step's micro-batches. trainer.total_loss spans the
            # whole logging window, so dividing it by a fixed window size every
            # step produced a sawtooth, under-scaled loss.
            step_loss = accum_loss / args.grad_accum
            accum_loss = 0.0
            lr = trainer.optimizer.param_groups[0]["lr"]
            elapsed = time.time() - t_start
            n_tokens = batch["input_ids"].numel() * args.grad_accum
            tok_per_sec = n_tokens * (trainer.global_step - start_step) / max(elapsed, 1.0)

            if trainer.auto_tuner is not None:
                trainer.auto_tuner.record(trainer.global_step, step_loss, grad_norm)

            if curriculum_router is not None and trainer.auto_tuner is not None:
                curriculum_router.set_phase(trainer.auto_tuner.phase, trainer.global_step)
                signals = {
                    "plateau": trainer.auto_tuner.state == "plateau",
                    "overfit": getattr(trainer.auto_tuner, 'overfit_signal', False),
                    "sick": len(getattr(trainer.auto_tuner, 'sick_layers', set())) > 0,
                }
                curriculum_router.explore(trainer.global_step, signals)

            if se_metrics is not None:
                se_metrics.record(step_loss, grad_norm, lr, tok_per_sec)

            smart_log.log_step(trainer.global_step, step_loss, grad_norm, lr, tok_per_sec)
            metrics_tracker.update(
                loss=step_loss, grad_norm=grad_norm, lr=lr, tok_per_sec=tok_per_sec
            )

            if se_framework is not None and trainer.global_step % se_consolidation_interval == 0:
                se_evolution_step += 1
                _self_evolution_round(se_framework, se_metrics, trainer.global_step,
                                      se_evolution_step)

            # ── Console logging (average over the logging window) ──
            if args.logging_steps > 0 and trainer.global_step % args.logging_steps == 0:
                window_loss = trainer.total_loss / (steps_in_window * args.grad_accum)
                line = (f"Step {trainer.global_step:>6,} | Loss: {window_loss:.4f} | "
                        f"LR: {lr:.6f} | Grad: {grad_norm:.2f} | "
                        f"Tok/s: {tok_per_sec:,.0f}")
                if trainer.auto_tuner is not None:
                    at = trainer.auto_tuner
                    line += f" | AT: {at.state} LRx{at.lr_mult:.2f}"
                    if hasattr(at, 'diagnosis') and at.diagnosis != "unknown":
                        line += f" | Diag: {at.diagnosis}"
                if curriculum_router is not None:
                    weights = curriculum_router.get_weights()
                    top_cats = sorted(weights.items(), key=lambda x: -x[1])[:3]
                    line += f" | Curr: {', '.join(f'{c}={w:.1f}' for c, w in top_cats)}"
                print(f"{C.CYAN}[train]{C.END} {line}")
                trainer.total_loss = 0.0
                steps_in_window = 0

            # ── Evaluation ──
            if eval_loader and args.eval_steps > 0 and trainer.global_step % args.eval_steps == 0:
                eval_loss = trainer.evaluate(eval_loader)
                if trainer.auto_tuner is not None:
                    trainer.auto_tuner.capture_eval(trainer.global_step, eval_loss)
                if eval_loss < trainer.best_eval_loss:
                    trainer.best_eval_loss = eval_loss
                    trainer._save_checkpoint(best=True)
                trainer.model.train()

            # ── Checkpoint ──
            if args.save_steps > 0 and trainer.global_step % args.save_steps == 0:
                trainer._save_checkpoint()

            if reached_max_steps():
                break

    # Final save
    trainer._save_checkpoint(final=True)
    smart_log.log_summary(trainer.global_step, step_loss, trainer.best_eval_loss)
    if se_framework is not None:
        se_framework.memory.save()
        log(f"Self-Evolution memory saved ({len(se_framework.memory.entries)} episodes)")
    elapsed = time.time() - t_start
    log(f"{C.GREEN}Training complete in {elapsed/3600:.1f}h ({elapsed:.0f}s){C.END}")
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()