bbkdevops's picture
download
raw
8.26 kB
"""Training loop for Φ-Mind — the physics-derived LLM architecture."""
from __future__ import annotations
import json
import math
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.phimind import PhiMindConfig, PhiMindModel, count_params, phimind_tiny
# ---------------------------------------------------------------------------
# Dataset helpers (character-level UTF-8 tokenisation — no external tokenizer)
# ---------------------------------------------------------------------------
def _encode(text: str, vocab_size: int, max_len: int) -> torch.Tensor:
    """Turn *text* into a 1-D ``torch.long`` tensor of byte-level token ids.

    Ids 0-3 are reserved; this function emits 2 as ``<bos>`` and 3 as
    ``<eos>``.  Every UTF-8 byte ``b`` maps to ``4 + (b % usable)`` where
    ``usable = max(vocab_size - 4, 1)``.  The result is cut to at most
    ``max_len`` ids, so a long input may lose its trailing ``<eos>``.
    """
    bos_id, eos_id, first_byte_id = 2, 3, 4
    byte_slots = max(vocab_size - first_byte_id, 1)
    body = [first_byte_id + (byte % byte_slots) for byte in text.encode("utf-8")]
    token_ids = [bos_id, *body, eos_id]
    return torch.tensor(token_ids[:max_len], dtype=torch.long)
