Feature Extraction
Transformers
Safetensors
PyTorch
English
eden
text-enhancement
grammar-correction
text-rewriting
encoder-decoder
transformer
custom_code
Instructions to use Rybib/EDEN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Rybib/EDEN with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="Rybib/EDEN", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Rybib/EDEN", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Training loop, evaluation, checkpointing, generation, CLI commands, and UI. | |
| This module holds the behavioural core of EDEN. The cleanly separable pieces | |
| (configuration, model, data, runtime helpers) live in sibling modules and are | |
| imported here. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import http.server | |
| import json | |
| import math | |
| import os | |
| import random | |
| import shlex | |
| import shutil | |
| import subprocess | |
| import sys | |
| import threading | |
| import time | |
| import urllib.parse | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from typing import Callable, Iterable | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .constants import * | |
| from .io_utils import * | |
| from .config import * | |
| from .runtime import * | |
| from .model import * | |
| from .data import * | |
def lr_lambda_factory(total_updates: int, warmup: int, min_ratio: float):
    """Build a ``LambdaLR`` multiplier: linear warmup followed by cosine decay.

    Args:
        total_updates: Number of optimizer updates the whole schedule spans.
        warmup: Number of initial updates over which the multiplier ramps
            linearly towards 1.
        min_ratio: Floor applied to the multiplier during the decay phase.

    Returns:
        A callable mapping the scheduler step index to an LR multiplier.
    """

    def lr_lambda(step: int) -> float:
        if step < warmup:
            # Never return exactly 0 so the first update still moves the weights.
            return max(1e-6, step / max(1, warmup))
        progress = (step - warmup) / max(1, total_updates - warmup)
        # Clamp so extra scheduler steps beyond ``total_updates`` stay at the
        # floor instead of the cosine curving back upwards and raising the LR.
        progress = min(1.0, max(0.0, progress))
        return max(min_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return lr_lambda
def checkpoint_payload(
    model: EdenTransformer,
    optimizer,
    scheduler,
    cfg: TrainConfig,
    epoch: int,
    step: int,
    best_val: float,
    completed_epoch: bool = False,
) -> dict:
    """Assemble the dictionary that is serialised into a ``.pt`` checkpoint.

    Model tensors are detached and copied to CPU. Optimizer and scheduler state
    are stored via their ``state_dict()`` when given, otherwise as ``None``.
    The training config is stored as a plain dict (``dataclasses.asdict``).
    """
    cpu_weights = {}
    for name, tensor in model.state_dict().items():
        cpu_weights[name] = tensor.detach().cpu()

    def _state_of(component):
        # Absent components are recorded as None rather than skipped.
        return None if component is None else component.state_dict()

    return dict(
        model_state=cpu_weights,
        optimizer_state=_state_of(optimizer),
        scheduler_state=_state_of(scheduler),
        config=asdict(cfg),
        epoch=epoch,
        step=step,
        best_val=best_val,
        completed_epoch=completed_epoch,
        special_tokens=SPECIAL_TOKENS,
    )
def save_checkpoint(
    path: Path,
    model: EdenTransformer,
    optimizer,
    scheduler,
    cfg: TrainConfig,
    epoch: int,
    step: int,
    best_val: float,
    completed_epoch: bool = False,
) -> None:
    """Write a checkpoint to ``path`` through a temporary file.

    The payload is written to ``path`` with a ``.tmp`` suffix and then moved
    over ``path`` with ``Path.replace`` (an atomic rename on POSIX when both
    live on the same filesystem, which they do here), so an interrupted save
    never leaves a truncated file under the final name. If writing or renaming
    fails, the partial temporary file is removed and the error re-raised.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    payload = checkpoint_payload(model, optimizer, scheduler, cfg, epoch, step, best_val, completed_epoch)
    try:
        torch.save(payload, tmp)
        tmp.replace(path)
    except BaseException:
        # Don't leave a half-written ``.tmp`` behind (disk full, Ctrl-C, ...).
        tmp.unlink(missing_ok=True)
        raise
def load_checkpoint(path: Path, map_location: str | torch.device = "cpu") -> dict:
    """Load a checkpoint dictionary from ``path``, mapping tensors to ``map_location``.

    The file is loaded with ``weights_only=False``, i.e. fully unpickled.
    SECURITY NOTE(review): full unpickling can execute arbitrary code embedded
    in the file, so only load checkpoints from trusted sources.
    """
    loaded = torch.load(
        path,
        map_location=map_location,
        weights_only=False,
    )
    return loaded
def latest_checkpoint() -> Path:
    """Return the checkpoint to use when none is given explicitly.

    Picks the first ``best.pt`` among ``all_checkpoint_files()``; when there
    is none, falls back to the first file in that list.

    Raises:
        FileNotFoundError: if no checkpoint files exist at all.
    """
    found = all_checkpoint_files()
    if not found:
        raise FileNotFoundError("No checkpoint found. Train first with: python3 main.py train")
    return next((candidate for candidate in found if candidate.name == "best.pt"), found[0])
def evaluate(
    model: EdenTransformer,
    rows: list[tuple[str, str]],
    tok,
    cfg: TrainConfig,
    device: torch.device,
    max_batches: int = 100,
) -> tuple[float, float]:
    """Compute teacher-forced loss and token accuracy on ``rows``.

    Args:
        model: Model to score.
        rows: (input, target) text pairs.
        tok: Tokenizer passed through to ``collate_pairs``.
        cfg: Config; ``cfg.batch_size`` sets the batch size and ``cfg`` is also
            passed to ``collate_pairs``.
        device: Device the batches are moved to.
        max_batches: Upper bound on the number of (unshuffled) batches scored.

    Returns:
        ``(mean_loss_per_token, token_accuracy)`` over target positions whose
        label is not ``-100``; both are 0.0 when no such position was scored.

    The model's train/eval mode is restored on exit, so this works both inside
    the training loop and on an inference-only model.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    correct = 0
    batches = make_batches(rows, cfg.batch_size, shuffle_batches=False)[:max_batches]
    try:
        # Scoring never backpropagates; skipping autograd keeps peak memory down.
        with torch.no_grad():
            for batch in batches:
                batch_rows = [rows[i] for i in batch]
                src, tin, tout = collate_pairs(batch_rows, tok, cfg)
                src = src.to(device)
                tin = tin.to(device)
                tout = tout.to(device)
                logits = model(src, tin)
                loss = F.cross_entropy(
                    logits.float().reshape(-1, logits.size(-1)),
                    tout.reshape(-1),
                    ignore_index=-100,
                    reduction="sum",
                )
                # -100 marks padding/ignored label positions (same as ignore_index).
                mask = tout.ne(-100)
                preds = logits.argmax(-1)
                correct += (preds[mask] == tout[mask]).sum().item()
                total_tokens += mask.sum().item()
                total_loss += loss.item()
                del src, tin, tout, logits, loss, preds, mask
    finally:
        model.train(was_training)
    return total_loss / max(1, total_tokens), correct / max(1, total_tokens)
def build_model_from_cfg(cfg: TrainConfig, device: torch.device) -> EdenTransformer:
    """Instantiate an ``EdenTransformer`` from ``cfg`` and return it moved to ``device``."""
    return EdenTransformer(cfg).to(device)
def maybe_prepare_data(args, cfg: TrainConfig) -> None:
    """Run ``command_prepare`` unless reusable prepared data is already on disk.

    Preparation is skipped only when both ``PAIRS_PATH`` and ``TOKENIZER_PATH``
    exist, ``args.rebuild_data`` is not set and no custom ``args.data`` file was
    given. Otherwise the data is rebuilt with ``force=True``, using ``cfg`` for
    the vocabulary size and (when ``args`` lacks it) ``max_pairs``.
    """
    custom_data = getattr(args, "data", None)
    have_prepared = PAIRS_PATH.exists() and TOKENIZER_PATH.exists()
    wants_rebuild = getattr(args, "rebuild_data", False) or bool(custom_data)
    if have_prepared and not wants_rebuild:
        return
    command_prepare(
        argparse.Namespace(
            recipe=getattr(args, "recipe", "m5-smart"),
            max_pairs=getattr(args, "max_pairs", cfg.max_pairs),
            vocab_size=cfg.vocab_size,
            include_c4=getattr(args, "include_c4", False),
            data=custom_data,
            force=True,
        )
    )
def train_loop(
    cfg: TrainConfig,
    rows: list[tuple[str, str]],
    tok,
    device: torch.device,
    resume_path: Path | None = None,
    finetune: bool = False,
    checkpoint_dir: Path | None = None,
) -> Path:
    """Train (or fine-tune) an EDEN model and return the last checkpoint written.

    Args:
        cfg: Training configuration (model shape, optimisation, logging cadence).
        rows: (input, target) text pairs, split into train/validation here.
        tok: Tokenizer passed to ``collate_pairs`` and ``evaluate``.
        device: Device to train on.
        resume_path: Optional checkpoint to start from. Its weights are always
            loaded; optimizer/scheduler state, epoch, step and best validation
            loss are restored only when ``finetune`` is false.
        finetune: Start a fresh optimisation run from ``resume_path`` weights.
        checkpoint_dir: Where checkpoints are written (default ``CHECKPOINT_DIR``).

    Returns:
        ``final.pt`` when all epochs complete, or ``pause.pt`` when a pause was
        requested through ``PAUSE_REQUEST_PATH``.

    Raises:
        SystemExit: the memory watchdog stopped training (``watchdog.pt`` saved).
        RuntimeError: the training loss became NaN/Inf.
    """
    checkpoint_dir = checkpoint_dir or CHECKPOINT_DIR
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    session_dir = checkpoint_dir.parent if checkpoint_dir.name == "checkpoints" else checkpoint_dir
    session_name = session_dir.name if session_dir != CHECKPOINT_DIR else "legacy"
    set_seed(cfg.seed)
    train_rows, val_rows = split_train_val(list(rows), cfg.val_split, cfg.seed)
    batches_per_epoch = math.ceil(len(train_rows) / cfg.batch_size)
    # One optimizer update per ``grad_accum`` batches, plus one for the trailing
    # partial group of each epoch (see the per-epoch accumulation below).
    updates_per_epoch = math.ceil(batches_per_epoch / cfg.grad_accum)
    total_updates = max(1, updates_per_epoch * cfg.epochs)
    model = build_model_from_cfg(cfg, device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        betas=(0.9, 0.98),
        eps=1e-9,
        weight_decay=cfg.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda_factory(total_updates, cfg.warmup_steps, cfg.min_lr_ratio),
    )
    start_epoch = 1
    global_step = 0
    best_val = float("inf")
    if resume_path:
        ckpt = load_checkpoint(resume_path, map_location="cpu")
        model.load_state_dict(ckpt["model_state"])
        model.to(device)
        if not finetune:
            if ckpt.get("optimizer_state"):
                optimizer.load_state_dict(ckpt["optimizer_state"])
            if ckpt.get("scheduler_state"):
                scheduler.load_state_dict(ckpt["scheduler_state"])
            saved_epoch = max(1, int(ckpt.get("epoch", 1)))
            # An end-of-epoch checkpoint continues with the next epoch; a
            # mid-epoch one re-runs its epoch from the first batch.
            start_epoch = saved_epoch + 1 if ckpt.get("completed_epoch") else saved_epoch
            global_step = int(ckpt.get("step", 0))
            best_val = float(ckpt.get("best_val", best_val))
        # The state has been copied into model/optimizer; don't keep the loaded
        # dict alive for the rest of training.
        del ckpt
        log(f"Loaded checkpoint: {resume_path}")
    exact_params = model.parameter_count()
    total_steps = max(1, batches_per_epoch * cfg.epochs)
    log("")
    log("EDEN training")
    log(f" device: {device}")
    log(f" model: {exact_params / 1e6:.1f}M parameters")
    log(f" data: {len(train_rows):,} train / {len(val_rows):,} validation pairs")
    log(f" context: {cfg.max_len} tokens")
    log(f" batch: {cfg.batch_size} x accum {cfg.grad_accum} = effective {cfg.batch_size * cfg.grad_accum}")
    log(f" session: {session_name}")
    log(f" checkpoints: {checkpoint_dir}")
    log("")
    write_run_state(
        status="running",
        mode="finetune" if finetune else "train",
        device=str(device),
        params=exact_params,
        train_pairs=len(train_rows),
        val_pairs=len(val_rows),
        epoch=start_epoch,
        epochs=cfg.epochs,
        completed_epochs=max(0, start_epoch - 1),
        epoch_progress=0.0,
        epoch_steps_done=0,
        epoch_total_steps=batches_per_epoch,
        step=global_step,
        total_steps=total_steps,
        best_val=None if best_val == float("inf") else best_val,
        checkpoint=str(checkpoint_dir / "latest.pt"),
        session=session_name,
        session_dir=str(session_dir),
        config=asdict(cfg),
    )
    write_metric(
        "start",
        epoch=start_epoch,
        epochs=cfg.epochs,
        step=global_step,
        total_steps=total_steps,
        params=exact_params,
        train_pairs=len(train_rows),
        val_pairs=len(val_rows),
        device=str(device),
    )
    optimizer.zero_grad(set_to_none=True)
    running_loss = 0.0
    running_count = 0
    last_log = time.time()
    for epoch in range(start_epoch, cfg.epochs + 1):
        model.train()
        batches = make_batches(train_rows, cfg.batch_size, shuffle_batches=True)
        n_batches = len(batches)
        for batch_i, batch in enumerate(batches, start=1):
            # Batches of this epoch already trained on (the current one has not run yet).
            done_in_epoch = batch_i - 1
            if PAUSE_REQUEST_PATH.exists():
                pause_path = checkpoint_dir / "pause.pt"
                latest_path = checkpoint_dir / "latest.pt"
                save_checkpoint(pause_path, model, optimizer, scheduler, cfg, epoch, global_step, best_val)
                save_checkpoint(latest_path, model, optimizer, scheduler, cfg, epoch, global_step, best_val)
                try:
                    PAUSE_REQUEST_PATH.unlink()
                except FileNotFoundError:
                    pass
                write_run_state(
                    status="paused",
                    epoch=epoch,
                    epochs=cfg.epochs,
                    completed_epochs=max(0, epoch - 1),
                    epoch_progress=done_in_epoch / max(1, n_batches),
                    epoch_steps_done=done_in_epoch,
                    epoch_total_steps=n_batches,
                    step=global_step,
                    total_steps=total_steps,
                    progress=min(1.0, global_step / total_steps),
                    checkpoint=str(pause_path),
                    best_val=None if best_val == float("inf") else best_val,
                )
                write_metric(
                    "pause",
                    epoch=epoch,
                    epochs=cfg.epochs,
                    step=global_step,
                    total_steps=total_steps,
                    checkpoint=str(pause_path),
                )
                cleanup_device(device)
                log(f"Training paused. Checkpoint saved: {pause_path}")
                return pause_path
            rss, total_ram, frac = memory_fraction()
            if frac >= cfg.memory_stop_fraction:
                path = checkpoint_dir / "watchdog.pt"
                save_checkpoint(path, model, optimizer, scheduler, cfg, epoch, global_step, best_val)
                write_run_state(
                    status="stopped",
                    reason="memory_watchdog",
                    epoch=epoch,
                    step=global_step,
                    completed_epochs=max(0, epoch - 1),
                    epoch_progress=done_in_epoch / max(1, n_batches),
                    epoch_steps_done=done_in_epoch,
                    epoch_total_steps=n_batches,
                    checkpoint=str(path),
                    ram_gb=rss,
                    ram_total_gb=total_ram,
                    ram_fraction=frac,
                )
                write_metric(
                    "stop",
                    reason="memory_watchdog",
                    epoch=epoch,
                    step=global_step,
                    ram_gb=rss,
                    ram_total_gb=total_ram,
                    ram_fraction=frac,
                )
                cleanup_device(device)
                raise SystemExit(
                    f"Memory watchdog stopped safely at {rss:.1f}/{total_ram:.0f} GB "
                    f"({frac * 100:.0f}%). Saved resumable checkpoint: {path}"
                )
            # Accumulation groups are counted within the epoch (batches 1..k,
            # k+1..2k, ...) so they line up with ``updates_per_epoch`` no matter
            # how many batches earlier epochs had; the last group may be shorter.
            group_start = (batch_i - 1) // cfg.grad_accum * cfg.grad_accum
            group_size = min(cfg.grad_accum, n_batches - group_start)
            batch_rows = [train_rows[i] for i in batch]
            src, tin, tout = collate_pairs(batch_rows, tok, cfg)
            src = src.to(device)
            tin = tin.to(device)
            tout = tout.to(device)
            logits = model(src, tin)
            loss = F.cross_entropy(
                logits.float().reshape(-1, logits.size(-1)),
                tout.reshape(-1),
                ignore_index=-100,
                label_smoothing=cfg.label_smoothing,
            )
            if not torch.isfinite(loss):
                optimizer.zero_grad(set_to_none=True)
                cleanup_device(device)
                raise RuntimeError("Loss became NaN/Inf. Try lowering lr or using the survivor recipe.")
            # Average over the real group size so a short trailing group is not under-weighted.
            (loss / group_size).backward()
            loss_value = float(loss.item())
            running_loss += loss_value
            running_count += 1
            global_step += 1
            if batch_i % cfg.grad_accum == 0 or batch_i == n_batches:
                grad_norm = nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                # Skip the update (but still clear gradients) when the norm is NaN/Inf.
                if torch.isfinite(grad_norm):
                    optimizer.step()
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)
            del src, tin, tout, logits, loss
            if global_step % cfg.empty_cache_every == 0:
                cleanup_device(device)
            if global_step % cfg.log_every_steps == 0:
                elapsed = time.time() - last_log
                avg_loss = running_loss / max(1, running_count)
                lr = scheduler.get_last_lr()[0]
                rss, total_ram, frac = memory_fraction()
                log(
                    f"epoch {epoch}/{cfg.epochs} step {global_step:,} "
                    f"loss {avg_loss:.4f} lr {lr:.2e} "
                    f"ram {rss:.1f}/{total_ram:.0f}GB ({frac * 100:.0f}%) "
                    f"{elapsed:.1f}s"
                )
                progress = min(1.0, global_step / total_steps)
                write_run_state(
                    status="running",
                    epoch=epoch,
                    epochs=cfg.epochs,
                    completed_epochs=max(0, epoch - 1),
                    epoch_progress=batch_i / max(1, n_batches),
                    epoch_steps_done=batch_i,
                    epoch_total_steps=n_batches,
                    step=global_step,
                    total_steps=total_steps,
                    train_loss=avg_loss,
                    lr=lr,
                    progress=progress,
                    ram_gb=rss,
                    ram_total_gb=total_ram,
                    ram_fraction=frac,
                )
                write_metric(
                    "step",
                    epoch=epoch,
                    epochs=cfg.epochs,
                    step=global_step,
                    total_steps=total_steps,
                    loss=avg_loss,
                    lr=lr,
                    progress=progress,
                    ram_gb=rss,
                    ram_total_gb=total_ram,
                    ram_fraction=frac,
                )
                running_loss = 0.0
                running_count = 0
                last_log = time.time()
            if global_step % cfg.eval_every_steps == 0:
                val_loss, token_acc = evaluate(model, val_rows, tok, cfg, device)
                log(f"validation step {global_step:,}: loss {val_loss:.4f}, token_acc {token_acc * 100:.1f}%")
                save_checkpoint(checkpoint_dir / "latest.pt", model, optimizer, scheduler, cfg, epoch, global_step, best_val)
                if val_loss < best_val:
                    best_val = val_loss
                    save_checkpoint(checkpoint_dir / "best.pt", model, optimizer, scheduler, cfg, epoch, global_step, best_val)
                    log(f"new best checkpoint: {checkpoint_dir / 'best.pt'}")
                write_run_state(
                    status="running",
                    epoch=epoch,
                    step=global_step,
                    completed_epochs=max(0, epoch - 1),
                    epoch_progress=batch_i / max(1, n_batches),
                    epoch_steps_done=batch_i,
                    epoch_total_steps=n_batches,
                    val_loss=val_loss,
                    token_acc=token_acc,
                    quality_percent=token_acc * 100.0,
                    best_val=best_val,
                )
                write_metric(
                    "val",
                    epoch=epoch,
                    epochs=cfg.epochs,
                    step=global_step,
                    total_steps=total_steps,
                    val_loss=val_loss,
                    token_acc=token_acc,
                    quality_percent=token_acc * 100.0,
                )
                cleanup_device(device)
            if global_step % cfg.save_every_steps == 0:
                save_checkpoint(checkpoint_dir / "latest.pt", model, optimizer, scheduler, cfg, epoch, global_step, best_val)
        val_loss, token_acc = evaluate(model, val_rows, tok, cfg, device)
        log(f"end epoch {epoch}: val_loss {val_loss:.4f}, token_acc {token_acc * 100:.1f}%")
        save_checkpoint(
            checkpoint_dir / "latest.pt",
            model,
            optimizer,
            scheduler,
            cfg,
            epoch,
            global_step,
            best_val,
            completed_epoch=True,
        )
        if val_loss < best_val:
            best_val = val_loss
            save_checkpoint(
                checkpoint_dir / "best.pt",
                model,
                optimizer,
                scheduler,
                cfg,
                epoch,
                global_step,
                best_val,
                completed_epoch=True,
            )
        write_run_state(
            status="running",
            epoch=epoch,
            epochs=cfg.epochs,
            completed_epochs=epoch,
            epoch_progress=1.0,
            epoch_steps_done=n_batches,
            epoch_total_steps=n_batches,
            step=global_step,
            total_steps=total_steps,
            val_loss=val_loss,
            token_acc=token_acc,
            quality_percent=token_acc * 100.0,
            best_val=best_val,
        )
        write_metric(
            "epoch",
            epoch=epoch,
            epochs=cfg.epochs,
            step=global_step,
            total_steps=total_steps,
            val_loss=val_loss,
            token_acc=token_acc,
            quality_percent=token_acc * 100.0,
        )
        cleanup_device(device)
    final_path = checkpoint_dir / "final.pt"
    save_checkpoint(
        final_path,
        model,
        optimizer,
        scheduler,
        cfg,
        cfg.epochs,
        global_step,
        best_val,
        completed_epoch=True,
    )
    # Guarantee a best.pt exists even if validation never improved on +inf.
    if not (checkpoint_dir / "best.pt").exists():
        shutil.copy2(final_path, checkpoint_dir / "best.pt")
    log(f"Training complete. Final checkpoint: {final_path}")
    write_run_state(
        status="done",
        epoch=cfg.epochs,
        epochs=cfg.epochs,
        completed_epochs=cfg.epochs,
        epoch_progress=1.0,
        epoch_steps_done=batches_per_epoch,
        epoch_total_steps=batches_per_epoch,
        step=global_step,
        total_steps=total_steps,
        progress=1.0,
        checkpoint=str(final_path),
        best_val=best_val,
    )
    write_metric("done", epoch=cfg.epochs, step=global_step, total_steps=total_steps, best_val=best_val)
    return final_path
