feat: add authenticated remote control UI and ngrok launcher
Browse files- README.md +14 -4
- data/bootstrap.py +105 -0
- data/pipeline.py +93 -0
- docs/COMMANDS.md +129 -18
- scripts/run_data_pipeline.sh +1 -1
- serve/control_plane.py +80 -3
- serve/server.py +142 -17
- serve/server_cpu.py +29 -1
- serve/static/index.html +121 -1
- tests/test_control_plane.py +13 -0
- tests/test_data_pipeline.py +52 -0
- tests/test_servers.py +44 -0
- tests/test_tokenizer.py +20 -0
- tokenizer/train_tokenizer.py +32 -13
README.md
CHANGED
|
@@ -6,8 +6,16 @@ Designed to be both educational and functional, SAGE can be trained, fine-tuned,
|
|
| 6 |
|
| 7 |
---
|
| 8 |
|
| 9 |
-
##
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
---
|
| 13 |
|
|
@@ -92,11 +100,13 @@ Once launched, simply type your message to chat with SAGE. The system uses a rol
|
|
| 92 |
The FastAPI server now serves a minimal remote control panel at `/`.
|
| 93 |
|
| 94 |
```bash
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
|
| 97 |
```
|
| 98 |
|
| 99 |
-
Open the server root in a browser, log in with `SAGE_WEB_PASSWORD`, and use preset actions or a raw command box to drive the local repo from the UI. The included `test.ipynb` notebook starts the real app and exposes it through ngrok for Colab.
|
| 100 |
|
| 101 |
---
|
| 102 |
|
|
|
|
| 6 |
|
| 7 |
---
|
| 8 |
|
| 9 |
+
## Command Guide
|
| 10 |
+
|
| 11 |
+
Use the repo command sheet at `docs/COMMANDS.md` for the current end-to-end flow:
|
| 12 |
+
|
| 13 |
+
- bootstrap starter JSONL data
|
| 14 |
+
- train the tokenizer
|
| 15 |
+
- build parquet shards
|
| 16 |
+
- launch training
|
| 17 |
+
- serve the model
|
| 18 |
+
- use the browser UI and chat endpoint
|
| 19 |
|
| 20 |
---
|
| 21 |
|
|
|
|
| 100 |
The FastAPI server now serves a minimal remote control panel at `/`.
|
| 101 |
|
| 102 |
```bash
|
| 103 |
+
$env:SAGE_WEB_PASSWORD="change-me"
|
| 104 |
+
$env:SAGE_CHECKPOINT_DIR="runs/sage-1b"
|
| 105 |
+
$env:SAGE_TOKENIZER_MODEL="tokenizer/tokenizer.model"
|
| 106 |
python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
|
| 107 |
```
|
| 108 |
|
| 109 |
+
Open the server root in a browser, log in with `SAGE_WEB_PASSWORD`, and use preset actions or a raw command box to drive the local repo from the UI. The browser UI now also includes direct text chat through `/chat` when a tokenizer is available. If no checkpoint is loaded yet, the UI warns that outputs are coming from randomly initialized weights. The included `test.ipynb` notebook starts the real app and exposes it through ngrok for Colab.
|
| 110 |
|
| 111 |
---
|
| 112 |
|
data/bootstrap.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bootstrap small raw corpora for tokenizer and smoke-training flows."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
BOOTSTRAP_CORPORA: dict[str, list[str]] = {
|
| 11 |
+
"general_web": [
|
| 12 |
+
"Large language models learn by predicting the next token in a sequence, but useful systems depend just as much on data quality as on architecture size.",
|
| 13 |
+
"A good training corpus mixes clean prose, documentation, dialogue, and reference material so the model sees multiple ways humans structure information.",
|
| 14 |
+
"When you build a local model, start with small smoke runs, measure loss curves, and only then scale sequence length, batch size, and parameter count.",
|
| 15 |
+
"The fastest way to waste compute is to train on noisy duplicated text without checking tokenization, filtering, and validation splits first.",
|
| 16 |
+
"Evaluation should include both regression tests and qualitative prompts because perplexity alone does not tell you whether a model follows instructions well.",
|
| 17 |
+
"A serving stack usually needs checkpoint loading, tokenization, generation settings, and telemetry before it is practical for iterative experiments.",
|
| 18 |
+
],
|
| 19 |
+
"code": [
|
| 20 |
+
"def running_mean(values):\n total = 0.0\n result = []\n for index, value in enumerate(values, start=1):\n total += value\n result.append(total / index)\n return result",
|
| 21 |
+
"class TextBatch:\n def __init__(self, items):\n self.items = list(items)\n\n def join(self, sep='\\n'):\n return sep.join(self.items)",
|
| 22 |
+
"from pathlib import Path\n\ndef read_text(path):\n return Path(path).read_text(encoding='utf-8')",
|
| 23 |
+
"def clamp(value, lo, hi):\n if value < lo:\n return lo\n if value > hi:\n return hi\n return value",
|
| 24 |
+
"def format_metrics(step, loss):\n return f'step={step} loss={loss:.4f}'",
|
| 25 |
+
"def greedy_decode(logits):\n import torch\n return int(torch.argmax(logits, dim=-1).item())",
|
| 26 |
+
],
|
| 27 |
+
"math_science": [
|
| 28 |
+
"The derivative of x squared is 2x, and gradient-based optimization uses derivatives to decide how to update model parameters.",
|
| 29 |
+
"Perplexity is the exponential of average negative log likelihood; lower perplexity means the model assigns higher probability to the observed sequence.",
|
| 30 |
+
"If a batch contains B sequences of length T, then the number of next-token predictions is roughly B times T.",
|
| 31 |
+
"Matrix multiplication is central to transformer inference because projections for queries, keys, values, and feed-forward layers are all linear maps.",
|
| 32 |
+
"Softmax converts raw logits into a probability distribution by exponentiating each value and dividing by the sum of exponentials.",
|
| 33 |
+
"The context window bounds how many previous tokens the decoder can attend to while producing the next token.",
|
| 34 |
+
],
|
| 35 |
+
"multilingual": [
|
| 36 |
+
"English: Training data should be filtered, deduplicated, and documented before long runs begin.",
|
| 37 |
+
"Hindi: अच्छे मॉडल के लिए साफ और विविध डेटा उतना ही जरूरी है जितना अच्छा आर्किटेक्चर।",
|
| 38 |
+
"Arabic: جودة البيانات تؤثر على جودة النموذج بقدر تأثير حجم النموذج نفسه.",
|
| 39 |
+
"Chinese: 在开始长时间训练之前,先做小规模验证可以节省大量计算资源。",
|
| 40 |
+
"Spanish: Un buen flujo de datos incluye limpieza, deduplicacion y particiones reproducibles.",
|
| 41 |
+
"French: Un modele utile demande des donnees propres, des tests et une boucle d'evaluation simple.",
|
| 42 |
+
],
|
| 43 |
+
"synthetic": [
|
| 44 |
+
"[INST] Explain why deduplication matters before tokenizer training. [/INST] Deduplication prevents repeated passages from dominating merge statistics and reduces wasted compute during later model training.",
|
| 45 |
+
"[INST] Write a short checklist for a smoke training run. [/INST] Verify shards exist, verify tokenizer loads, run a short job, inspect metrics, and confirm checkpoints are written.",
|
| 46 |
+
"[INST] How do you know a dataset is too noisy? [/INST] Look for low alpha ratios, malformed markup, repeated content, excessive boilerplate, or corrupted encoding.",
|
| 47 |
+
"[INST] What is the purpose of a validation split? [/INST] It gives you held-out data for tracking generalization and for catching regressions during training.",
|
| 48 |
+
"[INST] Summarize the role of the tokenizer. [/INST] The tokenizer maps raw text into stable token ids the model can consume during training and generation.",
|
| 49 |
+
"[INST] Why keep metadata with each record? [/INST] Metadata helps audit provenance, quality, language mix, and filtering decisions across the pipeline.",
|
| 50 |
+
],
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _pad_sample(text: str, minimum_chars: int = 240) -> str:
|
| 55 |
+
"""Extend short bootstrap samples so they survive the default filters."""
|
| 56 |
+
trailer = (
|
| 57 |
+
" This bootstrap record is intentionally longer so the repo's default "
|
| 58 |
+
"quality filters keep it during smoke-test data preparation and tokenizer training."
|
| 59 |
+
)
|
| 60 |
+
padded = text.strip()
|
| 61 |
+
while len(padded) < minimum_chars:
|
| 62 |
+
padded += trailer
|
| 63 |
+
return padded
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def bootstrap_raw_corpora(output_dir: str = "data/raw", overwrite: bool = False) -> dict[str, int]:
|
| 67 |
+
"""Write one small JSONL corpus file per registered source."""
|
| 68 |
+
root = Path(output_dir)
|
| 69 |
+
root.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
counts: dict[str, int] = {}
|
| 71 |
+
for source_name, samples in BOOTSTRAP_CORPORA.items():
|
| 72 |
+
path = root / f"{source_name}.jsonl"
|
| 73 |
+
if path.exists() and not overwrite:
|
| 74 |
+
existing = sum(1 for _ in path.open("r", encoding="utf-8"))
|
| 75 |
+
counts[source_name] = existing
|
| 76 |
+
continue
|
| 77 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 78 |
+
for index, text in enumerate(samples, start=1):
|
| 79 |
+
payload = {
|
| 80 |
+
"id": f"{source_name}-{index:04d}",
|
| 81 |
+
"text": _pad_sample(text),
|
| 82 |
+
"source_name": source_name,
|
| 83 |
+
}
|
| 84 |
+
handle.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
| 85 |
+
counts[source_name] = len(samples)
|
| 86 |
+
return counts
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def build_argparser() -> argparse.ArgumentParser:
|
| 90 |
+
"""Build the CLI parser for corpus bootstrapping."""
|
| 91 |
+
parser = argparse.ArgumentParser(description="Create small JSONL corpora for SAGE smoke runs.")
|
| 92 |
+
parser.add_argument("--output-dir", default="data/raw", help="Directory for raw JSONL corpus files.")
|
| 93 |
+
parser.add_argument("--overwrite", action="store_true", help="Replace any existing bootstrap corpus files.")
|
| 94 |
+
return parser
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def main() -> None:
|
| 98 |
+
"""CLI entrypoint."""
|
| 99 |
+
args = build_argparser().parse_args()
|
| 100 |
+
summary = bootstrap_raw_corpora(output_dir=args.output_dir, overwrite=args.overwrite)
|
| 101 |
+
print(json.dumps({"output_dir": args.output_dir, "sources": summary}, indent=2, ensure_ascii=False))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
main()
|
data/pipeline.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""End-to-end raw-corpus to Parquet shard pipeline."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import sentencepiece as spm
|
| 10 |
+
|
| 11 |
+
from data.dedup import deduplicate_records
|
| 12 |
+
from data.filter import FilterConfig, filter_record
|
| 13 |
+
from data.ingest import SOURCE_REGISTRY, stream_source
|
| 14 |
+
from data.shard import ShardConfig, write_shards
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _select_sources(names: list[str] | None) -> tuple:
|
| 18 |
+
if not names:
|
| 19 |
+
return SOURCE_REGISTRY
|
| 20 |
+
wanted = set(names)
|
| 21 |
+
selected = tuple(spec for spec in SOURCE_REGISTRY if spec.name in wanted)
|
| 22 |
+
missing = sorted(wanted - {spec.name for spec in selected})
|
| 23 |
+
if missing:
|
| 24 |
+
raise ValueError(f"Unknown sources: {', '.join(missing)}")
|
| 25 |
+
return selected
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def build_records(source_names: list[str] | None = None, limit_per_source: int | None = None) -> list[dict[str, object]]:
|
| 29 |
+
"""Load, filter, and deduplicate records from the configured raw sources."""
|
| 30 |
+
records: list[dict[str, object]] = []
|
| 31 |
+
for spec in _select_sources(source_names):
|
| 32 |
+
source_records: list[dict[str, object]] = []
|
| 33 |
+
for record in stream_source(spec):
|
| 34 |
+
filtered = filter_record(record, FilterConfig())
|
| 35 |
+
if filtered is None:
|
| 36 |
+
continue
|
| 37 |
+
source_records.append(filtered)
|
| 38 |
+
if limit_per_source is not None and len(source_records) >= limit_per_source:
|
| 39 |
+
break
|
| 40 |
+
records.extend(source_records)
|
| 41 |
+
return deduplicate_records(records)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run_pipeline(
|
| 45 |
+
tokenizer_model: str,
|
| 46 |
+
output_dir: str = "data/processed",
|
| 47 |
+
source_names: list[str] | None = None,
|
| 48 |
+
shard_size: int = 2048,
|
| 49 |
+
limit_per_source: int | None = None,
|
| 50 |
+
) -> dict[str, object]:
|
| 51 |
+
"""Create Parquet shards from the current raw JSONL corpora."""
|
| 52 |
+
tokenizer = spm.SentencePieceProcessor()
|
| 53 |
+
tokenizer.load(tokenizer_model)
|
| 54 |
+
records = build_records(source_names=source_names, limit_per_source=limit_per_source)
|
| 55 |
+
manifest = write_shards(records, tokenizer, ShardConfig(output_dir=output_dir, shard_size=shard_size))
|
| 56 |
+
summary = {
|
| 57 |
+
"tokenizer_model": tokenizer_model,
|
| 58 |
+
"output_dir": output_dir,
|
| 59 |
+
"records": len(records),
|
| 60 |
+
"sources": source_names or [spec.name for spec in SOURCE_REGISTRY],
|
| 61 |
+
"manifest": manifest,
|
| 62 |
+
}
|
| 63 |
+
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
| 64 |
+
(Path(output_dir) / "pipeline_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
| 65 |
+
return summary
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_argparser() -> argparse.ArgumentParser:
|
| 69 |
+
"""Build the CLI parser."""
|
| 70 |
+
parser = argparse.ArgumentParser(description="Filter, deduplicate, tokenize, and shard SAGE raw corpora.")
|
| 71 |
+
parser.add_argument("--tokenizer-model", default="tokenizer/tokenizer.model", help="SentencePiece tokenizer model.")
|
| 72 |
+
parser.add_argument("--output-dir", default="data/processed", help="Destination directory for parquet shards.")
|
| 73 |
+
parser.add_argument("--sources", nargs="*", default=None, help="Subset of source names from data.ingest.SOURCE_REGISTRY.")
|
| 74 |
+
parser.add_argument("--shard-size", type=int, default=2048, help="Rows per parquet shard.")
|
| 75 |
+
parser.add_argument("--limit-per-source", type=int, default=None, help="Optional cap for smoke-testing.")
|
| 76 |
+
return parser
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def main() -> None:
|
| 80 |
+
"""CLI entrypoint."""
|
| 81 |
+
args = build_argparser().parse_args()
|
| 82 |
+
summary = run_pipeline(
|
| 83 |
+
tokenizer_model=args.tokenizer_model,
|
| 84 |
+
output_dir=args.output_dir,
|
| 85 |
+
source_names=args.sources,
|
| 86 |
+
shard_size=args.shard_size,
|
| 87 |
+
limit_per_source=args.limit_per_source,
|
| 88 |
+
)
|
| 89 |
+
print(json.dumps(summary, indent=2))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
docs/COMMANDS.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# SAGE Commands
|
| 2 |
|
| 3 |
-
This
|
| 4 |
|
| 5 |
## Install
|
| 6 |
|
|
@@ -14,25 +14,81 @@ pip install -r requirements.txt
|
|
| 14 |
pytest -q
|
| 15 |
```
|
| 16 |
|
| 17 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
```bash
|
| 20 |
python -m tokenizer.train_tokenizer \
|
| 21 |
-
--input data/raw/general_web.
|
| 22 |
-
--model-prefix tokenizer/tokenizer \
|
| 23 |
--model-prefix tokenizer/tokenizer \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
```
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
```bash
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
```
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
```bash
|
| 35 |
python -m train.trainer \
|
|
|
|
|
|
|
| 36 |
--train-shards data/processed/shard-00000.parquet \
|
| 37 |
--validation-shards data/processed/shard-00001.parquet \
|
| 38 |
--output-dir runs/smoke \
|
|
@@ -40,7 +96,7 @@ python -m train.trainer \
|
|
| 40 |
--disable-wandb
|
| 41 |
```
|
| 42 |
|
| 43 |
-
|
| 44 |
|
| 45 |
```bash
|
| 46 |
python -m train.trainer \
|
|
@@ -51,38 +107,93 @@ python -m train.trainer \
|
|
| 51 |
--output-dir runs/sage-1b
|
| 52 |
```
|
| 53 |
|
| 54 |
-
##
|
|
|
|
|
|
|
| 55 |
|
| 56 |
```bash
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
| 58 |
```
|
| 59 |
|
| 60 |
-
|
| 61 |
|
| 62 |
```bash
|
| 63 |
-
|
| 64 |
-
|
| 65 |
```
|
| 66 |
|
| 67 |
-
|
| 68 |
|
| 69 |
```bash
|
| 70 |
-
|
| 71 |
bash scripts/run_serve_cpu.sh
|
| 72 |
```
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
|
| 78 |
```bash
|
| 79 |
curl http://127.0.0.1:8000/health
|
| 80 |
```
|
| 81 |
|
| 82 |
-
|
| 83 |
|
| 84 |
```bash
|
| 85 |
curl -X POST http://127.0.0.1:8000/generate \
|
| 86 |
-H "Content-Type: application/json" \
|
| 87 |
-d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
|
| 88 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# SAGE Commands
|
| 2 |
|
| 3 |
+
This is the repo's current command reference for data preparation, tokenizer training, model training, serving, browser control, and validation.
|
| 4 |
|
| 5 |
## Install
|
| 6 |
|
|
|
|
| 14 |
pytest -q
|
| 15 |
```
|
| 16 |
|
| 17 |
+
## 1. Create a starter dataset
|
| 18 |
+
|
| 19 |
+
This repo does not ship a large training corpus. The fastest way to unblock the pipeline is to generate the built-in smoke dataset first:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
python -m data.bootstrap --output-dir data/raw --overwrite
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
That writes JSONL files like:
|
| 26 |
+
|
| 27 |
+
```text
|
| 28 |
+
data/raw/general_web.jsonl
|
| 29 |
+
data/raw/code.jsonl
|
| 30 |
+
data/raw/math_science.jsonl
|
| 31 |
+
data/raw/multilingual.jsonl
|
| 32 |
+
data/raw/synthetic.jsonl
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
If you want to use your own corpus, put JSONL records in the same folder with at least a `text` field:
|
| 36 |
+
|
| 37 |
+
```json
|
| 38 |
+
{"text": "your training sample here"}
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## 2. Train the tokenizer
|
| 42 |
+
|
| 43 |
+
The tokenizer trainer now accepts plain text files or JSONL files.
|
| 44 |
|
| 45 |
```bash
|
| 46 |
python -m tokenizer.train_tokenizer \
|
| 47 |
+
--input data/raw/general_web.jsonl data/raw/code.jsonl data/raw/math_science.jsonl data/raw/multilingual.jsonl data/raw/synthetic.jsonl \
|
|
|
|
| 48 |
--model-prefix tokenizer/tokenizer \
|
| 49 |
+
--vocab-size 4096 \
|
| 50 |
+
--training-text tokenizer/training_corpus.txt
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## 3. Validate the tokenizer
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
python -m tokenizer.validate_tokenizer tokenizer/tokenizer.model
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## 4. Build parquet shards
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
python -m data.pipeline \
|
| 63 |
+
--tokenizer-model tokenizer/tokenizer.model \
|
| 64 |
+
--output-dir data/processed \
|
| 65 |
+
--shard-size 128
|
| 66 |
```
|
| 67 |
|
| 68 |
+
For a short smoke run:
|
| 69 |
|
| 70 |
```bash
|
| 71 |
+
python -m data.pipeline \
|
| 72 |
+
--tokenizer-model tokenizer/tokenizer.model \
|
| 73 |
+
--output-dir data/processed \
|
| 74 |
+
--shard-size 32 \
|
| 75 |
+
--limit-per-source 4
|
| 76 |
```
|
| 77 |
|
| 78 |
+
The shell helper now points to the real data pipeline:
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
bash scripts/run_data_pipeline.sh --tokenizer-model tokenizer/tokenizer.model --output-dir data/processed
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
## 5. Start training
|
| 85 |
+
|
| 86 |
+
Smoke run:
|
| 87 |
|
| 88 |
```bash
|
| 89 |
python -m train.trainer \
|
| 90 |
+
--model-config configs/model/1b.yaml \
|
| 91 |
+
--schedule-config configs/train/schedule.yaml \
|
| 92 |
--train-shards data/processed/shard-00000.parquet \
|
| 93 |
--validation-shards data/processed/shard-00001.parquet \
|
| 94 |
--output-dir runs/smoke \
|
|
|
|
| 96 |
--disable-wandb
|
| 97 |
```
|
| 98 |
|
| 99 |
+
Longer run:
|
| 100 |
|
| 101 |
```bash
|
| 102 |
python -m train.trainer \
|
|
|
|
| 107 |
--output-dir runs/sage-1b
|
| 108 |
```
|
| 109 |
|
| 110 |
+
## 6. Serve the model
|
| 111 |
+
|
| 112 |
+
GPU/PyTorch server:
|
| 113 |
|
| 114 |
```bash
|
| 115 |
+
$env:SAGE_WEB_PASSWORD="change-me"
|
| 116 |
+
$env:SAGE_CHECKPOINT_DIR="runs/sage-1b"
|
| 117 |
+
$env:SAGE_TOKENIZER_MODEL="tokenizer/tokenizer.model"
|
| 118 |
+
python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
|
| 119 |
```
|
| 120 |
|
| 121 |
+
CPU control-plane server:
|
| 122 |
|
| 123 |
```bash
|
| 124 |
+
$env:SAGE_WEB_PASSWORD="change-me"
|
| 125 |
+
python -m uvicorn serve.server_cpu:app --host 0.0.0.0 --port 8001
|
| 126 |
```
|
| 127 |
|
| 128 |
+
Helper scripts:
|
| 129 |
|
| 130 |
```bash
|
| 131 |
+
bash scripts/run_serve.sh
|
| 132 |
bash scripts/run_serve_cpu.sh
|
| 133 |
```
|
| 134 |
|
| 135 |
+
## 7. Browser control panel
|
| 136 |
+
|
| 137 |
+
Open the server root:
|
| 138 |
+
|
| 139 |
+
```text
|
| 140 |
+
http://127.0.0.1:8000/
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
The browser UI now supports:
|
| 144 |
+
|
| 145 |
+
- login with `SAGE_WEB_PASSWORD`
|
| 146 |
+
- dataset bootstrap preset
|
| 147 |
+
- shard-building preset
|
| 148 |
+
- tokenizer/train/eval/server presets
|
| 149 |
+
- raw shell commands
|
| 150 |
+
- live job logs
|
| 151 |
+
- direct model chat through `/chat`
|
| 152 |
+
|
| 153 |
+
## 8. API commands
|
| 154 |
|
| 155 |
+
Health:
|
| 156 |
|
| 157 |
```bash
|
| 158 |
curl http://127.0.0.1:8000/health
|
| 159 |
```
|
| 160 |
|
| 161 |
+
Generate from token ids:
|
| 162 |
|
| 163 |
```bash
|
| 164 |
curl -X POST http://127.0.0.1:8000/generate \
|
| 165 |
-H "Content-Type: application/json" \
|
| 166 |
-d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
|
| 167 |
```
|
| 168 |
+
|
| 169 |
+
Chat from text:
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
curl -X POST http://127.0.0.1:8000/chat \
|
| 173 |
+
-H "Content-Type: application/json" \
|
| 174 |
+
-d "{\"prompt\": \"Explain the training flow in this repo.\", \"max_new_tokens\": 64}"
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
Chat status:
|
| 178 |
+
|
| 179 |
+
```bash
|
| 180 |
+
curl http://127.0.0.1:8000/chat/status
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
## 9. Evaluation
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
python -m eval.run_benchmarks
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
Or use the helper:
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
bash scripts/run_eval.sh
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
## 10. Hugging Face sync
|
| 196 |
+
|
| 197 |
+
```bash
|
| 198 |
+
python hf_push.py
|
| 199 |
+
```
|
scripts/run_data_pipeline.sh
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
set -euo pipefail
|
| 3 |
|
| 4 |
-
python -m
|
|
|
|
| 1 |
#!/usr/bin/env bash
|
| 2 |
set -euo pipefail
|
| 3 |
|
| 4 |
+
python -m data.pipeline "$@"
|
serve/control_plane.py
CHANGED
|
@@ -158,6 +158,34 @@ def _build_presets(enable_generate: bool) -> list[CommandPreset]:
|
|
| 158 |
"Call the local /health API and show the JSON response.",
|
| 159 |
"api",
|
| 160 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
CommandPreset(
|
| 162 |
"serve_gpu",
|
| 163 |
"Serve GPU",
|
|
@@ -188,7 +216,7 @@ def _build_presets(enable_generate: bool) -> list[CommandPreset]:
|
|
| 188 |
"input_paths",
|
| 189 |
"Input Paths",
|
| 190 |
kind="textarea",
|
| 191 |
-
placeholder="data/raw/general_web.
|
| 192 |
required=True,
|
| 193 |
),
|
| 194 |
PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
|
|
@@ -472,7 +500,49 @@ def _api_response(handler: Callable[[dict[str, Any]], dict[str, Any]], args: dic
|
|
| 472 |
return {"kind": "api", "result": handler(args)}
|
| 473 |
|
| 474 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
if preset_id == "serve_gpu":
|
| 477 |
return [
|
| 478 |
sys.executable,
|
|
@@ -607,6 +677,7 @@ def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict
|
|
| 607 |
preset = preset_map.get(payload.preset_id)
|
| 608 |
if preset is None:
|
| 609 |
raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
|
|
|
|
| 610 |
if preset.mode == "api":
|
| 611 |
handler = api_handlers.get(payload.preset_id)
|
| 612 |
if handler is None:
|
|
@@ -614,11 +685,17 @@ def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict
|
|
| 614 |
return _api_response(handler, payload.args)
|
| 615 |
command = _build_command_for_preset(payload.preset_id, payload.args)
|
| 616 |
mode = "shell" if isinstance(command, str) else "job"
|
| 617 |
-
|
|
|
|
|
|
|
|
|
|
| 618 |
return {"kind": "job", "job": job.to_dict()}
|
| 619 |
if payload.command:
|
| 620 |
cwd = payload.cwd or str(REPO_ROOT)
|
| 621 |
-
|
|
|
|
|
|
|
|
|
|
| 622 |
return {"kind": "job", "job": job.to_dict()}
|
| 623 |
raise HTTPException(status_code=400, detail="Provide either preset_id or command.")
|
| 624 |
|
|
|
|
| 158 |
"Call the local /health API and show the JSON response.",
|
| 159 |
"api",
|
| 160 |
),
|
| 161 |
+
CommandPreset(
|
| 162 |
+
"data_bootstrap",
|
| 163 |
+
"Bootstrap Dataset",
|
| 164 |
+
"Create small JSONL corpora under data/raw for tokenizer and smoke-training runs.",
|
| 165 |
+
"job",
|
| 166 |
+
(
|
| 167 |
+
PresetField("output_dir", "Output Dir", default="data/raw"),
|
| 168 |
+
PresetField("overwrite", "Overwrite Existing Files", kind="boolean", default=False),
|
| 169 |
+
),
|
| 170 |
+
),
|
| 171 |
+
CommandPreset(
|
| 172 |
+
"data_pipeline",
|
| 173 |
+
"Build Data Shards",
|
| 174 |
+
"Filter raw JSONL corpora, deduplicate them, then write parquet shards with the trained tokenizer.",
|
| 175 |
+
"job",
|
| 176 |
+
(
|
| 177 |
+
PresetField("tokenizer_model", "Tokenizer Model", default="tokenizer/tokenizer.model"),
|
| 178 |
+
PresetField("output_dir", "Output Dir", default="data/processed"),
|
| 179 |
+
PresetField(
|
| 180 |
+
"sources",
|
| 181 |
+
"Sources",
|
| 182 |
+
kind="textarea",
|
| 183 |
+
placeholder="general_web\ncode\nmath_science\nmultilingual\nsynthetic",
|
| 184 |
+
),
|
| 185 |
+
PresetField("shard_size", "Shard Size", kind="number", default=2048),
|
| 186 |
+
PresetField("limit_per_source", "Limit Per Source", kind="number", default=0),
|
| 187 |
+
),
|
| 188 |
+
),
|
| 189 |
CommandPreset(
|
| 190 |
"serve_gpu",
|
| 191 |
"Serve GPU",
|
|
|
|
| 216 |
"input_paths",
|
| 217 |
"Input Paths",
|
| 218 |
kind="textarea",
|
| 219 |
+
placeholder="data/raw/general_web.jsonl\ndata/raw/code.jsonl",
|
| 220 |
required=True,
|
| 221 |
),
|
| 222 |
PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
|
|
|
|
| 500 |
return {"kind": "api", "result": handler(args)}
|
| 501 |
|
| 502 |
|
| 503 |
+
def _validate_preset_args(preset: CommandPreset, args: dict[str, Any]) -> None:
|
| 504 |
+
missing: list[str] = []
|
| 505 |
+
for field in preset.fields:
|
| 506 |
+
if not field.required:
|
| 507 |
+
continue
|
| 508 |
+
value = args.get(field.name)
|
| 509 |
+
if value is None:
|
| 510 |
+
missing.append(field.label)
|
| 511 |
+
continue
|
| 512 |
+
if isinstance(value, str) and not value.strip():
|
| 513 |
+
missing.append(field.label)
|
| 514 |
+
continue
|
| 515 |
+
if isinstance(value, list) and not value:
|
| 516 |
+
missing.append(field.label)
|
| 517 |
+
if missing:
|
| 518 |
+
raise HTTPException(status_code=400, detail=f"Missing required fields: {', '.join(missing)}")
|
| 519 |
+
|
| 520 |
+
|
| 521 |
def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str:
|
| 522 |
+
if preset_id == "data_bootstrap":
|
| 523 |
+
command = [sys.executable, "-m", "data.bootstrap", "--output-dir", str(args.get("output_dir") or "data/raw")]
|
| 524 |
+
if bool(args.get("overwrite", False)):
|
| 525 |
+
command.append("--overwrite")
|
| 526 |
+
return command
|
| 527 |
+
if preset_id == "data_pipeline":
|
| 528 |
+
command = [
|
| 529 |
+
sys.executable,
|
| 530 |
+
"-m",
|
| 531 |
+
"data.pipeline",
|
| 532 |
+
"--tokenizer-model",
|
| 533 |
+
str(args.get("tokenizer_model") or "tokenizer/tokenizer.model"),
|
| 534 |
+
"--output-dir",
|
| 535 |
+
str(args.get("output_dir") or "data/processed"),
|
| 536 |
+
"--shard-size",
|
| 537 |
+
str(_parse_number(args.get("shard_size"), 2048)),
|
| 538 |
+
]
|
| 539 |
+
sources = _split_multi_value(args.get("sources"))
|
| 540 |
+
if sources:
|
| 541 |
+
command.extend(["--sources", *sources])
|
| 542 |
+
limit_per_source = _parse_number(args.get("limit_per_source"), 0)
|
| 543 |
+
if limit_per_source > 0:
|
| 544 |
+
command.extend(["--limit-per-source", str(limit_per_source)])
|
| 545 |
+
return command
|
| 546 |
if preset_id == "serve_gpu":
|
| 547 |
return [
|
| 548 |
sys.executable,
|
|
|
|
| 677 |
preset = preset_map.get(payload.preset_id)
|
| 678 |
if preset is None:
|
| 679 |
raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
|
| 680 |
+
_validate_preset_args(preset, payload.args)
|
| 681 |
if preset.mode == "api":
|
| 682 |
handler = api_handlers.get(payload.preset_id)
|
| 683 |
if handler is None:
|
|
|
|
| 685 |
return _api_response(handler, payload.args)
|
| 686 |
command = _build_command_for_preset(payload.preset_id, payload.args)
|
| 687 |
mode = "shell" if isinstance(command, str) else "job"
|
| 688 |
+
try:
|
| 689 |
+
job = CONTROL_MANAGER.start_job(preset.label, command, cwd=str(REPO_ROOT), mode=mode)
|
| 690 |
+
except OSError as exc:
|
| 691 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 692 |
return {"kind": "job", "job": job.to_dict()}
|
| 693 |
if payload.command:
|
| 694 |
cwd = payload.cwd or str(REPO_ROOT)
|
| 695 |
+
try:
|
| 696 |
+
job = CONTROL_MANAGER.start_job("Raw Command", payload.command, cwd=cwd, mode="shell")
|
| 697 |
+
except OSError as exc:
|
| 698 |
+
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
| 699 |
return {"kind": "job", "job": job.to_dict()}
|
| 700 |
raise HTTPException(status_code=400, detail="Provide either preset_id or command.")
|
| 701 |
|
serve/server.py
CHANGED
|
@@ -2,7 +2,9 @@
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from fastapi import FastAPI
|
|
@@ -11,13 +13,21 @@ from pydantic import BaseModel
|
|
| 11 |
from model.config import ModelConfig
|
| 12 |
from model.model import SageTransformer
|
| 13 |
from serve.control_plane import build_control_router
|
| 14 |
-
from
|
| 15 |
from train.hardware import HardwareConfig
|
| 16 |
|
| 17 |
|
| 18 |
app = FastAPI(title="SAGE Server")
|
| 19 |
_MODEL: SageTransformer | None = None
|
| 20 |
_TOKENIZER = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
class GenerationRequest(BaseModel):
|
|
@@ -27,37 +37,152 @@ class GenerationRequest(BaseModel):
|
|
| 27 |
max_new_tokens: int = 32
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def get_model() -> SageTransformer:
|
| 31 |
"""Lazily create the model for server startup."""
|
| 32 |
-
global _MODEL
|
| 33 |
if _MODEL is None:
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
_MODEL.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
return _MODEL
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
@app.get("/health")
|
| 40 |
def health() -> dict[str, object]:
|
| 41 |
"""Return basic health and hardware information."""
|
| 42 |
hw = HardwareConfig(model_size_b=1.0, context_length=4096)
|
| 43 |
-
return {"status": "ok", "hardware": hw.summary()}
|
| 44 |
|
| 45 |
|
| 46 |
@app.post("/generate")
|
| 47 |
def generate(request: GenerationRequest) -> dict[str, object]:
|
| 48 |
"""Generate continuation token ids from an input token list."""
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def _health_action(_: dict[str, object]) -> dict[str, object]:
|
|
|
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Optional
|
| 8 |
|
| 9 |
import torch
|
| 10 |
from fastapi import FastAPI
|
|
|
|
| 13 |
from model.config import ModelConfig
|
| 14 |
from model.model import SageTransformer
|
| 15 |
from serve.control_plane import build_control_router
|
| 16 |
+
from train.checkpoint import load_latest_checkpoint
|
| 17 |
from train.hardware import HardwareConfig
|
| 18 |
|
| 19 |
|
| 20 |
app = FastAPI(title="SAGE Server")
|
| 21 |
_MODEL: SageTransformer | None = None
|
| 22 |
_TOKENIZER = None
|
| 23 |
+
_MODEL_DEVICE: torch.device | None = None
|
| 24 |
+
_MODEL_STATE: dict[str, Any] = {
|
| 25 |
+
"model_config": None,
|
| 26 |
+
"checkpoint_dir": None,
|
| 27 |
+
"checkpoint_loaded": False,
|
| 28 |
+
"checkpoint_step": 0,
|
| 29 |
+
"tokenizer_path": None,
|
| 30 |
+
}
|
| 31 |
|
| 32 |
|
| 33 |
class GenerationRequest(BaseModel):
|
|
|
|
| 37 |
max_new_tokens: int = 32
|
| 38 |
|
| 39 |
|
| 40 |
+
class ChatRequest(BaseModel):
|
| 41 |
+
"""Request schema for text generation through the tokenizer."""
|
| 42 |
+
|
| 43 |
+
prompt: str
|
| 44 |
+
max_new_tokens: int = 64
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_generation_device() -> torch.device:
|
| 48 |
+
"""Pick the active inference device."""
|
| 49 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _resolve_model_config_path() -> Path:
|
| 53 |
+
configured = Path(os.environ.get("SAGE_MODEL_CONFIG", "configs/model/1b.yaml"))
|
| 54 |
+
return configured if configured.exists() else Path("configs/model/1b.yaml")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _resolve_checkpoint_dir() -> Path:
|
| 58 |
+
return Path(os.environ.get("SAGE_CHECKPOINT_DIR", "runs/default"))
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _resolve_tokenizer_path() -> Path:
|
| 62 |
+
return Path(os.environ.get("SAGE_TOKENIZER_MODEL", "tokenizer/tokenizer.model"))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
def get_model() -> SageTransformer:
|
| 66 |
"""Lazily create the model for server startup."""
|
| 67 |
+
global _MODEL, _MODEL_DEVICE
|
| 68 |
if _MODEL is None:
|
| 69 |
+
config_path = _resolve_model_config_path()
|
| 70 |
+
config = ModelConfig.from_yaml(config_path) if config_path.exists() else ModelConfig()
|
| 71 |
+
_MODEL = SageTransformer(config)
|
| 72 |
+
checkpoint_dir = _resolve_checkpoint_dir()
|
| 73 |
+
checkpoint_step = 0
|
| 74 |
+
if checkpoint_dir.exists():
|
| 75 |
+
checkpoint_step = load_latest_checkpoint(_MODEL, None, None, None, str(checkpoint_dir), device="cpu")
|
| 76 |
+
_MODEL_STATE.update(
|
| 77 |
+
{
|
| 78 |
+
"model_config": str(config_path),
|
| 79 |
+
"checkpoint_dir": str(checkpoint_dir),
|
| 80 |
+
"checkpoint_loaded": checkpoint_step > 0,
|
| 81 |
+
"checkpoint_step": checkpoint_step,
|
| 82 |
+
}
|
| 83 |
+
)
|
| 84 |
_MODEL.eval()
|
| 85 |
+
device = get_generation_device()
|
| 86 |
+
if _MODEL_DEVICE != device:
|
| 87 |
+
_MODEL = _MODEL.to(device)
|
| 88 |
+
_MODEL_DEVICE = device
|
| 89 |
return _MODEL
|
| 90 |
|
| 91 |
|
| 92 |
+
def get_tokenizer():
|
| 93 |
+
"""Lazily load the SentencePiece tokenizer if present."""
|
| 94 |
+
global _TOKENIZER
|
| 95 |
+
if _TOKENIZER is None:
|
| 96 |
+
tokenizer_path = _resolve_tokenizer_path()
|
| 97 |
+
_MODEL_STATE["tokenizer_path"] = str(tokenizer_path)
|
| 98 |
+
if not tokenizer_path.exists():
|
| 99 |
+
return None
|
| 100 |
+
from tokenizer.validate_tokenizer import load_processor
|
| 101 |
+
|
| 102 |
+
_TOKENIZER = load_processor(str(tokenizer_path))
|
| 103 |
+
return _TOKENIZER
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _generate_token_ids(input_ids: list[int], max_new_tokens: int) -> list[int]:
|
| 107 |
+
"""Run greedy decoding from input token ids."""
|
| 108 |
+
model = get_model()
|
| 109 |
+
device = get_generation_device()
|
| 110 |
+
generated = list(input_ids)
|
| 111 |
+
tensor_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
|
| 112 |
+
cache: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None
|
| 113 |
+
with torch.inference_mode():
|
| 114 |
+
for _ in range(max(0, int(max_new_tokens))):
|
| 115 |
+
logits, cache = model(tensor_ids[:, -1:] if cache is not None else tensor_ids, past_key_values=cache)
|
| 116 |
+
next_token = int(torch.argmax(logits[:, -1, :], dim=-1).item())
|
| 117 |
+
generated.append(next_token)
|
| 118 |
+
tensor_ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
|
| 119 |
+
return generated
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def chat_status() -> dict[str, object]:
|
| 123 |
+
"""Return whether text chat is configured for the current server."""
|
| 124 |
+
tokenizer = get_tokenizer()
|
| 125 |
+
checkpoint_loaded = bool(_MODEL_STATE["checkpoint_loaded"])
|
| 126 |
+
available = tokenizer is not None
|
| 127 |
+
warning = None
|
| 128 |
+
if tokenizer is None:
|
| 129 |
+
warning = "Tokenizer model not found. Train or place tokenizer/tokenizer.model before using browser chat."
|
| 130 |
+
elif not checkpoint_loaded:
|
| 131 |
+
warning = "No checkpoint loaded. Chat will run with randomly initialized model weights until you train or configure SAGE_CHECKPOINT_DIR."
|
| 132 |
+
return {
|
| 133 |
+
"available": available,
|
| 134 |
+
"tokenizer_path": _MODEL_STATE["tokenizer_path"],
|
| 135 |
+
"checkpoint_dir": _MODEL_STATE["checkpoint_dir"],
|
| 136 |
+
"checkpoint_loaded": checkpoint_loaded,
|
| 137 |
+
"checkpoint_step": _MODEL_STATE["checkpoint_step"],
|
| 138 |
+
"warning": warning,
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
@app.get("/health")
|
| 143 |
def health() -> dict[str, object]:
|
| 144 |
"""Return basic health and hardware information."""
|
| 145 |
hw = HardwareConfig(model_size_b=1.0, context_length=4096)
|
| 146 |
+
return {"status": "ok", "hardware": hw.summary(), "chat": chat_status()}
|
| 147 |
|
| 148 |
|
| 149 |
@app.post("/generate")
|
| 150 |
def generate(request: GenerationRequest) -> dict[str, object]:
|
| 151 |
"""Generate continuation token ids from an input token list."""
|
| 152 |
+
return {"tokens": _generate_token_ids(request.input_ids, request.max_new_tokens)}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@app.get("/chat/status")
|
| 156 |
+
def get_chat_status() -> dict[str, object]:
|
| 157 |
+
"""Expose browser-chat readiness."""
|
| 158 |
+
return chat_status()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@app.post("/chat")
|
| 162 |
+
def chat(request: ChatRequest) -> dict[str, object]:
|
| 163 |
+
"""Generate text from a prompt using the local tokenizer."""
|
| 164 |
+
tokenizer = get_tokenizer()
|
| 165 |
+
if tokenizer is None:
|
| 166 |
+
return {
|
| 167 |
+
"success": False,
|
| 168 |
+
"detail": "Tokenizer model not found. Train the tokenizer first or set SAGE_TOKENIZER_MODEL.",
|
| 169 |
+
**chat_status(),
|
| 170 |
+
}
|
| 171 |
+
prompt = request.prompt.strip()
|
| 172 |
+
if not prompt:
|
| 173 |
+
return {"success": False, "detail": "Prompt cannot be empty.", **chat_status()}
|
| 174 |
+
prompt_ids = list(tokenizer.encode(prompt, out_type=int))
|
| 175 |
+
generated = _generate_token_ids(prompt_ids, request.max_new_tokens)
|
| 176 |
+
completion_ids = generated[len(prompt_ids) :]
|
| 177 |
+
return {
|
| 178 |
+
"success": True,
|
| 179 |
+
"prompt": prompt,
|
| 180 |
+
"response": tokenizer.decode(completion_ids),
|
| 181 |
+
"input_ids": prompt_ids,
|
| 182 |
+
"output_ids": generated,
|
| 183 |
+
"new_token_ids": completion_ids,
|
| 184 |
+
**chat_status(),
|
| 185 |
+
}
|
| 186 |
|
| 187 |
|
| 188 |
def _health_action(_: dict[str, object]) -> dict[str, object]:
|
serve/server_cpu.py
CHANGED
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
| 5 |
import shutil
|
| 6 |
|
| 7 |
from fastapi import FastAPI
|
|
|
|
| 8 |
|
| 9 |
from serve.control_plane import build_control_router
|
| 10 |
|
|
@@ -12,10 +13,37 @@ from serve.control_plane import build_control_router
|
|
| 12 |
app = FastAPI(title="SAGE CPU Server")
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
@app.get("/health")
|
| 16 |
def health() -> dict[str, object]:
|
| 17 |
"""Report llama.cpp availability for CPU serving."""
|
| 18 |
-
return {"status": "ok", "llama_cpp_available": shutil.which("llama-server") is not None}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def _health_action(_: dict[str, object]) -> dict[str, object]:
|
|
|
|
| 5 |
import shutil
|
| 6 |
|
| 7 |
from fastapi import FastAPI
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
|
| 10 |
from serve.control_plane import build_control_router
|
| 11 |
|
|
|
|
| 13 |
app = FastAPI(title="SAGE CPU Server")
|
| 14 |
|
| 15 |
|
| 16 |
+
class ChatRequest(BaseModel):
|
| 17 |
+
"""Request schema for the browser chat surface."""
|
| 18 |
+
|
| 19 |
+
prompt: str
|
| 20 |
+
max_new_tokens: int = 64
|
| 21 |
+
|
| 22 |
+
|
| 23 |
@app.get("/health")
|
| 24 |
def health() -> dict[str, object]:
|
| 25 |
"""Report llama.cpp availability for CPU serving."""
|
| 26 |
+
return {"status": "ok", "llama_cpp_available": shutil.which("llama-server") is not None, "chat": chat_status()}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def chat_status() -> dict[str, object]:
|
| 30 |
+
"""Return chat readiness for the CPU server."""
|
| 31 |
+
return {
|
| 32 |
+
"available": False,
|
| 33 |
+
"warning": "Browser chat is only wired to the PyTorch GPU server in this repo. Use serve.server:app for direct interaction.",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@app.get("/chat/status")
|
| 38 |
+
def get_chat_status() -> dict[str, object]:
|
| 39 |
+
"""Expose browser-chat readiness."""
|
| 40 |
+
return chat_status()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@app.post("/chat")
|
| 44 |
+
def chat(_: ChatRequest) -> dict[str, object]:
|
| 45 |
+
"""Return a clear error for CPU-only control-plane mode."""
|
| 46 |
+
return {"success": False, "detail": chat_status()["warning"], **chat_status()}
|
| 47 |
|
| 48 |
|
| 49 |
def _health_action(_: dict[str, object]) -> dict[str, object]:
|
serve/static/index.html
CHANGED
|
@@ -139,6 +139,36 @@
|
|
| 139 |
font-size: 12px;
|
| 140 |
color: var(--muted);
|
| 141 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
@media (max-width: 980px) {
|
| 143 |
.grid { grid-template-columns: 1fr; }
|
| 144 |
}
|
|
@@ -188,6 +218,21 @@
|
|
| 188 |
<h2>API Result</h2>
|
| 189 |
<div id="api-result" class="mono">No API result yet.</div>
|
| 190 |
</section>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
</div>
|
| 192 |
|
| 193 |
<div>
|
|
@@ -215,6 +260,7 @@
|
|
| 215 |
selectedJobId: null,
|
| 216 |
stream: null,
|
| 217 |
pollTimer: null,
|
|
|
|
| 218 |
};
|
| 219 |
|
| 220 |
const loginPanel = document.getElementById("login-panel");
|
|
@@ -229,6 +275,11 @@
|
|
| 229 |
const logsEl = document.getElementById("logs");
|
| 230 |
const selectedJobEl = document.getElementById("selected-job");
|
| 231 |
const apiResultEl = document.getElementById("api-result");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
async function api(path, options = {}) {
|
| 234 |
const response = await fetch(path, {
|
|
@@ -265,6 +316,28 @@
|
|
| 265 |
return state.presets.find((item) => item.id === presetSelect.value);
|
| 266 |
}
|
| 267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
function renderPresetFields() {
|
| 269 |
const preset = currentPreset();
|
| 270 |
presetFields.innerHTML = "";
|
|
@@ -275,7 +348,7 @@
|
|
| 275 |
for (const field of preset.fields) {
|
| 276 |
const wrapper = document.createElement("div");
|
| 277 |
const label = document.createElement("label");
|
| 278 |
-
label.textContent = field.label;
|
| 279 |
label.htmlFor = `field-${field.name}`;
|
| 280 |
wrapper.appendChild(label);
|
| 281 |
|
|
@@ -308,6 +381,7 @@
|
|
| 308 |
input.id = `field-${field.name}`;
|
| 309 |
input.dataset.kind = field.kind;
|
| 310 |
input.dataset.name = field.name;
|
|
|
|
| 311 |
input.placeholder = field.placeholder || "";
|
| 312 |
wrapper.appendChild(input);
|
| 313 |
presetFields.appendChild(wrapper);
|
|
@@ -319,13 +393,20 @@
|
|
| 319 |
for (const element of presetFields.querySelectorAll("[data-name]")) {
|
| 320 |
const name = element.dataset.name;
|
| 321 |
const kind = element.dataset.kind;
|
|
|
|
| 322 |
if (kind === "boolean") {
|
| 323 |
args[name] = element.checked;
|
| 324 |
} else if (kind === "number") {
|
| 325 |
args[name] = element.value === "" ? "" : Number(element.value);
|
| 326 |
} else if (kind === "json") {
|
|
|
|
|
|
|
|
|
|
| 327 |
args[name] = element.value.trim() ? JSON.parse(element.value) : null;
|
| 328 |
} else {
|
|
|
|
|
|
|
|
|
|
| 329 |
args[name] = element.value;
|
| 330 |
}
|
| 331 |
}
|
|
@@ -388,6 +469,15 @@
|
|
| 388 |
renderJobs(payload.jobs || []);
|
| 389 |
}
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
function appendLog(line) {
|
| 392 |
if (!line) {
|
| 393 |
return;
|
|
@@ -432,6 +522,7 @@
|
|
| 432 |
presetSelect.appendChild(option);
|
| 433 |
}
|
| 434 |
renderPresetFields();
|
|
|
|
| 435 |
}
|
| 436 |
|
| 437 |
async function login() {
|
|
@@ -492,6 +583,35 @@
|
|
| 492 |
apiResultEl.textContent = String(error.message || error);
|
| 493 |
}
|
| 494 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
loadPresets().then(() => {
|
| 497 |
showApp();
|
|
|
|
| 139 |
font-size: 12px;
|
| 140 |
color: var(--muted);
|
| 141 |
}
|
| 142 |
+
.chat-status {
|
| 143 |
+
margin-bottom: 12px;
|
| 144 |
+
font-size: 13px;
|
| 145 |
+
color: var(--muted);
|
| 146 |
+
}
|
| 147 |
+
.chat-messages {
|
| 148 |
+
min-height: 220px;
|
| 149 |
+
max-height: 360px;
|
| 150 |
+
overflow: auto;
|
| 151 |
+
padding: 12px;
|
| 152 |
+
border-radius: 10px;
|
| 153 |
+
border: 1px solid var(--panel-border);
|
| 154 |
+
background: #0b1015;
|
| 155 |
+
display: grid;
|
| 156 |
+
gap: 10px;
|
| 157 |
+
}
|
| 158 |
+
.message {
|
| 159 |
+
border-radius: 10px;
|
| 160 |
+
padding: 10px 12px;
|
| 161 |
+
white-space: pre-wrap;
|
| 162 |
+
line-height: 1.5;
|
| 163 |
+
}
|
| 164 |
+
.message-user {
|
| 165 |
+
background: #0f2330;
|
| 166 |
+
border: 1px solid #1e3c50;
|
| 167 |
+
}
|
| 168 |
+
.message-model {
|
| 169 |
+
background: #151b22;
|
| 170 |
+
border: 1px solid var(--panel-border);
|
| 171 |
+
}
|
| 172 |
@media (max-width: 980px) {
|
| 173 |
.grid { grid-template-columns: 1fr; }
|
| 174 |
}
|
|
|
|
| 218 |
<h2>API Result</h2>
|
| 219 |
<div id="api-result" class="mono">No API result yet.</div>
|
| 220 |
</section>
|
| 221 |
+
|
| 222 |
+
<section class="panel">
|
| 223 |
+
<h2>Model Chat</h2>
|
| 224 |
+
<div id="chat-status" class="chat-status">Checking chat status...</div>
|
| 225 |
+
<div id="chat-messages" class="chat-messages">
|
| 226 |
+
<div class="message message-model">Prompt the local model here once a tokenizer exists. If no checkpoint is loaded yet, outputs will be random.</div>
|
| 227 |
+
</div>
|
| 228 |
+
<label for="chat-prompt">Prompt</label>
|
| 229 |
+
<textarea id="chat-prompt" placeholder="Explain what data files I need before training this repo."></textarea>
|
| 230 |
+
<label for="chat-max-tokens">Max New Tokens</label>
|
| 231 |
+
<input id="chat-max-tokens" type="number" value="96" min="1" max="512">
|
| 232 |
+
<div class="button-row">
|
| 233 |
+
<button id="chat-send" class="primary">Send Prompt</button>
|
| 234 |
+
</div>
|
| 235 |
+
</section>
|
| 236 |
</div>
|
| 237 |
|
| 238 |
<div>
|
|
|
|
| 260 |
selectedJobId: null,
|
| 261 |
stream: null,
|
| 262 |
pollTimer: null,
|
| 263 |
+
chatStatus: null,
|
| 264 |
};
|
| 265 |
|
| 266 |
const loginPanel = document.getElementById("login-panel");
|
|
|
|
| 275 |
const logsEl = document.getElementById("logs");
|
| 276 |
const selectedJobEl = document.getElementById("selected-job");
|
| 277 |
const apiResultEl = document.getElementById("api-result");
|
| 278 |
+
const chatStatusEl = document.getElementById("chat-status");
|
| 279 |
+
const chatMessagesEl = document.getElementById("chat-messages");
|
| 280 |
+
const chatPromptEl = document.getElementById("chat-prompt");
|
| 281 |
+
const chatMaxTokensEl = document.getElementById("chat-max-tokens");
|
| 282 |
+
const chatSendEl = document.getElementById("chat-send");
|
| 283 |
|
| 284 |
async function api(path, options = {}) {
|
| 285 |
const response = await fetch(path, {
|
|
|
|
| 316 |
return state.presets.find((item) => item.id === presetSelect.value);
|
| 317 |
}
|
| 318 |
|
| 319 |
+
function appendMessage(role, text) {
|
| 320 |
+
const block = document.createElement("div");
|
| 321 |
+
block.className = `message ${role === "user" ? "message-user" : "message-model"}`;
|
| 322 |
+
block.textContent = text;
|
| 323 |
+
chatMessagesEl.appendChild(block);
|
| 324 |
+
chatMessagesEl.scrollTop = chatMessagesEl.scrollHeight;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
function renderChatStatus(payload) {
|
| 328 |
+
state.chatStatus = payload;
|
| 329 |
+
const warning = payload && payload.warning ? ` ${payload.warning}` : "";
|
| 330 |
+
if (!payload) {
|
| 331 |
+
chatStatusEl.textContent = "Chat status unavailable.";
|
| 332 |
+
chatSendEl.disabled = true;
|
| 333 |
+
return;
|
| 334 |
+
}
|
| 335 |
+
chatStatusEl.textContent = payload.available
|
| 336 |
+
? `Chat ready.${warning}`
|
| 337 |
+
: `Chat unavailable.${warning}`;
|
| 338 |
+
chatSendEl.disabled = !payload.available;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
function renderPresetFields() {
|
| 342 |
const preset = currentPreset();
|
| 343 |
presetFields.innerHTML = "";
|
|
|
|
| 348 |
for (const field of preset.fields) {
|
| 349 |
const wrapper = document.createElement("div");
|
| 350 |
const label = document.createElement("label");
|
| 351 |
+
label.textContent = field.required ? `${field.label} *` : field.label;
|
| 352 |
label.htmlFor = `field-${field.name}`;
|
| 353 |
wrapper.appendChild(label);
|
| 354 |
|
|
|
|
| 381 |
input.id = `field-${field.name}`;
|
| 382 |
input.dataset.kind = field.kind;
|
| 383 |
input.dataset.name = field.name;
|
| 384 |
+
input.dataset.required = field.required ? "true" : "false";
|
| 385 |
input.placeholder = field.placeholder || "";
|
| 386 |
wrapper.appendChild(input);
|
| 387 |
presetFields.appendChild(wrapper);
|
|
|
|
| 393 |
for (const element of presetFields.querySelectorAll("[data-name]")) {
|
| 394 |
const name = element.dataset.name;
|
| 395 |
const kind = element.dataset.kind;
|
| 396 |
+
const required = element.dataset.required === "true";
|
| 397 |
if (kind === "boolean") {
|
| 398 |
args[name] = element.checked;
|
| 399 |
} else if (kind === "number") {
|
| 400 |
args[name] = element.value === "" ? "" : Number(element.value);
|
| 401 |
} else if (kind === "json") {
|
| 402 |
+
if (required && !element.value.trim()) {
|
| 403 |
+
throw new Error(`Field ${name} is required.`);
|
| 404 |
+
}
|
| 405 |
args[name] = element.value.trim() ? JSON.parse(element.value) : null;
|
| 406 |
} else {
|
| 407 |
+
if (required && !element.value.trim()) {
|
| 408 |
+
throw new Error(`Field ${name} is required.`);
|
| 409 |
+
}
|
| 410 |
args[name] = element.value;
|
| 411 |
}
|
| 412 |
}
|
|
|
|
| 469 |
renderJobs(payload.jobs || []);
|
| 470 |
}
|
| 471 |
|
| 472 |
+
async function loadChatStatus() {
|
| 473 |
+
try {
|
| 474 |
+
const payload = await api("/chat/status");
|
| 475 |
+
renderChatStatus(payload);
|
| 476 |
+
} catch (error) {
|
| 477 |
+
renderChatStatus({ available: false, warning: String(error.message || error) });
|
| 478 |
+
}
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
function appendLog(line) {
|
| 482 |
if (!line) {
|
| 483 |
return;
|
|
|
|
| 522 |
presetSelect.appendChild(option);
|
| 523 |
}
|
| 524 |
renderPresetFields();
|
| 525 |
+
await loadChatStatus();
|
| 526 |
}
|
| 527 |
|
| 528 |
async function login() {
|
|
|
|
| 583 |
apiResultEl.textContent = String(error.message || error);
|
| 584 |
}
|
| 585 |
});
|
| 586 |
+
chatSendEl.addEventListener("click", async () => {
|
| 587 |
+
const prompt = chatPromptEl.value.trim();
|
| 588 |
+
if (!prompt) {
|
| 589 |
+
apiResultEl.textContent = "Prompt cannot be empty.";
|
| 590 |
+
return;
|
| 591 |
+
}
|
| 592 |
+
appendMessage("user", prompt);
|
| 593 |
+
chatPromptEl.value = "";
|
| 594 |
+
chatSendEl.disabled = true;
|
| 595 |
+
try {
|
| 596 |
+
const payload = await api("/chat", {
|
| 597 |
+
method: "POST",
|
| 598 |
+
body: JSON.stringify({
|
| 599 |
+
prompt,
|
| 600 |
+
max_new_tokens: Number(chatMaxTokensEl.value || 96),
|
| 601 |
+
}),
|
| 602 |
+
});
|
| 603 |
+
renderChatStatus(payload);
|
| 604 |
+
if (!payload.success) {
|
| 605 |
+
appendMessage("model", payload.detail || "Chat failed.");
|
| 606 |
+
return;
|
| 607 |
+
}
|
| 608 |
+
appendMessage("model", payload.response || "[empty response]");
|
| 609 |
+
} catch (error) {
|
| 610 |
+
appendMessage("model", String(error.message || error));
|
| 611 |
+
} finally {
|
| 612 |
+
chatSendEl.disabled = !(state.chatStatus && state.chatStatus.available);
|
| 613 |
+
}
|
| 614 |
+
});
|
| 615 |
|
| 616 |
loadPresets().then(() => {
|
| 617 |
showApp();
|
tests/test_control_plane.py
CHANGED
|
@@ -50,6 +50,8 @@ def test_login_and_html_index(monkeypatch) -> None:
|
|
| 50 |
assert response.status_code == 200
|
| 51 |
payload = response.json()
|
| 52 |
preset_ids = {item["id"] for item in payload["presets"]}
|
|
|
|
|
|
|
| 53 |
assert "serve_cpu" in preset_ids
|
| 54 |
assert "git_status" in preset_ids
|
| 55 |
|
|
@@ -120,3 +122,14 @@ def test_health_api_preset(monkeypatch) -> None:
|
|
| 120 |
payload = response.json()
|
| 121 |
assert payload["kind"] == "api"
|
| 122 |
assert payload["result"]["status"] == "ok"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
assert response.status_code == 200
|
| 51 |
payload = response.json()
|
| 52 |
preset_ids = {item["id"] for item in payload["presets"]}
|
| 53 |
+
assert "data_bootstrap" in preset_ids
|
| 54 |
+
assert "data_pipeline" in preset_ids
|
| 55 |
assert "serve_cpu" in preset_ids
|
| 56 |
assert "git_status" in preset_ids
|
| 57 |
|
|
|
|
| 122 |
payload = response.json()
|
| 123 |
assert payload["kind"] == "api"
|
| 124 |
assert payload["result"]["status"] == "ok"
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_required_preset_field_validation(monkeypatch) -> None:
|
| 128 |
+
monkeypatch.setenv("SAGE_WEB_PASSWORD", "test-password")
|
| 129 |
+
CONTROL_MANAGER.reset_for_tests()
|
| 130 |
+
client = TestClient(app)
|
| 131 |
+
_login(client)
|
| 132 |
+
|
| 133 |
+
response = client.post("/api/commands/run", json={"preset_id": "tokenizer_train", "args": {"input_paths": ""}})
|
| 134 |
+
assert response.status_code == 400
|
| 135 |
+
assert "Input Paths" in response.json()["detail"]
|
tests/test_data_pipeline.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from data.dataset import pack_sequence
|
|
|
|
| 2 |
from data.dedup import deduplicate_records
|
| 3 |
from data.filter import filter_record
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def test_filter_record_masks_pii() -> None:
|
|
@@ -31,3 +38,48 @@ def test_pack_sequence_shapes() -> None:
|
|
| 31 |
assert packed["input_ids"].tolist() == [1, 2, 3, 4]
|
| 32 |
assert packed["labels"].tolist() == [2, 3, 4, 5]
|
| 33 |
assert packed["document_boundaries"].tolist() == [0, 0, 1, 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
from data.dataset import pack_sequence
|
| 5 |
+
from data.bootstrap import bootstrap_raw_corpora
|
| 6 |
from data.dedup import deduplicate_records
|
| 7 |
from data.filter import filter_record
|
| 8 |
+
from data.ingest import SourceSpec
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
|
| 12 |
|
| 13 |
def test_filter_record_masks_pii() -> None:
|
|
|
|
| 38 |
assert packed["input_ids"].tolist() == [1, 2, 3, 4]
|
| 39 |
assert packed["labels"].tolist() == [2, 3, 4, 5]
|
| 40 |
assert packed["document_boundaries"].tolist() == [0, 0, 1, 0]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_bootstrap_raw_corpora_writes_jsonl(tmp_path: Path) -> None:
|
| 44 |
+
summary = bootstrap_raw_corpora(output_dir=str(tmp_path), overwrite=True)
|
| 45 |
+
assert summary["general_web"] > 0
|
| 46 |
+
sample_path = tmp_path / "general_web.jsonl"
|
| 47 |
+
first = json.loads(sample_path.read_text(encoding="utf-8").splitlines()[0])
|
| 48 |
+
assert "text" in first
|
| 49 |
+
assert len(first["text"]) >= 240
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_pipeline_writes_manifest_from_bootstrap_data(tmp_path: Path, monkeypatch) -> None:
|
| 53 |
+
pytest.importorskip("sentencepiece")
|
| 54 |
+
from data import pipeline
|
| 55 |
+
from tokenizer.train_tokenizer import train_sentencepiece, write_training_text
|
| 56 |
+
|
| 57 |
+
raw_dir = tmp_path / "raw"
|
| 58 |
+
bootstrap_raw_corpora(output_dir=str(raw_dir), overwrite=True)
|
| 59 |
+
training_text = tmp_path / "training.txt"
|
| 60 |
+
write_training_text([str(path) for path in raw_dir.glob("*.jsonl")], str(training_text))
|
| 61 |
+
prefix = tmp_path / "tokenizer"
|
| 62 |
+
train_sentencepiece(str(training_text), str(prefix), vocab_size=512)
|
| 63 |
+
|
| 64 |
+
registry = tuple(
|
| 65 |
+
SourceSpec(
|
| 66 |
+
name=path.stem,
|
| 67 |
+
domain_tag="general",
|
| 68 |
+
quality_tier="high",
|
| 69 |
+
license_category="permissive",
|
| 70 |
+
estimated_tokens=1_000,
|
| 71 |
+
path=str(path),
|
| 72 |
+
)
|
| 73 |
+
for path in raw_dir.glob("*.jsonl")
|
| 74 |
+
)
|
| 75 |
+
monkeypatch.setattr(pipeline, "SOURCE_REGISTRY", registry)
|
| 76 |
+
|
| 77 |
+
output_dir = tmp_path / "processed"
|
| 78 |
+
summary = pipeline.run_pipeline(
|
| 79 |
+
tokenizer_model=str(prefix) + ".model",
|
| 80 |
+
output_dir=str(output_dir),
|
| 81 |
+
shard_size=4,
|
| 82 |
+
)
|
| 83 |
+
manifest_path = output_dir / "manifest.json"
|
| 84 |
+
assert summary["records"] > 0
|
| 85 |
+
assert manifest_path.exists()
|
tests/test_servers.py
CHANGED
|
@@ -45,3 +45,47 @@ def test_gpu_server_generate(monkeypatch) -> None:
|
|
| 45 |
assert response.status_code == 200
|
| 46 |
payload = response.json()
|
| 47 |
assert payload["tokens"] == [1, 2, 3, 3, 3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
assert response.status_code == 200
|
| 46 |
payload = response.json()
|
| 47 |
assert payload["tokens"] == [1, 2, 3, 3, 3]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_gpu_server_chat(monkeypatch) -> None:
|
| 51 |
+
class FakeModel:
|
| 52 |
+
def eval(self) -> "FakeModel":
|
| 53 |
+
return self
|
| 54 |
+
|
| 55 |
+
def to(self, _device) -> "FakeModel":
|
| 56 |
+
return self
|
| 57 |
+
|
| 58 |
+
def __call__(self, input_ids, past_key_values=None):
|
| 59 |
+
import torch
|
| 60 |
+
|
| 61 |
+
batch, seq = input_ids.shape
|
| 62 |
+
logits = torch.zeros((batch, seq, 32), dtype=torch.float32)
|
| 63 |
+
logits[:, :, 7] = 1.0
|
| 64 |
+
cache = [] if past_key_values is None else past_key_values
|
| 65 |
+
return logits, cache
|
| 66 |
+
|
| 67 |
+
class FakeTokenizer:
|
| 68 |
+
def encode(self, text, out_type=int):
|
| 69 |
+
return [1, 2, 3]
|
| 70 |
+
|
| 71 |
+
def decode(self, ids):
|
| 72 |
+
return "decoded:" + ",".join(str(item) for item in ids)
|
| 73 |
+
|
| 74 |
+
monkeypatch.setattr(gpu_server, "get_model", lambda: FakeModel())
|
| 75 |
+
monkeypatch.setattr(gpu_server, "get_tokenizer", lambda: FakeTokenizer())
|
| 76 |
+
monkeypatch.setattr(gpu_server, "chat_status", lambda: {"available": True, "warning": None, "checkpoint_loaded": False})
|
| 77 |
+
monkeypatch.setattr(gpu_server.torch.cuda, "is_available", lambda: False)
|
| 78 |
+
client = TestClient(gpu_app)
|
| 79 |
+
response = client.post("/chat", json={"prompt": "hello", "max_new_tokens": 2})
|
| 80 |
+
assert response.status_code == 200
|
| 81 |
+
payload = response.json()
|
| 82 |
+
assert payload["success"] is True
|
| 83 |
+
assert payload["response"] == "decoded:7,7"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def test_cpu_server_chat_status() -> None:
|
| 87 |
+
client = TestClient(cpu_app)
|
| 88 |
+
response = client.get("/chat/status")
|
| 89 |
+
assert response.status_code == 200
|
| 90 |
+
payload = response.json()
|
| 91 |
+
assert payload["available"] is False
|
tests/test_tokenizer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from pathlib import Path
|
|
|
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
|
|
@@ -23,3 +24,22 @@ def test_validation_suite_roundtrip(tmp_path: Path) -> None:
|
|
| 23 |
train_sentencepiece(str(corpus), str(prefix), vocab_size=512)
|
| 24 |
results = run_validation_suite(str(prefix) + ".model")
|
| 25 |
assert all(result.passed for result in results), results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
|
| 4 |
import pytest
|
| 5 |
|
|
|
|
| 24 |
train_sentencepiece(str(corpus), str(prefix), vocab_size=512)
|
| 25 |
results = run_validation_suite(str(prefix) + ".model")
|
| 26 |
assert all(result.passed for result in results), results
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_write_training_text_reads_jsonl(tmp_path: Path) -> None:
|
| 30 |
+
from tokenizer.train_tokenizer import write_training_text
|
| 31 |
+
|
| 32 |
+
raw = tmp_path / "raw.jsonl"
|
| 33 |
+
raw.write_text(
|
| 34 |
+
"\n".join(
|
| 35 |
+
[
|
| 36 |
+
json.dumps({"text": "first training sample"}),
|
| 37 |
+
json.dumps({"text": "second training sample"}),
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
+ "\n",
|
| 41 |
+
encoding="utf-8",
|
| 42 |
+
)
|
| 43 |
+
output = tmp_path / "combined.txt"
|
| 44 |
+
write_training_text([str(raw)], str(output))
|
| 45 |
+
assert output.read_text(encoding="utf-8").splitlines() == ["first training sample", "second training sample"]
|
tokenizer/train_tokenizer.py
CHANGED
|
@@ -3,32 +3,50 @@
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import argparse
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
-
from typing import Iterable
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
-
def write_training_text(corpus_paths: Iterable[str], output_path: str) -> str:
|
| 16 |
"""Concatenate corpus text into a plain-text file for SentencePiece."""
|
| 17 |
output = Path(output_path)
|
| 18 |
output.parent.mkdir(parents=True, exist_ok=True)
|
| 19 |
with output.open("w", encoding="utf-8") as sink:
|
| 20 |
-
for
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
line = line.strip()
|
| 24 |
-
if line:
|
| 25 |
-
sink.write(line)
|
| 26 |
-
sink.write("\n")
|
| 27 |
return str(output)
|
| 28 |
|
| 29 |
|
| 30 |
def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None:
|
| 31 |
"""Train a byte-fallback SentencePiece BPE model."""
|
|
|
|
|
|
|
| 32 |
spm.SentencePieceTrainer.train(
|
| 33 |
input=input_path,
|
| 34 |
model_prefix=model_prefix,
|
|
@@ -52,17 +70,18 @@ def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50
|
|
| 52 |
def build_argparser() -> argparse.ArgumentParser:
|
| 53 |
"""Build the CLI parser."""
|
| 54 |
parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.")
|
| 55 |
-
parser.add_argument("--input", nargs="+", required=True, help="Plain-text corpus files.")
|
| 56 |
parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.")
|
| 57 |
parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.")
|
| 58 |
parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.")
|
|
|
|
| 59 |
return parser
|
| 60 |
|
| 61 |
|
| 62 |
def main() -> None:
|
| 63 |
"""Train the tokenizer from CLI arguments."""
|
| 64 |
args = build_argparser().parse_args()
|
| 65 |
-
training_text = write_training_text(args.input, args.training_text)
|
| 66 |
train_sentencepiece(training_text, args.model_prefix, args.vocab_size)
|
| 67 |
|
| 68 |
|
|
|
|
| 3 |
from __future__ import annotations
|
| 4 |
|
| 5 |
import argparse
|
| 6 |
+
import json
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Iterable, Iterator
|
| 9 |
|
| 10 |
+
DEFAULT_SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>", "<unk>", "[INST]", "[/INST]")
|
| 11 |
|
| 12 |
|
| 13 |
+
def iter_training_text(corpus_paths: Iterable[str], text_key: str = "text") -> Iterator[str]:
|
| 14 |
+
"""Yield training lines from plain-text or JSONL corpus files."""
|
| 15 |
+
for path in corpus_paths:
|
| 16 |
+
source = Path(path)
|
| 17 |
+
suffix = source.suffix.lower()
|
| 18 |
+
with source.open("r", encoding="utf-8") as handle:
|
| 19 |
+
if suffix == ".jsonl":
|
| 20 |
+
for raw_line in handle:
|
| 21 |
+
raw_line = raw_line.strip()
|
| 22 |
+
if not raw_line:
|
| 23 |
+
continue
|
| 24 |
+
payload = json.loads(raw_line)
|
| 25 |
+
text = payload.get(text_key)
|
| 26 |
+
if isinstance(text, str) and text.strip():
|
| 27 |
+
yield text.strip()
|
| 28 |
+
continue
|
| 29 |
+
for raw_line in handle:
|
| 30 |
+
text = raw_line.strip()
|
| 31 |
+
if text:
|
| 32 |
+
yield text
|
| 33 |
|
| 34 |
|
| 35 |
+
def write_training_text(corpus_paths: Iterable[str], output_path: str, text_key: str = "text") -> str:
|
| 36 |
"""Concatenate corpus text into a plain-text file for SentencePiece."""
|
| 37 |
output = Path(output_path)
|
| 38 |
output.parent.mkdir(parents=True, exist_ok=True)
|
| 39 |
with output.open("w", encoding="utf-8") as sink:
|
| 40 |
+
for line in iter_training_text(corpus_paths, text_key=text_key):
|
| 41 |
+
sink.write(line)
|
| 42 |
+
sink.write("\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
return str(output)
|
| 44 |
|
| 45 |
|
| 46 |
def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None:
|
| 47 |
"""Train a byte-fallback SentencePiece BPE model."""
|
| 48 |
+
import sentencepiece as spm
|
| 49 |
+
|
| 50 |
spm.SentencePieceTrainer.train(
|
| 51 |
input=input_path,
|
| 52 |
model_prefix=model_prefix,
|
|
|
|
| 70 |
def build_argparser() -> argparse.ArgumentParser:
|
| 71 |
"""Build the CLI parser."""
|
| 72 |
parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.")
|
| 73 |
+
parser.add_argument("--input", nargs="+", required=True, help="Plain-text or JSONL corpus files.")
|
| 74 |
parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.")
|
| 75 |
parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.")
|
| 76 |
parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.")
|
| 77 |
+
parser.add_argument("--text-key", default="text", help="JSONL field to read when --input contains .jsonl files.")
|
| 78 |
return parser
|
| 79 |
|
| 80 |
|
| 81 |
def main() -> None:
|
| 82 |
"""Train the tokenizer from CLI arguments."""
|
| 83 |
args = build_argparser().parse_args()
|
| 84 |
+
training_text = write_training_text(args.input, args.training_text, text_key=args.text_key)
|
| 85 |
train_sentencepiece(training_text, args.model_prefix, args.vocab_size)
|
| 86 |
|
| 87 |
|