def _collate(
    sequences: list[torch.Tensor], pad_id: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a list of 1-D id tensors into a batch.

    Returns ``(input_ids, labels)``, both ``torch.long`` with shape
    ``(len(sequences), longest_sequence)``.  ``input_ids`` is padded with
    ``pad_id``; ``labels`` copies the ids and holds -100 at padded positions.
    """
    lengths = [int(seq.numel()) for seq in sequences]
    shape = (len(sequences), max(lengths))
    input_ids = torch.full(shape, pad_id, dtype=torch.long)
    # Labels start fully masked; only real-token positions are overwritten.
    labels = torch.full(shape, -100, dtype=torch.long)
    for row, (seq, length) in enumerate(zip(sequences, lengths)):
        input_ids[row, :length] = seq
        labels[row, :length] = seq
    return input_ids, labels
def _causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: position ``t`` is scored against label ``t + 1``.

    The last time step of ``logits`` and the first column of ``labels`` are
    dropped so the two line up.  Label value -100 is excluded from the mean.
    """
    vocab = logits.size(-1)
    predictions = logits[:, :-1, :].reshape(-1, vocab)
    targets = labels[:, 1:].reshape(-1)
    return F.cross_entropy(predictions, targets, ignore_index=-100)
# ---------------------------------------------------------------------------
# Training config
# ---------------------------------------------------------------------------
@dataclass
class PhiMindTrainConfig:
    """Hyper-parameters and paths consumed by ``PhiMindTrainer``."""

    # --- paths ---
    data_path: str = "data/filtered"
    out_dir: str = "checkpoints/phimind"  # checkpoints are written here

    # --- optimisation ---
    train_steps: int = 200  # number of training-loop iterations
    batch_size: int = 2  # sequences per iteration
    grad_accum: int = 4  # iterations accumulated per optimizer step
    lr: float = 3e-4  # peak learning rate (scaled by the schedule)
    weight_decay: float = 0.01
    warmup_steps: int = 20  # linear warm-up length
    clip_grad: float = 1.0  # max gradient norm passed to clipping

    # --- evaluation ---
    eval_interval: int = 50  # evaluate every this many iterations
    eval_steps: int = 10

    # --- logging / reproducibility ---
    log_interval: int = 10  # history entry every this many iterations
    seed: int = 20260522  # passed to torch.manual_seed
# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
class PhiMindTrainer:
    """Minimal single-device training loop for ``PhiMindModel``.

    The trainer owns the model and an AdamW optimizer, walks the in-memory
    ``sequences`` list round-robin, applies warm-up + cosine learning-rate
    scaling, accumulates gradients, clips them, and writes ``best.pt`` /
    ``final.pt`` checkpoints under ``train_cfg.out_dir``.
    """

    def __init__(
        self,
        model_cfg: PhiMindConfig,
        train_cfg: PhiMindTrainConfig,
        sequences: list[torch.Tensor],
        eval_sequences: list[torch.Tensor] | None = None,
        device: str = "cpu",
    ):
        """Seed torch, build the model on *device* and create the optimizer.

        Args:
            model_cfg: Configuration passed to ``PhiMindModel``.
            train_cfg: Training hyper-parameters and output directory.
            sequences: Training sequences (1-D id tensors).
            eval_sequences: Held-out sequences; when missing or empty the
                last two training sequences are used instead.
            device: Anything accepted by ``torch.device``.
        """
        torch.manual_seed(train_cfg.seed)
        self.cfg = train_cfg
        self.device = torch.device(device)
        self.model = PhiMindModel(model_cfg).to(self.device)
        self.sequences = sequences
        # NOTE: the fallback evaluates on training data, so the reported eval
        # loss is not a held-out metric in that case.
        self.eval_sequences = eval_sequences or sequences[-2:]
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=train_cfg.lr,
            weight_decay=train_cfg.weight_decay,
            betas=(0.9, 0.95),
        )
        self.history: list[dict] = []

    def _lr_schedule(self, step: int) -> float:
        """Return the multiplier applied to ``cfg.lr`` at 0-based *step*.

        Linear warm-up over ``cfg.warmup_steps`` steps, then cosine decay
        towards the end of training, never dropping below 0.1.
        """
        if step < self.cfg.warmup_steps:
            return float(step + 1) / max(self.cfg.warmup_steps, 1)
        decay_span = max(self.cfg.train_steps - self.cfg.warmup_steps, 1)
        progress = (step - self.cfg.warmup_steps) / decay_span
        return max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))

    @torch.no_grad()
    def eval_loss(self) -> float:
        """Return the mean per-batch causal-LM loss over ``eval_sequences``.

        Batches whose loss is not finite are skipped; 0.0 is returned when no
        batch contributed.  The model's train/eval mode is restored to what it
        was on entry, even if evaluation raises.
        """
        was_training = self.model.training
        self.model.eval()
        total = 0.0
        count = 0
        try:
            for start in range(0, len(self.eval_sequences), self.cfg.batch_size):
                batch = self.eval_sequences[start : start + self.cfg.batch_size]
                if not batch:
                    continue
                input_ids, labels = _collate(batch)
                input_ids = input_ids.to(self.device)
                labels = labels.to(self.device)
                out = self.model(input_ids)
                loss = _causal_lm_loss(out["logits"], labels)
                if torch.isfinite(loss):
                    total += float(loss.item())
                    count += 1
        finally:
            self.model.train(was_training)
        return total / max(count, 1)

    def train(self) -> dict:
        """Run ``cfg.train_steps`` iterations and return a summary dict.

        Side effects: creates ``cfg.out_dir``, writes ``best.pt`` whenever a
        periodic evaluation improves on the best loss so far, writes
        ``final.pt`` at the end, and appends entries to ``self.history``.

        Raises:
            ValueError: if there are no training sequences.
        """
        num_sequences = len(self.sequences)
        if num_sequences == 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # from the round-robin index arithmetic below.
            raise ValueError("PhiMindTrainer.train() requires at least one training sequence")

        self.model.train()
        out_dir = Path(self.cfg.out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        micro_step = 0
        window_micro_steps = 0  # micro-batches accumulated since last optimizer step
        accumulated_loss = 0.0
        t0 = time.perf_counter()
        initial_eval = self.eval_loss()
        train_losses: list[float] = []
        train_loss = float("nan")  # stays NaN until the first optimizer step
        best_eval = float("inf")
        grad_norm = 0.0

        self.optimizer.zero_grad()
        for step in range(self.cfg.train_steps):
            # Learning rate warmup + cosine decay
            scale = self._lr_schedule(step)
            current_lr = self.cfg.lr * scale
            for pg in self.optimizer.param_groups:
                pg["lr"] = current_lr

            # Mini-batch: walk the sequence list round-robin.
            idx = (step * self.cfg.batch_size) % num_sequences
            batch = [
                self.sequences[(idx + i) % num_sequences]
                for i in range(self.cfg.batch_size)
            ]
            input_ids, labels = _collate(batch)
            input_ids = input_ids.to(self.device)
            labels = labels.to(self.device)

            out = self.model(input_ids)
            # Pre-divide so gradients summed over a full window form a mean.
            loss = _causal_lm_loss(out["logits"], labels) / self.cfg.grad_accum
            loss.backward()
            accumulated_loss += float(loss.item())
            micro_step += 1
            window_micro_steps += 1

            is_last_step = step == self.cfg.train_steps - 1
            if micro_step % self.cfg.grad_accum == 0 or is_last_step:
                gn = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.clip_grad
                )
                grad_norm = float(gn.item() if hasattr(gn, "item") else gn)
                self.optimizer.step()
                self.optimizer.zero_grad()
                # Undo the 1/grad_accum scaling and average over the
                # micro-batches actually seen, so a short final window still
                # reports a true mean loss.  (Its gradients remain scaled by
                # 1/grad_accum.)
                train_loss = (
                    accumulated_loss * self.cfg.grad_accum / window_micro_steps
                )
                train_losses.append(train_loss)
                accumulated_loss = 0.0
                window_micro_steps = 0

            # Checked every iteration so log_interval is honoured even when it
            # does not line up with grad_accum; train_loss is the most recent
            # completed accumulation window.
            if (step + 1) % self.cfg.log_interval == 0:
                elapsed = time.perf_counter() - t0
                self.history.append({
                    "step": step + 1,
                    "train_loss": train_loss,
                    "grad_norm": grad_norm,
                    "lr": current_lr,
                    "elapsed_s": elapsed,
                })

            # Eval checkpoint
            if (step + 1) % self.cfg.eval_interval == 0:
                ev = self.eval_loss()
                if ev < best_eval:
                    best_eval = ev
                    torch.save(
                        {
                            "step": step + 1,
                            "model_state": self.model.state_dict(),
                            "eval_loss": ev,
                        },
                        out_dir / "best.pt",
                    )

        final_eval = self.eval_loss()
        checkpoint_path = out_dir / "final.pt"
        torch.save(
            {
                "step": self.cfg.train_steps,
                "model_state": self.model.state_dict(),
                "model_cfg": self.model.cfg,
                "train_losses": train_losses,
                "eval_loss": final_eval,
            },
            checkpoint_path,
        )
        return {
            "train_steps": self.cfg.train_steps,
            "initial_eval_loss": initial_eval,
            "final_train_loss": train_losses[-1] if train_losses else float("nan"),
            "final_eval_loss": final_eval,
            "best_eval_loss": best_eval,
            # Loss is capped at 20 before exponentiation to keep this finite.
            "perplexity": float(math.exp(min(final_eval, 20.0))),
            "grad_norm": grad_norm,
            "loss_decreased": final_eval < initial_eval,
            "checkpoint_path": str(checkpoint_path),
            "param_count": count_params(self.model),
        }

Xet Storage Details

Size:
8.26 kB
·
Xet hash:
8f2b5ba26b92d439ed2009d6750064879929ac76ba8b67fe10b784f51bd5f8de

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.