# bbkdevops's picture
# download
# raw
# 3.52 kB
"""
SpectralMind Quick-Start
========================
รัน: python -m train.spectral_quickstart
สร้าง SpectralMind nano model และ train บน data/filtered/clean_qa.jsonl
"""
from __future__ import annotations
import torch
from pathlib import Path
from torch.utils.data import DataLoader, random_split
from tokenizers import Tokenizer
from model.config import spectral_config, spectral_hyperparams
from model.spectral_compact import SpectralMindModel
from train.dataset import QADataset, collate_fn
from train.spectral_trainer import SpectralTrainer, SpectralTrainConfig
class _PadCollate:
    """Picklable collate callable that binds the padding token id.

    DataLoader worker processes started with the ``spawn`` method must
    pickle ``collate_fn``; a lambda cannot be pickled, a module-level class
    instance can.
    """

    def __init__(self, pad_id) -> None:
        self.pad_id = pad_id

    def __call__(self, batch):
        # Same call the original lambdas made: collate_fn(batch, pad_id).
        return collate_fn(batch, self.pad_id)


def main(
    data_path: str = "data/filtered/clean_qa.jsonl",
    tokenizer_path: str = "data/tokenizer/tokenizer.json",
    size: str = "nano",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> None:
    """Build a SpectralMind model and train it on a QA dataset.

    Args:
        data_path: Path of the dataset file handed to ``QADataset``.
        tokenizer_path: Path of the tokenizer file loaded with
            ``Tokenizer.from_file``.
        size: Model size name passed to ``spectral_config``; also used in
            the run name (``spectral_<size>``).
        device: Device string forwarded to ``SpectralTrainer``. Defaults to
            ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

    Raises:
        FileNotFoundError: If ``data_path`` or ``tokenizer_path`` does not
            exist.
        ValueError: If the dataset contains no examples.
    """
    print("=" * 60)
    print("SpectralMind — Pure Mathematics Training")
    print("=" * 60)

    # ── Input validation ──────────────────────────────────────────────
    # Check both input files before building the model so that a bad path
    # fails immediately instead of after model construction.
    if not Path(data_path).exists():
        raise FileNotFoundError(
            f"ไม่พบ {data_path}\n"
            "รัน python -m data.collect && python -m data.filter ก่อน"
        )
    if not Path(tokenizer_path).exists():
        raise FileNotFoundError(f"ไม่พบ {tokenizer_path}")

    # ── Model ─────────────────────────────────────────────────────────
    cfg = spectral_config(size)
    hp = spectral_hyperparams(cfg)
    model = SpectralMindModel(
        cfg,
        attn_rank=hp["attn_rank"],
        ffn_rank=hp["ffn_rank"],
        bloom_buckets=hp["bloom_buckets"],
        bloom_hashes=hp["bloom_hashes"],
    )
    print(model.count_params())
    print()

    # ── Data ──────────────────────────────────────────────────────────
    tokenizer = Tokenizer.from_file(tokenizer_path)
    dataset = QADataset(data_path, tokenizer, max_len=cfg.max_seq_len)
    if len(dataset) == 0:
        # A shuffling DataLoader over an empty dataset fails with an opaque
        # sampler error; report the actual cause instead.
        raise ValueError(f"{data_path} contains no examples")

    # Hold out 10% of the examples for validation, capped at 2000.
    n_val = min(len(dataset) // 10, 2000)
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val])

    batch_size = 32  # shared by the train loader and the trainer config
    collate = _PadCollate(cfg.pad_token_id)
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just triggers a warning.
    pin_memory = str(device).startswith("cuda")

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=collate,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=2 * batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=collate,
        pin_memory=pin_memory,
    )
    print(f"Train: {n_train:,} | Val: {n_val:,}")

    # ── Train config ──────────────────────────────────────────────────
    train_cfg = SpectralTrainConfig(
        max_steps=30_000,
        batch_size=batch_size,
        grad_accum=4,
        lr=3e-4,
        warmup_steps=500,
        spectral_reg=1e-4,
        rank_grow_every=1_000,
        rank_grow_threshold=0.12,
        rank_max_per_layer=64,
        warmstart_batches=8,
        log_every=100,
        save_every=5_000,
        run_name=f"spectral_{size}",
    )

    # ── Train ─────────────────────────────────────────────────────────
    trainer = SpectralTrainer(model, train_loader, val_loader, train_cfg, device)
    trainer.train(warmstart=True)
    print("\nDone. Model saved to", train_cfg.save_dir)


# Keep the entry point guarded: DataLoader workers re-import this module
# under the spawn start method and must not re-run training.
if __name__ == "__main__":
    main()

# Xet Storage Details
#
# Size:
# 3.52 kB
# ·
# Xet hash:
# 30a0f2d0435c922188e86c757bde91999d8d707123ff5510f4305c39314e176f
#
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.