def beam_generate(
    model: EdenTransformer,
    src: torch.Tensor,
    cfg: TrainConfig,
    beam_size: int,
    max_new_tokens: int,
    length_penalty: float,
    repetition_penalty: float,
) -> list[int]:
    """Decode one source sequence with beam search.

    Args:
        model: Model exposing ``encode``, ``decode`` and ``lm_head``.
        src: Source token ids for a single sequence (presumably ``(1, src_len)``
            since hypotheses are decoded one at a time — TODO confirm).
        cfg: Config; ``cfg.max_len`` bounds the decoder context window.
        beam_size: Number of hypotheses kept after each step.
        max_new_tokens: Maximum number of decoding steps.
        length_penalty: Exponent ``a`` in ``score / length**a`` used to rank beams.
        repetition_penalty: Values > 1 make already-generated tokens less likely.

    Returns:
        The best hypothesis' token ids, cut at the first EOS, with the special
        ids (PAD/BOS/EOS/UNK) removed.
    """
    model.eval()
    device = src.device

    def rank(item: tuple[list[int], float, bool]) -> float:
        # Length-normalised log-probability; the leading BOS is not counted.
        toks, score, _ = item
        return score / (max(1, len(toks) - 1) ** length_penalty)

    # Generation never backpropagates, so don't build autograd graphs.
    with torch.no_grad():
        memory, src_padding = model.encode(src)
        beams: list[tuple[list[int], float, bool]] = [([BOS_ID], 0.0, False)]
        for _ in range(max_new_tokens):
            if all(done for _, _, done in beams):
                break
            candidates: list[tuple[list[int], float, bool]] = []
            for tokens, score, done in beams:
                if done:
                    # Finished hypotheses compete unchanged with the extended ones.
                    candidates.append((tokens, score, True))
                    continue
                tgt = torch.tensor([tokens[-cfg.max_len:]], dtype=torch.long, device=device)
                hidden = model.decode(tgt, memory, src_padding)
                logits = model.lm_head(hidden[:, -1, :]).float().squeeze(0)
                if repetition_penalty != 1.0:
                    for token_id in set(tokens):
                        if 0 <= token_id < logits.numel():
                            value = logits[token_id]
                            # Dividing a negative logit would *raise* it, so
                            # negatives are multiplied instead (CTRL-style).
                            logits[token_id] = value / repetition_penalty if value > 0 else value * repetition_penalty
                # Special ids that must never be generated.
                logits[UNK_ID] = -float("inf")
                logits[PAD_ID] = -float("inf")
                logits[BOS_ID] = -float("inf")
                log_probs = F.log_softmax(logits, dim=-1)
                values, indices = torch.topk(log_probs, k=min(beam_size, log_probs.numel()))
                for value, index in zip(values.tolist(), indices.tolist()):
                    candidates.append((tokens + [int(index)], score + float(value), int(index) == EOS_ID))
            candidates.sort(key=rank, reverse=True)
            beams = candidates[:beam_size]
    best = max(beams, key=rank)
    out = best[0][1:]
    if EOS_ID in out:
        out = out[: out.index(EOS_ID)]
    return [t for t in out if t not in (PAD_ID, BOS_ID, EOS_ID, UNK_ID)]
def clamp_float(value, low: float, high: float, default: float) -> float:
    """Parse ``value`` as a float and clamp it to ``[low, high]``.

    Values ``float()`` cannot convert (``None``, non-numeric strings, ints too
    large for a float) and NaN fall back to ``default``, which is clamped too.
    Infinities are clamped to the nearest bound.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        number = default
    if math.isnan(number):
        # NaN compares false against everything, so min/max would silently
        # yield ``high``; treat it as "no usable value" instead.
        number = default
    return max(low, min(high, number))
def clamp_int(value, low: int, high: int, default: int) -> int:
    """Parse ``value`` with ``int()`` and clamp it to ``[low, high]``.

    Anything ``int()`` rejects — ``None``, non-integer strings such as ``"2.5"``,
    NaN, or infinite floats — falls back to ``default`` (which is clamped too).
    Finite floats are truncated towards zero by ``int()``.
    """
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        # ``int(float("inf"))`` raises OverflowError rather than ValueError.
        number = default
    return max(low, min(high, number))
def filter_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Return a copy of 1-D ``logits`` with tokens outside top-k / nucleus set to -inf.

    Args:
        logits: Unnormalised scores over the vocabulary (left unmodified).
        top_k: Keep only scores >= the k-th largest; ignored unless
            ``0 < top_k < logits.numel()``.
        top_p: Nucleus threshold; ignored unless ``0 < top_p < 1``. Tokens are
            kept in descending order up to and including the first one at which
            the cumulative probability exceeds ``top_p``; the best token is
            always kept.
    """
    blocked = -float("inf")
    result = logits.clone()
    if 0 < top_k < result.numel():
        kth_largest = torch.topk(result, top_k).values[-1]
        result[result < kth_largest] = blocked
    if not (0.0 < top_p < 1.0):
        return result
    ranked, ranked_ids = torch.sort(result, descending=True)
    cumulative = torch.cumsum(F.softmax(ranked, dim=-1), dim=-1)
    over_budget = cumulative > top_p
    # Shift the mask right by one so the token that crosses the threshold is
    # still kept; the leading False means the top-ranked token never drops.
    drop = torch.cat((over_budget.new_zeros(1), over_budget[:-1]))
    result[ranked_ids[drop]] = blocked
    return result
def token_generate(
    model: EdenTransformer,
    src: torch.Tensor,
    cfg: TrainConfig,
    strategy: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
) -> list[int]:
    """Decode one source sequence greedily or by sampling.

    Args:
        model: Model exposing ``encode``, ``decode`` and ``lm_head``.
        src: Source token ids for a single sequence (presumably ``(1, src_len)``
            — TODO confirm).
        cfg: Config; ``cfg.max_len`` bounds the decoder context window.
        strategy: ``"sample"`` for temperature + top-k/top-p sampling; any other
            value decodes greedily (argmax).
        max_new_tokens: Maximum number of decoding steps.
        temperature: Sampling temperature (floored at 0.05).
        top_k: Top-k filter applied before sampling.
        top_p: Nucleus filter applied before sampling.
        repetition_penalty: Values > 1 make already-generated tokens less likely.

    Returns:
        Generated token ids without the leading BOS. Generation stops at EOS,
        which is not included.
    """
    model.eval()
    device = src.device
    tokens = [BOS_ID]
    # Generation never backpropagates, so don't build autograd graphs.
    with torch.no_grad():
        memory, src_padding = model.encode(src)
        for _ in range(max_new_tokens):
            tgt = torch.tensor([tokens[-cfg.max_len:]], dtype=torch.long, device=device)
            hidden = model.decode(tgt, memory, src_padding)
            logits = model.lm_head(hidden[:, -1, :]).float().squeeze(0)
            if repetition_penalty != 1.0:
                for token_id in set(tokens):
                    if 0 <= token_id < logits.numel():
                        value = logits[token_id]
                        # Dividing a negative logit would *raise* it, so
                        # negatives are multiplied instead (CTRL-style).
                        logits[token_id] = value / repetition_penalty if value > 0 else value * repetition_penalty
            # Special ids that must never be generated.
            logits[UNK_ID] = -float("inf")
            logits[PAD_ID] = -float("inf")
            logits[BOS_ID] = -float("inf")
            if strategy == "sample":
                logits = logits / max(0.05, temperature)
                logits = filter_top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                # Fall back to argmax if filtering left no valid distribution.
                if not torch.isfinite(probs).all() or float(probs.sum().item()) <= 0:
                    next_id = int(torch.argmax(logits).item())
                else:
                    next_id = int(torch.multinomial(probs.detach().cpu(), 1).item())
            else:
                next_id = int(torch.argmax(logits).item())
            if next_id == EOS_ID:
                break
            if next_id not in (PAD_ID, BOS_ID, EOS_ID, UNK_ID):
                tokens.append(next_id)
    return tokens[1:]
def decode_token_piece(tok, token_id: int) -> str:
    """Decode a single token id for display.

    The byte-level markers ``\\u0120`` (space) and ``\\u010a`` (newline) are
    turned back into whitespace. An id that decodes to an empty string is
    shown as ``[<id>]`` instead.
    """
    piece = tok.decode([token_id])
    for marker, plain in (("\u0120", " "), ("\u010a", "\n")):
        piece = piece.replace(marker, plain)
    return piece or f"[{token_id}]"
