Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-colab-handoff / bundle / train / phimind_trainer.py
| """Training loop for Φ-Mind — the physics-derived LLM architecture.""" | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import time | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Iterable | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.phimind import PhiMindConfig, PhiMindModel, count_params, phimind_tiny | |
| # --------------------------------------------------------------------------- | |
| # Dataset helpers (character-level UTF-8 tokenisation — no external tokenizer) | |
| # --------------------------------------------------------------------------- | |
| def _encode(text: str, vocab_size: int, max_len: int) -> torch.Tensor: | |
| """Encode text to token ids using UTF-8 byte values + 4 special ids.""" | |
| usable = max(vocab_size - 4, 1) | |
| ids = [2] # <bos> | |
| ids.extend(4 + (b % usable) for b in text.encode("utf-8")) | |
| ids.append(3) # <eos> | |
| return torch.tensor(ids[:max_len], dtype=torch.long) | |
| def _collate( | |
| sequences: list[torch.Tensor], pad_id: int = 0 | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| max_len = max(int(s.numel()) for s in sequences) | |
| input_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long) | |
| labels = torch.full((len(sequences), max_len), -100, dtype=torch.long) | |
| for i, seq in enumerate(sequences): | |
| n = int(seq.numel()) | |
| input_ids[i, :n] = seq | |
| labels[i, :n] = seq | |
| labels[i, n:] = -100 | |
| return input_ids, labels | |
| def _causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: | |
| """Shift-by-one cross-entropy loss for causal language modelling.""" | |
| shift_logits = logits[:, :-1, :].contiguous() | |
| shift_labels = labels[:, 1:].contiguous() | |
| return F.cross_entropy( | |
| shift_logits.view(-1, shift_logits.size(-1)), | |
| shift_labels.view(-1), | |
| ignore_index=-100, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Training config | |
| # --------------------------------------------------------------------------- | |
@dataclass
class PhiMindTrainConfig:
    """Hyper-parameters and paths for a Φ-Mind training run.

    Declared as a dataclass so individual values can be overridden at
    construction time, e.g. ``PhiMindTrainConfig(train_steps=50, lr=1e-3)``.
    Every field has a default, so ``PhiMindTrainConfig()`` is also valid.
    """

    # Data
    data_path: str = "data/filtered"  # NOTE(review): not read in this section of the file
    out_dir: str = "checkpoints/phimind"  # directory checkpoints are written to

    # Training
    train_steps: int = 200  # number of forward/backward iterations of the training loop
    batch_size: int = 2  # sequences per forward pass
    grad_accum: int = 4  # forward/backward passes accumulated per optimizer step
    lr: float = 3e-4  # peak learning rate; multiplied by the schedule factor
    weight_decay: float = 0.01  # passed to AdamW
    warmup_steps: int = 20  # length of the linear learning-rate warmup
    clip_grad: float = 1.0  # max gradient norm passed to clip_grad_norm_

    # Eval
    eval_interval: int = 50  # evaluate every this many training iterations
    eval_steps: int = 10  # NOTE(review): not read in this section of the file

    # Logging
    log_interval: int = 10  # append a history record every this many iterations
    seed: int = 20260522  # passed to torch.manual_seed
| # --------------------------------------------------------------------------- | |
| # Trainer | |
| # --------------------------------------------------------------------------- | |
class PhiMindTrainer:
    """Single-device training loop for a ``PhiMindModel``.

    The trainer owns the model and an AdamW optimizer, cycles through an
    in-memory list of token-id tensors in order, and writes ``best.pt`` /
    ``final.pt`` checkpoints into ``train_cfg.out_dir``.
    """

    def __init__(
        self,
        model_cfg: PhiMindConfig,
        train_cfg: PhiMindTrainConfig,
        sequences: list[torch.Tensor],
        eval_sequences: list[torch.Tensor] | None = None,
        device: str = "cpu",
    ):
        """Build the model and optimizer.

        Args:
            model_cfg: Configuration handed to ``PhiMindModel``.
            train_cfg: Training hyper-parameters.
            sequences: Training sequences (token-id tensors).
            eval_sequences: Held-out sequences.  When ``None`` or empty, the
                last two *training* sequences are used instead, so the
                reported eval loss is then not a held-out metric.
            device: Anything accepted by ``torch.device``.
        """
        # Seed before the model is built so weight initialisation is reproducible.
        torch.manual_seed(train_cfg.seed)
        self.cfg = train_cfg
        self.device = torch.device(device)
        self.model = PhiMindModel(model_cfg).to(self.device)
        self.sequences = sequences
        self.eval_sequences = eval_sequences or sequences[-2:]
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            betas=(0.9, 0.95),
        )
        # One record per logging interval; filled in by train().
        self.history: list[dict] = []

    def _lr_schedule(self, step: int) -> float:
        """Return the learning-rate multiplier for the zero-based ``step``.

        Linear warmup from ``1 / warmup_steps`` up to 1.0, followed by a
        cosine decay over the remaining steps that is floored at 0.1.
        """
        warmup = self.cfg.warmup_steps
        if step < warmup:
            return float(step + 1) / max(warmup, 1)
        decay_span = max(self.cfg.train_steps - warmup, 1)
        cosine = 0.5 * (1.0 + math.cos(math.pi * (step - warmup) / decay_span))
        return max(0.1, cosine)

    def eval_loss(self) -> float:
        """Return the mean per-batch loss over ``self.eval_sequences``.

        Batches whose loss is not finite are skipped; ``0.0`` is returned
        when no batch contributed.  The model is put in eval mode for the
        pass and switched back to train mode afterwards.
        """
        self.model.eval()
        total = 0.0
        count = 0
        batch_size = self.cfg.batch_size
        try:
            # Evaluation never calls backward(), so skip autograd bookkeeping.
            with torch.no_grad():
                for start in range(0, len(self.eval_sequences), batch_size):
                    batch = self.eval_sequences[start : start + batch_size]
                    input_ids, labels = _collate(batch)
                    input_ids = input_ids.to(self.device)
                    labels = labels.to(self.device)
                    out = self.model(input_ids)
                    loss = _causal_lm_loss(out["logits"], labels)
                    if torch.isfinite(loss):
                        total += float(loss.item())
                        count += 1
        finally:
            self.model.train()
        return total / max(count, 1)

    def train(self) -> dict:
        """Run the training loop and return a summary of the run.

        Each iteration processes one batch; gradients are accumulated over
        ``grad_accum`` iterations (or until the last iteration) before they
        are clipped and applied.  A history record is appended every
        ``log_interval`` iterations and the model is evaluated every
        ``eval_interval`` iterations, saving ``best.pt`` whenever the eval
        loss improves.  A non-positive interval disables that feature.
        ``final.pt`` is always written at the end.

        Returns:
            A dict with keys ``train_steps``, ``initial_eval_loss``,
            ``final_train_loss``, ``final_eval_loss``, ``best_eval_loss``
            (``inf`` if no periodic evaluation ran), ``perplexity``,
            ``grad_norm``, ``loss_decreased``, ``checkpoint_path`` and
            ``param_count``.

        Raises:
            ValueError: If there are training steps to run but no training
                sequences.
        """
        cfg = self.cfg
        num_sequences = len(self.sequences)
        if cfg.train_steps > 0 and num_sequences == 0:
            raise ValueError(
                "PhiMindTrainer.train() needs at least one training sequence"
            )
        grad_accum = max(cfg.grad_accum, 1)

        self.model.train()
        out_dir = Path(cfg.out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        t0 = time.perf_counter()
        initial_eval = self.eval_loss()
        train_losses: list[float] = []
        best_eval = float("inf")
        grad_norm = 0.0
        # Most recent completed-window loss; NaN until the first optimizer
        # step so that logging before that point cannot hit an unbound name.
        train_loss = float("nan")
        window_loss = 0.0
        window_size = 0

        self.optimizer.zero_grad()
        for step in range(cfg.train_steps):
            # Learning rate warmup + cosine decay
            scale = self._lr_schedule(step)
            current_lr = cfg.lr * scale
            for group in self.optimizer.param_groups:
                group["lr"] = current_lr

            # Mini-batch: walk the sequence list in order, wrapping around.
            offset = step * cfg.batch_size
            batch = [
                self.sequences[(offset + i) % num_sequences]
                for i in range(cfg.batch_size)
            ]
            input_ids, labels = _collate(batch)
            input_ids = input_ids.to(self.device)
            labels = labels.to(self.device)

            out = self.model(input_ids)
            raw_loss = _causal_lm_loss(out["logits"], labels)
            # Scale so the accumulated gradient is the mean over a full window.
            # NOTE: a short final window is still divided by grad_accum.
            (raw_loss / grad_accum).backward()
            window_loss += float(raw_loss.item())
            window_size += 1

            is_last_step = step == cfg.train_steps - 1
            if (step + 1) % grad_accum == 0 or is_last_step:
                gn = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), cfg.clip_grad
                )
                grad_norm = float(gn.item() if hasattr(gn, "item") else gn)
                self.optimizer.step()
                self.optimizer.zero_grad()
                # True mean over the batches actually in this window, which
                # stays correct when the final window is shorter than usual.
                train_loss = window_loss / window_size
                train_losses.append(train_loss)
                window_loss = 0.0
                window_size = 0

            if cfg.log_interval > 0 and (step + 1) % cfg.log_interval == 0:
                self.history.append({
                    "step": step + 1,
                    "train_loss": train_loss,
                    "grad_norm": grad_norm,
                    "lr": current_lr,
                    "elapsed_s": time.perf_counter() - t0,
                })

            # Eval checkpoint
            if cfg.eval_interval > 0 and (step + 1) % cfg.eval_interval == 0:
                ev = self.eval_loss()
                if ev < best_eval:
                    best_eval = ev
                    torch.save(
                        {
                            "step": step + 1,
                            "model_state": self.model.state_dict(),
                            "eval_loss": ev,
                        },
                        out_dir / "best.pt",
                    )

        final_eval = self.eval_loss()
        checkpoint_path = out_dir / "final.pt"
        torch.save(
            {
                "step": cfg.train_steps,
                "model_state": self.model.state_dict(),
                "model_cfg": self.model.cfg,
                "train_losses": train_losses,
                "eval_loss": final_eval,
            },
            checkpoint_path,
        )
        return {
            "train_steps": cfg.train_steps,
            "initial_eval_loss": initial_eval,
            "final_train_loss": train_losses[-1] if train_losses else float("nan"),
            "final_eval_loss": final_eval,
            "best_eval_loss": best_eval,
            # Clamp the exponent so a diverged run cannot overflow exp().
            "perplexity": float(math.exp(min(final_eval, 20.0))),
            "grad_norm": grad_norm,
            "loss_decreased": final_eval < initial_eval,
            "checkpoint_path": str(checkpoint_path),
            "param_count": count_params(self.model),
        }
Xet Storage Details
- Size: 8.26 kB
- Xet hash: 8f2b5ba26b92d439ed2009d6750064879929ac76ba8b67fe10b784f51bd5f8de
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.