sage002 committed on
Commit
1e799aa
·
verified ·
1 Parent(s): b4f432f

feat: add authenticated remote control UI and ngrok launcher

README.md CHANGED
@@ -6,8 +6,16 @@ Designed to be both educational and functional, SAGE can be trained, fine-tuned,
6
 
7
  ---
8
 
9
- ## ☁️ Cloud Quickstart (Kaggle / Colab)
10
- Running SAGE in the cloud? Check out the **[Kaggle & Colab Quickstart Guide](file:///c:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/SAGE_KAGGLE_GUIDE.md)** for one-click setup and a premium interactive chat interface.
 
 
 
 
 
 
 
 
11
 
12
  ---
13
 
@@ -92,11 +100,13 @@ Once launched, simply type your message to chat with SAGE. The system uses a rol
92
  The FastAPI server now serves a minimal remote control panel at `/`.
93
 
94
  ```bash
95
- export SAGE_WEB_PASSWORD=change-me
 
 
96
  python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
97
  ```
98
 
99
- Open the server root in a browser, log in with `SAGE_WEB_PASSWORD`, and use preset actions or a raw command box to drive the local repo from the UI. The included `test.ipynb` notebook starts the real app and exposes it through ngrok for Colab.
100
 
101
  ---
102
 
 
6
 
7
  ---
8
 
9
+ ## Command Guide
10
+
11
+ Use the repo command sheet at `docs/COMMANDS.md` for the current end-to-end flow:
12
+
13
+ - bootstrap starter JSONL data
14
+ - train the tokenizer
15
+ - build parquet shards
16
+ - launch training
17
+ - serve the model
18
+ - use the browser UI and chat endpoint
19
 
20
  ---
21
 
 
100
  The FastAPI server now serves a minimal remote control panel at `/`.
101
 
102
  ```bash
103
+ export SAGE_WEB_PASSWORD="change-me"
104
+ export SAGE_CHECKPOINT_DIR="runs/sage-1b"
105
+ export SAGE_TOKENIZER_MODEL="tokenizer/tokenizer.model"
106
  python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
107
  ```
108
 
109
+ Open the server root in a browser, log in with `SAGE_WEB_PASSWORD`, and use preset actions or a raw command box to drive the local repo from the UI. The browser UI now also includes direct text chat through `/chat` when a tokenizer is available. If no checkpoint is loaded yet, the UI warns that outputs are coming from randomly initialized weights. The included `test.ipynb` notebook starts the real app and exposes it through ngrok for Colab.
110
 
111
  ---
112
 
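The README launch snippet sets `SAGE_WEB_PASSWORD`, `SAGE_CHECKPOINT_DIR`, and `SAGE_TOKENIZER_MODEL` before starting uvicorn. A shell-independent way to do the same from Python might look like the sketch below. The helper name `build_server_launch` is hypothetical and not part of the repo; only the environment variable names and the uvicorn arguments are taken from the README.

```python
from __future__ import annotations

import os
import sys


def build_server_launch(
    password: str,
    checkpoint_dir: str = "runs/sage-1b",
    tokenizer_model: str = "tokenizer/tokenizer.model",
    port: int = 8000,
) -> tuple[list[str], dict[str, str]]:
    """Return (command, env) for launching the SAGE server (hypothetical helper)."""
    env = dict(os.environ)  # copy so the parent environment is not mutated
    env["SAGE_WEB_PASSWORD"] = password
    env["SAGE_CHECKPOINT_DIR"] = checkpoint_dir
    env["SAGE_TOKENIZER_MODEL"] = tokenizer_model
    command = [
        sys.executable, "-m", "uvicorn", "serve.server:app",
        "--host", "0.0.0.0", "--port", str(port),
    ]
    return command, env
```

Passing the result to `subprocess.run(command, env=env, check=True)` from the repo root would start the server the same way on bash and PowerShell.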
data/bootstrap.py ADDED
@@ -0,0 +1,105 @@
1
+ """Bootstrap small raw corpora for tokenizer and smoke-training flows."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+
10
+ BOOTSTRAP_CORPORA: dict[str, list[str]] = {
11
+ "general_web": [
12
+ "Large language models learn by predicting the next token in a sequence, but useful systems depend just as much on data quality as on architecture size.",
13
+ "A good training corpus mixes clean prose, documentation, dialogue, and reference material so the model sees multiple ways humans structure information.",
14
+ "When you build a local model, start with small smoke runs, measure loss curves, and only then scale sequence length, batch size, and parameter count.",
15
+ "The fastest way to waste compute is to train on noisy duplicated text without checking tokenization, filtering, and validation splits first.",
16
+ "Evaluation should include both regression tests and qualitative prompts because perplexity alone does not tell you whether a model follows instructions well.",
17
+ "A serving stack usually needs checkpoint loading, tokenization, generation settings, and telemetry before it is practical for iterative experiments.",
18
+ ],
19
+ "code": [
20
+ "def running_mean(values):\n total = 0.0\n result = []\n for index, value in enumerate(values, start=1):\n total += value\n result.append(total / index)\n return result",
21
+ "class TextBatch:\n def __init__(self, items):\n self.items = list(items)\n\n def join(self, sep='\\n'):\n return sep.join(self.items)",
22
+ "from pathlib import Path\n\ndef read_text(path):\n return Path(path).read_text(encoding='utf-8')",
23
+ "def clamp(value, lo, hi):\n if value < lo:\n return lo\n if value > hi:\n return hi\n return value",
24
+ "def format_metrics(step, loss):\n return f'step={step} loss={loss:.4f}'",
25
+ "def greedy_decode(logits):\n import torch\n return int(torch.argmax(logits, dim=-1).item())",
26
+ ],
27
+ "math_science": [
28
+ "The derivative of x squared is 2x, and gradient-based optimization uses derivatives to decide how to update model parameters.",
29
+ "Perplexity is the exponential of average negative log likelihood; lower perplexity means the model assigns higher probability to the observed sequence.",
30
+ "If a batch contains B sequences of length T, then the number of next-token predictions is roughly B times T.",
31
+ "Matrix multiplication is central to transformer inference because projections for queries, keys, values, and feed-forward layers are all linear maps.",
32
+ "Softmax converts raw logits into a probability distribution by exponentiating each value and dividing by the sum of exponentials.",
33
+ "The context window bounds how many previous tokens the decoder can attend to while producing the next token.",
34
+ ],
35
+ "multilingual": [
36
+ "English: Training data should be filtered, deduplicated, and documented before long runs begin.",
37
+ "Hindi: अच्छे मॉडल के लिए साफ और विविध डेटा उतना ही जरूरी है जितना अच्छा आर्किटेक्चर।",
38
+ "Arabic: جودة البيانات تؤثر على جودة النموذج بقدر تأثير حجم النموذج نفسه.",
39
+ "Chinese: 在开始长时间训练之前,先做小规模验证可以节省大量计算资源。",
40
+ "Spanish: Un buen flujo de datos incluye limpieza, deduplicacion y particiones reproducibles.",
41
+ "French: Un modele utile demande des donnees propres, des tests et une boucle d'evaluation simple.",
42
+ ],
43
+ "synthetic": [
44
+ "[INST] Explain why deduplication matters before tokenizer training. [/INST] Deduplication prevents repeated passages from dominating merge statistics and reduces wasted compute during later model training.",
45
+ "[INST] Write a short checklist for a smoke training run. [/INST] Verify shards exist, verify tokenizer loads, run a short job, inspect metrics, and confirm checkpoints are written.",
46
+ "[INST] How do you know a dataset is too noisy? [/INST] Look for low alpha ratios, malformed markup, repeated content, excessive boilerplate, or corrupted encoding.",
47
+ "[INST] What is the purpose of a validation split? [/INST] It gives you held-out data for tracking generalization and for catching regressions during training.",
48
+ "[INST] Summarize the role of the tokenizer. [/INST] The tokenizer maps raw text into stable token ids the model can consume during training and generation.",
49
+ "[INST] Why keep metadata with each record? [/INST] Metadata helps audit provenance, quality, language mix, and filtering decisions across the pipeline.",
50
+ ],
51
+ }
52
+
53
+
54
+ def _pad_sample(text: str, minimum_chars: int = 240) -> str:
55
+     """Extend short bootstrap samples so they survive the default filters."""
56
+     trailer = (
57
+         " This bootstrap record is intentionally longer so the repo's default "
58
+         "quality filters keep it during smoke-test data preparation and tokenizer training."
59
+     )
60
+     padded = text.strip()
61
+     while len(padded) < minimum_chars:
62
+         padded += trailer
63
+     return padded
64
+
65
+
66
+ def bootstrap_raw_corpora(output_dir: str = "data/raw", overwrite: bool = False) -> dict[str, int]:
67
+     """Write one small JSONL corpus file per registered source."""
68
+     root = Path(output_dir)
69
+     root.mkdir(parents=True, exist_ok=True)
70
+     counts: dict[str, int] = {}
71
+     for source_name, samples in BOOTSTRAP_CORPORA.items():
72
+         path = root / f"{source_name}.jsonl"
73
+         if path.exists() and not overwrite:
74
+             existing = sum(1 for _ in path.open("r", encoding="utf-8"))
75
+             counts[source_name] = existing
76
+             continue
77
+         with path.open("w", encoding="utf-8") as handle:
78
+             for index, text in enumerate(samples, start=1):
79
+                 payload = {
80
+                     "id": f"{source_name}-{index:04d}",
81
+                     "text": _pad_sample(text),
82
+                     "source_name": source_name,
83
+                 }
84
+                 handle.write(json.dumps(payload, ensure_ascii=False) + "\n")
85
+         counts[source_name] = len(samples)
86
+     return counts
87
+
88
+
89
+ def build_argparser() -> argparse.ArgumentParser:
90
+     """Build the CLI parser for corpus bootstrapping."""
91
+     parser = argparse.ArgumentParser(description="Create small JSONL corpora for SAGE smoke runs.")
92
+     parser.add_argument("--output-dir", default="data/raw", help="Directory for raw JSONL corpus files.")
93
+     parser.add_argument("--overwrite", action="store_true", help="Replace any existing bootstrap corpus files.")
94
+     return parser
95
+
96
+
97
+ def main() -> None:
98
+     """CLI entrypoint."""
99
+     args = build_argparser().parse_args()
100
+     summary = bootstrap_raw_corpora(output_dir=args.output_dir, overwrite=args.overwrite)
101
+     print(json.dumps({"output_dir": args.output_dir, "sources": summary}, indent=2, ensure_ascii=False))
102
+
103
+
104
+ if __name__ == "__main__":
105
+     main()
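Each line that `bootstrap_raw_corpora` writes is a JSON object with `id`, `text`, and `source_name`, and `_pad_sample` stretches every text to at least 240 characters. A minimal sketch of how such a file could be read back and sanity-checked follows. `load_jsonl` and `check_records` are hypothetical helpers, not part of the repo, and the 240-character threshold only mirrors the `_pad_sample` default; the real thresholds in `data.filter` are not shown in this diff.

```python
from __future__ import annotations

import json
from pathlib import Path

REQUIRED_KEYS = ("id", "text", "source_name")  # keys written by bootstrap_raw_corpora


def load_jsonl(path: str | Path) -> list[dict[str, object]]:
    """Read one JSON object per non-empty line (hypothetical helper)."""
    records: list[dict[str, object]] = []
    with Path(path).open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def check_records(records: list[dict[str, object]], minimum_chars: int = 240) -> list[str]:
    """Return human-readable problems; an empty list means the file looks fine."""
    problems: list[str] = []
    for position, record in enumerate(records, start=1):
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            problems.append(f"record {position}: missing {', '.join(missing)}")
            continue
        if len(str(record["text"])) < minimum_chars:
            problems.append(f"record {position}: text shorter than {minimum_chars} chars")
    return problems
```

Custom corpora only need a `text` field according to `docs/COMMANDS.md`, so the `id` and `source_name` checks apply to the bootstrap output rather than to user-supplied files.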
data/pipeline.py ADDED
@@ -0,0 +1,93 @@
1
+ """End-to-end raw-corpus to Parquet shard pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ from pathlib import Path
8
+
9
+ import sentencepiece as spm
10
+
11
+ from data.dedup import deduplicate_records
12
+ from data.filter import FilterConfig, filter_record
13
+ from data.ingest import SOURCE_REGISTRY, stream_source
14
+ from data.shard import ShardConfig, write_shards
15
+
16
+
17
+ def _select_sources(names: list[str] | None) -> tuple:
18
+     if not names:
19
+         return SOURCE_REGISTRY
20
+     wanted = set(names)
21
+     selected = tuple(spec for spec in SOURCE_REGISTRY if spec.name in wanted)
22
+     missing = sorted(wanted - {spec.name for spec in selected})
23
+     if missing:
24
+         raise ValueError(f"Unknown sources: {', '.join(missing)}")
25
+     return selected
26
+
27
+
28
+ def build_records(source_names: list[str] | None = None, limit_per_source: int | None = None) -> list[dict[str, object]]:
29
+     """Load, filter, and deduplicate records from the configured raw sources."""
30
+     records: list[dict[str, object]] = []
31
+     for spec in _select_sources(source_names):
32
+         source_records: list[dict[str, object]] = []
33
+         for record in stream_source(spec):
34
+             filtered = filter_record(record, FilterConfig())
35
+             if filtered is None:
36
+                 continue
37
+             source_records.append(filtered)
38
+             if limit_per_source is not None and len(source_records) >= limit_per_source:
39
+                 break
40
+         records.extend(source_records)
41
+     return deduplicate_records(records)
42
+
43
+
44
+ def run_pipeline(
45
+     tokenizer_model: str,
46
+     output_dir: str = "data/processed",
47
+     source_names: list[str] | None = None,
48
+     shard_size: int = 2048,
49
+     limit_per_source: int | None = None,
50
+ ) -> dict[str, object]:
51
+     """Create Parquet shards from the current raw JSONL corpora."""
52
+     tokenizer = spm.SentencePieceProcessor()
53
+     tokenizer.load(tokenizer_model)
54
+     records = build_records(source_names=source_names, limit_per_source=limit_per_source)
55
+     manifest = write_shards(records, tokenizer, ShardConfig(output_dir=output_dir, shard_size=shard_size))
56
+     summary = {
57
+         "tokenizer_model": tokenizer_model,
58
+         "output_dir": output_dir,
59
+         "records": len(records),
60
+         "sources": source_names or [spec.name for spec in SOURCE_REGISTRY],
61
+         "manifest": manifest,
62
+     }
63
+     Path(output_dir).mkdir(parents=True, exist_ok=True)
64
+     (Path(output_dir) / "pipeline_summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
65
+     return summary
66
+
67
+
68
+ def build_argparser() -> argparse.ArgumentParser:
69
+     """Build the CLI parser."""
70
+     parser = argparse.ArgumentParser(description="Filter, deduplicate, tokenize, and shard SAGE raw corpora.")
71
+     parser.add_argument("--tokenizer-model", default="tokenizer/tokenizer.model", help="SentencePiece tokenizer model.")
72
+     parser.add_argument("--output-dir", default="data/processed", help="Destination directory for parquet shards.")
73
+     parser.add_argument("--sources", nargs="*", default=None, help="Subset of source names from data.ingest.SOURCE_REGISTRY.")
74
+     parser.add_argument("--shard-size", type=int, default=2048, help="Rows per parquet shard.")
75
+     parser.add_argument("--limit-per-source", type=int, default=None, help="Optional cap for smoke-testing.")
76
+     return parser
77
+
78
+
79
+ def main() -> None:
80
+     """CLI entrypoint."""
81
+     args = build_argparser().parse_args()
82
+     summary = run_pipeline(
83
+         tokenizer_model=args.tokenizer_model,
84
+         output_dir=args.output_dir,
85
+         source_names=args.sources,
86
+         shard_size=args.shard_size,
87
+         limit_per_source=args.limit_per_source,
88
+     )
89
+     print(json.dumps(summary, indent=2))
90
+
91
+
92
+ if __name__ == "__main__":
93
+     main()
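`build_records` filters each source, stops once `limit_per_source` records have survived the filter, and deduplicates only after all sources are combined. The sketch below reproduces that control flow with stand-ins, since `data.filter` and `data.dedup` are not part of this diff. `keep_record` (a minimum-length check) and `dedup_exact` (exact match on a SHA-256 of the text) are assumptions for illustration and may differ from what the real modules do.

```python
from __future__ import annotations

import hashlib


def keep_record(record: dict, min_chars: int = 200) -> dict | None:
    """Stand-in for data.filter.filter_record: drop records whose text is too short."""
    text = str(record.get("text", "")).strip()
    return {**record, "text": text} if len(text) >= min_chars else None


def dedup_exact(records: list[dict]) -> list[dict]:
    """Stand-in for data.dedup.deduplicate_records: keep the first copy of each text."""
    seen: set[str] = set()
    unique: list[dict] = []
    for record in records:
        digest = hashlib.sha256(record["text"].encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique.append(record)
    return unique


def build_records_sketch(sources: dict[str, list[dict]], limit_per_source: int | None = None) -> list[dict]:
    """Filter each source, cap it at limit_per_source, then deduplicate across sources."""
    records: list[dict] = []
    for _name, stream in sources.items():
        kept: list[dict] = []
        for record in stream:
            filtered = keep_record(record)
            if filtered is None:
                continue  # rejected records do not count toward the cap
            kept.append(filtered)
            if limit_per_source is not None and len(kept) >= limit_per_source:
                break
        records.extend(kept)
    return dedup_exact(records)
```

Because the cap is applied before deduplication, `--limit-per-source 4` can yield fewer than four records per source in the final shards when sources overlap.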
docs/COMMANDS.md CHANGED
@@ -1,6 +1,6 @@
1
  # SAGE Commands
2
 
3
- This file is the short command-only reference for the repo.
4
 
5
  ## Install
6
 
@@ -14,25 +14,81 @@ pip install -r requirements.txt
14
  pytest -q
15
  ```
16
 
17
- ## Train tokenizer
18
 
19
  ```bash
20
  python -m tokenizer.train_tokenizer \
21
- --input data/raw/general_web.txt data/raw/code.txt \
22
- --model-prefix tokenizer/tokenizer \
23
  --model-prefix tokenizer/tokenizer \
24
  ```
25
 
26
- ## Validate tokenizer
27
 
28
  ```bash
29
- bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model
 
 
 
 
30
  ```
31
 
32
- ## Start a short training smoke run
 
 
 
 
 
 
 
 
33
 
34
  ```bash
35
  python -m train.trainer \
 
 
36
  --train-shards data/processed/shard-00000.parquet \
37
  --validation-shards data/processed/shard-00001.parquet \
38
  --output-dir runs/smoke \
@@ -40,7 +96,7 @@ python -m train.trainer \
40
  --disable-wandb
41
  ```
42
 
43
- ## Start full training
44
 
45
  ```bash
46
  python -m train.trainer \
@@ -51,38 +107,93 @@ python -m train.trainer \
51
  --output-dir runs/sage-1b
52
  ```
53
 
54
- ## Run eval harness
 
 
55
 
56
  ```bash
57
- bash scripts/run_eval.sh
 
 
 
58
  ```
59
 
60
- ## Start GPU server
61
 
62
  ```bash
63
- export SAGE_WEB_PASSWORD=change-me
64
- bash scripts/run_serve.sh
65
  ```
66
 
67
- ## Start CPU server
68
 
69
  ```bash
70
- export SAGE_WEB_PASSWORD=change-me
71
  bash scripts/run_serve_cpu.sh
72
  ```
73
 
74
- The server root now hosts the browser control panel at `/`. Log in with `SAGE_WEB_PASSWORD`, then use presets or raw commands from the UI.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- ## Check server health
77
 
78
  ```bash
79
  curl http://127.0.0.1:8000/health
80
  ```
81
 
82
- ## Generate tokens from the API
83
 
84
  ```bash
85
  curl -X POST http://127.0.0.1:8000/generate \
86
  -H "Content-Type: application/json" \
87
  -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
88
  ```
1
  # SAGE Commands
2
 
3
+ This is the repo's current command reference for data preparation, tokenizer training, model training, serving, browser control, and validation.
4
 
5
  ## Install
6
 
 
14
  pytest -q
15
  ```
16
 
17
+ ## 1. Create a starter dataset
18
+
19
+ This repo does not ship a large training corpus. The fastest way to unblock the pipeline is to generate the built-in smoke dataset first:
20
+
21
+ ```bash
22
+ python -m data.bootstrap --output-dir data/raw --overwrite
23
+ ```
24
+
25
+ That writes JSONL files like:
26
+
27
+ ```text
28
+ data/raw/general_web.jsonl
29
+ data/raw/code.jsonl
30
+ data/raw/math_science.jsonl
31
+ data/raw/multilingual.jsonl
32
+ data/raw/synthetic.jsonl
33
+ ```
34
+
35
+ If you want to use your own corpus, put JSONL records in the same folder with at least a `text` field:
36
+
37
+ ```json
38
+ {"text": "your training sample here"}
39
+ ```
40
+
41
+ ## 2. Train the tokenizer
42
+
43
+ The tokenizer trainer now accepts plain text files or JSONL files.
44
 
45
  ```bash
46
  python -m tokenizer.train_tokenizer \
47
+ --input data/raw/general_web.jsonl data/raw/code.jsonl data/raw/math_science.jsonl data/raw/multilingual.jsonl data/raw/synthetic.jsonl \
 
48
  --model-prefix tokenizer/tokenizer \
49
+ --vocab-size 4096 \
50
+ --training-text tokenizer/training_corpus.txt
51
+ ```
52
+
53
+ ## 3. Validate the tokenizer
54
+
55
+ ```bash
56
+ python -m tokenizer.validate_tokenizer tokenizer/tokenizer.model
57
+ ```
58
+
59
+ ## 4. Build parquet shards
60
+
61
+ ```bash
62
+ python -m data.pipeline \
63
+ --tokenizer-model tokenizer/tokenizer.model \
64
+ --output-dir data/processed \
65
+ --shard-size 128
66
  ```
67
 
68
+ For a short smoke run:
69
 
70
  ```bash
71
+ python -m data.pipeline \
72
+ --tokenizer-model tokenizer/tokenizer.model \
73
+ --output-dir data/processed \
74
+ --shard-size 32 \
75
+ --limit-per-source 4
76
  ```
77
 
78
+ The shell helper now points to the real data pipeline:
79
+
80
+ ```bash
81
+ bash scripts/run_data_pipeline.sh --tokenizer-model tokenizer/tokenizer.model --output-dir data/processed
82
+ ```
83
+
84
+ ## 5. Start training
85
+
86
+ Smoke run:
87
 
88
  ```bash
89
  python -m train.trainer \
90
+ --model-config configs/model/1b.yaml \
91
+ --schedule-config configs/train/schedule.yaml \
92
  --train-shards data/processed/shard-00000.parquet \
93
  --validation-shards data/processed/shard-00001.parquet \
94
  --output-dir runs/smoke \
 
96
  --disable-wandb
97
  ```
98
 
99
+ Longer run:
100
 
101
  ```bash
102
  python -m train.trainer \
 
107
  --output-dir runs/sage-1b
108
  ```
109
 
110
+ ## 6. Serve the model
111
+
112
+ GPU/PyTorch server:
113
 
114
  ```bash
115
+ export SAGE_WEB_PASSWORD="change-me"
116
+ export SAGE_CHECKPOINT_DIR="runs/sage-1b"
117
+ export SAGE_TOKENIZER_MODEL="tokenizer/tokenizer.model"
118
+ python -m uvicorn serve.server:app --host 0.0.0.0 --port 8000
119
  ```
120
 
121
+ CPU control-plane server:
122
 
123
  ```bash
124
+ export SAGE_WEB_PASSWORD="change-me"
125
+ python -m uvicorn serve.server_cpu:app --host 0.0.0.0 --port 8001
126
  ```
127
 
128
+ Helper scripts:
129
 
130
  ```bash
131
+ bash scripts/run_serve.sh
132
  bash scripts/run_serve_cpu.sh
133
  ```
134
 
135
+ ## 7. Browser control panel
136
+
137
+ Open the server root:
138
+
139
+ ```text
140
+ http://127.0.0.1:8000/
141
+ ```
142
+
143
+ The browser UI now supports:
144
+
145
+ - login with `SAGE_WEB_PASSWORD`
146
+ - dataset bootstrap preset
147
+ - shard-building preset
148
+ - tokenizer/train/eval/server presets
149
+ - raw shell commands
150
+ - live job logs
151
+ - direct model chat through `/chat`
152
+
153
+ ## 8. API commands
154
 
155
+ Health:
156
 
157
  ```bash
158
  curl http://127.0.0.1:8000/health
159
  ```
160
 
161
+ Generate from token ids:
162
 
163
  ```bash
164
  curl -X POST http://127.0.0.1:8000/generate \
165
  -H "Content-Type: application/json" \
166
  -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
167
  ```
168
+
169
+ Chat from text:
170
+
171
+ ```bash
172
+ curl -X POST http://127.0.0.1:8000/chat \
173
+ -H "Content-Type: application/json" \
174
+ -d "{\"prompt\": \"Explain the training flow in this repo.\", \"max_new_tokens\": 64}"
175
+ ```
176
+
177
+ Chat status:
178
+
179
+ ```bash
180
+ curl http://127.0.0.1:8000/chat/status
181
+ ```
182
+
183
+ ## 9. Evaluation
184
+
185
+ ```bash
186
+ python -m eval.run_benchmarks
187
+ ```
188
+
189
+ Or use the helper:
190
+
191
+ ```bash
192
+ bash scripts/run_eval.sh
193
+ ```
194
+
195
+ ## 10. Hugging Face sync
196
+
197
+ ```bash
198
+ python hf_push.py
199
+ ```
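The `/chat` call from the API commands section can also be issued from Python with the standard library. The sketch below builds the same JSON body as the curl example (`prompt` and `max_new_tokens`). `build_chat_request` and `send` are hypothetical helper names. The response schema and any authentication requirement for `/chat` are not shown in this diff, so the sketch simply returns whatever JSON the server sends back.

```python
from __future__ import annotations

import json
import urllib.request


def build_chat_request(base_url: str, prompt: str, max_new_tokens: int = 64) -> urllib.request.Request:
    """Build a POST request for the /chat endpoint (hypothetical helper)."""
    body = json.dumps({"prompt": prompt, "max_new_tokens": max_new_tokens}).encode("utf-8")
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/chat",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request, timeout: float = 60.0) -> dict:
    """Send the request and decode the JSON response body."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With a server running locally, `send(build_chat_request("http://127.0.0.1:8000", "Explain the training flow in this repo."))` would perform the same call as the curl command.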
scripts/run_data_pipeline.sh CHANGED
@@ -1,4 +1,4 @@
1
  #!/usr/bin/env bash
2
  set -euo pipefail
3
 
4
- python -m tokenizer.train_tokenizer "$@"
 
1
  #!/usr/bin/env bash
2
  set -euo pipefail
3
 
4
+ python -m data.pipeline "$@"
serve/control_plane.py CHANGED
@@ -158,6 +158,34 @@ def _build_presets(enable_generate: bool) -> list[CommandPreset]:
158
  "Call the local /health API and show the JSON response.",
159
  "api",
160
  ),
161
  CommandPreset(
162
  "serve_gpu",
163
  "Serve GPU",
@@ -188,7 +216,7 @@ def _build_presets(enable_generate: bool) -> list[CommandPreset]:
188
  "input_paths",
189
  "Input Paths",
190
  kind="textarea",
191
- placeholder="data/raw/general_web.txt\ndata/raw/code.txt",
192
  required=True,
193
  ),
194
  PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
@@ -472,7 +500,49 @@ def _api_response(handler: Callable[[dict[str, Any]], dict[str, Any]], args: dic
472
  return {"kind": "api", "result": handler(args)}
473
 
474
 
475
  def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str:
476
  if preset_id == "serve_gpu":
477
  return [
478
  sys.executable,
@@ -607,6 +677,7 @@ def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict
607
  preset = preset_map.get(payload.preset_id)
608
  if preset is None:
609
  raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
 
610
  if preset.mode == "api":
611
  handler = api_handlers.get(payload.preset_id)
612
  if handler is None:
@@ -614,11 +685,17 @@ def build_control_router(api_handlers: dict[str, Callable[[dict[str, Any]], dict
614
  return _api_response(handler, payload.args)
615
  command = _build_command_for_preset(payload.preset_id, payload.args)
616
  mode = "shell" if isinstance(command, str) else "job"
617
- job = CONTROL_MANAGER.start_job(preset.label, command, cwd=str(REPO_ROOT), mode=mode)
 
 
 
618
  return {"kind": "job", "job": job.to_dict()}
619
  if payload.command:
620
  cwd = payload.cwd or str(REPO_ROOT)
621
- job = CONTROL_MANAGER.start_job("Raw Command", payload.command, cwd=cwd, mode="shell")
 
 
 
622
  return {"kind": "job", "job": job.to_dict()}
623
  raise HTTPException(status_code=400, detail="Provide either preset_id or command.")
624
 
 
158
  "Call the local /health API and show the JSON response.",
159
  "api",
160
  ),
161
+ CommandPreset(
162
+ "data_bootstrap",
163
+ "Bootstrap Dataset",
164
+ "Create small JSONL corpora under data/raw for tokenizer and smoke-training runs.",
165
+ "job",
166
+ (
167
+ PresetField("output_dir", "Output Dir", default="data/raw"),
168
+ PresetField("overwrite", "Overwrite Existing Files", kind="boolean", default=False),
169
+ ),
170
+ ),
171
+ CommandPreset(
172
+ "data_pipeline",
173
+ "Build Data Shards",
174
+ "Filter raw JSONL corpora, deduplicate them, then write parquet shards with the trained tokenizer.",
175
+ "job",
176
+ (
177
+ PresetField("tokenizer_model", "Tokenizer Model", default="tokenizer/tokenizer.model"),
178
+ PresetField("output_dir", "Output Dir", default="data/processed"),
179
+ PresetField(
180
+ "sources",
181
+ "Sources",
182
+ kind="textarea",
183
+ placeholder="general_web\ncode\nmath_science\nmultilingual\nsynthetic",
184
+ ),
185
+ PresetField("shard_size", "Shard Size", kind="number", default=2048),
186
+ PresetField("limit_per_source", "Limit Per Source", kind="number", default=0),
187
+ ),
188
+ ),
189
  CommandPreset(
190
  "serve_gpu",
191
  "Serve GPU",
 
216
  "input_paths",
217
  "Input Paths",
218
  kind="textarea",
219
+ placeholder="data/raw/general_web.jsonl\ndata/raw/code.jsonl",
220
  required=True,
221
  ),
222
  PresetField("model_prefix", "Model Prefix", default="tokenizer/tokenizer"),
 
500
  return {"kind": "api", "result": handler(args)}
501
 
502
 
503
+ def _validate_preset_args(preset: CommandPreset, args: dict[str, Any]) -> None:
504
+     missing: list[str] = []
505
+     for field in preset.fields:
506
+         if not field.required:
507
+             continue
508
+         value = args.get(field.name)
509
+         if value is None:
510
+             missing.append(field.label)
511
+             continue
512
+         if isinstance(value, str) and not value.strip():
513
+             missing.append(field.label)
514
+             continue
515
+         if isinstance(value, list) and not value:
516
+             missing.append(field.label)
517
+     if missing:
518
+         raise HTTPException(status_code=400, detail=f"Missing required fields: {', '.join(missing)}")
519
+
520
+
521
  def _build_command_for_preset(preset_id: str, args: dict[str, Any]) -> list[str] | str:
522
+     if preset_id == "data_bootstrap":
523
+         command = [sys.executable, "-m", "data.bootstrap", "--output-dir", str(args.get("output_dir") or "data/raw")]
524
+         if bool(args.get("overwrite", False)):
525
+             command.append("--overwrite")
526
+         return command
527
+     if preset_id == "data_pipeline":
528
+         command = [
529
+             sys.executable,
530
+             "-m",
531
+             "data.pipeline",
532
+             "--tokenizer-model",
533
+             str(args.get("tokenizer_model") or "tokenizer/tokenizer.model"),
534
+             "--output-dir",
535
+             str(args.get("output_dir") or "data/processed"),
536
+             "--shard-size",
537
+             str(_parse_number(args.get("shard_size"), 2048)),
538
+         ]
539
+         sources = _split_multi_value(args.get("sources"))
540
+         if sources:
541
+             command.extend(["--sources", *sources])
542
+         limit_per_source = _parse_number(args.get("limit_per_source"), 0)
543
+         if limit_per_source > 0:
544
+             command.extend(["--limit-per-source", str(limit_per_source)])
545
+         return command
546
      if preset_id == "serve_gpu":
547
          return [
548
              sys.executable,
 
677
  preset = preset_map.get(payload.preset_id)
678
  if preset is None:
679
  raise HTTPException(status_code=404, detail=f"Unknown preset: {payload.preset_id}")
680
+ _validate_preset_args(preset, payload.args)
681
  if preset.mode == "api":
682
  handler = api_handlers.get(payload.preset_id)
683
  if handler is None:
 
685
  return _api_response(handler, payload.args)
686
  command = _build_command_for_preset(payload.preset_id, payload.args)
687
  mode = "shell" if isinstance(command, str) else "job"
688
+ try:
689
+ job = CONTROL_MANAGER.start_job(preset.label, command, cwd=str(REPO_ROOT), mode=mode)
690
+ except OSError as exc:
691
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
692
  return {"kind": "job", "job": job.to_dict()}
693
  if payload.command:
694
  cwd = payload.cwd or str(REPO_ROOT)
695
+ try:
696
+ job = CONTROL_MANAGER.start_job("Raw Command", payload.command, cwd=cwd, mode="shell")
697
+ except OSError as exc:
698
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
699
  return {"kind": "job", "job": job.to_dict()}
700
  raise HTTPException(status_code=400, detail="Provide either preset_id or command.")
701
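The `data_pipeline` branch of `_build_command_for_preset` relies on two helpers, `_parse_number` and `_split_multi_value`, whose bodies are not part of this diff. The sketch below shows one plausible behavior for them and how the resulting argv list is assembled. `parse_number` and `split_multi_value` here are assumptions written for illustration, not the repo's implementations.

```python
from __future__ import annotations

import sys
from typing import Any


def parse_number(value: Any, default: int) -> int:
    """Assumed behavior of _parse_number: coerce to int, fall back to default on bad input."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def split_multi_value(value: Any) -> list[str]:
    """Assumed behavior of _split_multi_value: split textarea input on newlines and commas."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return [str(item).strip() for item in value if str(item).strip()]
    parts = str(value).replace(",", "\n").splitlines()
    return [part.strip() for part in parts if part.strip()]


def build_pipeline_command(args: dict[str, Any]) -> list[str]:
    """Mirror the data_pipeline branch: defaults first, optional flags only when set."""
    command = [
        sys.executable, "-m", "data.pipeline",
        "--tokenizer-model", str(args.get("tokenizer_model") or "tokenizer/tokenizer.model"),
        "--output-dir", str(args.get("output_dir") or "data/processed"),
        "--shard-size", str(parse_number(args.get("shard_size"), 2048)),
    ]
    sources = split_multi_value(args.get("sources"))
    if sources:
        command.extend(["--sources", *sources])
    limit = parse_number(args.get("limit_per_source"), 0)
    if limit > 0:
        command.extend(["--limit-per-source", str(limit)])
    return command
```

Returning a list rather than a string matters in the router: a list is started in `"job"` mode, while a string is started in `"shell"` mode.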
 
serve/server.py CHANGED
@@ -2,7 +2,9 @@
2
 
3
  from __future__ import annotations
4
 
5
- from typing import Optional
 
 
6
 
7
  import torch
8
  from fastapi import FastAPI
@@ -11,13 +13,21 @@ from pydantic import BaseModel
11
  from model.config import ModelConfig
12
  from model.model import SageTransformer
13
  from serve.control_plane import build_control_router
14
- from serve.kv_cache import KVCache
15
  from train.hardware import HardwareConfig
16
 
17
 
18
  app = FastAPI(title="SAGE Server")
19
  _MODEL: SageTransformer | None = None
20
  _TOKENIZER = None
 
 
 
 
 
 
 
 
21
 
22
 
23
  class GenerationRequest(BaseModel):
@@ -27,37 +37,152 @@ class GenerationRequest(BaseModel):
27
  max_new_tokens: int = 32
28
 
29
 
30
  def get_model() -> SageTransformer:
31
  """Lazily create the model for server startup."""
32
- global _MODEL
33
  if _MODEL is None:
34
- _MODEL = SageTransformer(ModelConfig())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  _MODEL.eval()
 
 
 
 
36
  return _MODEL
37
 
38
 
39
  @app.get("/health")
40
  def health() -> dict[str, object]:
41
  """Return basic health and hardware information."""
42
  hw = HardwareConfig(model_size_b=1.0, context_length=4096)
43
- return {"status": "ok", "hardware": hw.summary()}
44
 
45
 
46
  @app.post("/generate")
47
  def generate(request: GenerationRequest) -> dict[str, object]:
48
  """Generate continuation token ids from an input token list."""
49
- model = get_model()
50
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
- model.to(device)
52
- input_ids = torch.tensor([request.input_ids], dtype=torch.long, device=device)
53
- generated = list(request.input_ids)
54
- cache: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None
55
- for _ in range(request.max_new_tokens):
56
- logits, cache = model(input_ids[:, -1:] if cache is not None else input_ids, past_key_values=cache)
57
- next_token = int(torch.argmax(logits[:, -1, :], dim=-1).item())
58
- generated.append(next_token)
59
- input_ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
60
- return {"tokens": generated}
61
 
62
 
63
  def _health_action(_: dict[str, object]) -> dict[str, object]:
 
2
 
3
  from __future__ import annotations
4
 
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any, Optional
8
 
9
  import torch
10
  from fastapi import FastAPI
 
13
  from model.config import ModelConfig
14
  from model.model import SageTransformer
15
  from serve.control_plane import build_control_router
16
+ from train.checkpoint import load_latest_checkpoint
17
  from train.hardware import HardwareConfig
18
 
19
 
20
  app = FastAPI(title="SAGE Server")
21
  _MODEL: SageTransformer | None = None
22
  _TOKENIZER = None
23
+ _MODEL_DEVICE: torch.device | None = None
24
+ _MODEL_STATE: dict[str, Any] = {
25
+ "model_config": None,
26
+ "checkpoint_dir": None,
27
+ "checkpoint_loaded": False,
28
+ "checkpoint_step": 0,
29
+ "tokenizer_path": None,
30
+ }
31
 
32
 
33
  class GenerationRequest(BaseModel):
 
37
  max_new_tokens: int = 32
38
 
39
 
40
+ class ChatRequest(BaseModel):
41
+     """Request schema for text generation through the tokenizer."""
42
+
43
+     prompt: str
44
+     max_new_tokens: int = 64
45
+
46
+
47
+ def get_generation_device() -> torch.device:
48
+     """Pick the active inference device."""
49
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+
51
+
52
+ def _resolve_model_config_path() -> Path:
53
+     configured = Path(os.environ.get("SAGE_MODEL_CONFIG", "configs/model/1b.yaml"))
54
+     return configured if configured.exists() else Path("configs/model/1b.yaml")
55
+
56
+
57
+ def _resolve_checkpoint_dir() -> Path:
58
+     return Path(os.environ.get("SAGE_CHECKPOINT_DIR", "runs/default"))
59
+
60
+
61
+ def _resolve_tokenizer_path() -> Path:
62
+     return Path(os.environ.get("SAGE_TOKENIZER_MODEL", "tokenizer/tokenizer.model"))
63
+
64
+
65
  def get_model() -> SageTransformer:
66
      """Lazily create the model for server startup."""
67
+     global _MODEL, _MODEL_DEVICE
68
      if _MODEL is None:
69
+         config_path = _resolve_model_config_path()
70
+         config = ModelConfig.from_yaml(config_path) if config_path.exists() else ModelConfig()
71
+         _MODEL = SageTransformer(config)
72
+         checkpoint_dir = _resolve_checkpoint_dir()
73
+         checkpoint_step = 0
74
+         if checkpoint_dir.exists():
75
+             checkpoint_step = load_latest_checkpoint(_MODEL, None, None, None, str(checkpoint_dir), device="cpu")
76
+         _MODEL_STATE.update(
77
+             {
78
+                 "model_config": str(config_path),
79
+                 "checkpoint_dir": str(checkpoint_dir),
80
+                 "checkpoint_loaded": checkpoint_step > 0,
81
+                 "checkpoint_step": checkpoint_step,
82
+             }
83
+         )
84
          _MODEL.eval()
85
+     device = get_generation_device()
86
+     if _MODEL_DEVICE != device:
87
+         _MODEL = _MODEL.to(device)
88
+         _MODEL_DEVICE = device
89
      return _MODEL
90
 
91
 
92
+ def get_tokenizer():
93
+     """Lazily load the SentencePiece tokenizer if present."""
94
+     global _TOKENIZER
95
+     if _TOKENIZER is None:
96
+         tokenizer_path = _resolve_tokenizer_path()
97
+         _MODEL_STATE["tokenizer_path"] = str(tokenizer_path)
98
+         if not tokenizer_path.exists():
99
+             return None
100
+         from tokenizer.validate_tokenizer import load_processor
101
+
102
+         _TOKENIZER = load_processor(str(tokenizer_path))
103
+     return _TOKENIZER
104
+
105
+
106
+ def _generate_token_ids(input_ids: list[int], max_new_tokens: int) -> list[int]:
107
+     """Run greedy decoding from input token ids."""
108
+     model = get_model()
109
+     device = get_generation_device()
110
+     generated = list(input_ids)
111
+     tensor_ids = torch.tensor([input_ids], dtype=torch.long, device=device)
112
+     cache: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None
113
+     with torch.inference_mode():
114
+         for _ in range(max(0, int(max_new_tokens))):
115
+             logits, cache = model(tensor_ids[:, -1:] if cache is not None else tensor_ids, past_key_values=cache)
116
+             next_token = int(torch.argmax(logits[:, -1, :], dim=-1).item())
117
+             generated.append(next_token)
118
+             tensor_ids = torch.tensor([[next_token]], dtype=torch.long, device=device)
119
+     return generated
120
+
121
+
122
+ def chat_status() -> dict[str, object]:
123
+ """Return whether text chat is configured for the current server."""
124
+ tokenizer = get_tokenizer()
125
+ checkpoint_loaded = bool(_MODEL_STATE["checkpoint_loaded"])
126
+ available = tokenizer is not None
127
+ warning = None
128
+ if tokenizer is None:
129
+ warning = "Tokenizer model not found. Train or place tokenizer/tokenizer.model before using browser chat."
130
+ elif not checkpoint_loaded:
131
+ warning = "No checkpoint loaded. Chat will run with randomly initialized model weights until you train or configure SAGE_CHECKPOINT_DIR."
132
+ return {
133
+ "available": available,
134
+ "tokenizer_path": _MODEL_STATE["tokenizer_path"],
135
+ "checkpoint_dir": _MODEL_STATE["checkpoint_dir"],
136
+ "checkpoint_loaded": checkpoint_loaded,
137
+ "checkpoint_step": _MODEL_STATE["checkpoint_step"],
138
+ "warning": warning,
139
+ }
140
+
141
+
142
  @app.get("/health")
143
  def health() -> dict[str, object]:
144
  """Return basic health and hardware information."""
145
  hw = HardwareConfig(model_size_b=1.0, context_length=4096)
146
+ return {"status": "ok", "hardware": hw.summary(), "chat": chat_status()}
147
 
148
 
149
  @app.post("/generate")
150
  def generate(request: GenerationRequest) -> dict[str, object]:
151
  """Generate continuation token ids from an input token list."""
152
+ return {"tokens": _generate_token_ids(request.input_ids, request.max_new_tokens)}
153
+
154
+
155
+ @app.get("/chat/status")
156
+ def get_chat_status() -> dict[str, object]:
157
+ """Expose browser-chat readiness."""
158
+ return chat_status()
159
+
160
+
161
+ @app.post("/chat")
162
+ def chat(request: ChatRequest) -> dict[str, object]:
163
+ """Generate text from a prompt using the local tokenizer."""
164
+ tokenizer = get_tokenizer()
165
+ if tokenizer is None:
166
+ return {
167
+ "success": False,
168
+ "detail": "Tokenizer model not found. Train the tokenizer first or set SAGE_TOKENIZER_MODEL.",
169
+ **chat_status(),
170
+ }
171
+ prompt = request.prompt.strip()
172
+ if not prompt:
173
+ return {"success": False, "detail": "Prompt cannot be empty.", **chat_status()}
174
+ prompt_ids = list(tokenizer.encode(prompt, out_type=int))
175
+ generated = _generate_token_ids(prompt_ids, request.max_new_tokens)
176
+ completion_ids = generated[len(prompt_ids) :]
177
+ return {
178
+ "success": True,
179
+ "prompt": prompt,
180
+ "response": tokenizer.decode(completion_ids),
181
+ "input_ids": prompt_ids,
182
+ "output_ids": generated,
183
+ "new_token_ids": completion_ids,
184
+ **chat_status(),
185
+ }
186
 
187
 
188
  def _health_action(_: dict[str, object]) -> dict[str, object]:
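The `_generate_token_ids` helper in this file decodes greedily: the full prompt is fed once, and each later step feeds only the newest token together with the cache returned by the previous call. The control flow can be sketched without PyTorch as follows. This is a minimal illustration, not SAGE code: `toy_step` and `greedy_decode` are hypothetical names, plain lists stand in for tensors, and the "cache" is simply the list of tokens seen so far.

```python
from typing import Callable, List, Optional, Tuple

Cache = List[int]  # hypothetical cache: the tokens the toy model has seen
StepFn = Callable[[List[int], Optional[Cache]], Tuple[List[float], Cache]]


def toy_step(tokens: List[int], cache: Optional[Cache]) -> Tuple[List[float], Cache]:
    """Toy stand-in for a model call: scores a vocab of 8, favouring (last + 1) % 8."""
    seen = (cache or []) + tokens
    logits = [0.0] * 8
    logits[(seen[-1] + 1) % 8] = 1.0
    return logits, seen


def greedy_decode(step: StepFn, input_ids: List[int], max_new_tokens: int) -> List[int]:
    """Feed the whole prompt once, then one token per step, always taking the argmax."""
    generated = list(input_ids)
    step_input = list(input_ids)  # first call sees the full prompt
    cache: Optional[Cache] = None
    for _ in range(max(0, int(max_new_tokens))):  # negative budgets become zero steps
        logits, cache = step(step_input, cache)
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_token)
        step_input = [next_token]  # later calls see only the newest token
    return generated
```

The returned list includes the prompt ids, which is why the `/chat` handler slices off `len(prompt_ids)` entries before decoding the completion.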
serve/server_cpu.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import shutil
6
 
7
  from fastapi import FastAPI
 
8
 
9
  from serve.control_plane import build_control_router
10
 
@@ -12,10 +13,37 @@ from serve.control_plane import build_control_router
12
  app = FastAPI(title="SAGE CPU Server")
13
 
14
 
15
  @app.get("/health")
16
  def health() -> dict[str, object]:
17
  """Report llama.cpp availability for CPU serving."""
18
- return {"status": "ok", "llama_cpp_available": shutil.which("llama-server") is not None}
19
 
20
 
21
  def _health_action(_: dict[str, object]) -> dict[str, object]:
 
5
  import shutil
6
 
7
  from fastapi import FastAPI
8
+ from pydantic import BaseModel
9
 
10
  from serve.control_plane import build_control_router
11
 
 
13
  app = FastAPI(title="SAGE CPU Server")
14
 
15
 
16
+ class ChatRequest(BaseModel):
17
+ """Request schema for the browser chat surface."""
18
+
19
+ prompt: str
20
+ max_new_tokens: int = 64
21
+
22
+
23
  @app.get("/health")
24
  def health() -> dict[str, object]:
25
  """Report llama.cpp availability for CPU serving."""
26
+ return {"status": "ok", "llama_cpp_available": shutil.which("llama-server") is not None, "chat": chat_status()}
27
+
28
+
29
+ def chat_status() -> dict[str, object]:
30
+ """Return chat readiness for the CPU server."""
31
+ return {
32
+ "available": False,
33
+ "warning": "Browser chat is only wired to the PyTorch GPU server in this repo. Use serve.server:app for direct interaction.",
34
+ }
35
+
36
+
37
+ @app.get("/chat/status")
38
+ def get_chat_status() -> dict[str, object]:
39
+ """Expose browser-chat readiness."""
40
+ return chat_status()
41
+
42
+
43
+ @app.post("/chat")
44
+ def chat(_: ChatRequest) -> dict[str, object]:
45
+ """Return a clear error for CPU-only control-plane mode."""
46
+ return {"success": False, "detail": chat_status()["warning"], **chat_status()}
47
 
48
 
49
  def _health_action(_: dict[str, object]) -> dict[str, object]:
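Both `/chat` handlers in this commit return a plain dict, so a failure (missing tokenizer, empty prompt, or the CPU server's "not wired" case) arrives as an ordinary response whose body carries `success: false` and a `detail` string. A client therefore has to inspect the payload rather than rely on the status code. A minimal sketch of that check, assuming only the payload keys shown in the diff; `render_chat_reply` is a hypothetical helper, and the fallback strings mirror the ones used in `serve/static/index.html`:

```python
from typing import Any, Mapping


def render_chat_reply(payload: Mapping[str, Any]) -> str:
    """Turn a /chat JSON payload into the text a client would display."""
    if not payload.get("success"):
        # Failure path: surface the server's explanation when one is present.
        return str(payload.get("detail") or "Chat failed.")
    # Success path: an empty completion is possible, so provide a visible fallback.
    return str(payload.get("response") or "[empty response]")
```

For the CPU server, whose `/chat` always reports `success: False`, this returns the warning text pointing the user at `serve.server:app`.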
serve/static/index.html CHANGED
@@ -139,6 +139,36 @@
139
  font-size: 12px;
140
  color: var(--muted);
141
  }
142
  @media (max-width: 980px) {
143
  .grid { grid-template-columns: 1fr; }
144
  }
@@ -188,6 +218,21 @@
188
  <h2>API Result</h2>
189
  <div id="api-result" class="mono">No API result yet.</div>
190
  </section>
191
  </div>
192
 
193
  <div>
@@ -215,6 +260,7 @@
215
  selectedJobId: null,
216
  stream: null,
217
  pollTimer: null,
 
218
  };
219
 
220
  const loginPanel = document.getElementById("login-panel");
@@ -229,6 +275,11 @@
229
  const logsEl = document.getElementById("logs");
230
  const selectedJobEl = document.getElementById("selected-job");
231
  const apiResultEl = document.getElementById("api-result");
 
 
 
 
 
232
 
233
  async function api(path, options = {}) {
234
  const response = await fetch(path, {
@@ -265,6 +316,28 @@
265
  return state.presets.find((item) => item.id === presetSelect.value);
266
  }
267
 
268
  function renderPresetFields() {
269
  const preset = currentPreset();
270
  presetFields.innerHTML = "";
@@ -275,7 +348,7 @@
275
  for (const field of preset.fields) {
276
  const wrapper = document.createElement("div");
277
  const label = document.createElement("label");
278
- label.textContent = field.label;
279
  label.htmlFor = `field-${field.name}`;
280
  wrapper.appendChild(label);
281
 
@@ -308,6 +381,7 @@
308
  input.id = `field-${field.name}`;
309
  input.dataset.kind = field.kind;
310
  input.dataset.name = field.name;
 
311
  input.placeholder = field.placeholder || "";
312
  wrapper.appendChild(input);
313
  presetFields.appendChild(wrapper);
@@ -319,13 +393,20 @@
319
  for (const element of presetFields.querySelectorAll("[data-name]")) {
320
  const name = element.dataset.name;
321
  const kind = element.dataset.kind;
 
322
  if (kind === "boolean") {
323
  args[name] = element.checked;
324
  } else if (kind === "number") {
325
  args[name] = element.value === "" ? "" : Number(element.value);
326
  } else if (kind === "json") {
 
 
 
327
  args[name] = element.value.trim() ? JSON.parse(element.value) : null;
328
  } else {
 
 
 
329
  args[name] = element.value;
330
  }
331
  }
@@ -388,6 +469,15 @@
388
  renderJobs(payload.jobs || []);
389
  }
390
 
391
  function appendLog(line) {
392
  if (!line) {
393
  return;
@@ -432,6 +522,7 @@
432
  presetSelect.appendChild(option);
433
  }
434
  renderPresetFields();
 
435
  }
436
 
437
  async function login() {
@@ -492,6 +583,35 @@
492
  apiResultEl.textContent = String(error.message || error);
493
  }
494
  });
495
 
496
  loadPresets().then(() => {
497
  showApp();
 
139
  font-size: 12px;
140
  color: var(--muted);
141
  }
142
+ .chat-status {
143
+ margin-bottom: 12px;
144
+ font-size: 13px;
145
+ color: var(--muted);
146
+ }
147
+ .chat-messages {
148
+ min-height: 220px;
149
+ max-height: 360px;
150
+ overflow: auto;
151
+ padding: 12px;
152
+ border-radius: 10px;
153
+ border: 1px solid var(--panel-border);
154
+ background: #0b1015;
155
+ display: grid;
156
+ gap: 10px;
157
+ }
158
+ .message {
159
+ border-radius: 10px;
160
+ padding: 10px 12px;
161
+ white-space: pre-wrap;
162
+ line-height: 1.5;
163
+ }
164
+ .message-user {
165
+ background: #0f2330;
166
+ border: 1px solid #1e3c50;
167
+ }
168
+ .message-model {
169
+ background: #151b22;
170
+ border: 1px solid var(--panel-border);
171
+ }
172
  @media (max-width: 980px) {
173
  .grid { grid-template-columns: 1fr; }
174
  }
 
218
  <h2>API Result</h2>
219
  <div id="api-result" class="mono">No API result yet.</div>
220
  </section>
221
+
222
+ <section class="panel">
223
+ <h2>Model Chat</h2>
224
+ <div id="chat-status" class="chat-status">Checking chat status...</div>
225
+ <div id="chat-messages" class="chat-messages">
226
+ <div class="message message-model">Prompt the local model here once a tokenizer exists. If no checkpoint is loaded yet, outputs will be random.</div>
227
+ </div>
228
+ <label for="chat-prompt">Prompt</label>
229
+ <textarea id="chat-prompt" placeholder="Explain what data files I need before training this repo."></textarea>
230
+ <label for="chat-max-tokens">Max New Tokens</label>
231
+ <input id="chat-max-tokens" type="number" value="96" min="1" max="512">
232
+ <div class="button-row">
233
+ <button id="chat-send" class="primary">Send Prompt</button>
234
+ </div>
235
+ </section>
236
  </div>
237
 
238
  <div>
 
260
  selectedJobId: null,
261
  stream: null,
262
  pollTimer: null,
263
+ chatStatus: null,
264
  };
265
 
266
  const loginPanel = document.getElementById("login-panel");
 
275
  const logsEl = document.getElementById("logs");
276
  const selectedJobEl = document.getElementById("selected-job");
277
  const apiResultEl = document.getElementById("api-result");
278
+ const chatStatusEl = document.getElementById("chat-status");
279
+ const chatMessagesEl = document.getElementById("chat-messages");
280
+ const chatPromptEl = document.getElementById("chat-prompt");
281
+ const chatMaxTokensEl = document.getElementById("chat-max-tokens");
282
+ const chatSendEl = document.getElementById("chat-send");
283
 
284
  async function api(path, options = {}) {
285
  const response = await fetch(path, {
 
316
  return state.presets.find((item) => item.id === presetSelect.value);
317
  }
318
 
319
+ function appendMessage(role, text) {
320
+ const block = document.createElement("div");
321
+ block.className = `message ${role === "user" ? "message-user" : "message-model"}`;
322
+ block.textContent = text;
323
+ chatMessagesEl.appendChild(block);
324
+ chatMessagesEl.scrollTop = chatMessagesEl.scrollHeight;
325
+ }
326
+
327
+ function renderChatStatus(payload) {
328
+ state.chatStatus = payload;
329
+ const warning = payload && payload.warning ? ` ${payload.warning}` : "";
330
+ if (!payload) {
331
+ chatStatusEl.textContent = "Chat status unavailable.";
332
+ chatSendEl.disabled = true;
333
+ return;
334
+ }
335
+ chatStatusEl.textContent = payload.available
336
+ ? `Chat ready.${warning}`
337
+ : `Chat unavailable.${warning}`;
338
+ chatSendEl.disabled = !payload.available;
339
+ }
340
+
341
  function renderPresetFields() {
342
  const preset = currentPreset();
343
  presetFields.innerHTML = "";
 
348
  for (const field of preset.fields) {
349
  const wrapper = document.createElement("div");
350
  const label = document.createElement("label");
351
+ label.textContent = field.required ? `${field.label} *` : field.label;
352
  label.htmlFor = `field-${field.name}`;
353
  wrapper.appendChild(label);
354
 
 
381
  input.id = `field-${field.name}`;
382
  input.dataset.kind = field.kind;
383
  input.dataset.name = field.name;
384
+ input.dataset.required = field.required ? "true" : "false";
385
  input.placeholder = field.placeholder || "";
386
  wrapper.appendChild(input);
387
  presetFields.appendChild(wrapper);
 
393
  for (const element of presetFields.querySelectorAll("[data-name]")) {
394
  const name = element.dataset.name;
395
  const kind = element.dataset.kind;
396
+ const required = element.dataset.required === "true";
397
  if (kind === "boolean") {
398
  args[name] = element.checked;
399
  } else if (kind === "number") {
400
  args[name] = element.value === "" ? "" : Number(element.value);
401
  } else if (kind === "json") {
402
+ if (required && !element.value.trim()) {
403
+ throw new Error(`Field ${name} is required.`);
404
+ }
405
  args[name] = element.value.trim() ? JSON.parse(element.value) : null;
406
  } else {
407
+ if (required && !element.value.trim()) {
408
+ throw new Error(`Field ${name} is required.`);
409
+ }
410
  args[name] = element.value;
411
  }
412
  }
 
469
  renderJobs(payload.jobs || []);
470
  }
471
 
472
+ async function loadChatStatus() {
473
+ try {
474
+ const payload = await api("/chat/status");
475
+ renderChatStatus(payload);
476
+ } catch (error) {
477
+ renderChatStatus({ available: false, warning: String(error.message || error) });
478
+ }
479
+ }
480
+
481
  function appendLog(line) {
482
  if (!line) {
483
  return;
 
522
  presetSelect.appendChild(option);
523
  }
524
  renderPresetFields();
525
+ await loadChatStatus();
526
  }
527
 
528
  async function login() {
 
583
  apiResultEl.textContent = String(error.message || error);
584
  }
585
  });
586
+ chatSendEl.addEventListener("click", async () => {
587
+ const prompt = chatPromptEl.value.trim();
588
+ if (!prompt) {
589
+ apiResultEl.textContent = "Prompt cannot be empty.";
590
+ return;
591
+ }
592
+ appendMessage("user", prompt);
593
+ chatPromptEl.value = "";
594
+ chatSendEl.disabled = true;
595
+ try {
596
+ const payload = await api("/chat", {
597
+ method: "POST",
598
+ body: JSON.stringify({
599
+ prompt,
600
+ max_new_tokens: Number(chatMaxTokensEl.value || 96),
601
+ }),
602
+ });
603
+ renderChatStatus(payload);
604
+ if (!payload.success) {
605
+ appendMessage("model", payload.detail || "Chat failed.");
606
+ return;
607
+ }
608
+ appendMessage("model", payload.response || "[empty response]");
609
+ } catch (error) {
610
+ appendMessage("model", String(error.message || error));
611
+ } finally {
612
+ chatSendEl.disabled = !(state.chatStatus && state.chatStatus.available);
613
+ }
614
+ });
615
 
616
  loadPresets().then(() => {
617
  showApp();
tests/test_control_plane.py CHANGED
@@ -50,6 +50,8 @@ def test_login_and_html_index(monkeypatch) -> None:
50
  assert response.status_code == 200
51
  payload = response.json()
52
  preset_ids = {item["id"] for item in payload["presets"]}
 
 
53
  assert "serve_cpu" in preset_ids
54
  assert "git_status" in preset_ids
55
 
@@ -120,3 +122,14 @@ def test_health_api_preset(monkeypatch) -> None:
120
  payload = response.json()
121
  assert payload["kind"] == "api"
122
  assert payload["result"]["status"] == "ok"
50
  assert response.status_code == 200
51
  payload = response.json()
52
  preset_ids = {item["id"] for item in payload["presets"]}
53
+ assert "data_bootstrap" in preset_ids
54
+ assert "data_pipeline" in preset_ids
55
  assert "serve_cpu" in preset_ids
56
  assert "git_status" in preset_ids
57
 
 
122
  payload = response.json()
123
  assert payload["kind"] == "api"
124
  assert payload["result"]["status"] == "ok"
125
+
126
+
127
+ def test_required_preset_field_validation(monkeypatch) -> None:
128
+ monkeypatch.setenv("SAGE_WEB_PASSWORD", "test-password")
129
+ CONTROL_MANAGER.reset_for_tests()
130
+ client = TestClient(app)
131
+ _login(client)
132
+
133
+ response = client.post("/api/commands/run", json={"preset_id": "tokenizer_train", "args": {"input_paths": ""}})
134
+ assert response.status_code == 400
135
+ assert "Input Paths" in response.json()["detail"]
tests/test_data_pipeline.py CHANGED
@@ -1,6 +1,13 @@
 
 
 
1
  from data.dataset import pack_sequence
 
2
  from data.dedup import deduplicate_records
3
  from data.filter import filter_record
 
 
 
4
 
5
 
6
  def test_filter_record_masks_pii() -> None:
@@ -31,3 +38,48 @@ def test_pack_sequence_shapes() -> None:
31
  assert packed["input_ids"].tolist() == [1, 2, 3, 4]
32
  assert packed["labels"].tolist() == [2, 3, 4, 5]
33
  assert packed["document_boundaries"].tolist() == [0, 0, 1, 0]
1
+ import json
2
+ from pathlib import Path
3
+
4
  from data.dataset import pack_sequence
5
+ from data.bootstrap import bootstrap_raw_corpora
6
  from data.dedup import deduplicate_records
7
  from data.filter import filter_record
8
+ from data.ingest import SourceSpec
9
+
10
+ import pytest
11
 
12
 
13
  def test_filter_record_masks_pii() -> None:
 
38
  assert packed["input_ids"].tolist() == [1, 2, 3, 4]
39
  assert packed["labels"].tolist() == [2, 3, 4, 5]
40
  assert packed["document_boundaries"].tolist() == [0, 0, 1, 0]
41
+
42
+
43
+ def test_bootstrap_raw_corpora_writes_jsonl(tmp_path: Path) -> None:
44
+ summary = bootstrap_raw_corpora(output_dir=str(tmp_path), overwrite=True)
45
+ assert summary["general_web"] > 0
46
+ sample_path = tmp_path / "general_web.jsonl"
47
+ first = json.loads(sample_path.read_text(encoding="utf-8").splitlines()[0])
48
+ assert "text" in first
49
+ assert len(first["text"]) >= 240
50
+
51
+
52
+ def test_pipeline_writes_manifest_from_bootstrap_data(tmp_path: Path, monkeypatch) -> None:
53
+ pytest.importorskip("sentencepiece")
54
+ from data import pipeline
55
+ from tokenizer.train_tokenizer import train_sentencepiece, write_training_text
56
+
57
+ raw_dir = tmp_path / "raw"
58
+ bootstrap_raw_corpora(output_dir=str(raw_dir), overwrite=True)
59
+ training_text = tmp_path / "training.txt"
60
+ write_training_text([str(path) for path in raw_dir.glob("*.jsonl")], str(training_text))
61
+ prefix = tmp_path / "tokenizer"
62
+ train_sentencepiece(str(training_text), str(prefix), vocab_size=512)
63
+
64
+ registry = tuple(
65
+ SourceSpec(
66
+ name=path.stem,
67
+ domain_tag="general",
68
+ quality_tier="high",
69
+ license_category="permissive",
70
+ estimated_tokens=1_000,
71
+ path=str(path),
72
+ )
73
+ for path in raw_dir.glob("*.jsonl")
74
+ )
75
+ monkeypatch.setattr(pipeline, "SOURCE_REGISTRY", registry)
76
+
77
+ output_dir = tmp_path / "processed"
78
+ summary = pipeline.run_pipeline(
79
+ tokenizer_model=str(prefix) + ".model",
80
+ output_dir=str(output_dir),
81
+ shard_size=4,
82
+ )
83
+ manifest_path = output_dir / "manifest.json"
84
+ assert summary["records"] > 0
85
+ assert manifest_path.exists()
tests/test_servers.py CHANGED
@@ -45,3 +45,47 @@ def test_gpu_server_generate(monkeypatch) -> None:
45
  assert response.status_code == 200
46
  payload = response.json()
47
  assert payload["tokens"] == [1, 2, 3, 3, 3]
45
  assert response.status_code == 200
46
  payload = response.json()
47
  assert payload["tokens"] == [1, 2, 3, 3, 3]
48
+
49
+
50
+ def test_gpu_server_chat(monkeypatch) -> None:
51
+ class FakeModel:
52
+ def eval(self) -> "FakeModel":
53
+ return self
54
+
55
+ def to(self, _device) -> "FakeModel":
56
+ return self
57
+
58
+ def __call__(self, input_ids, past_key_values=None):
59
+ import torch
60
+
61
+ batch, seq = input_ids.shape
62
+ logits = torch.zeros((batch, seq, 32), dtype=torch.float32)
63
+ logits[:, :, 7] = 1.0
64
+ cache = [] if past_key_values is None else past_key_values
65
+ return logits, cache
66
+
67
+ class FakeTokenizer:
68
+ def encode(self, text, out_type=int):
69
+ return [1, 2, 3]
70
+
71
+ def decode(self, ids):
72
+ return "decoded:" + ",".join(str(item) for item in ids)
73
+
74
+ monkeypatch.setattr(gpu_server, "get_model", lambda: FakeModel())
75
+ monkeypatch.setattr(gpu_server, "get_tokenizer", lambda: FakeTokenizer())
76
+ monkeypatch.setattr(gpu_server, "chat_status", lambda: {"available": True, "warning": None, "checkpoint_loaded": False})
77
+ monkeypatch.setattr(gpu_server.torch.cuda, "is_available", lambda: False)
78
+ client = TestClient(gpu_app)
79
+ response = client.post("/chat", json={"prompt": "hello", "max_new_tokens": 2})
80
+ assert response.status_code == 200
81
+ payload = response.json()
82
+ assert payload["success"] is True
83
+ assert payload["response"] == "decoded:7,7"
84
+
85
+
86
+ def test_cpu_server_chat_status() -> None:
87
+ client = TestClient(cpu_app)
88
+ response = client.get("/chat/status")
89
+ assert response.status_code == 200
90
+ payload = response.json()
91
+ assert payload["available"] is False
tests/test_tokenizer.py CHANGED
@@ -1,4 +1,5 @@
1
  from pathlib import Path
 
2
 
3
  import pytest
4
 
@@ -23,3 +24,22 @@ def test_validation_suite_roundtrip(tmp_path: Path) -> None:
23
  train_sentencepiece(str(corpus), str(prefix), vocab_size=512)
24
  results = run_validation_suite(str(prefix) + ".model")
25
  assert all(result.passed for result in results), results
1
  from pathlib import Path
2
+ import json
3
 
4
  import pytest
5
 
 
24
  train_sentencepiece(str(corpus), str(prefix), vocab_size=512)
25
  results = run_validation_suite(str(prefix) + ".model")
26
  assert all(result.passed for result in results), results
27
+
28
+
29
+ def test_write_training_text_reads_jsonl(tmp_path: Path) -> None:
30
+ from tokenizer.train_tokenizer import write_training_text
31
+
32
+ raw = tmp_path / "raw.jsonl"
33
+ raw.write_text(
34
+ "\n".join(
35
+ [
36
+ json.dumps({"text": "first training sample"}),
37
+ json.dumps({"text": "second training sample"}),
38
+ ]
39
+ )
40
+ + "\n",
41
+ encoding="utf-8",
42
+ )
43
+ output = tmp_path / "combined.txt"
44
+ write_training_text([str(raw)], str(output))
45
+ assert output.read_text(encoding="utf-8").splitlines() == ["first training sample", "second training sample"]
tokenizer/train_tokenizer.py CHANGED
@@ -3,32 +3,50 @@
3
  from __future__ import annotations
4
 
5
  import argparse
 
6
  from pathlib import Path
7
- from typing import Iterable
8
 
9
- import sentencepiece as spm
10
 
11
 
12
- DEFAULT_SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>", "<unk>", "[INST]", "[/INST]")
13
 
14
 
15
- def write_training_text(corpus_paths: Iterable[str], output_path: str) -> str:
16
  """Concatenate corpus text into a plain-text file for SentencePiece."""
17
  output = Path(output_path)
18
  output.parent.mkdir(parents=True, exist_ok=True)
19
  with output.open("w", encoding="utf-8") as sink:
20
- for path in corpus_paths:
21
- with Path(path).open("r", encoding="utf-8") as source:
22
- for line in source:
23
- line = line.strip()
24
- if line:
25
- sink.write(line)
26
- sink.write("\n")
27
  return str(output)
28
 
29
 
30
  def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None:
31
  """Train a byte-fallback SentencePiece BPE model."""
 
 
32
  spm.SentencePieceTrainer.train(
33
  input=input_path,
34
  model_prefix=model_prefix,
@@ -52,17 +70,18 @@ def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50
52
  def build_argparser() -> argparse.ArgumentParser:
53
  """Build the CLI parser."""
54
  parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.")
55
- parser.add_argument("--input", nargs="+", required=True, help="Plain-text corpus files.")
56
  parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.")
57
  parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.")
58
  parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.")
 
59
  return parser
60
 
61
 
62
  def main() -> None:
63
  """Train the tokenizer from CLI arguments."""
64
  args = build_argparser().parse_args()
65
- training_text = write_training_text(args.input, args.training_text)
66
  train_sentencepiece(training_text, args.model_prefix, args.vocab_size)
67
 
68
 
 
3
  from __future__ import annotations
4
 
5
  import argparse
6
+ import json
7
  from pathlib import Path
8
+ from typing import Iterable, Iterator
9
 
10
+ DEFAULT_SPECIAL_TOKENS = ("<bos>", "<eos>", "<pad>", "<unk>", "[INST]", "[/INST]")
11
 
12
 
13
+ def iter_training_text(corpus_paths: Iterable[str], text_key: str = "text") -> Iterator[str]:
14
+ """Yield training lines from plain-text or JSONL corpus files."""
15
+ for path in corpus_paths:
16
+ source = Path(path)
17
+ suffix = source.suffix.lower()
18
+ with source.open("r", encoding="utf-8") as handle:
19
+ if suffix == ".jsonl":
20
+ for raw_line in handle:
21
+ raw_line = raw_line.strip()
22
+ if not raw_line:
23
+ continue
24
+ payload = json.loads(raw_line)
25
+ text = payload.get(text_key)
26
+ if isinstance(text, str) and text.strip():
27
+ yield text.strip()
28
+ continue
29
+ for raw_line in handle:
30
+ text = raw_line.strip()
31
+ if text:
32
+ yield text
33
 
34
 
35
+ def write_training_text(corpus_paths: Iterable[str], output_path: str, text_key: str = "text") -> str:
36
  """Concatenate corpus text into a plain-text file for SentencePiece."""
37
  output = Path(output_path)
38
  output.parent.mkdir(parents=True, exist_ok=True)
39
  with output.open("w", encoding="utf-8") as sink:
40
+ for line in iter_training_text(corpus_paths, text_key=text_key):
41
+ sink.write(line)
42
+ sink.write("\n")
 
 
 
 
43
  return str(output)
44
 
45
 
46
  def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None:
47
  """Train a byte-fallback SentencePiece BPE model."""
48
+ import sentencepiece as spm
49
+
50
  spm.SentencePieceTrainer.train(
51
  input=input_path,
52
  model_prefix=model_prefix,
 
70
  def build_argparser() -> argparse.ArgumentParser:
71
  """Build the CLI parser."""
72
  parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.")
73
+ parser.add_argument("--input", nargs="+", required=True, help="Plain-text or JSONL corpus files.")
74
  parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.")
75
  parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.")
76
  parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.")
77
+ parser.add_argument("--text-key", default="text", help="JSONL field to read when --input contains .jsonl files.")
78
  return parser
79
 
80
 
81
  def main() -> None:
82
  """Train the tokenizer from CLI arguments."""
83
  args = build_argparser().parse_args()
84
+ training_text = write_training_text(args.input, args.training_text, text_key=args.text_key)
85
  train_sentencepiece(training_text, args.model_prefix, args.vocab_size)
86
 
87