def chunk_text_for_model(text: str, tok, cfg: TrainConfig) -> list[str]:
    """Split ``text`` into pieces that fit the model's source window.

    The window is ``cfg.max_len - 2`` tokens, leaving room for BOS/EOS. Text
    that already fits is returned as a single normalised chunk. Otherwise the
    sentences from ``sentence_split`` are packed greedily into space-joined
    chunks. Any single sentence longer than the window is cut into fixed-size
    token slices that are decoded back to text.
    """
    text = normalise_text(text)
    max_src = cfg.max_len - 2
    if len(tok.encode(text).ids) <= max_src:
        return [text]
    chunks: list[str] = []
    current: list[str] = []
    current_len = 0  # token count of the sentences in ``current``
    for sent in sentence_split(text) or [text]:
        sent_ids = tok.encode(sent).ids
        if current and current_len + len(sent_ids) > max_src:
            chunks.append(" ".join(current))
            current = []
            current_len = 0
        if len(sent_ids) > max_src:
            for i in range(0, len(sent_ids), max_src):
                piece = tok.decode(sent_ids[i : i + max_src])
                # Undo byte-level markers exactly like the other decode sites in
                # this module, so the slice is re-encoded later as ordinary text.
                chunks.append(piece.replace("\u0120", " ").replace("\u010a", "\n"))
        else:
            current.append(sent)
            current_len += len(sent_ids)
    if current:
        chunks.append(" ".join(current))
    return chunks
def enhance_text(
    text: str,
    model: EdenTransformer,
    tok,
    cfg: TrainConfig,
    device: torch.device,
    beam_size: int | None = None,
    strategy: str = "beam",
    max_new_tokens: int | None = None,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.9,
    length_penalty: float | None = None,
    repetition_penalty: float | None = None,
    return_details: bool = False,
):
    """Rewrite ``text`` with the model, chunk by chunk.

    The input is split by ``chunk_text_for_model``. Each chunk is wrapped in
    BOS/EOS and decoded:
    - with ``beam_generate`` for ``"beam"`` (also used for any unrecognised
      strategy);
    - with ``token_generate`` for ``"greedy"``/``"sample"``.
    Decoded chunks are normalised and joined with spaces. A chunk whose output
    is empty is kept as-is.

    Generation parameters are clamped to fixed ranges. Missing or invalid values
    fall back to ``cfg`` (beam size, length/repetition penalty) or built-in
    defaults.

    Returns:
        The rewritten string; or, when ``return_details`` is true, a dict with:
        ``output``; ``tokens`` (decoded pieces, at most 400 per chunk); and
        ``settings`` (the effective parameters used).
    """
    if strategy not in {"beam", "greedy", "sample"}:
        strategy = "beam"
    beam = max(1, int(beam_size or cfg.beam_size))
    settings = {
        "strategy": strategy,
        "beam_size": beam,
        "max_new_tokens": clamp_int(max_new_tokens, 8, max(8, cfg.max_len - 1), min(256, cfg.max_len - 1)),
        "temperature": clamp_float(temperature, 0.05, 2.0, 0.7),
        "top_k": clamp_int(top_k, 0, 200, 40),
        "top_p": clamp_float(top_p, 0.05, 1.0, 0.9),
        "length_penalty": clamp_float(length_penalty, 0.05, 2.0, cfg.length_penalty),
        "repetition_penalty": clamp_float(repetition_penalty, 1.0, 2.0, cfg.repetition_penalty),
    }
    rewritten: list[str] = []
    token_trace: list[str] = []
    for chunk in chunk_text_for_model(text, tok, cfg):
        body_ids = tok.encode(chunk).ids[: cfg.max_len - 2]
        src = torch.tensor([[BOS_ID, *body_ids, EOS_ID]], dtype=torch.long, device=device)
        if strategy == "beam":
            generated = beam_generate(
                model,
                src,
                cfg,
                beam_size=beam,
                max_new_tokens=settings["max_new_tokens"],
                length_penalty=settings["length_penalty"],
                repetition_penalty=settings["repetition_penalty"],
            )
        else:
            generated = token_generate(
                model,
                src,
                cfg,
                strategy=strategy,
                max_new_tokens=settings["max_new_tokens"],
                temperature=settings["temperature"],
                top_k=settings["top_k"],
                top_p=settings["top_p"],
                repetition_penalty=settings["repetition_penalty"],
            )
        token_trace.extend(decode_token_piece(tok, token_id) for token_id in generated[:400])
        raw = tok.decode(generated).replace("\u0120", " ").replace("\u010a", "\n")
        rewritten.append(normalise_text(raw) or chunk)
    result = normalise_text(" ".join(rewritten))
    if not return_details:
        return result
    return {"output": result, "tokens": token_trace, "settings": settings}
def load_model_for_inference(checkpoint: Path | None, force_cpu: bool = False):
    """Load a checkpoint and prepare everything needed for generation.

    Uses ``checkpoint`` when given, otherwise ``latest_checkpoint()``. The model
    is rebuilt from the checkpoint's saved config.

    Returns:
        ``(model, tok, cfg, device)``, with the model in eval mode on ``device``.
    """
    ckpt_path = checkpoint if checkpoint else latest_checkpoint()
    saved = load_checkpoint(ckpt_path, map_location="cpu")
    cfg = TrainConfig(**saved["config"])
    tok = load_tokenizer()
    device = device_for_training(force_cpu)
    model = build_model_from_cfg(cfg, device)
    model.load_state_dict(saved["model_state"])
    model.to(device)
    model.eval()
    millions = model.parameter_count() / 1e6
    log(f"Loaded {ckpt_path.name} on {device} ({millions:.1f}M params)")
    return model, tok, cfg, device
def command_prepare(args) -> None:
    """CLI ``prepare``: build the training pairs file and train the tokenizer.

    Does nothing (besides a hint) when both the pairs file and the tokenizer
    already exist and ``args.force`` is not set. Otherwise:
    - custom pairs from ``args.data`` (if given) are topped up with built-in
      pairs, then de-duplicated, shuffled and saved;
    - a tokenizer is trained on them;
    - the config, with the tokenizer's actual vocabulary size, is written to
      ``CONFIG_PATH``.
    """
    ensure_dirs()
    cfg = apply_recipe(args.recipe)
    if args.max_pairs:
        cfg.max_pairs = int(args.max_pairs)
    if args.vocab_size:
        cfg.vocab_size = int(args.vocab_size)
    already_prepared = PAIRS_PATH.exists() and TOKENIZER_PATH.exists()
    if already_prepared and not args.force:
        log(f"Prepared data already exists in {DATA_DIR}")
        log("Use --force to rebuild it.")
        return
    set_seed(cfg.seed)
    rows: list[tuple[str, str]] = []
    if args.data:
        custom_path = Path(args.data)
        log(f"Loading custom pairs from {custom_path}...")
        rows.extend(read_pairs_file(custom_path))
    # At least 1000 built-in pairs are always requested, even if custom data fills max_pairs.
    rows.extend(load_builtin_pairs(max(1000, cfg.max_pairs - len(rows)), args.include_c4, log))
    rows = dedupe_pairs(rows, cfg.max_pairs)
    random.shuffle(rows)
    save_pairs(rows, PAIRS_PATH)
    log(f"Saved {len(rows):,} training pairs to {PAIRS_PATH}")
    log(f"Training tokenizer with vocab_size={cfg.vocab_size}...")
    cfg.vocab_size = train_tokenizer(rows, cfg.vocab_size, TOKENIZER_PATH).get_vocab_size()
    CONFIG_PATH.write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")
    log(f"Tokenizer saved to {TOKENIZER_PATH} (actual vocab={cfg.vocab_size})")
    log("Prepare complete.")
def command_train(args) -> None:
    """CLI ``train``: train a new model, or resume one with ``args.resume``.

    A fresh run starts from ``apply_recipe(args.recipe)`` plus any CLI
    overrides. A resumed run uses the checkpoint's saved config unchanged (CLI
    overrides are ignored) and requires the tokenizer's vocabulary size to
    match it. Prepared data is (re)built through ``maybe_prepare_data`` when
    needed.
    """
    ensure_dirs()
    resume = Path(args.resume) if args.resume else None
    if resume:
        # Only the saved config is needed here; train_loop reloads the checkpoint
        # itself. Not binding the loaded dict to a local keeps a second copy of the
        # weights and optimizer state from staying alive for the whole run.
        cfg = TrainConfig(**load_checkpoint(resume, map_location="cpu")["config"])
        log("Resume mode: using the checkpoint's saved training settings.")
    else:
        cfg = apply_recipe(args.recipe)
        if args.epochs:
            cfg.epochs = int(args.epochs)
        if args.max_pairs:
            cfg.max_pairs = int(args.max_pairs)
        if args.lr:
            cfg.lr = float(args.lr)
        if getattr(args, "max_len", None):
            cfg.max_len = int(args.max_len)
        if getattr(args, "batch_size", None):
            cfg.batch_size = int(args.batch_size)
        if getattr(args, "grad_accum", None):
            cfg.grad_accum = int(args.grad_accum)
        if getattr(args, "memory_stop_fraction", None):
            cfg.memory_stop_fraction = float(args.memory_stop_fraction)
    maybe_prepare_data(args, cfg)
    rows = load_prepared_pairs(PAIRS_PATH)
    if not rows:
        raise SystemExit("No training pairs found. Run: python3 main.py prepare")
    rows = rows[: cfg.max_pairs]
    tok = load_tokenizer(TOKENIZER_PATH)
    tokenizer_vocab = tok.get_vocab_size()
    if resume and tokenizer_vocab != cfg.vocab_size:
        raise SystemExit(
            "The tokenizer does not match this checkpoint. Resume needs the same "
            "eden_system/data/tokenizer.json that was used when the checkpoint was created."
        )
    cfg.vocab_size = tokenizer_vocab
    CONFIG_PATH.write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")
    device = device_for_training(args.force_cpu)
    if resume:
        # Continue in the checkpoint's own session directory when it can be found.
        session_dir = session_dir_from_checkpoint(resume) or next_training_session_dir()
    else:
        session_dir = next_training_session_dir()
    train_loop(
        cfg,
        rows,
        tok,
        device,
        resume_path=resume,
        finetune=False,
        checkpoint_dir=session_dir / "checkpoints",
    )
def command_finetune(args) -> None:
    """CLI ``finetune``: continue training a checkpoint on a custom pairs file.

    Starts from ``args.checkpoint`` (or ``latest_checkpoint()``) with fresh
    optimizer/scheduler state. ``args.epochs`` and ``args.lr`` override the
    saved config.

    With ``args.mix_base`` and an existing prepared pairs file, a random sample
    of the base pairs is mixed in. The sample size is
    ``max(len(custom), 2000)``, capped at what is available.

    Requires at least 10 pairs after de-duplication and the tokenizer the
    checkpoint was trained with.
    """
    ensure_dirs()
    base = Path(args.checkpoint) if args.checkpoint else latest_checkpoint()
    # Only the saved config is needed here; train_loop reloads the weights from
    # ``base``. Not binding the loaded dict to a local keeps a second copy of the
    # weights and optimizer state from staying alive for the whole run.
    cfg = TrainConfig(**load_checkpoint(base, map_location="cpu")["config"])
    cfg.epochs = int(args.epochs)
    cfg.lr = float(args.lr)
    cfg.max_pairs = int(args.max_pairs) if args.max_pairs else cfg.max_pairs
    cfg.warmup_steps = min(cfg.warmup_steps, 200)
    custom = read_pairs_file(Path(args.data))
    if args.mix_base and PAIRS_PATH.exists():
        base_rows = load_prepared_pairs(PAIRS_PATH)
        random.shuffle(base_rows)
        keep = min(len(base_rows), max(len(custom), 2000))
        rows = dedupe_pairs(custom + base_rows[:keep], cfg.max_pairs)
    else:
        rows = dedupe_pairs(custom, cfg.max_pairs)
    if len(rows) < 10:
        raise SystemExit("Fine-tuning needs at least 10 valid input/target pairs.")
    tok = load_tokenizer(TOKENIZER_PATH)
    tokenizer_vocab = tok.get_vocab_size()
    if tokenizer_vocab != cfg.vocab_size:
        raise SystemExit(
            "The tokenizer does not match this checkpoint. Fine-tuning needs the same "
            "eden_system/data/tokenizer.json that was used when the checkpoint was created."
        )
    cfg.vocab_size = tokenizer_vocab
    device = device_for_training(args.force_cpu)
    log(f"Fine-tuning from {base} on {len(rows):,} pairs...")
    session_dir = next_training_session_dir()
    train_loop(cfg, rows, tok, device, resume_path=base, finetune=True, checkpoint_dir=session_dir / "checkpoints")
def command_enhance(args) -> None:
    """CLI ``enhance``: rewrite text from the arguments (or stdin) and print it.

    The model is loaded first. The text is ``args.text`` joined with spaces, or
    stdin when that is empty.

    Raises:
        SystemExit: if neither source provides any text.
    """
    checkpoint = Path(args.checkpoint) if args.checkpoint else None
    model, tok, cfg, device = load_model_for_inference(checkpoint, force_cpu=args.force_cpu)
    text = " ".join(args.text).strip() or sys.stdin.read().strip()
    if not text:
        raise SystemExit("Provide text as an argument or via stdin.")
    print(enhance_text(text, model, tok, cfg, device, beam_size=args.beam))
def command_interactive(args) -> None:
    """CLI ``interactive``: read lines, enhance them and print the results.

    Stops on EOF, Ctrl-C at the prompt, or a quit word
    (``/q``, ``/quit``, ``quit``, ``exit``; case-insensitive). Blank lines are
    ignored.
    """
    checkpoint = Path(args.checkpoint) if args.checkpoint else None
    model, tok, cfg, device = load_model_for_inference(checkpoint, force_cpu=args.force_cpu)
    log("Interactive EDEN. Type text and press Enter. Type /quit to stop.")
    quit_words = {"/q", "/quit", "quit", "exit"}
    while True:
        try:
            line = input("\nrough> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            return
        if line.lower() in quit_words:
            return
        if line:
            print("clean> " + enhance_text(line, model, tok, cfg, device, beam_size=args.beam))
def command_eval(args) -> None:
    """Evaluate a checkpoint and log its loss and token accuracy.

    Eval pairs come from ``args.data`` when given; otherwise they come from
    the prepared pairs file at ``PAIRS_PATH``, if it exists. The prepared file
    is no longer read when a custom file is supplied. ``args.samples``
    optionally truncates the pair list. Pairs are gathered before the model is
    loaded, so a missing dataset fails fast.

    Raises:
        SystemExit: if no eval pairs are available.
    """
    if args.data:
        rows = read_pairs_file(Path(args.data))
    elif PAIRS_PATH.exists():
        rows = load_prepared_pairs(PAIRS_PATH)
    else:
        # No prepared data yet; report "no pairs" instead of failing inside the loader.
        rows = []
    if args.samples:
        rows = rows[: int(args.samples)]
    if not rows:
        raise SystemExit("No eval pairs found.")
    model, tok, cfg, device = load_model_for_inference(
        Path(args.checkpoint) if args.checkpoint else None,
        force_cpu=args.force_cpu,
    )
    loss, acc = evaluate(model, rows, tok, cfg, device, max_batches=args.max_batches)
    log(f"eval loss {loss:.4f}, token_acc {acc * 100:.1f}% on {len(rows):,} pairs")
def command_info(_args) -> None:
    """Log a summary of the EDEN workspace and environment.

    The summary covers the workspace and system directories, the PyTorch
    version and MPS availability, and the size and batch settings of every
    recipe in ``RECIPES``. It also covers the prepared-pair count and
    tokenizer vocabulary (when those files exist) and up to eight checkpoints
    with their on-disk sizes.
    """
    ensure_dirs()
    log(f"Workspace: {ROOT}")
    log(f"System dir: {SYSTEM_DIR}")
    log(f"PyTorch: {torch.__version__}")
    log(f"MPS built: {torch.backends.mps.is_built()}")
    log(f"MPS available: {torch.backends.mps.is_available()}")
    for name in RECIPES:
        cfg = apply_recipe(name)
        log(
            f"Recipe {name}: ~{model_param_count(cfg) / 1e6:.1f}M params, "
            f"ctx {cfg.max_len}, batch {cfg.batch_size} x accum {cfg.grad_accum}, "
            f"pairs {cfg.max_pairs:,}"
        )
    if PAIRS_PATH.exists():
        # Close the file deterministically instead of leaking the handle.
        with PAIRS_PATH.open("r", encoding="utf-8") as handle:
            pair_count = sum(1 for _ in handle)
        log(f"Prepared pairs: {pair_count:,}")
    if TOKENIZER_PATH.exists():
        tok = load_tokenizer(TOKENIZER_PATH)
        log(f"Tokenizer vocab: {tok.get_vocab_size():,}")
    checkpoints = all_checkpoint_files()
    if checkpoints:
        log("Checkpoints:")
        for path in checkpoints[:8]:
            try:
                label = str(path.relative_to(SYSTEM_DIR))
            except ValueError:
                label = str(path)
            try:
                size_mb = path.stat().st_size / 1024 ** 2
            except OSError:
                # File vanished or became unreadable after it was listed.
                log(f" {label} (unavailable)")
                continue
            log(f" {label} ({size_mb:.1f} MB)")
def command_smoke(args) -> None:
    """Run a tiny end-to-end training smoke test.

    The test builds up to 32 rough/clean pairs from ``SEED_CLEAN_SENTENCES``
    and trains a 512-entry tokenizer into ``SYSTEM_DIR / "smoke"``. It then
    takes three AdamW steps with batches of two on a one-layer model, which
    exercises tokenization, collation, forward, loss, backward and optimizer
    updates on the selected device.

    Raises:
        SystemExit: if no training pairs could be built.
    """
    ensure_dirs()
    smoke_dir = SYSTEM_DIR / "smoke"
    smoke_dir.mkdir(parents=True, exist_ok=True)
    rows = []
    for clean in SEED_CLEAN_SENTENCES:
        add_pair(rows, corrupt_sentence(clean, 0.5), clean)
        add_pair(rows, clean.lower(), clean)
    rows = dedupe_pairs(rows, 32)
    if not rows:
        raise SystemExit("Smoke test could not build any training pairs.")
    tok_path = smoke_dir / "tokenizer.json"
    tok = train_tokenizer(rows, 512, tok_path)
    cfg = TrainConfig(
        vocab_size=tok.get_vocab_size(),
        d_model=64,
        n_heads=4,
        n_layers=1,
        dim_feedforward=128,
        max_len=64,
        batch_size=2,
        grad_accum=1,
        epochs=1,
        eval_every_steps=9999,
        save_every_steps=9999,
        log_every_steps=1,
        memory_stop_fraction=0.95,
    )
    device = device_for_training(args.force_cpu)
    try:
        model = build_model_from_cfg(cfg, device)
        opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for step in range(3):
            # Wrap around so a short pair list never produces an empty batch;
            # with >= 6 pairs this matches plain consecutive slicing.
            start = step * cfg.batch_size
            batch = [rows[(start + i) % len(rows)] for i in range(cfg.batch_size)]
            src, tin, tout = collate_pairs(batch, tok, cfg)
            src, tin, tout = src.to(device), tin.to(device), tout.to(device)
            logits = model(src, tin)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), tout.reshape(-1), ignore_index=-100
            )
            loss.backward()
            opt.step()
            opt.zero_grad(set_to_none=True)
            log(f"smoke step {step + 1}: loss {float(loss.item()):.4f}")
    finally:
        # Release device memory even when a step fails.
        cleanup_device(device)
    log("Smoke test passed.")
def checkpoint_options() -> list[dict]:
    """Describe every known checkpoint for the dashboard's pickers.

    Returns one dict per file from ``all_checkpoint_files()``, in that order,
    with keys ``name``, ``path``, ``label`` (relative to ``SYSTEM_DIR`` when
    possible, otherwise the full path), ``session`` (the owning session
    directory's name, or ``"legacy"``), ``size_mb`` and ``mtime``. Files that
    cannot be stat'ed are skipped.
    """
    ensure_dirs()
    options: list[dict] = []
    for ckpt in all_checkpoint_files():
        try:
            info = ckpt.stat()
        except OSError:
            continue  # removed or unreadable since it was listed
        owner = session_dir_from_checkpoint(ckpt)
        try:
            shown = str(ckpt.relative_to(SYSTEM_DIR))
        except ValueError:
            shown = str(ckpt)  # lives outside SYSTEM_DIR
        options.append(
            {
                "name": ckpt.name,
                "path": str(ckpt),
                "label": shown,
                "session": owner.name if owner else "legacy",
                "size_mb": info.st_size / 1024 ** 2,
                "mtime": info.st_mtime,
            }
        )
    return options
| def resolve_checkpoint_path(value: str | None) -> Path: | |
| if not value: | |
| return latest_checkpoint() | |
| path = Path(value).expanduser() | |
| if not path.is_absolute(): | |
| path = ROOT / path | |
| path = path.resolve() | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Checkpoint not found: {path}") | |
| if path.suffix != ".pt": | |
| raise ValueError("Checkpoint must be a .pt file.") | |
| return path | |
| def resolve_finetune_data_path(value: str | None) -> Path: | |
| if not value: | |
| raise ValueError("Choose a fine-tune data file first.") | |
| path = Path(value).expanduser() | |
| if not path.is_absolute(): | |
| path = ROOT / path | |
| path = path.resolve() | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Fine-tune data file not found: {path}") | |
| if path.suffix.lower() not in {".jsonl", ".ndjson", ".json", ".csv", ".tsv"}: | |
| raise ValueError("Fine-tune data must be JSONL, JSON, CSV, or TSV.") | |
| return path | |
| UI_HTML = r"""<!doctype html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="utf-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1"> | |
| <meta http-equiv="Cache-Control" content="no-store"> | |
| <meta http-equiv="Pragma" content="no-cache"> | |
| <title>EDEN Dashboard</title> | |
| <style> | |
| :root { | |
| color-scheme: light; | |
| --bg: #f6f7f9; | |
| --surface: #ffffff; | |
| --surface-2: #edf1f5; | |
| --ink: #172033; | |
| --muted: #667085; | |
| --line: #d7dee8; | |
| --accent: #147c72; | |
| --accent-2: #315c9c; | |
| --warn: #a86416; | |
| --danger: #9f2f2f; | |
| --good: #247a3d; | |
| --header-bg: rgba(255,255,255,.94); | |
| --chart-bg: #fbfcfd; | |
| --field-bg: #ffffff; | |
| --log-bg: #171a17; | |
| --log-ink: #ecf6e8; | |
| --shadow: rgba(23, 32, 51, .06); | |
| } | |
| body[data-theme="dark"] { | |
| color-scheme: dark; | |
| --bg: #111827; | |
| --surface: #172033; | |
| --surface-2: #243044; | |
| --ink: #eef4ff; | |
| --muted: #aab6c8; | |
| --line: #344159; | |
| --accent: #18a79a; | |
| --accent-2: #6f93d6; | |
| --warn: #d18a32; | |
| --danger: #d65f5f; | |
| --good: #5ac878; | |
| --header-bg: rgba(23,32,51,.94); | |
| --chart-bg: #111827; | |
| --field-bg: #101725; | |
| --log-bg: #090d14; | |
| --log-ink: #dbeafe; | |
| --shadow: rgba(0, 0, 0, .20); | |
| } | |
| * { box-sizing: border-box; } | |
| body { | |
| margin: 0; | |
| background: var(--bg); | |
| color: var(--ink); | |
| font: 14px/1.45 ui-sans-serif, -apple-system, BlinkMacSystemFont, "Avenir Next", "Segoe UI", sans-serif; | |
| } | |
| header { | |
| display: flex; | |
| align-items: center; | |
| justify-content: space-between; | |
| gap: 16px; | |
| padding: 16px 20px; | |
| border-bottom: 1px solid var(--line); | |
| background: var(--header-bg); | |
| position: sticky; | |
| top: 0; | |
| z-index: 4; | |
| backdrop-filter: blur(12px); | |
| } | |
| h1 { margin: 0; font-size: 20px; letter-spacing: 0; } | |
| main { max-width: 1240px; margin: 0 auto; padding: 18px; } | |
| .nav { display: flex; gap: 8px; flex-wrap: wrap; } | |
| .nav button, button { | |
| border: 1px solid var(--line); | |
| border-radius: 6px; | |
| background: var(--surface); | |
| color: var(--ink); | |
| padding: 9px 12px; | |
| font: inherit; | |
| font-weight: 700; | |
| cursor: pointer; | |
| } | |
| .nav button.active { background: var(--ink); color: #ffffff; } | |
| button.primary { background: var(--accent); color: white; border-color: var(--accent); } | |
| button.secondary { background: var(--accent-2); color: white; border-color: var(--accent-2); } | |
| button.warn { background: var(--warn); color: white; border-color: var(--warn); } | |
| button.danger { background: var(--danger); color: white; border-color: var(--danger); } | |
| button:disabled { opacity: .45; cursor: default; } | |
| .page { display: none; } | |
| .page.active { display: block; } | |
| .header-title { display: grid; gap: 2px; flex: 0 0 auto; } | |
| .top-status { | |
| width: min(380px, 34vw); | |
| min-width: 260px; | |
| border: 1px solid var(--line); | |
| border-radius: 8px; | |
| background: var(--surface); | |
| padding: 8px 10px; | |
| display: grid; | |
| gap: 2px; | |
| overflow: hidden; | |
| } | |
| .top-status span { white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } | |
| .nav { flex: 0 0 auto; } | |
| .header-meta { display: flex; gap: 12px; flex-wrap: wrap; align-items: center; } | |
| .panel { | |
| background: var(--surface); | |
| border: 1px solid var(--line); | |
| border-radius: 8px; | |
| padding: 14px; | |
| box-shadow: 0 10px 24px var(--shadow); | |
| } | |
| .topline { display: flex; align-items: center; justify-content: space-between; gap: 12px; margin-bottom: 12px; } | |
| .muted { color: var(--muted); } | |
| .small { font-size: 12px; } | |
| .cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(165px, 1fr)); gap: 12px; margin-bottom: 12px; } | |
| .card { | |
| background: var(--surface); | |
| border: 1px solid var(--line); | |
| border-radius: 8px; | |
| padding: 12px; | |
| min-height: 86px; | |
| } | |
| .label { color: var(--muted); font-size: 12px; } | |
| .value { font-size: 24px; font-weight: 800; margin-top: 4px; overflow-wrap: anywhere; } | |
| .progress { height: 10px; background: var(--surface-2); border-radius: 999px; overflow: hidden; margin-top: 9px; } | |
| .bar { height: 100%; width: 0%; background: var(--accent); transition: width .25s ease; } | |
| .actions { display: flex; gap: 8px; flex-wrap: wrap; } | |
| .charts { display: grid; grid-template-columns: 1fr 1fr; gap: 12px; } | |
| canvas { display: block; width: 100%; height: 260px; background: var(--chart-bg); border-radius: 6px; } | |
| .form-grid { display: grid; grid-template-columns: repeat(3, minmax(0, 1fr)); gap: 12px; } | |
| label { display: grid; gap: 5px; color: var(--muted); font-size: 12px; font-weight: 700; } | |
| .hint { color: var(--muted); font-size: 11px; font-weight: 500; line-height: 1.35; } | |
| select, textarea, input[type="range"], input[type="text"] { | |
| width: 100%; | |
| font: inherit; | |
| } | |
| select, textarea, input[type="text"] { | |
| border: 1px solid var(--line); | |
| border-radius: 6px; | |
| background: var(--field-bg); | |
| color: var(--ink); | |
| padding: 9px 10px; | |
| } | |
| input[type="range"] { accent-color: var(--accent); } | |
| .range-value { color: var(--ink); font-size: 12px; font-weight: 800; } | |
| textarea { min-height: 150px; resize: vertical; } | |
| pre { | |
| margin: 0; | |
| min-height: 260px; | |
| max-height: 440px; | |
| overflow: auto; | |
| white-space: pre-wrap; | |
| border-radius: 6px; | |
| background: var(--log-bg); | |
| color: var(--log-ink); | |
| padding: 12px; | |
| font-size: 12px; | |
| } | |
| .monitor-log { margin-top: 12px; } | |
| .monitor-log pre { min-height: 180px; max-height: 280px; } | |
| .source-list { display: grid; gap: 8px; margin-top: 12px; } | |
| .source-item { border: 1px solid var(--line); border-radius: 6px; padding: 10px; background: var(--field-bg); } | |
| .source-item strong { display: block; margin-bottom: 2px; } | |
| .run-grid { display: grid; grid-template-columns: 1fr 1fr; gap: 12px; } | |
| .infer-grid { display: grid; grid-template-columns: repeat(2, minmax(0, 1fr)); gap: 12px; margin-top: 12px; } | |
| .token-trace { min-height: 120px; max-height: 220px; margin-top: 8px; } | |
| .modal-backdrop { | |
| position: fixed; | |
| inset: 0; | |
| display: none; | |
| align-items: center; | |
| justify-content: center; | |
| padding: 20px; | |
| background: rgba(32,32,29,.38); | |
| z-index: 20; | |
| } | |
| .modal-backdrop.open { display: flex; } | |
| .modal { | |
| width: min(620px, 100%); | |
| background: var(--surface); | |
| border: 1px solid var(--line); | |
| border-radius: 8px; | |
| padding: 16px; | |
| box-shadow: 0 24px 80px rgba(23, 32, 51, .24); | |
| } | |
| .modal-grid { display: grid; gap: 12px; margin: 14px 0; } | |
| @media (max-width: 920px) { | |
| header { align-items: flex-start; flex-direction: column; } | |
| .top-status { width: 100%; min-width: 0; } | |
| .cards { grid-template-columns: repeat(2, minmax(0, 1fr)); } | |
| .charts, .form-grid, .run-grid, .infer-grid { grid-template-columns: 1fr; } | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <header> | |
| <div class="header-title"> | |
| <h1>EDEN</h1> | |
| </div> | |
| <div class="top-status muted small"> | |
| <span id="statusText">Status: loading dashboard</span> | |
| <span id="etaText">ETA -</span> | |
| </div> | |
| <nav class="nav"> | |
| <button class="active" data-page="monitor">Monitor</button> | |
| <button data-page="settings">Settings</button> | |
| <button data-page="finetune">Fine-Tune</button> | |
| <button data-page="run">Run Model</button> | |
| <button data-page="logs">Logs</button> | |
| <button id="themeToggle" type="button">Dark Mode</button> | |
| </nav> | |
| </header> | |
| <main> | |
| <section id="monitor" class="page active"> | |
| <div class="topline"> | |
| <div> | |
| <strong>Training</strong> | |
| <div id="sessionLine" class="muted small">Session not started</div> | |
| </div> | |
| <div class="actions"> | |
| <button id="startBtn" class="primary">Start New Training</button> | |
| <button id="pauseBtn" class="warn">Pause + Save</button> | |
| <button id="resumeBtn" class="secondary">Resume</button> | |
| <button id="stopBtn" class="danger">Force Stop</button> | |
| </div> | |
| </div> | |
| <div class="cards"> | |
| <div class="card"><div class="label">Model Size</div><div id="params" class="value">-</div><div id="device" class="muted small">-</div></div> | |
| <div class="card"><div class="label">Epoch</div><div id="epochsDone" class="value">-</div><div id="epochNow" class="muted small">completed -</div></div> | |
| <div class="card"><div class="label">Current Epoch</div><div id="epochPercent" class="value">-</div><div class="progress"><div id="epochBar" class="bar"></div></div><div id="epochSteps" class="muted small">epoch steps -</div></div> | |
| <div class="card"><div class="label">Overall Training</div><div id="completion" class="value">-</div><div class="progress"><div id="progressBar" class="bar"></div></div><div id="stepsLeft" class="muted small">total steps -</div></div> | |
| <div class="card"><div class="label">Latest Loss</div><div id="loss" class="value">-</div><div id="lr" class="muted small">lr -</div></div> | |
| <div class="card"><div class="label">Validation Quality</div><div id="quality" class="value">-</div><div id="valLoss" class="muted small">validation -</div></div> | |
| </div> | |
| <div class="charts"> | |
| <div class="panel"><div class="label">Loss Per Step</div><canvas id="stepChart"></canvas></div> | |
| <div class="panel"><div class="label">Validation Loss / Accuracy</div><canvas id="valChart"></canvas></div> | |
| </div> | |
| <div class="panel monitor-log"> | |
| <div class="label">Training Log</div> | |
| <pre id="monitorLog"></pre> | |
| </div> | |
| </section> | |
| <section id="settings" class="page"> | |
| <div class="panel"> | |
| <div class="topline"> | |
| <div> | |
| <strong>Training Settings</strong> | |
| <div id="ramTip" class="muted small">Loading RAM recommendation...</div> | |
| </div> | |
| <button id="recommendedBtn" class="secondary" type="button">Use Recommended Settings</button> | |
| </div> | |
| <div class="form-grid"> | |
| <label><span>Model Recipe</span> | |
| <span class="hint">Controls model size. m5-smart is the best default for 32 GB RAM.</span> | |
| <select id="recipe"> | |
| <option value="m5-smart">56M params - M5 Smart (recommended)</option> | |
| <option value="survivor">23M params - Survivor / safest</option> | |
| <option value="m5-large">107M params - M5 Large / high memory</option> | |
| </select> | |
| </label> | |
| <label><span>Context Length</span> | |
| <span class="hint">How many tokens the model sees at once. Higher helps longer paragraphs but uses more memory.</span> | |
| <select id="maxLen"> | |
| <option value="256">256 tokens - short text / safest</option> | |
| <option value="320">320 tokens</option> | |
| <option value="384">384 tokens</option> | |
| <option value="512" selected>512 tokens - longer paragraphs (recommended)</option> | |
| </select> | |
| </label> | |
| <label><span>Batch Size</span> | |
| <span class="hint">How many examples run at once. Higher can train faster but pushes RAM harder.</span> | |
| <select id="batchSize"> | |
| <option value="1">1 example at once - safest</option> | |
| <option value="2" selected>2 examples at once (recommended)</option> | |
| <option value="4">4 examples at once - high RAM</option> | |
| </select> | |
| </label> | |
| <label><span>Gradient Accumulation</span> | |
| <span class="hint">Combines several small batches into one update. Higher is safer for memory but a little slower.</span> | |
| <select id="gradAccum"> | |
| <option value="8" selected>8 mini-batches per update (recommended)</option> | |
| <option value="16">16 mini-batches per update - safer</option> | |
| <option value="24">24 mini-batches per update</option> | |
| <option value="32">32 mini-batches per update - safest memory</option> | |
| </select> | |
| </label> | |
| <label><span>RAM Safety Limit</span> | |
| <span class="hint">Safety stop before macOS runs out of memory. 78% leaves about 7 GB free on a 32 GB Mac.</span> | |
| <select id="ramLimit"> | |
| <option value="0.72">72% RAM - extra safe</option> | |
| <option value="0.78" selected>78% RAM (recommended)</option> | |
| <option value="0.82">82% RAM - less headroom</option> | |
| </select> | |
| </label> | |
| <label><span>Training Epochs</span> | |
| <span class="hint">How many full passes through the training data. More can improve quality but takes longer.</span> | |
| <select id="epochs"> | |
| <option value="">Use recipe default (recommended)</option> | |
| <option value="1">1 epoch - quick test</option> | |
| <option value="3">3 epochs</option> | |
| <option value="6">6 epochs</option> | |
| <option value="8">8 epochs</option> | |
| <option value="10">10 epochs - longer training</option> | |
| </select> | |
| </label> | |
| <label><span>Training Pairs</span> | |
| <span class="hint">How many rough-to-clean examples to train on. More examples usually help quality and increase time.</span> | |
| <select id="pairs"> | |
| <option value="">Use recipe default (recommended)</option> | |
| <option value="5000">5k pairs - quick test</option> | |
| <option value="25000">25k pairs</option> | |
| <option value="80000">80k pairs</option> | |
| <option value="120000">120k pairs</option> | |
| <option value="180000">180k pairs - longer training</option> | |
| </select> | |
| </label> | |
| </div> | |
| </div> | |
| <div class="panel" style="margin-top:12px"> | |
| <div class="topline"> | |
| <div> | |
| <strong>Training Data</strong> | |
| <div id="dataSummary" class="muted small">Loading prepared data...</div> | |
| </div> | |
| </div> | |
| <div id="dataSources" class="source-list"></div> | |
| </div> | |
| </section> | |
| <section id="finetune" class="page"> | |
| <div class="panel"> | |
| <div class="topline"> | |
| <div> | |
| <strong>Fine-Tune</strong> | |
| <div class="muted small">Teach the selected model your own writing style or task using input/target examples.</div> | |
| </div> | |
| <button id="startFinetuneBtn" class="primary">Start Fine-Tune</button> | |
| </div> | |
| <div class="form-grid"> | |
| <label><span>Base Model</span> | |
| <select id="finetuneCheckpoint"></select> | |
| <span class="hint">The model checkpoint to specialize. Defaults to the current active checkpoint.</span> | |
| </label> | |
| <label><span>Custom Training File</span> | |
| <input id="finetuneData" type="text" readonly placeholder="Choose a JSONL, JSON, CSV, or TSV file"> | |
| <span class="hint">Use pairs with input/target columns, like rough text to polished text.</span> | |
| </label> | |
| <label><span>Learning Rate</span> | |
| <select id="finetuneLr"> | |
| <option value="0.00008" selected>0.00008 - careful style learning (recommended)</option> | |
| <option value="0.00005">0.00005 - extra gentle</option> | |
| <option value="0.00012">0.00012 - stronger change</option> | |
| <option value="0.0002">0.00020 - aggressive</option> | |
| </select> | |
| <span class="hint">Lower protects the model. Higher adapts faster but can damage general ability.</span> | |
| </label> | |
| <label><span>Fine-Tune Epochs</span> | |
| <select id="finetuneEpochs"> | |
| <option value="1">1 epoch - quick test</option> | |
| <option value="2">2 epochs</option> | |
| <option value="3" selected>3 epochs (recommended)</option> | |
| <option value="5">5 epochs - stronger style</option> | |
| <option value="8">8 epochs - high overfit risk</option> | |
| </select> | |
| <span class="hint">More passes learn your file harder. Stop early if it starts copying style too aggressively.</span> | |
| </label> | |
| <label><span>Max Fine-Tune Pairs</span> | |
| <select id="finetuneMaxPairs"> | |
| <option value="">Use all valid pairs (recommended)</option> | |
| <option value="100">100 pairs - tiny test</option> | |
| <option value="500">500 pairs</option> | |
| <option value="2000">2k pairs</option> | |
| <option value="10000">10k pairs</option> | |
| </select> | |
| <span class="hint">Limits how many examples from your file are used.</span> | |
| </label> | |
| <label><span>Preserve General Skill</span> | |
| <select id="finetuneMixBase"> | |
| <option value="on" selected>Mix base data in (recommended)</option> | |
| <option value="off">Only my custom file</option> | |
| </select> | |
| <span class="hint">Mixing base examples helps prevent forgetting normal spelling, grammar, and rewrite ability.</span> | |
| </label> | |
| </div> | |
| <div class="actions" style="margin-top:12px"> | |
| <button id="chooseFinetuneCheckpoint" class="secondary">Choose Base Model</button> | |
| <button id="chooseFinetuneData">Choose Data File</button> | |
| </div> | |
| </div> | |
| <div class="panel" style="margin-top:12px"> | |
| <div class="label">Example Fine-Tune File</div> | |
| <pre style="min-height:110px; max-height:180px">{"input":"i need this to sound more warm and clear","target":"I need this to sound warmer and clearer."} | |
| {"input":"rough text here","target":"Polished text here."}</pre> | |
| </div> | |
| </section> | |
| <section id="run" class="page"> | |
| <div class="run-grid"> | |
| <div class="panel"> | |
| <div class="topline"> | |
| <strong>Run Model</strong> | |
| <div class="actions"> | |
| <button id="chooseRunCheckpoint" class="secondary">Choose Checkpoint</button> | |
| <button id="openInferenceBtn" type="button">Inference Controls</button> | |
| </div> | |
| </div> | |
| <label>Checkpoint | |
| <select id="runCheckpoint"></select> | |
| <span class="hint">Defaults to the current training checkpoint when one is available.</span> | |
| </label> | |
| <label style="margin-top:12px">Text | |
| <textarea id="modelInput">i relly wnt this sentance to sound more profesional and clear</textarea> | |
| </label> | |
| <div class="actions" style="margin-top:12px"> | |
| <button id="enhanceBtn" class="primary">Run Model</button> | |
| </div> | |
| </div> | |
| <div class="panel"> | |
| <div class="label">Output</div> | |
| <pre id="modelOutput"></pre> | |
| <div id="tokenTraceWrap" style="display:none"> | |
| <div class="label" style="margin-top:12px">Token Trace</div> | |
| <pre id="tokenTrace" class="token-trace"></pre> | |
| </div> | |
| </div> | |
| </div> | |
| </section> | |
| <section id="logs" class="page"> | |
| <div class="panel"> | |
| <div class="label">Training Log</div> | |
| <pre id="log"></pre> | |
| </div> | |
| </section> | |
| </main> | |
| <div id="resumeModal" class="modal-backdrop"> | |
| <div class="modal"> | |
| <div class="topline"> | |
| <div> | |
| <strong id="checkpointModalTitle">Resume Training</strong> | |
| <div id="checkpointModalNote" class="muted small">Choose the training session first. The checkpoint's saved settings are restored automatically.</div> | |
| </div> | |
| </div> | |
| <div class="modal-grid"> | |
| <label>Training Session | |
| <select id="resumeSession"></select> | |
| </label> | |
| <label>Checkpoint | |
| <select id="resumeCheckpoint"></select> | |
| </label> | |
| </div> | |
| <div class="actions"> | |
| <button id="confirmResumeBtn" class="secondary">Resume</button> | |
| <button id="cancelResumeBtn">Cancel</button> | |
| </div> | |
| </div> | |
| </div> | |
| <div id="inferenceModal" class="modal-backdrop"> | |
| <div class="modal"> | |
| <div class="topline"> | |
| <div> | |
| <strong>Inference Controls</strong> | |
| <div class="muted small">Beam search is the recommended best-intelligence mode.</div> | |
| </div> | |
| <button id="recommendedInferBtn" class="secondary" type="button">Use Recommended Inference</button> | |
| </div> | |
| <div class="infer-grid"> | |
| <label><span>Generation Mode</span> | |
| <select id="inferMode"> | |
| <option value="beam" selected>Beam search - best answer (recommended)</option> | |
| <option value="greedy">Greedy - deterministic token by token</option> | |
| <option value="sample">Sampling - creative token by token</option> | |
| </select> | |
| <span class="hint">Beam compares several drafts and picks the strongest one. Sampling is more playful but less reliable.</span> | |
| </label> | |
| <label><span>Output View</span> | |
| <select id="showTokens"> | |
| <option value="off" selected>Final answer only (recommended)</option> | |
| <option value="on">Show token-by-token trace</option> | |
| </select> | |
| <span class="hint">Token trace shows the small pieces the model generated before they are joined into text.</span> | |
| </label> | |
| <label><span>Beam Width</span> | |
| <span id="beamSizeValue" class="range-value">4 (recommended)</span> | |
| <input id="beamSize" type="range" min="1" max="8" step="1" value="4"> | |
| <span class="hint">Higher checks more possible answers. 4 is the best balance for intelligence and speed.</span> | |
| </label> | |
| <label><span>Max Output Tokens</span> | |
| <span id="maxTokensValue" class="range-value">256 (recommended)</span> | |
| <input id="maxTokens" type="range" min="32" max="512" step="16" value="256"> | |
| <span class="hint">Caps how long the answer can be. Longer can help paragraphs but takes more time.</span> | |
| </label> | |
| <label><span>Temperature</span> | |
| <span id="temperatureValue" class="range-value">0.70 (recommended)</span> | |
| <input id="temperature" type="range" min="0.1" max="1.5" step="0.05" value="0.70"> | |
| <span class="hint">For sampling mode: lower is safer, higher is more creative.</span> | |
| </label> | |
| <label><span>Top-K</span> | |
| <span id="topKValue" class="range-value">40 (recommended)</span> | |
| <input id="topK" type="range" min="0" max="120" step="5" value="40"> | |
| <span class="hint">For sampling mode: limits choices to the top K tokens. 0 means no top-k limit.</span> | |
| </label> | |
| <label><span>Top-P</span> | |
| <span id="topPValue" class="range-value">0.90 (recommended)</span> | |
| <input id="topP" type="range" min="0.50" max="1.00" step="0.01" value="0.90"> | |
| <span class="hint">For sampling mode: keeps the most likely token group. 0.90 is a steady default.</span> | |
| </label> | |
| <label><span>Repetition Penalty</span> | |
| <span id="repetitionPenaltyValue" class="range-value">1.08 (recommended)</span> | |
| <input id="repetitionPenalty" type="range" min="1.00" max="1.50" step="0.01" value="1.08"> | |
| <span class="hint">Discourages repeating the same words. Too high can make wording strange.</span> | |
| </label> | |
| <label><span>Length Penalty</span> | |
| <span id="lengthPenaltyValue" class="range-value">0.70 (recommended)</span> | |
| <input id="lengthPenalty" type="range" min="0.10" max="1.50" step="0.05" value="0.70"> | |
| <span class="hint">For beam mode: nudges answer length. Lower is concise, higher allows longer wording.</span> | |
| </label> | |
| </div> | |
| <div class="actions" style="margin-top:14px"> | |
| <button id="closeInferenceBtn" class="primary" type="button">Done</button> | |
| </div> | |
| </div> | |
| </div> | |
| <script> | |
| const $ = (id) => document.getElementById(id); | |
| let lastData = {metrics: []}; | |
| let lastCheckpoints = []; | |
| let checkpointPickerMode = "resume"; | |
| let runCheckpointTouched = false; | |
| let finetuneCheckpointTouched = false; | |
| const navButtons = [...document.querySelectorAll(".nav button[data-page]")]; | |
| navButtons.forEach(btn => btn.onclick = () => { | |
| navButtons.forEach(b => b.classList.toggle("active", b === btn)); | |
| document.querySelectorAll(".page").forEach(p => p.classList.toggle("active", p.id === btn.dataset.page)); | |
| drawCharts(); | |
| }); | |
| function applyTheme(theme) { | |
| document.body.dataset.theme = theme; | |
| localStorage.setItem("edenTheme", theme); | |
| $("themeToggle").textContent = theme === "dark" ? "Light Mode" : "Dark Mode"; | |
| drawCharts(); | |
| } | |
| applyTheme(localStorage.getItem("edenTheme") || "light"); | |
| $("themeToggle").onclick = () => applyTheme(document.body.dataset.theme === "dark" ? "light" : "dark"); | |
| function fmtNum(n, digits=2) { | |
| if (n === undefined || n === null || Number.isNaN(Number(n))) return "-"; | |
| return Number(n).toFixed(digits); | |
| } | |
| function fmtParams(n) { | |
| if (!n) return "-"; | |
| return (Number(n) / 1e6).toFixed(1) + "M"; | |
| } | |
| function fmtDuration(seconds) { | |
| if (!Number.isFinite(seconds) || seconds <= 0) return "-"; | |
| const s = Math.round(seconds); | |
| const days = Math.floor(s / 86400); | |
| const hours = Math.floor((s % 86400) / 3600); | |
| const mins = Math.floor((s % 3600) / 60); | |
| if (days) return `${days}d ${hours}h`; | |
| if (hours) return `${hours}h ${mins}m`; | |
| return `${Math.max(1, mins)}m`; | |
| } | |
| function estimateRemaining(metrics, step, total, running, status) { | |
| if (total && step >= total) return "ETA complete"; | |
| if (!total) return "ETA -"; | |
| const remaining = Math.max(0, total - step); | |
| const points = (metrics || []) | |
| .filter(m => Number.isFinite(Number(m.time)) && Number.isFinite(Number(m.step))) | |
| .map(m => ({time: Number(m.time), step: Number(m.step)})) | |
| .filter(p => p.step <= step) | |
| .slice(-30); | |
| const last = points[points.length - 1]; | |
| let first = null; | |
| for (let i = points.length - 2; i >= 0; i--) { | |
| if (points[i].step < last?.step && points[i].time < last?.time) { | |
| first = points[i]; | |
| break; | |
| } | |
| } | |
| if (!first || !last) return running ? "ETA learning speed..." : "ETA -"; | |
| const rate = (last.step - first.step) / Math.max(1, last.time - first.time); | |
| if (!Number.isFinite(rate) || rate <= 0) return running ? "ETA learning speed..." : "ETA -"; | |
| const label = `ETA ${fmtDuration(remaining / rate)} remaining`; | |
| return running ? label : `${label} if resumed`; | |
| } | |
| async function post(url, data={}) { | |
| const res = await fetch(url, {method: "POST", cache: "no-store", headers: {"Content-Type": "application/json"}, body: JSON.stringify(data)}); | |
| return await res.json(); | |
| } | |
| function trainingPayload() { | |
| return { | |
| recipe: $("recipe").value, | |
| epochs: $("epochs").value, | |
| max_pairs: $("pairs").value, | |
| max_len: $("maxLen").value, | |
| batch_size: $("batchSize").value, | |
| grad_accum: $("gradAccum").value, | |
| memory_stop_fraction: $("ramLimit").value | |
| }; | |
| } | |
| function setSelectIfOption(id, value) { | |
| const select = $(id); | |
| const stringValue = String(value); | |
| if ([...select.options].some(opt => opt.value === stringValue)) { | |
| select.value = stringValue; | |
| } | |
| } | |
| function applyRecommendedSettings() { | |
| const rec = lastData.recommendation || {}; | |
| setSelectIfOption("recipe", "m5-smart"); | |
| if (rec.max_len) setSelectIfOption("maxLen", rec.max_len); | |
| if (rec.batch_size) setSelectIfOption("batchSize", rec.batch_size); | |
| if (rec.grad_accum) setSelectIfOption("gradAccum", rec.grad_accum); | |
| if (rec.memory_stop_fraction) setSelectIfOption("ramLimit", Number(rec.memory_stop_fraction).toFixed(2)); | |
| } | |
| function syncInferenceLabels() { | |
| $("beamSizeValue").textContent = `${$("beamSize").value}${$("beamSize").value === "4" ? " (recommended)" : ""}`; | |
| $("maxTokensValue").textContent = `${$("maxTokens").value}${$("maxTokens").value === "256" ? " (recommended)" : ""}`; | |
| $("temperatureValue").textContent = `${Number($("temperature").value).toFixed(2)}${Number($("temperature").value) === 0.7 ? " (recommended)" : ""}`; | |
| $("topKValue").textContent = `${$("topK").value}${$("topK").value === "40" ? " (recommended)" : ""}`; | |
| $("topPValue").textContent = `${Number($("topP").value).toFixed(2)}${Number($("topP").value) === 0.9 ? " (recommended)" : ""}`; | |
| $("repetitionPenaltyValue").textContent = `${Number($("repetitionPenalty").value).toFixed(2)}${Number($("repetitionPenalty").value) === 1.08 ? " (recommended)" : ""}`; | |
| $("lengthPenaltyValue").textContent = `${Number($("lengthPenalty").value).toFixed(2)}${Number($("lengthPenalty").value) === 0.7 ? " (recommended)" : ""}`; | |
| } | |
| function applyRecommendedInferenceSettings() { | |
| $("inferMode").value = "beam"; | |
| $("showTokens").value = "off"; | |
| $("beamSize").value = "4"; | |
| $("maxTokens").value = "256"; | |
| $("temperature").value = "0.70"; | |
| $("topK").value = "40"; | |
| $("topP").value = "0.90"; | |
| $("repetitionPenalty").value = "1.08"; | |
| $("lengthPenalty").value = "0.70"; | |
| syncInferenceLabels(); | |
| } | |
| function inferencePayload() { | |
| return { | |
| mode: $("inferMode").value, | |
| show_tokens: $("showTokens").value === "on", | |
| beam: $("beamSize").value, | |
| max_new_tokens: $("maxTokens").value, | |
| temperature: $("temperature").value, | |
| top_k: $("topK").value, | |
| top_p: $("topP").value, | |
| repetition_penalty: $("repetitionPenalty").value, | |
| length_penalty: $("lengthPenalty").value | |
| }; | |
| } | |
| function finetunePayload() { | |
| return { | |
| checkpoint: $("finetuneCheckpoint").value, | |
| data: $("finetuneData").value, | |
| epochs: $("finetuneEpochs").value, | |
| lr: $("finetuneLr").value, | |
| max_pairs: $("finetuneMaxPairs").value, | |
| mix_base: $("finetuneMixBase").value === "on" | |
| }; | |
| } | |
| function renderDataSources(summary) { | |
| $("dataSummary").textContent = summary?.prepared | |
| ? "Using prepared pairs in eden_system/data/pairs.jsonl. This list does not load the datasets." | |
| : "Prepared pairs are not built yet. This list shows the default sources EDEN will use."; | |
| $("dataSources").innerHTML = ""; | |
| (summary?.sources || []).forEach(source => { | |
| const item = document.createElement("div"); | |
| item.className = "source-item"; | |
| item.innerHTML = `<strong>${source.name}</strong><span class="muted small">${source.detail}</span>`; | |
| $("dataSources").appendChild(item); | |
| }); | |
| } | |
| function fillCheckpoints(select, checkpoints, preferredPath=null) { | |
| const current = preferredPath || select.value; | |
| select.innerHTML = ""; | |
| if (!checkpoints.length) { | |
| const opt = document.createElement("option"); | |
| opt.value = ""; | |
| opt.textContent = "no checkpoints yet"; | |
| select.appendChild(opt); | |
| return; | |
| } | |
| checkpoints.forEach((ckpt) => { | |
| const opt = document.createElement("option"); | |
| opt.value = ckpt.path; | |
| opt.textContent = `${ckpt.label || ckpt.name} (${ckpt.size_mb.toFixed(1)} MB)`; | |
| select.appendChild(opt); | |
| }); | |
| if ([...select.options].some(o => o.value === current)) { | |
| select.value = current; | |
| } else if (select.options.length) { | |
| select.value = select.options[0].value; | |
| } | |
| } | |
| function sessionForCheckpointPath(checkpoints, checkpointPath) { | |
| const match = checkpoints.find(ckpt => ckpt.path === checkpointPath); | |
| return match ? (match.session || "legacy") : ""; | |
| } | |
| function sessionNames(checkpoints) { | |
| const names = []; | |
| checkpoints.forEach(ckpt => { | |
| const name = ckpt.session || "legacy"; | |
| if (!names.includes(name)) names.push(name); | |
| }); | |
| return names; | |
| } | |
| function fillResumeSessions(checkpoints, preferredPath=null) { | |
| const preferredSession = sessionForCheckpointPath(checkpoints, preferredPath); | |
| const current = preferredSession || $("resumeSession").value; | |
| $("resumeSession").innerHTML = ""; | |
| const names = sessionNames(checkpoints); | |
| if (!names.length) { | |
| const opt = document.createElement("option"); | |
| opt.value = ""; | |
| opt.textContent = "no sessions yet"; | |
| $("resumeSession").appendChild(opt); | |
| fillResumeCheckpoints([]); | |
| return; | |
| } | |
| names.forEach(name => { | |
| const opt = document.createElement("option"); | |
| opt.value = name; | |
| opt.textContent = name === "legacy" ? "legacy checkpoints" : name; | |
| $("resumeSession").appendChild(opt); | |
| }); | |
| $("resumeSession").value = names.includes(current) ? current : names[0]; | |
| fillResumeCheckpoints(checkpoints, preferredPath); | |
| } | |
| function fillResumeCheckpoints(checkpoints, preferredPath=null) { | |
| const session = $("resumeSession").value; | |
| const current = preferredPath || $("resumeCheckpoint").value; | |
| const scoped = checkpoints.filter(ckpt => (ckpt.session || "legacy") === session); | |
| $("resumeCheckpoint").innerHTML = ""; | |
| scoped.forEach(ckpt => { | |
| const opt = document.createElement("option"); | |
| opt.value = ckpt.path; | |
| opt.textContent = `${ckpt.name} (${ckpt.size_mb.toFixed(1)} MB)`; | |
| $("resumeCheckpoint").appendChild(opt); | |
| }); | |
| if ([...$("resumeCheckpoint").options].some(o => o.value === current)) { | |
| $("resumeCheckpoint").value = current; | |
| } else if ($("resumeCheckpoint").options.length) { | |
| $("resumeCheckpoint").value = $("resumeCheckpoint").options[0].value; | |
| } | |
| } | |
| function openResumeModal() { | |
| checkpointPickerMode = "resume"; | |
| $("checkpointModalTitle").textContent = "Resume Training"; | |
| $("checkpointModalNote").textContent = "Choose the training session first. The checkpoint's saved settings are restored automatically."; | |
| $("confirmResumeBtn").textContent = "Resume"; | |
| fillResumeSessions(lastCheckpoints, lastData.state?.checkpoint || ""); | |
| $("resumeModal").classList.add("open"); | |
| } | |
| function openRunCheckpointModal() { | |
| checkpointPickerMode = "run"; | |
| $("checkpointModalTitle").textContent = "Choose Model Checkpoint"; | |
| $("checkpointModalNote").textContent = "Choose the training session first, then the checkpoint to run in the Run Model page."; | |
| $("confirmResumeBtn").textContent = "Use This Checkpoint"; | |
| fillResumeSessions(lastCheckpoints, $("runCheckpoint").value || lastData.state?.checkpoint || ""); | |
| $("resumeModal").classList.add("open"); | |
| } | |
| function openFinetuneCheckpointModal() { | |
| checkpointPickerMode = "finetune"; | |
| $("checkpointModalTitle").textContent = "Choose Base Model"; | |
| $("checkpointModalNote").textContent = "Choose the checkpoint you want to specialize with your custom examples."; | |
| $("confirmResumeBtn").textContent = "Use This Model"; | |
| fillResumeSessions(lastCheckpoints, $("finetuneCheckpoint").value || lastData.state?.checkpoint || ""); | |
| $("resumeModal").classList.add("open"); | |
| } | |
| function closeResumeModal() { | |
| $("resumeModal").classList.remove("open"); | |
| } | |
| function openInferenceModal() { | |
| syncInferenceLabels(); | |
| $("inferenceModal").classList.add("open"); | |
| } | |
| function closeInferenceModal() { | |
| $("inferenceModal").classList.remove("open"); | |
| } | |
| function drawLine(canvas, series, color, label, y2Series=null) { | |
| const ctx = canvas.getContext("2d"); | |
| const styles = getComputedStyle(document.body); | |
| const gridColor = styles.getPropertyValue("--line").trim() || "#d7dee8"; | |
| const textColor = styles.getPropertyValue("--muted").trim() || "#667085"; | |
| const goodColor = styles.getPropertyValue("--good").trim() || "#247a3d"; | |
| const dpr = window.devicePixelRatio || 1; | |
| const w = canvas.clientWidth || 500, h = canvas.clientHeight || 260; | |
| canvas.width = Math.floor(w * dpr); canvas.height = Math.floor(h * dpr); | |
| ctx.setTransform(dpr, 0, 0, dpr, 0, 0); | |
| ctx.clearRect(0, 0, w, h); | |
| ctx.strokeStyle = gridColor; ctx.lineWidth = 1; | |
| ctx.beginPath(); | |
| for (let i = 0; i <= 4; i++) { | |
| const y = 28 + (h - 52) * i / 4; | |
| ctx.moveTo(38, y); ctx.lineTo(w - 12, y); | |
| } | |
| ctx.stroke(); | |
| ctx.fillStyle = textColor; ctx.font = "12px Avenir Next, sans-serif"; | |
| ctx.fillText(label, 12, 18); | |
| if (!series.length) return; | |
| const vals = series.map(p => p.y).filter(Number.isFinite); | |
| const min = Math.min(...vals), max = Math.max(...vals); | |
| const span = Math.max(1e-6, max - min); | |
| const xFor = (i, len) => 38 + (w - 54) * (i / Math.max(1, len - 1)); | |
| const yFor = (v) => 28 + (h - 52) * (1 - (v - min) / span); | |
| ctx.strokeStyle = color; ctx.lineWidth = 2; | |
| ctx.beginPath(); | |
| series.forEach((p, i) => { | |
| const x = xFor(i, series.length), y = yFor(p.y); | |
| if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y); | |
| }); | |
| ctx.stroke(); | |
| ctx.fillStyle = textColor; | |
| ctx.fillText(max.toFixed(3), 4, 34); | |
| ctx.fillText(min.toFixed(3), 4, h - 20); | |
| if (y2Series && y2Series.length) { | |
| ctx.strokeStyle = goodColor; ctx.lineWidth = 2; | |
| ctx.beginPath(); | |
| y2Series.forEach((p, i) => { | |
| const x = xFor(i, y2Series.length); | |
| const y = 28 + (h - 52) * (1 - Math.max(0, Math.min(100, p.y)) / 100); | |
| if (i === 0) ctx.moveTo(x, y); else ctx.lineTo(x, y); | |
| }); | |
| ctx.stroke(); | |
| ctx.fillStyle = goodColor; ctx.fillText("accuracy %", w - 92, 18); | |
| } | |
| } | |
| function drawCharts() { | |
| const metrics = lastData.metrics || []; | |
| const steps = metrics.filter(m => m.kind === "step" && Number.isFinite(Number(m.loss))).map(m => ({x: m.step, y: Number(m.loss)})); | |
| const vals = metrics.filter(m => (m.kind === "val" || m.kind === "epoch") && Number.isFinite(Number(m.val_loss))).map(m => ({x: m.step, y: Number(m.val_loss)})); | |
| const accs = metrics.filter(m => (m.kind === "val" || m.kind === "epoch") && Number.isFinite(Number(m.quality_percent))).map(m => ({x: m.step, y: Number(m.quality_percent)})); | |
| drawLine($("stepChart"), steps, "#147c72", "train loss"); | |
| drawLine($("valChart"), vals, "#a86416", "val loss", accs); | |
| } | |
| async function chooseCheckpoint() { | |
| const picked = await post("/api/pick-checkpoint"); | |
| if (!picked.ok) return null; | |
| return picked.checkpoint; | |
| } | |
| async function refresh() { | |
| const res = await fetch("/api/status", {cache: "no-store"}); | |
| const data = await res.json(); | |
| lastData = data; | |
| const s = data.state || {}; | |
| const running = !!data.running; | |
| const checkpoints = data.checkpoints || []; | |
| lastCheckpoints = checkpoints; | |
| $("startBtn").disabled = running; | |
| $("resumeBtn").disabled = running || !checkpoints.length; | |
| $("pauseBtn").disabled = !running; | |
| $("stopBtn").disabled = !running; | |
| $("enhanceBtn").disabled = running || !checkpoints.length; | |
| $("startFinetuneBtn").disabled = running || !checkpoints.length; | |
| const statusLabel = running ? "running" : (s.status || "idle"); | |
| $("statusText").textContent = `Status: ${statusLabel}`; | |
| const recipeParams = data.recipe_params || {}; | |
| const paramCount = Number(s.params || recipeParams[$("recipe").value] || 0); | |
| $("params").textContent = paramCount ? fmtParams(paramCount) : "-"; | |
| $("device").textContent = s.device || data.device ? `device ${s.device || data.device}` : "-"; | |
| const step = Number(s.step || 0), total = Number(s.total_steps || 0); | |
| $("etaText").textContent = estimateRemaining(data.metrics || [], step, total, running, s.status || ""); | |
| const progress = total ? Math.max(0, Math.min(1, step / total)) : (s.progress || 0); | |
| const completed = Number(s.completed_epochs ?? (s.status === "done" ? s.epochs || 0 : Math.max(0, (s.epoch || 1) - 1))); | |
| let fallbackEpochProgress = 0; | |
| if (s.epochs && total) { | |
| const epochFloat = progress * Number(s.epochs); | |
| fallbackEpochProgress = progress >= 1 ? 1 : epochFloat - Math.floor(epochFloat); | |
| } | |
| const epochProgress = Math.max(0, Math.min(1, Number(s.epoch_progress ?? fallbackEpochProgress))); | |
| const epochDone = Number(s.epoch_steps_done || 0); | |
| const epochTotal = Number(s.epoch_total_steps || 0); | |
| const epochTotalCount = Number(s.epochs || 0); | |
| const currentEpoch = epochTotalCount ? Math.max(1, Math.min(epochTotalCount, Number(s.epoch || completed + 1))) : 0; | |
| $("epochsDone").textContent = epochTotalCount ? `${currentEpoch}/${epochTotalCount}` : "-"; | |
| $("epochNow").textContent = epochTotalCount ? `${completed} completed` : "completed -"; | |
| $("epochPercent").textContent = s.epochs ? `${(epochProgress * 100).toFixed(1)}%` : "-"; | |
| $("epochSteps").textContent = epochTotal ? `${epochDone.toLocaleString()} / ${epochTotal.toLocaleString()} this epoch` : "current epoch progress"; | |
| $("epochBar").style.width = `${Math.max(0, Math.min(100, epochProgress * 100))}%`; | |
| $("completion").textContent = total ? `${(progress * 100).toFixed(1)}%` : "-"; | |
| $("stepsLeft").textContent = total ? `${step.toLocaleString()} done / ${Math.max(0, total - step).toLocaleString()} left` : "steps -"; | |
| $("progressBar").style.width = `${Math.max(0, Math.min(100, progress * 100))}%`; | |
| $("loss").textContent = fmtNum(s.train_loss, 4); | |
| $("lr").textContent = s.lr ? `lr ${Number(s.lr).toExponential(2)}` : "lr -"; | |
| $("quality").textContent = s.quality_percent ? `${fmtNum(s.quality_percent, 1)}%` : "-"; | |
| $("valLoss").textContent = s.val_loss ? `val loss ${fmtNum(s.val_loss, 4)}` : "validation -"; | |
| $("sessionLine").textContent = s.session ? `${s.session} | ${s.session_dir || ""}` : (s.checkpoint || "No active session"); | |
| const rec = data.recommendation || {}; | |
| $("ramTip").textContent = data.ram_total_gb ? `Detected ${Number(data.ram_total_gb).toFixed(0)} GB RAM. Recommended: context ${rec.max_len}, batch ${rec.batch_size}, grad accumulation ${rec.grad_accum}, RAM limit ${(Number(rec.memory_stop_fraction || 0) * 100).toFixed(0)}%. ${rec.note || ""}` : "RAM recommendation unavailable."; | |
| renderDataSources(data.data_summary || {}); | |
| fillCheckpoints($("runCheckpoint"), checkpoints, runCheckpointTouched ? null : (s.checkpoint || null)); | |
| fillCheckpoints($("finetuneCheckpoint"), checkpoints, finetuneCheckpointTouched ? null : (s.checkpoint || null)); | |
| const logText = data.log || ""; | |
| ["log", "monitorLog"].forEach(id => { | |
| const el = $(id); | |
| if (!el) return; | |
| el.textContent = logText; | |
| el.scrollTop = el.scrollHeight; | |
| }); | |
| drawCharts(); | |
| } | |
| $("startBtn").onclick = async () => { | |
| const res = await post("/api/start", trainingPayload()); | |
| if (!res.ok) alert(res.error || "Could not start training"); | |
| refresh(); | |
| }; | |
| $("resumeBtn").onclick = openResumeModal; | |
| $("resumeSession").onchange = () => fillResumeCheckpoints(lastCheckpoints); | |
| $("cancelResumeBtn").onclick = closeResumeModal; | |
| $("resumeModal").onclick = (event) => { | |
| if (event.target === $("resumeModal")) closeResumeModal(); | |
| }; | |
| $("openInferenceBtn").onclick = openInferenceModal; | |
| $("closeInferenceBtn").onclick = closeInferenceModal; | |
| $("inferenceModal").onclick = (event) => { | |
| if (event.target === $("inferenceModal")) closeInferenceModal(); | |
| }; | |
| $("confirmResumeBtn").onclick = async () => { | |
| const checkpoint = $("resumeCheckpoint").value; | |
| if (!checkpoint) return; | |
| if (checkpointPickerMode === "run") { | |
| fillCheckpoints($("runCheckpoint"), lastCheckpoints, checkpoint); | |
| runCheckpointTouched = true; | |
| closeResumeModal(); | |
| return; | |
| } | |
| if (checkpointPickerMode === "finetune") { | |
| fillCheckpoints($("finetuneCheckpoint"), lastCheckpoints, checkpoint); | |
| finetuneCheckpointTouched = true; | |
| closeResumeModal(); | |
| return; | |
| } | |
| const payload = {checkpoint}; | |
| const res = await post("/api/resume", payload); | |
| if (!res.ok) alert(res.error || "Could not resume training"); | |
| closeResumeModal(); | |
| refresh(); | |
| }; | |
| $("pauseBtn").onclick = async () => { | |
| const res = await post("/api/pause"); | |
| if (!res.ok) alert(res.error || "Could not pause training"); | |
| refresh(); | |
| }; | |
| $("stopBtn").onclick = async () => { | |
| const res = await post("/api/stop"); | |
| if (!res.ok) alert(res.error || "Could not stop training"); | |
| refresh(); | |
| }; | |
| $("recommendedBtn").onclick = applyRecommendedSettings; | |
| $("chooseRunCheckpoint").onclick = openRunCheckpointModal; | |
| $("runCheckpoint").onchange = () => { runCheckpointTouched = true; }; | |
| $("chooseFinetuneCheckpoint").onclick = openFinetuneCheckpointModal; | |
| $("finetuneCheckpoint").onchange = () => { finetuneCheckpointTouched = true; }; | |
| $("chooseFinetuneData").onclick = async () => { | |
| const picked = await post("/api/pick-data"); | |
| if (picked.ok) $("finetuneData").value = picked.path; | |
| }; | |
| $("startFinetuneBtn").onclick = async () => { | |
| const payload = finetunePayload(); | |
| if (!payload.data) { | |
| alert("Choose a fine-tune data file first."); | |
| return; | |
| } | |
| const res = await post("/api/finetune", payload); | |
| if (!res.ok) alert(res.error || "Could not start fine-tuning"); | |
| refresh(); | |
| }; | |
| $("recommendedInferBtn").onclick = applyRecommendedInferenceSettings; | |
| ["beamSize", "maxTokens", "temperature", "topK", "topP", "repetitionPenalty", "lengthPenalty"].forEach(id => { | |
| $(id).oninput = syncInferenceLabels; | |
| }); | |
| $("enhanceBtn").onclick = async () => { | |
| $("modelOutput").textContent = "Running..."; | |
| $("tokenTraceWrap").style.display = "none"; | |
| $("tokenTrace").textContent = ""; | |
| const payload = {checkpoint: $("runCheckpoint").value, text: $("modelInput").value, ...inferencePayload()}; | |
| const res = await post("/api/enhance", payload); | |
| $("modelOutput").textContent = res.ok ? res.output : (res.error || "Could not run model"); | |
| if (res.ok && payload.show_tokens) { | |
| $("tokenTraceWrap").style.display = "block"; | |
| $("tokenTrace").textContent = (res.tokens || []).map((piece, idx) => `${idx + 1}. ${piece}`).join("\n"); | |
| } | |
| }; | |
| window.addEventListener("resize", drawCharts); | |
| syncInferenceLabels(); | |
| refresh(); | |
| setInterval(refresh, 2000); | |
| </script> | |
| </body> | |
| </html>""" | |
def command_ui(args) -> None:
    """Serve the EDEN dashboard and its JSON API until interrupted.

    ``GET /`` returns ``UI_HTML``. ``GET /api/status`` returns the run state,
    metrics tail, log tail, checkpoints and RAM-based recommendations. The POST
    endpoints start/resume training, start fine-tuning, open native macOS file
    pickers, pause or stop the child process, and run inference in-process
    (``/api/enhance``). Training and fine-tuning run as
    ``python -m eden.cli train|finetune`` subprocesses whose stdout and stderr
    are appended to ``TRAIN_LOG_PATH``.

    Args:
        args: Parsed CLI namespace; reads ``host``, ``port`` and
            ``stop_training_on_exit``.
    """
    ensure_dirs()
    # Single-item holders so the nested functions and Handler can rebind
    # shared state without ``nonlocal``.
    process_holder: dict[str, subprocess.Popen | None] = {"process": None}
    log_file_holder = {"file": None}
    # One-entry inference cache. The model is reloaded when the requested
    # checkpoint path OR its modification time changes, so a checkpoint that
    # training re-saved at the same path is not served stale.
    model_cache: dict[str, object] = {"path": None, "mtime": None, "model": None, "tok": None, "cfg": None, "device": None}
    # Parameter count per recipe, filled on the first /api/status request.
    recipe_params_cache: dict[str, object] = {}
    # Files describing the current run; removed before a new child is launched.
    run_artifacts = (METRICS_PATH, STATE_PATH, TRAIN_LOG_PATH, PAUSE_REQUEST_PATH)
    def running_process() -> subprocess.Popen | None:
        """Return the launched child process while it is still alive, else None."""
        proc = process_holder.get("process")
        if proc is not None and proc.poll() is None:
            return proc
        return None
    def json_response(handler, payload: dict, status: int = 200) -> None:
        """Write ``payload`` to ``handler`` as a non-cacheable JSON response."""
        body = json.dumps(payload).encode("utf-8")
        handler.send_response(status)
        handler.send_header("Content-Type", "application/json; charset=utf-8")
        handler.send_header("Cache-Control", "no-store")
        handler.send_header("Pragma", "no-cache")
        handler.send_header("Expires", "0")
        handler.send_header("Content-Length", str(len(body)))
        handler.end_headers()
        handler.wfile.write(body)
    def current_data_summary() -> dict:
        """Describe the prepared-data files and the configured dataset sources."""
        return {
            "prepared": PAIRS_PATH.exists(),
            "pairs_path": str(PAIRS_PATH),
            "tokenizer_path": str(TOKENIZER_PATH) if TOKENIZER_PATH.exists() else None,
            "sources": DATASET_SOURCES,
        }
    def recipe_param_counts() -> dict:
        """Return ``{recipe_name: parameter_count}``, computed once and then reused.

        The dashboard polls /api/status every two seconds; RECIPES is treated
        as fixed for the lifetime of the server, so the counts are cached.
        """
        if not recipe_params_cache:
            # Build the whole mapping before storing it so a failure part-way
            # through never leaves a partially filled cache behind.
            recipe_params_cache.update({name: model_param_count(apply_recipe(name)) for name in RECIPES})
        return recipe_params_cache
    def clear_run_artifacts() -> None:
        """Delete the previous run's metrics/state/log/pause files, if present."""
        for path in run_artifacts:
            try:
                path.unlink()
            except FileNotFoundError:
                pass
    def spawn_logged(cmd: list[str]) -> subprocess.Popen:
        """Start ``cmd`` from ROOT with stdout and stderr appended to TRAIN_LOG_PATH.

        The new log handle replaces (and closes) the previously held one, and
        the process becomes the one tracked by ``running_process()``.
        """
        env = dict(os.environ)
        env["PYTHONUNBUFFERED"] = "1"  # stream child output into the log promptly
        log_fh = TRAIN_LOG_PATH.open("a", encoding="utf-8")
        old_fh = log_file_holder.get("file")
        if old_fh:
            old_fh.close()
        log_file_holder["file"] = log_fh
        proc = subprocess.Popen(
            cmd,
            cwd=str(ROOT),
            stdout=log_fh,
            stderr=subprocess.STDOUT,
            env=env,
            text=True,
        )
        process_holder["process"] = proc
        return proc
    def start_training_process(body: dict, resume_path: Path | None = None) -> subprocess.Popen:
        """Launch ``eden.cli train`` for a fresh run, or to resume ``resume_path``.

        A fresh run validates ``body["recipe"]`` (default ``"m5-smart"``) and
        forwards any non-empty overrides (epochs, max_pairs, max_len,
        batch_size, grad_accum, memory_stop_fraction). A resume passes only
        ``--resume``; the UI states the checkpoint's saved settings are restored.

        Raises:
            ValueError: if a fresh run names a recipe not in ``RECIPES``.
        """
        recipe = body.get("recipe") or "m5-smart"
        if recipe not in RECIPES and resume_path is None:
            raise ValueError(f"Unknown recipe: {recipe}")
        clear_run_artifacts()
        cmd = [sys.executable, "-m", "eden.cli", "train"]
        if resume_path is not None:
            cmd += ["--resume", str(resume_path)]
        else:
            cmd += ["--recipe", recipe]
            for body_key, flag in [
                ("epochs", "--epochs"),
                ("max_pairs", "--max-pairs"),
                ("max_len", "--max-len"),
                ("batch_size", "--batch-size"),
                ("grad_accum", "--grad-accum"),
                ("memory_stop_fraction", "--memory-stop-fraction"),
            ]:
                value = body.get(body_key)
                if value not in (None, ""):
                    cmd += [flag, str(value)]
        proc = spawn_logged(cmd)
        write_run_state(
            status="starting",
            command=shlex.join(cmd),
            pid=proc.pid,
            resume_from=str(resume_path) if resume_path else None,
        )
        return proc
    def start_finetune_process(body: dict, checkpoint: Path, data_path: Path) -> subprocess.Popen:
        """Launch ``eden.cli finetune`` on ``checkpoint`` with examples from ``data_path``.

        ``body`` may supply ``epochs`` (default 3), ``lr`` (default 8e-5),
        ``max_pairs`` and ``mix_base``. ``--mix-base`` is passed unless
        ``mix_base`` is False/"false"/"off"/"0"/0.
        """
        clear_run_artifacts()
        cmd = [
            sys.executable,
            "-m",
            "eden.cli",
            "finetune",
            "--checkpoint",
            str(checkpoint),
            "--data",
            str(data_path),
            "--epochs",
            str(body.get("epochs") or 3),
            "--lr",
            str(body.get("lr") or 8e-5),
        ]
        if body.get("max_pairs") not in (None, ""):
            cmd += ["--max-pairs", str(body.get("max_pairs"))]
        mix_base = body.get("mix_base", True)
        if mix_base not in (False, "false", "off", "0", 0):
            cmd.append("--mix-base")
        proc = spawn_logged(cmd)
        write_run_state(
            status="starting",
            mode="finetune",
            command=shlex.join(cmd),
            pid=proc.pid,
            checkpoint=str(checkpoint),
            finetune_data=str(data_path),
        )
        return proc
    def choose_file_with_dialog(prompt: str, default_dir: Path) -> str | None:
        """Show the macOS "choose file" dialog through ``osascript``.

        Returns the chosen POSIX path, or None when the dialog exits non-zero
        (e.g. the user cancelled) or prints nothing.

        Raises:
            RuntimeError: when ``osascript`` is not on PATH (non-macOS hosts).
        """
        if shutil.which("osascript") is None:
            raise RuntimeError("The file picker needs macOS (osascript was not found).")
        # Escape for an AppleScript string literal so a path containing quotes
        # or backslashes cannot break the script.
        location = str(default_dir).replace("\\", "\\\\").replace('"', '\\"')
        script = (
            f'set chosenFile to choose file with prompt "{prompt}" '
            f'default location POSIX file "{location}/"\n'
            "POSIX path of chosenFile"
        )
        result = subprocess.run(["osascript", "-e", script], capture_output=True, text=True)
        if result.returncode != 0:
            return None
        return result.stdout.strip() or None
    def pick_checkpoint_with_dialog() -> Path | None:
        """Ask for a checkpoint file, starting in SESSIONS_DIR when it exists."""
        default_dir = SESSIONS_DIR if SESSIONS_DIR.exists() else CHECKPOINT_DIR
        picked = choose_file_with_dialog("Choose an EDEN checkpoint to resume from", default_dir)
        return resolve_checkpoint_path(picked) if picked else None
    def pick_data_file_with_dialog() -> Path | None:
        """Ask for a fine-tune data file, starting in the project ROOT."""
        picked = choose_file_with_dialog("Choose fine-tune examples (JSONL, JSON, CSV, or TSV)", ROOT)
        return resolve_finetune_data_path(picked) if picked else None
    class Handler(http.server.BaseHTTPRequestHandler):
        """Request handler for the dashboard page and its JSON API."""
        def log_message(self, fmt, *args):
            # Silence the default per-request access log.
            return
        def do_GET(self):
            parsed = urllib.parse.urlparse(self.path)
            if parsed.path == "/":
                body = UI_HTML.encode("utf-8")
                self.send_response(200)
                self.send_header("Content-Type", "text/html; charset=utf-8")
                self.send_header("Cache-Control", "no-store")
                self.send_header("Pragma", "no-cache")
                self.send_header("Expires", "0")
                self.send_header("Content-Length", str(len(body)))
                self.end_headers()
                self.wfile.write(body)
                return
            if parsed.path == "/api/status":
                proc = running_process()
                state = read_json_file(STATE_PATH, {})
                if proc is not None:
                    state["status"] = "running"
                elif state.get("status") in {"running", "starting", "stopping", "pause requested"}:
                    # No live child, so an in-progress status on disk is stale.
                    state["status"] = "stopped"
                _, total_gb, _ = memory_fraction()
                payload = {
                    "running": proc is not None,
                    "state": state,
                    "metrics": read_jsonl_tail(METRICS_PATH, limit=1500),
                    "log": read_text_tail(TRAIN_LOG_PATH),
                    "device": str(device_for_training(False)),
                    "checkpoints": checkpoint_options(),
                    "recipe_params": recipe_param_counts(),
                    "recommendation": recommended_runtime_settings(total_gb),
                    "ram_total_gb": total_gb,
                    "data_summary": current_data_summary(),
                }
                json_response(self, payload)
                return
            json_response(self, {"ok": False, "error": "not found"}, 404)
        def do_POST(self):
            parsed = urllib.parse.urlparse(self.path)
            try:
                length = int(self.headers.get("Content-Length", "0") or "0")
            except ValueError:
                length = -1
            if length < 0:
                # A negative read would block until EOF; reject instead.
                json_response(self, {"ok": False, "error": "invalid Content-Length"}, 400)
                return
            raw = self.rfile.read(length) if length else b"{}"
            try:
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                body = json.loads(raw.decode("utf-8") or "{}")
            except (ValueError, RecursionError):
                body = {}
            if not isinstance(body, dict):
                # Every endpoint reads fields via body.get(); a JSON array or
                # scalar would otherwise raise and drop the connection unanswered.
                body = {}
            if parsed.path == "/api/start":
                if running_process() is not None:
                    json_response(self, {"ok": False, "error": "Training is already running."}, 409)
                    return
                try:
                    proc = start_training_process(body)
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 400)
                    return
                json_response(self, {"ok": True, "pid": proc.pid})
                return
            if parsed.path == "/api/resume":
                if running_process() is not None:
                    json_response(self, {"ok": False, "error": "Training is already running."}, 409)
                    return
                try:
                    checkpoint = resolve_checkpoint_path(body.get("checkpoint"))
                    proc = start_training_process(body, resume_path=checkpoint)
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 400)
                    return
                json_response(self, {"ok": True, "pid": proc.pid, "checkpoint": str(checkpoint)})
                return
            if parsed.path == "/api/finetune":
                if running_process() is not None:
                    json_response(self, {"ok": False, "error": "Training or fine-tuning is already running."}, 409)
                    return
                try:
                    checkpoint = resolve_checkpoint_path(body.get("checkpoint"))
                    data_path = resolve_finetune_data_path(body.get("data"))
                    proc = start_finetune_process(body, checkpoint, data_path)
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 400)
                    return
                json_response(self, {"ok": True, "pid": proc.pid, "checkpoint": str(checkpoint), "data": str(data_path)})
                return
            if parsed.path == "/api/pick-checkpoint":
                try:
                    checkpoint = pick_checkpoint_with_dialog()
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 400)
                    return
                if checkpoint is None:
                    json_response(self, {"ok": False, "cancelled": True})
                    return
                json_response(self, {"ok": True, "checkpoint": str(checkpoint), "name": checkpoint.name})
                return
            if parsed.path == "/api/pick-data":
                try:
                    data_path = pick_data_file_with_dialog()
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 400)
                    return
                if data_path is None:
                    json_response(self, {"ok": False, "cancelled": True})
                    return
                json_response(self, {"ok": True, "path": str(data_path), "name": data_path.name})
                return
            if parsed.path == "/api/pause":
                proc = running_process()
                if proc is None:
                    json_response(self, {"ok": False, "error": "Training is not running."}, 409)
                    return
                # The child is expected to poll this file and pause itself.
                PAUSE_REQUEST_PATH.write_text(str(time.time()), encoding="utf-8")
                write_run_state(status="pause requested", pid=proc.pid)
                json_response(self, {"ok": True})
                return
            if parsed.path == "/api/stop":
                proc = running_process()
                if proc is None:
                    json_response(self, {"ok": False, "error": "Training is not running."}, 409)
                    return
                proc.terminate()
                write_run_state(status="stopping", pid=proc.pid)
                json_response(self, {"ok": True})
                return
            if parsed.path == "/api/enhance":
                if running_process() is not None:
                    json_response(self, {"ok": False, "error": "Pause or stop training before running the model."}, 409)
                    return
                text = (body.get("text") or "").strip()
                if not text:
                    json_response(self, {"ok": False, "error": "Enter text to enhance."}, 400)
                    return
                try:
                    checkpoint = resolve_checkpoint_path(body.get("checkpoint"))
                    cache_path = str(checkpoint)
                    cache_mtime = checkpoint.stat().st_mtime_ns
                    if model_cache.get("path") != cache_path or model_cache.get("mtime") != cache_mtime:
                        # Drop the cached model first so it can be freed before
                        # the replacement is loaded.
                        model_cache.update(path=None, mtime=None, model=None, tok=None, cfg=None, device=None)
                        model, tok, cfg, device = load_model_for_inference(checkpoint, force_cpu=False)
                        model_cache.update(path=cache_path, mtime=cache_mtime, model=model, tok=tok, cfg=cfg, device=device)
                    result = enhance_text(
                        text,
                        model_cache["model"],
                        model_cache["tok"],
                        model_cache["cfg"],
                        model_cache["device"],
                        beam_size=int(body.get("beam") or 4),
                        strategy=str(body.get("mode") or "beam"),
                        max_new_tokens=body.get("max_new_tokens"),
                        temperature=body.get("temperature", 0.7),
                        top_k=body.get("top_k", 40),
                        top_p=body.get("top_p", 0.9),
                        length_penalty=body.get("length_penalty"),
                        repetition_penalty=body.get("repetition_penalty"),
                        return_details=True,
                    )
                except Exception as exc:
                    json_response(self, {"ok": False, "error": str(exc)}, 500)
                    return
                json_response(
                    self,
                    {
                        "ok": True,
                        "output": result["output"],
                        "tokens": result["tokens"] if body.get("show_tokens") else [],
                        "settings": result["settings"],
                        "checkpoint": str(checkpoint),
                    },
                )
                return
            json_response(self, {"ok": False, "error": "not found"}, 404)
    server = http.server.ThreadingHTTPServer((args.host, args.port), Handler)
    url = f"http://{args.host}:{args.port}"
    log(f"EDEN UI running at {url}")
    log("Press Ctrl+C to stop the dashboard. New training checkpoints go into eden_system/training_sessions.")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        log("\nStopping EDEN UI...")
    finally:
        proc = running_process()
        if proc is not None and args.stop_training_on_exit:
            proc.terminate()
        fh = log_file_holder.get("file")
        if fh:
            fh.close()
        server.server_close()
def build_parser() -> argparse.ArgumentParser:
    """Create the EDEN command-line parser with every subcommand registered.

    Each subcommand binds its handler to ``func`` through ``set_defaults`` so
    that ``main`` can dispatch with ``args.func(args)``. The chosen subcommand
    name is stored in ``args.command``. Registered subcommands: install,
    prepare, train, finetune, enhance, interactive, eval, info, smoke, ui.

    Returns:
        The configured top-level ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(
        description="EDEN: train and run a from-scratch text-enhancement encoder-decoder model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    commands = root.add_subparsers(dest="command")

    # Options shared by several subcommands; each helper registers one flag.
    def with_recipe(cmd: argparse.ArgumentParser) -> None:
        cmd.add_argument("--recipe", choices=RECIPES.keys(), default="m5-smart")

    def with_checkpoint(cmd: argparse.ArgumentParser) -> None:
        cmd.add_argument("--checkpoint", type=str, default=None)

    def with_beam(cmd: argparse.ArgumentParser) -> None:
        cmd.add_argument("--beam", type=int, default=None)

    def with_force_cpu(cmd: argparse.ArgumentParser) -> None:
        cmd.add_argument("--force-cpu", action="store_true")

    # install: the handler ignores the parsed namespace entirely.
    install = commands.add_parser("install", help="Install/update Python dependencies.")
    install.set_defaults(func=lambda args: install_deps())

    # prepare: dataset + tokenizer construction.
    prepare = commands.add_parser("prepare", help="Build the training dataset and tokenizer.")
    with_recipe(prepare)
    prepare.add_argument("--max-pairs", type=int, default=None)
    prepare.add_argument("--vocab-size", type=int, default=None)
    prepare.add_argument("--data", type=str, default=None, help="Optional custom JSONL/JSON/CSV/TSV pairs.")
    prepare.add_argument("--include-c4", action="store_true", help="Include the optional large C4-200M GEC stream.")
    prepare.add_argument("--force", action="store_true", help="Rebuild even if prepared files already exist.")
    prepare.set_defaults(func=command_prepare)

    # train: options left at None fall back to whatever command_train chooses.
    train = commands.add_parser("train", help="Train EDEN from scratch or resume.")
    with_recipe(train)
    train.add_argument("--epochs", type=int, default=None)
    train.add_argument("--max-pairs", type=int, default=None)
    train.add_argument("--lr", type=float, default=None)
    train.add_argument("--max-len", type=int, default=None, help="Context length in tokens.")
    train.add_argument("--batch-size", type=int, default=None)
    train.add_argument("--grad-accum", type=int, default=None)
    train.add_argument("--memory-stop-fraction", type=float, default=None)
    train.add_argument("--resume", type=str, default=None)
    train.add_argument("--data", type=str, default=None, help="Optional custom data to include during auto-prepare.")
    train.add_argument("--include-c4", action="store_true")
    train.add_argument("--rebuild-data", action="store_true")
    with_force_cpu(train)
    train.set_defaults(func=command_train)

    # finetune: --data is mandatory here, unlike the other subcommands.
    finetune = commands.add_parser("finetune", help="Continue from a checkpoint on your own pairs.")
    finetune.add_argument("--data", required=True, help="JSONL/JSON/CSV/TSV pairs with input/target columns.")
    with_checkpoint(finetune)
    finetune.add_argument("--epochs", type=int, default=3)
    finetune.add_argument("--lr", type=float, default=8e-5)
    finetune.add_argument("--max-pairs", type=int, default=None)
    finetune.add_argument("--mix-base", action="store_true", help="Mix some base data to reduce forgetting.")
    with_force_cpu(finetune)
    finetune.set_defaults(func=command_finetune)

    # enhance: positional text words; an empty list means "read stdin".
    enhance = commands.add_parser("enhance", help="Enhance one piece of text.")
    enhance.add_argument("text", nargs="*", help="Text to enhance. If empty, reads stdin.")
    with_checkpoint(enhance)
    with_beam(enhance)
    with_force_cpu(enhance)
    enhance.set_defaults(func=command_enhance)

    interactive = commands.add_parser("interactive", help="Run an interactive local enhancer.")
    with_checkpoint(interactive)
    with_beam(interactive)
    with_force_cpu(interactive)
    interactive.set_defaults(func=command_interactive)

    evaluate = commands.add_parser("eval", help="Evaluate a checkpoint on prepared or custom pairs.")
    with_checkpoint(evaluate)
    evaluate.add_argument("--data", type=str, default=None)
    evaluate.add_argument("--samples", type=int, default=None)
    evaluate.add_argument("--max-batches", type=int, default=100)
    with_force_cpu(evaluate)
    evaluate.set_defaults(func=command_eval)

    info = commands.add_parser("info", help="Show system, recipe, data, and checkpoint info.")
    info.set_defaults(func=command_info)

    smoke = commands.add_parser("smoke", help="Run a tiny forward/backward test.")
    with_force_cpu(smoke)
    smoke.set_defaults(func=command_smoke)

    # ui: local dashboard; by default it leaves a running training process alive on exit.
    ui = commands.add_parser("ui", help="Start the lightweight local web dashboard.")
    ui.add_argument("--host", default="127.0.0.1")
    ui.add_argument("--port", type=int, default=7860)
    ui.add_argument(
        "--stop-training-on-exit",
        action="store_true",
        help="Also stop the training subprocess when the dashboard exits.",
    )
    ui.set_defaults(func=command_ui)

    return root
def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and run the selected EDEN subcommand.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    When no subcommand was given, no handler is bound on the namespace, so
    the top-level help is printed instead.
    """
    parser = build_parser()
    namespace = parser.parse_args(argv)
    # Namespace supports ``in``, which checks the parsed attributes; ``func``
    # is only present when a subcommand registered it via set_defaults.
    if "func" in namespace:
        namespace.func(namespace)
        return
    parser.print_help()