sage002 committed on
Commit
ef18673
·
verified ·
1 Parent(s): 36daf84

feat: rewrite SAGE 1B architecture and replace legacy repo contents

This view is limited to 50 files because it contains too many changes.
.gitattributes CHANGED
@@ -5,7 +5,7 @@
5
  # Git files
6
  .git/*
7
  .gitignore
8
-
9
  # Python virtual environments
10
  .venv/*
11
  venv/*
 
5
  # Git files
6
  .git/*
7
  .gitignore
8
+ hf_push.py
9
  # Python virtual environments
10
  .venv/*
11
  venv/*
.gitignore CHANGED
@@ -3,6 +3,8 @@ __pycache__/
3
  *.py[cod]
4
  *$py.class
5
 
 
 
6
  # C extensions
7
  *.so
8
 
@@ -25,8 +27,6 @@ share/python-wheels/
25
  .installed.cfg
26
  *.egg
27
  MANIFEST
28
- hf_push.py
29
- .hugging_face_ignore
30
 
31
  # PyInstaller
32
  # Usually these files are written by a python script, before a-one-file pack
@@ -112,6 +112,12 @@ celerybeat.pid
112
 
113
  # Sage Project Specific
114
  checkpoints/
 
 
 
 
 
 
115
  .venv/
116
  .env
117
  .DS_Store
 
3
  *.py[cod]
4
  *$py.class
5
 
6
+ wandb/
7
+
8
  # C extensions
9
  *.so
10
 
 
27
  .installed.cfg
28
  *.egg
29
  MANIFEST
 
 
30
 
31
  # PyInstaller
32
  # Usually these files are written by a python script, before a-one-file pack
 
112
 
113
  # Sage Project Specific
114
  checkpoints/
115
+ runs/
116
+ tokenizer/*.model
117
+ tokenizer/*.vocab
118
+ tokenizer/training_corpus.txt
119
+ data/raw/
120
+ data/processed/
121
  .venv/
122
  .env
123
  .DS_Store
README.md CHANGED
@@ -1,163 +1,455 @@
1
- # SAGE — Self-Adaptive General Engine
2
 
3
- **SAGE** is a senior-grade, production-structured Large Language Model (LLM) system built entirely from scratch using Python and PyTorch. It implements modern transformer architectures including Mixture of Experts (MoE), Rotary Positional Embeddings (RoPE), and Low-Rank Adaptation (LoRA).
4
 
5
- Designed to be both educational and functional, SAGE can be trained, fine-tuned, quantized, and deployed on a single consumer GPU (e.g., NVIDIA T4 with 16GB VRAM).
6
 
7
- ---
8
 
9
- ## ☁️ Cloud Quickstart (Kaggle / Colab)
10
- Running SAGE in the cloud? Check out the **[Kaggle & Colab Quickstart Guide](file:///c:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/SAGE_KAGGLE_GUIDE.md)** for one-click setup and a premium interactive chat interface.
 
 
 
 
 
 
 
 
11
 
12
- ---
13
 
14
- ## 🚀 Key Features
 
 
 
15
 
16
- - **Decoder-Only Transformer**: A GPT-style architecture with pre-layer normalization.
17
- - **Mixture of Experts (MoE)**: Efficient scaling with a learned router selecting top-k experts per token.
18
- - **Rotary Positional Embeddings (RoPE)**: Enhanced long-sequence generalization.
19
- - **KV-Cache Inference**: O(1) time-per-token generation for high-speed response.
20
- - **Retrieval-Augmented Generation (RAG)**: Integration with FAISS for document-based context lookup.
21
- - **Efficient Fine-Tuning**: Support for LoRA and instruction tuning with loss masking.
22
- - **Post-Training Quantization**: INT8 support to reduce memory footprint.
23
- - **Interactive CLI**: A full REPL (Read-Eval-Print Loop) for chatting and system management.
24
 
25
- ---
26
 
27
- ## 📂 Project Structure
28
 
29
- ```text
30
- sage/
31
- ├── model.py # Core architecture (Transformer, MoE, RoPE, Attention)
32
- ├── data.py # Tokenization (tiktoken) & Streaming Datasets (HuggingFace)
33
- ├── train.py # Pre-training loop with AdamW, AMP, and Cosine Decay
34
- ├── inference.py # Text generation (Greedy, Temp, Top-k, Top-p sampling)
35
- ├── finetune.py # LoRA implementation & Instruction Tuning
36
- ├── optimize.py # INT8 Quantization & Pruning utilities
37
- ├── memory.py # RAG Vector Store & Conversation History
38
- ├── cli.py # Interactive Terminal Interface
39
- ├── utils.py # Logging, Checkpointing, and Helper functions
40
- ├── config.py # Central Hyperparameter Configuration
41
- └── requirements.txt # System dependencies
42
- sage_single.py # Consolidated single-file version for easy portability
43
  ```
44
 
45
- ---
 
 
 
 
 
 
46
 
47
- ## 🛠️ Installation & Setup
48
 
49
- ### 1. Requirements
50
- Ensure you have Python 3.9+ and a CUDA-compatible GPU (recommended).
 
51
 
52
  ```bash
53
- # Clone the repository (GitHub)
54
- git clone https://github.com/er-del/sage.git
55
- cd sage
56
 
57
- # OR Clone from Hugging Face
58
- git clone https://huggingface.co/sage002/sage
59
- cd sage
60
 
61
- # Install dependencies
62
- pip install -r requirements.txt
63
  ```
64
 
65
- ### 2. Dependencies
66
- - **PyTorch**: Core deep learning framework.
67
- - **tiktoken**: Fast BPE tokenization (OpenAI's cl100k_base).
68
- - **datasets**: For streaming training data from HuggingFace.
69
- - **faiss-cpu**: For vector-based retrieval (RAG).
70
- - **tqdm**: Progress bars for training.
71
- - **bitsandbytes**: (Optional) For advanced quantization.
72
 
73
- ---
 
 
74
 
75
- ## 🎮 Getting Started
76
 
77
- ### Launching the CLI
78
- You can run the modular version or the single-file version:
79
 
80
  ```bash
81
- # Modular version
82
- python -m sage.cli
 
 
 
 
 
 
 
83
 
84
- # Single-file version
85
- python sage_single.py
 
 
 
 
 
 
 
 
 
 
86
  ```
87
 
88
- ### Basic Chat
89
- Once launched, simply type your message to chat with SAGE. The system uses a rolling conversation history to maintain context.
90
 
91
- ---
 
 
92
 
93
- ## 👨‍🏫 Training SAGE
94
 
95
- SAGE supports real-time training either directly from the interactive REPL or via simple one-liner CLI commands (useful for background scripts).
96
 
97
- ### Non-Interactive "One-Liner" Commands
98
- If you want to bypass the chat interface and just run a training job, pass the command as a CLI argument:
99
  ```bash
100
- python sage_single.py --train 100 # Pre-train for 100 steps
101
- python sage_single.py --finetune 200 # Instruction-tune for 200 steps
102
- python sage_single.py --quantize # Apply INT8 quantization
103
  ```
104
 
105
- ### Interactive REPL Commands
106
- If you are inside the chat interface, use the slash commands:
107
 
108
- ### /train [steps]
109
- Run pre-training using the `TinyStories` dataset (default).
110
- - `/train 100` — Trains for 100 steps and saves a checkpoint.
111
 
112
- ### /finetune [steps]
113
- Perform instruction fine-tuning using LoRA adapters.
114
- - `/finetune 200` — Trains on instruction/response pairs and merges weights.
115
 
116
- ---
 
 
117
 
118
- ## 🧠 Advanced Commands
119
 
120
- | Command | Action |
121
- | :--- | :--- |
122
- | `/save` | Manually save the current model checkpoint. |
123
- | `/load` | Reload the latest checkpoint from the `checkpoints` directory. |
124
- | `/quantize` | Convert model weights to INT8 (CPU) for reduced memory usage. |
125
- | `/rag on` | Enable Retrieval-Augmented Generation. |
126
- | `/rag add <text>` | Add new knowledge to SAGE's retrieval database. |
127
- | `/clear` | Clear the current conversation history. |
128
- | `/help` | Show the list of available commands. |
129
- | `/exit` | Exit the program cleanly. |
130
 
131
- ---
132
 
133
- ## 🏗️ Architecture Details
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- ### Mixture of Experts (MoE)
136
- SAGE swaps standard FFN layers for MoE blocks. Each block contains 4 experts, where exactly 2 are activated per token via a learned linear router. This allows for higher total capacity without increasing the computational cost per token.
137
 
138
- ### Rotary Positional Embeddings (RoPE)
139
- Positions are encoded via complex-valued rotations of query and key vectors. This allows SAGE to better handle sequences longer than what it was trained on compared to absolute position embeddings.
140
 
141
- ### Inference Engine
142
- Generation supports:
143
- - **Temperature**: Adjusts randomness.
144
- - **Top-k**: Limits sampling to the most likely 'k' tokens.
145
- - **Top-p (Nucleus)**: Limits sampling to a cumulative probability threshold.
146
- - **KV-Caching**: Caches Attention keys and values to avoid redundant computation.
 
 
 
 
 
147
 
148
- ---
149
 
150
- ## 🤗 Hugging Face Model Hub
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- This project is actively maintained on Hugging Face. You can find pre-trained checkpoints, datasets, and community discussions here:
153
 
154
- 🔗 **[huggingface.co/sage002/sage](https://huggingface.co/sage002/sage)**
155
 
156
- **Developed by Antigravity AI Systems.**
 
 
 
 
 
 
157
 
158
- ---
159
 
160
- ## 📜 Disclaimer
161
- SAGE is an experimental engine. While architecturally complete, the quality of generated responses depends heavily on the amount of training data and compute steps provided.
162
 
163
- **Developed by Antigravity AI Systems.**
 
 
 
 
 
1
+ # SAGE 1B
2
+
3
+ SAGE is a root-level rewrite of this repository into a production-style dense language model project. The current baseline is a 1B-class decoder-only transformer with RMSNorm, RoPE, grouped-query attention, SwiGLU, SentencePiece, resumable training, Parquet-backed datasets, and FastAPI serving.
4
+
5
+ This README is written as a practical operator guide. It tells you:
6
+
7
+ - what the project contains
8
+ - what is already implemented
9
+ - what commands to run
10
+ - what files are inputs and outputs
11
+ - what parts are scaffolding versus fully wired
12
+
13
+ ## What SAGE Is
14
+
15
+ SAGE is organized into these layers:
16
+
17
+ 1. `tokenizer/`
18
+ Trains and validates a SentencePiece tokenizer.
19
+ 2. `data/`
20
+ Handles raw corpus ingest, filtering, deduplication, sharding, and packed datasets.
21
+ 3. `model/`
22
+ Implements the dense decoder-only transformer.
23
+ 4. `train/`
24
+ Handles optimizer setup, scheduler, hardware detection, checkpoints, and the training loop.
25
+ 5. `eval/`
26
+ Provides perplexity evaluation and benchmark harness registration.
27
+ 6. `serve/`
28
+ Exposes FastAPI servers and quantization helpers.
29
+
30
+ ## Current Baseline
31
+
32
+ | Component | Value |
33
+ | --- | --- |
34
+ | Layers | 24 |
35
+ | d_model | 2048 |
36
+ | Attention heads | 16 |
37
+ | KV heads | 8 |
38
+ | Head dim | 128 |
39
+ | FFN dim | 5632 |
40
+ | Context length | 4096 |
41
+ | Vocab size | 50000 |
42
+ | Norm | RMSNorm |
43
+ | Positional encoding | RoPE |
44
+ | Attention | GQA + SDPA |
45
+ | Activation | SwiGLU |
46
+ | Weight tying | Enabled |
47
+
48
+ ## Repository Layout
49
 
50
+ ```text
51
+ configs/
52
+   model/        model YAMLs for 1B, 3B, 7B
53
+   data/         corpus mix and shard config
54
+   train/        LR, checkpoint, and logging schedule
55
+ data/
56
+   ingest.py     raw source registry and streaming helpers
57
+   filter.py     license/lang/PII/safety/quality filtering
58
+   dedup.py      exact and near-duplicate removal
59
+   shard.py      tokenization + parquet shard writing + manifest
60
+   dataset.py    packed iterable dataset with resume skip()
61
+ tokenizer/
62
+   train_tokenizer.py
63
+   validate_tokenizer.py
64
+ model/
65
+   config.py
66
+   rmsnorm.py
67
+   rope.py
68
+   attention.py
69
+   mlp.py
70
+   block.py
71
+   model.py
72
+ train/
73
+   loss.py
74
+   optimizer.py
75
+   checkpoint.py
76
+   distributed.py
77
+   hardware.py
78
+   trainer.py
79
+ eval/
80
+   perplexity.py
81
+   benchmarks.py
82
+   long_context.py
83
+   regression.py
84
+ serve/
85
+   kv_cache.py
86
+   quantize.py
87
+   server.py
88
+   server_cpu.py
89
+ scripts/
90
+   run_data_pipeline.sh
91
+   run_training.sh
92
+   run_eval.sh
93
+   run_serve.sh
94
+   run_serve_cpu.sh
95
+   run_validate_tokenizer.sh
96
+ tests/
97
+ ```
98
 
99
+ ## What Is Fully Working vs. What Is Scaffolded
100
 
101
+ ### Working now
102
 
103
+ - tokenizer training
104
+ - tokenizer validation
105
+ - data filtering and dedup helpers
106
+ - packed dataset logic
107
+ - dense transformer forward pass
108
+ - checkpoint save and resume
109
+ - hardware detection
110
+ - trainer entrypoint
111
+ - FastAPI health and basic generate endpoint
112
+ - unit and smoke tests
113
 
114
+ ### Scaffolded but not yet a full production runner
115
 
116
+ - benchmark execution against downloaded external datasets
117
+ - a single end-to-end corpus build command that downloads and preprocesses public corpora automatically
118
+ - production-grade multi-node launch tooling
119
+ - real llama.cpp server wiring beyond availability checks
120
 
121
+ That means the core codebase is real, but you still need to provide your own corpus files and Parquet shards before running a training job.
 
 
 
 
 
 
 
122
 
123
+ ## Install
124
 
125
+ Create and activate a virtual environment, then install dependencies:
126
 
127
+ ```bash
128
+ pip install -r requirements.txt
 
 
 
 
 
 
 
 
 
 
 
 
129
  ```
130
 
131
+ Recommended optional extras:
132
+
133
+ - `sentencepiece` is required for tokenizer training and validation
134
+ - `bitsandbytes` is useful for 8-bit experiments
135
+ - `llama.cpp` or `llama-cpp-python` is needed for the CPU serving path
136
+
137
+ ## Quick Start
138
 
139
+ If you want the shortest path to verifying the repo:
140
 
141
+ 1. Install dependencies.
142
+ 2. Run tests.
143
+ 3. Start the FastAPI server.
144
 
145
  ```bash
146
+ pytest -q
147
+ uvicorn serve.server:app --host 127.0.0.1 --port 8000
148
+ ```
149
 
150
+ Then check:
 
 
151
 
152
+ ```bash
153
+ curl http://127.0.0.1:8000/health
154
+ ```
155
+
156
+ ## Command Reference
157
+
158
+ The detailed command guide is in [docs/COMMANDS.md](docs/COMMANDS.md). The most important commands are below.
159
+
160
+ ### 1. Train tokenizer
161
+
162
+ Cross-platform Python command:
163
+
164
+ ```bash
165
+ python -m tokenizer.train_tokenizer \
166
+ --input data/raw/general_web.txt data/raw/code.txt \
167
+ --model-prefix tokenizer/tokenizer \
168
+ --vocab-size 50000
169
+ ```
170
+
171
+ Linux/macOS/WSL wrapper:
172
+
173
+ ```bash
174
+ bash scripts/run_data_pipeline.sh \
175
+ --input data/raw/general_web.txt data/raw/code.txt \
176
+ --model-prefix tokenizer/tokenizer \
177
+ --vocab-size 50000
178
+ ```
179
+
180
+ Outputs:
181
+
182
+ - `tokenizer/tokenizer.model`
183
+ - `tokenizer/tokenizer.vocab`
184
+ - `tokenizer/training_corpus.txt`
185
+
186
+ ### 2. Validate tokenizer
187
+
188
+ ```bash
189
+ python - <<'PY'
190
+ from tokenizer.validate_tokenizer import validate_model_file
191
+ validate_model_file("tokenizer/tokenizer.model")
192
+ print("tokenizer ok")
193
+ PY
194
  ```
195
 
196
+ Or:
 
 
 
 
 
 
197
 
198
+ ```bash
199
+ bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model
200
+ ```
201
 
202
+ ### 3. Train the model
203
 
204
+ Training expects existing Parquet shards. Example:
 
205
 
206
  ```bash
207
+ python -m train.trainer \
208
+ --model-config configs/model/1b.yaml \
209
+ --schedule-config configs/train/schedule.yaml \
210
+ --train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \
211
+ --validation-shards data/processed/shard-00002.parquet \
212
+ --output-dir runs/sage-1b
213
+ ```
214
+
215
+ Useful options:
216
 
217
+ - `--steps 100` for a short smoke run
218
+ - `--disable-wandb` to disable offline W&B logging
219
+
220
+ Example smoke run:
221
+
222
+ ```bash
223
+ python -m train.trainer \
224
+ --train-shards data/processed/shard-00000.parquet \
225
+ --validation-shards data/processed/shard-00001.parquet \
226
+ --output-dir runs/smoke \
227
+ --steps 20 \
228
+ --disable-wandb
229
  ```
230
 
231
+ ### 4. Run evaluation harness
 
232
 
233
+ ```bash
234
+ bash scripts/run_eval.sh
235
+ ```
236
 
237
+ This currently prints the registered benchmark surfaces. It is a harness check, not a full benchmark download-and-run pipeline.
238
 
239
+ ### 5. Start the GPU server
240
 
 
 
241
  ```bash
242
+ uvicorn serve.server:app --host 0.0.0.0 --port 8000
 
 
243
  ```
244
 
245
+ Or:
 
246
 
247
+ ```bash
248
+ bash scripts/run_serve.sh
249
+ ```
250
 
251
+ ### 6. Start the CPU server
 
 
252
 
253
+ ```bash
254
+ uvicorn serve.server_cpu:app --host 0.0.0.0 --port 8001
255
+ ```
256
 
257
+ Or:
258
 
259
+ ```bash
260
+ bash scripts/run_serve_cpu.sh
261
+ ```
262
+
263
+ ### 7. Call the generate endpoint
 
 
 
 
 
264
 
265
+ The current server takes token IDs directly, not raw text strings.
266
 
267
+ ```bash
268
+ curl -X POST http://127.0.0.1:8000/generate \
269
+ -H "Content-Type: application/json" \
270
+ -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
271
+ ```
272
+
273
+ Response shape:
274
+
275
+ ```json
276
+ {
277
+ "tokens": [1, 42, 99, 123, 456]
278
+ }
279
+ ```
280
 
281
+ ## How Training Works
 
282
 
283
+ The training flow is:
 
284
 
285
+ 1. load model config from `configs/model/*.yaml`
286
+ 2. load schedule config from `configs/train/schedule.yaml`
287
+ 3. detect hardware in `train/hardware.py`
288
+ 4. build optimizer and cosine scheduler
289
+ 5. load latest checkpoint if one exists
290
+ 6. call `PackedDataset.skip()` so resume does not replay already-trained batches
291
+ 7. run forward/backward with autocast on CUDA or MPS
292
+ 8. clip gradients
293
+ 9. log metrics to `metrics.jsonl` and optionally offline W&B
294
+ 10. run validation perplexity at eval intervals
295
+ 11. save checkpoint every configured interval
296
 
297
+ Important output files during training:
298
 
299
+ - `runs/<run-name>/metrics.jsonl`
300
+ - `runs/<run-name>/ckpt_step_0001000.pt`
301
+ - later checkpoints in the same folder
302
+
303
+ ## How Data Is Expected to Look
304
+
305
+ ### Raw text files for tokenizer training
306
+
307
+ Simple UTF-8 text files are enough:
308
+
309
+ ```text
310
+ This is a training document.
311
+ This is another one.
312
+ ```
313
 
314
+ ### Raw JSONL records for ingest/filter work
315
+
316
+ The ingest layer assumes records like:
317
+
318
+ ```json
319
+ {"text": "example text"}
320
+ ```
321
+
322
+ ### Processed Parquet shards for training
323
+
324
+ The trainer expects Parquet rows with at least:
325
+
326
+ - `tokens`
327
+ - `split`
328
+
329
+ The sharding helper writes:
330
+
331
+ - `id`
332
+ - `text`
333
+ - `tokens`
334
+ - `domain_tag`
335
+ - `quality_tier`
336
+ - `lang`
337
+ - `token_count`
338
+ - `split`
339
+
340
+ ## Main Config Files
341
+
342
+ ### [configs/model/1b.yaml](configs/model/1b.yaml)
343
+
344
+ Controls the model shape:
345
+
346
+ - layers
347
+ - hidden size
348
+ - heads
349
+ - KV heads
350
+ - FFN size
351
+ - vocab size
352
+ - context length
353
+
354
+ ### [configs/data/mix.yaml](configs/data/mix.yaml)
355
+
356
+ Controls corpus weights and split ratios.
357
+
358
+ ### [configs/train/schedule.yaml](configs/train/schedule.yaml)
359
+
360
+ Controls:
361
+
362
+ - total token target
363
+ - LR schedule
364
+ - warmup
365
+ - checkpoint interval
366
+ - log interval
367
+ - eval interval
368
+
369
+ ## Common Workflows
370
+
371
+ ### Workflow A: verify the repo
372
+
373
+ ```bash
374
+ pip install -r requirements.txt
375
+ pytest -q
376
+ ```
377
+
378
+ ### Workflow B: train tokenizer only
379
+
380
+ ```bash
381
+ python -m tokenizer.train_tokenizer --input data/raw/general_web.txt --model-prefix tokenizer/tokenizer
382
+ python - <<'PY'
383
+ from tokenizer.validate_tokenizer import validate_model_file
384
+ validate_model_file("tokenizer/tokenizer.model")
385
+ print("ok")
386
+ PY
387
+ ```
388
+
389
+ ### Workflow C: smoke-train on local shards
390
+
391
+ ```bash
392
+ python -m train.trainer \
393
+ --train-shards data/processed/shard-00000.parquet \
394
+ --validation-shards data/processed/shard-00001.parquet \
395
+ --output-dir runs/smoke \
396
+ --steps 20 \
397
+ --disable-wandb
398
+ ```
399
+
400
+ ### Workflow D: serve locally
401
+
402
+ ```bash
403
+ uvicorn serve.server:app --host 127.0.0.1 --port 8000
404
+ curl http://127.0.0.1:8000/health
405
+ ```
406
+
407
+ ## Troubleshooting
408
+
409
+ ### `No training shards provided`
410
+
411
+ You launched the trainer without `--train-shards`. The trainer is working as designed, but it needs Parquet shard paths.
412
+
413
+ ### `ModuleNotFoundError: sentencepiece`
414
+
415
+ Install dependencies:
416
+
417
+ ```bash
418
+ pip install -r requirements.txt
419
+ ```
420
+
421
+ ### FastAPI starts but generate is not useful
422
+
423
+ That is expected right now if you have not trained or loaded a checkpoint. The server instantiates the model architecture, but it does not yet load a trained checkpoint automatically.
424
+
425
+ ### CPU server says llama.cpp is unavailable
426
+
427
+ Install `llama.cpp` or `llama-cpp-python`. The current CPU server is a readiness surface, not a bundled llama.cpp runtime.
428
+
429
+ ## Tests
430
+
431
+ Run the full suite:
432
+
433
+ ```bash
434
+ pytest -q
435
+ ```
436
 
437
+ Coverage areas:
438
 
439
+ - tokenizer roundtrip validation
440
+ - model shapes
441
+ - attention math
442
+ - data filtering and packing
443
+ - checkpoint roundtrip
444
+ - hardware summaries
445
+ - FastAPI health endpoints
446
 
447
+ ## Next Practical Step
448
 
449
+ If you want the fastest real progress from here, the next step is:
 
450
 
451
+ 1. prepare a small local corpus
452
+ 2. train the tokenizer
453
+ 3. write Parquet shards with `data/shard.py`
454
+ 4. run a `--steps 20` smoke training job
455
+ 5. only then start extending benchmark or serving behavior
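Reviewer note on the "Current Baseline" table in the new README: grouped-query attention with 8 KV heads (against 16 query heads) halves the key/value cache relative to full multi-head attention. The sketch below is a rough sizing aid, not code from `serve/kv_cache.py`. The function name `kv_cache_bytes` is hypothetical, and it assumes a 2-byte cache dtype (fp16 or bf16), one K and one V tensor per layer, and batch size 1.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, context_length,
                   bytes_per_element=2, batch_size=1):
    """Bytes needed to cache keys and values for a full context window.

    Assumption: one K and one V tensor per layer, each of shape
    (batch, num_kv_heads, context_length, head_dim).
    """
    per_token_per_layer = 2 * num_kv_heads * head_dim * bytes_per_element  # K and V
    return batch_size * context_length * num_layers * per_token_per_layer


# Values taken from the README baseline table (24 layers, head dim 128, 4096 context).
gqa = kv_cache_bytes(num_layers=24, num_kv_heads=8, head_dim=128, context_length=4096)
mha = kv_cache_bytes(num_layers=24, num_kv_heads=16, head_dim=128, context_length=4096)

print(f"GQA, 8 KV heads:  {gqa / 2**20:.0f} MiB")   # 384 MiB
print(f"MHA, 16 KV heads: {mha / 2**20:.0f} MiB")   # 768 MiB
```

Under these assumptions the cache for one full 4096-token sequence is 384 MiB with GQA, compared with 768 MiB if every query head had its own KV head.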
SAGE_KAGGLE_GUIDE.md DELETED
@@ -1,146 +0,0 @@
1
- # 🪐 SAGE: Kaggle & Colab Quickstart Guide
2
-
3
- Welcome to the **Self-Adaptive General Engine (SAGE)**. This guide will help you get SAGE v2 running on a cloud environment (like Kaggle's 2x T4 or Google Colab) in under 5 minutes.
4
-
5
- ---
6
-
7
- ## 🛠️ Step 1: Environment Setup
8
-
9
- Run this cell first to install dependencies and fix any common binary incompatibilities (like the Numpy/Torch mismatch).
10
-
11
- ```python
12
- # Install PyTorch 2.1 with CUDA 12.1 (supports Tesla P100 sm_60)
13
- !pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
14
-
15
- # Install other dependencies
16
- !pip install "numpy<2.0.0" --force-reinstall
17
- !pip install bitsandbytes tqdm tiktoken faiss-cpu datasets wandb --upgrade
18
-
19
- print("✅ Environment ready. Please RESTART YOUR KERNEL now if this is your first run.")
20
- ```
21
-
22
- ---
23
-
24
- ## 🔑 Step 2: Weights & Biases Logging (Optional but Recommended)
25
-
26
- To track your training progress with professional charts:
27
-
28
- 1. Get your API Key from [wandb.ai/authorize](https://wandb.ai/authorize).
29
- 2. Add it to your Kaggle **Secrets** with the label `WANDB_API_KEY`.
30
- 3. Run this:
31
-
32
- ```python
33
- import wandb
34
- from kaggle_secrets import UserSecretsClient
35
- try:
36
-     user_secrets = UserSecretsClient()
37
-     wandb.login(key=user_secrets.get_secret("WANDB_API_KEY"))
38
- except:
39
-     import os
40
-     os.environ["WANDB_MODE"] = "offline"
41
-     print("⚠️ W&B Secret not found. Running in offline mode.")
42
- ```
43
-
44
- ---
45
-
46
- ## 💬 Step 3: Launch the SAGE Chat Interface
47
-
48
- This is a premium, multi-GPU enabled chat widget. Paste this into a cell to start interacting with SAGE.
49
-
50
- **Note:** SAGE automatically detects GPU compatibility and falls back to CPU if needed.
51
-
52
- ```python
53
- import sys, os, torch, random
54
- import torch.nn as nn
55
- import ipywidgets as widgets
56
- from IPython.display import display, HTML
57
-
58
- # Verify SAGE is accessible (debugging import issues)
59
- if not os.path.exists('sage/__init__.py'):
60
- print("❌ ERROR: sage/ folder not found in current directory!")
61
- print(" Make sure you've cloned the repo: !git clone https://github.com/er-del/sage.git")
62
- raise ImportError("sage module not found")
63
-
64
- # Add current directory to path if needed
65
- if '' not in sys.path and '.' not in sys.path:
66
- sys.path.insert(0, '')
67
-
68
- # Import SAGE
69
- from sage import SageModel, SageConfig, SageTokenizer, generate, ConversationHistory, train as train_model, finetune
70
- from sage import __version__ as sage_version
71
-
72
- # Verify import worked
73
- print(f"✅ SAGE v{sage_version} loaded successfully")
74
-
75
- # -- Initialization --
76
- config = SageConfig()
77
- # Note: config.device automatically checks GPU compatibility and falls back to CPU if needed
78
- device = config.device
79
- print(f"🖥️ Using device: {device}")
80
-
81
- tokenizer = SageTokenizer()
82
- history = ConversationHistory(tokenizer, max_tokens=1024)
83
- model = SageModel(config)
84
-
85
- # -- Multi-GPU Logic (only if CUDA is actually being used) --
86
- if device.type == "cuda":
87
- gpu_count = torch.cuda.device_count()
88
- if gpu_count > 1:
89
- print(f"🚀 Multi-GPU active: {gpu_count} GPUs.")
90
- model = nn.DataParallel(model)
91
- model = model.to(device)
92
-
93
- # -- Load Weights --
94
- ckpt_path = "checkpoints/sage_latest.pt"
95
- if os.path.exists(ckpt_path):
96
- base_model = getattr(model, "module", model)
97
- ckpt = torch.load(ckpt_path, map_location=device)
98
- base_model.load_state_dict(ckpt['model_state_dict'])
99
- print("✅ Weights loaded from checkpoint.")
100
- else:
101
- print("⚠️ RANDOM WEIGHTS (Type /train <steps> to begin learning).")
102
-
103
- # -- Render UI --
104
- chat_display = widgets.Output(layout={'border': '1px solid #444', 'height': '450px', 'overflow_y': 'scroll', 'padding': '10px'})
105
- text_input = widgets.Text(placeholder="Chat or type /train 1000...", layout={'width': '80%'})
106
- send_button = widgets.Button(description="Send", button_style='primary', layout={'width': '18%'})
107
- display(HTML("<style>.user-msg { background: #2b2d42; color: #fff; padding: 10px; border-radius: 10px; margin: 5px; border-left: 5px solid #ef233c; } .sage-msg { background: #1a1b2e; color: #fff; padding: 10px; border-radius: 10px; margin: 5px; border-left: 5px solid #4cc9f0; }</style>"))
108
-
109
- def on_send(_=None):
110
- user_text = text_input.value.strip()
111
- if not user_text: return
112
- text_input.value = ""
113
- with chat_display:
114
- if user_text.startswith("/train"):
115
- steps = int(user_text.split()[1]) if len(user_text.split()) > 1 else 100
116
- print(f"🚀 TRAINING STARTING ({steps} steps)...")
117
- train_model(model, config, total_steps=steps)
118
- print("✅ DONE.")
119
- return
120
- display(HTML(f'<div class="user-msg"><b>You:</b> {user_text}</div>'))
121
- response = generate(model, tokenizer, history.build_prompt(user_text), stream=False)
122
- res = response.split("SAGE:")[-1].split("</response>")[0].replace("<response>", "").strip()
123
- history.add("user", user_text); history.add("assistant", res)
124
- display(HTML(f'<div class="sage-msg"><b>SAGE:</b> {res}</div>'))
125
-
126
- text_input.on_submit(on_send); send_button.on_click(lambda b: on_send())
127
- display(chat_display, widgets.HBox([text_input, send_button]))
128
- ```
129
-
130
- ---
131
-
132
- ## 🎮 Command Cheat Sheet
133
-
134
- | Command | Action |
135
- | :--- | :--- |
136
- | `/train <steps>` | Starts pre-training (Base knowledge). Recommended: 5000+ |
137
- | `/clear` | Resets the conversation history. |
138
- | `/finetune <steps>` | (Coming Soon) Starts instruction fine-tuning. |
139
-
140
- ---
141
-
142
- ## 💡 Pro Tips for T4 GPUs
143
-
144
- 1. **Batch Size**: The default `batch_size=4` with `gradient_accumulation=16` is perfect for a 2x T4 setup (32GB VRAM total).
145
- 2. **Persistence**: Kaggle outputs are deleted when the session ends. Make sure to **download** the `checkpoints/` folder or sync it to **Hugging Face** regularly.
146
- 3. **Patience**: Loss will fluctuate. Look for a steady downward trend on your W&B dashboard!
SAGE_V3_ROADMAP.md DELETED
@@ -1,52 +0,0 @@
1
- # SAGE v3: The "Long-Vision" Roadmap 🗺️
2
-
3
- This document outlines the high-impact architectural upgrades that will transform SAGE into a multi-thousand token reasoning assistant with multimedia capabilities.
4
-
5
- ---
6
-
7
- ## 🏗️ 1. Long-Context Scaling (YaRN / RoPE-Interpolation)
8
-
9
- **Goal**: Increase SAGE's maximum comprehension from 1,024 to **4,096+ tokens**.
10
-
11
- ### Technical Strategy:
12
- Currently, our `freqs_cis` are precomputed for a fixed window. In v3, we will implement **NTK-Aware Interpolation**.
13
- - **Implementation**: We will add a `scaling_factor` to `SageConfig`.
14
- - **Logic**: During inference, if the sequence length exceeds the original training window, we will "stretch" the rotary frequencies dynamically rather than letting them overflow.
15
- - **Benefit**: SAGE can read entire source code files or long essays without "losing its mind" at the 1,025th token.
16
-
17
- ---
18
-
19
- ## 📂 2. Interactive RAG (Kaggle UI Integration)
20
-
21
- **Goal**: Allow users to "Upload and Chat" with any file instantly in the notebook.
22
-
23
- ### Technical Strategy:
24
- - **Widget Update**: Add a `widgets.FileUpload` component to the Kaggle chat interface.
25
- - **Auto-Ingestion**: When a file is uploaded, a background hook will:
26
- 1. Parse the text (PDF, `.py`, `.md`).
27
- 2. Chunk it into 200-token segments.
28
- 3. Generate embeddings and add them to the **FAISS Vector Store**.
29
- - **Real-time Recall**: SAGE will automatically pull context from these uploaded files using the `retrieve_context` logic we've already built.
30
-
31
- ---
32
-
33
- ## 👁️ 3. Multimodal Foundation (Vision Projection)
34
-
35
- **Goal**: Let SAGE "see" images.
36
-
37
- ### Technical Strategy:
38
- Since SAGE is a small, efficient model (133M), it is the perfect candidate for a **Vision-Language Model (VLM)**.
39
- - **Architecture**: We will add a frozen **CLIP-ViT** image encoder.
40
- - **The Bridge**: We will implement a `VisionProjector` (a simple 2-layer MLP) that converts CLIP image embeddings (e.g., 768-dim) into SAGE token embeddings (512-dim).
41
- - **Outcome**: You will be able to provide an image URL and a prompt like "What is in this picture?", and SAGE will respond based on the visual tokens.
42
-
43
- ---
44
-
45
- ## ⚡ 4. Training Stability: LayerNorm Tuning
46
-
47
- To support these advanced features, we will move to **RMSNorm** (Root Mean Square Layer Normalization) for even faster convergence and better numerical stability on the double-T4 setup.
48
-
49
- ---
50
-
51
- ### Which one first?
52
- We can begin implementing **RoPE Scaling** immediately to give SAGE a massive context boost without needing new weights. Just let me know when you're ready!
configs/data/mix.yaml ADDED
@@ -0,0 +1,25 @@
1
+ data_sources:
2
+   general_web:
3
+     weight_percent: 55
4
+     quality_tiers: [high, medium]
5
+   code:
6
+     weight_percent: 15
7
+     quality_tiers: [high, medium]
8
+   math_science:
9
+     weight_percent: 12
10
+     quality_tiers: [high, medium]
11
+   books_longform:
12
+     weight_percent: 10
13
+     quality_tiers: [high, medium]
14
+   multilingual:
15
+     weight_percent: 5
16
+     quality_tiers: [high, medium]
17
+   synthetic:
18
+     weight_percent: 3
19
+     quality_tiers: [high]
20
+ splits:
21
+   train: 0.989
22
+   validation: 0.01
23
+   test: 0.001
24
+ shard_size_bytes: 2147483648
25
+ format: parquet
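Reviewer note on `configs/data/mix.yaml`: the source weights should sum to 100 percent and the split ratios to 1.0. The sketch below checks this on a hand-copied dict, so it needs no YAML parser. The dict name `MIX` and the helper `check_mix` are hypothetical and not part of the repository. The nesting of `weight_percent` under each source is an assumption based on the key order in the diff.

```python
import math

# Values copied by hand from configs/data/mix.yaml (nesting is an assumption).
MIX = {
    "data_sources": {
        "general_web": 55,
        "code": 15,
        "math_science": 12,
        "books_longform": 10,
        "multilingual": 5,
        "synthetic": 3,
    },
    "splits": {"train": 0.989, "validation": 0.01, "test": 0.001},
    "shard_size_bytes": 2147483648,
}


def check_mix(mix):
    """Return (total weight percent, total split ratio, shard size in GiB)."""
    total_weight = sum(mix["data_sources"].values())
    total_split = sum(mix["splits"].values())
    shard_gib = mix["shard_size_bytes"] / 2**30
    return total_weight, total_split, shard_gib


total_weight, total_split, shard_gib = check_mix(MIX)
assert total_weight == 100, "source weights must sum to 100 percent"
assert math.isclose(total_split, 1.0), "split ratios must sum to 1.0"
print(f"weights={total_weight}%, splits={total_split:.3f}, shard={shard_gib:.0f} GiB")
```

With the committed values, the weights sum to exactly 100 and each shard targets 2 GiB.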
configs/model/1b.yaml ADDED
@@ -0,0 +1,13 @@
+ name: sage-1b
+ num_layers: 24
+ d_model: 2048
+ num_attn_heads: 16
+ num_kv_heads: 8
+ head_dim: 128
+ ffn_hidden_dim: 5632
+ vocab_size: 50000
+ context_length: 4096
+ rope_base_frequency: 500000
+ rope_scaling_factor: 1.0
+ dropout: 0.0
+ tie_word_embeddings: true
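As a rough cross-check of the "1B" label, the parameter count can be estimated from these fields. The sketch below is not the repository's code and rests on several assumptions: no bias terms, Q/K/V/O projections sized by `num_attn_heads`, `num_kv_heads` and `head_dim` (a fused QKV matrix has the same total), a SwiGLU MLP with three weight matrices, two RMSNorm weight vectors per block plus one final norm, and tied embeddings counted once. The names `ModelShape` and `estimate_parameters` are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelShape:
    """Subset of the YAML fields that determine the weight shapes."""

    num_layers: int
    d_model: int
    num_attn_heads: int
    num_kv_heads: int
    head_dim: int
    ffn_hidden_dim: int
    vocab_size: int
    tie_word_embeddings: bool


def estimate_parameters(shape: ModelShape) -> int:
    """Estimate the weight count under the assumptions listed in the lead-in."""
    q_out = shape.num_attn_heads * shape.head_dim
    kv_out = shape.num_kv_heads * shape.head_dim
    # Q and O projections use all heads; K and V use the smaller KV head count (GQA).
    attention = 2 * shape.d_model * q_out + 2 * shape.d_model * kv_out
    mlp = 3 * shape.d_model * shape.ffn_hidden_dim  # gate, up, down
    norms = 2 * shape.d_model  # two RMSNorm weight vectors per block
    per_layer = attention + mlp + norms
    embeddings = shape.vocab_size * shape.d_model
    lm_head = 0 if shape.tie_word_embeddings else embeddings
    final_norm = shape.d_model
    return embeddings + shape.num_layers * per_layer + final_norm + lm_head


SAGE_1B = ModelShape(24, 2048, 16, 8, 128, 5632, 50000, True)
print(f"{estimate_parameters(SAGE_1B):,}")  # 1,234,962,432
```

Under these assumptions the configuration comes to roughly 1.23 billion weights, of which about 102 million are the tied embedding table.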
configs/model/3b.yaml ADDED
@@ -0,0 +1,13 @@
+ name: sage-3b
+ num_layers: 28
+ d_model: 3072
+ num_attn_heads: 24
+ num_kv_heads: 8
+ head_dim: 128
+ ffn_hidden_dim: 8192
+ vocab_size: 50000
+ context_length: 8192
+ rope_base_frequency: 500000
+ rope_scaling_factor: 1.0
+ dropout: 0.0
+ tie_word_embeddings: true
configs/model/7b.yaml ADDED
@@ -0,0 +1,13 @@
+ name: sage-7b
+ num_layers: 32
+ d_model: 4096
+ num_attn_heads: 32
+ num_kv_heads: 8
+ head_dim: 128
+ ffn_hidden_dim: 11008
+ vocab_size: 50000
+ context_length: 8192
+ rope_base_frequency: 500000
+ rope_scaling_factor: 1.0
+ dropout: 0.0
+ tie_word_embeddings: true
configs/train/schedule.yaml ADDED
@@ -0,0 +1,14 @@
+ run_name: sage-1b-pretrain
+ total_tokens: 50000000000
+ effective_batch_tokens: 2000000
+ peak_learning_rate: 3.0e-4
+ min_learning_rate: 3.0e-5
+ warmup_steps: 2000
+ weight_decay: 0.1
+ betas: [0.9, 0.95]
+ adam_eps: 1.0e-8
+ gradient_clip_norm: 1.0
+ checkpoint_interval: 1000
+ log_interval: 10
+ eval_interval: 1000
+ seed: 42
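The schedule values imply a step count: 50,000,000,000 tokens at 2,000,000 tokens per optimizer step is 25,000 steps, of which the first 2,000 are warmup. The sketch below shows one way these numbers could drive a learning-rate curve. The architecture diagram in this commit mentions a cosine schedule, but the exact shape used here (linear warmup from zero, then cosine decay to `min_learning_rate`) is an assumption, and `learning_rate` is a hypothetical helper rather than the function in `train/optimizer.py`.

```python
import math

# Values copied from configs/train/schedule.yaml.
TOTAL_TOKENS = 50_000_000_000
EFFECTIVE_BATCH_TOKENS = 2_000_000
PEAK_LR = 3.0e-4
MIN_LR = 3.0e-5
WARMUP_STEPS = 2000

TOTAL_STEPS = TOTAL_TOKENS // EFFECTIVE_BATCH_TOKENS  # 25,000 optimizer steps


def learning_rate(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay to MIN_LR (assumed shape)."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return MIN_LR + (PEAK_LR - MIN_LR) * cosine


for step in (0, 1000, 2000, 13500, 25000):
    print(step, f"{learning_rate(step):.2e}")
```

With `checkpoint_interval: 1000`, a run of this length would write about 25 checkpoints before any pruning.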
data/__init__.py ADDED
@@ -0,0 +1 @@
+ """Data pipeline modules for SAGE."""
data/dataset.py ADDED
@@ -0,0 +1,92 @@
+ """Packed training dataset with deterministic resume support."""
+
+ from __future__ import annotations
+
+ import random
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable, Iterator
+
+ import torch
+ from torch.utils.data import IterableDataset
+
+ try:
+     import pyarrow.parquet as pq
+ except ImportError:  # pragma: no cover - optional at import time
+     pq = None
+
+
+ @dataclass(frozen=True)
+ class DatasetConfig:
+     """Configuration for packing token streams into training batches."""
+
+     shard_paths: tuple[str, ...]
+     context_length: int
+     split: str = "train"
+     seed: int = 42
+
+
+ class PackedDataset(IterableDataset):
+     """Iterate packed token sequences with document-boundary masks."""
+
+     def __init__(self, config: DatasetConfig):
+         super().__init__()
+         self.config = config
+         self._skip = 0
+
+     def skip(self, n_batches: int) -> None:
+         """Fast-forward the iterator by discarding the first n batches."""
+         self._skip = max(0, int(n_batches))
+
+     def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
+         skipped = 0
+         for batch in self._generate():
+             if skipped < self._skip:
+                 skipped += 1
+                 continue
+             yield batch
+
+     def _generate(self) -> Iterator[dict[str, torch.Tensor]]:
+         token_buffer: list[int] = []
+         boundary_buffer: list[int] = []
+         for row in self._iter_rows():
+             tokens = list(row["tokens"])
+             if len(tokens) < 2:
+                 continue
+             token_buffer.extend(tokens)
+             boundary_buffer.extend([0] * (len(tokens) - 1) + [1])
+             while len(token_buffer) >= self.config.context_length + 1:
+                 window_tokens = token_buffer[: self.config.context_length + 1]
+                 window_boundaries = boundary_buffer[: self.config.context_length + 1]
+                 yield pack_sequence(window_tokens, window_boundaries)
+                 token_buffer = token_buffer[self.config.context_length :]
+                 boundary_buffer = boundary_buffer[self.config.context_length :]
+
+     def _iter_rows(self) -> Iterator[dict[str, object]]:
+         if pq is None:
+             raise ImportError("pyarrow is required to read parquet shards.")
+         shard_paths = [Path(path) for path in self.config.shard_paths]
+         rng = random.Random(self.config.seed)
+         shard_paths = shard_paths[:]
+         rng.shuffle(shard_paths)
+         for path in shard_paths:
+             table = pq.read_table(path, columns=["tokens", "split"])
+             rows = table.to_pylist()
+             for row in rows:
+                 if row["split"] != self.config.split:
+                     continue
+                 yield row
+
+
+ def pack_sequence(tokens: list[int], boundaries: list[int]) -> dict[str, torch.Tensor]:
+     """Turn one packed token window into model-ready tensors."""
+     input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
+     labels = torch.tensor(tokens[1:], dtype=torch.long)
+     loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
+     attention_document_mask = torch.tensor(boundaries[:-1], dtype=torch.long)
+     return {
+         "input_ids": input_ids,
+         "labels": labels,
+         "loss_mask": loss_mask,
+         "document_boundaries": attention_document_mask,
+     }
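The packing loop in `_generate` is easier to follow on a toy input. The sketch below reproduces only the buffer and window arithmetic in plain Python, without torch, Parquet, or the boundary mask. `iter_windows` is a hypothetical helper written for this illustration; the token values are arbitrary.

```python
from __future__ import annotations

from typing import Iterable, Iterator


def iter_windows(
    documents: Iterable[list[int]], context_length: int
) -> Iterator[tuple[list[int], list[int]]]:
    """Yield (input_ids, labels) pairs using the same sliding-buffer logic."""
    buffer: list[int] = []
    for tokens in documents:
        if len(tokens) < 2:
            continue  # documents shorter than two tokens are skipped, as in the source
        buffer.extend(tokens)
        while len(buffer) >= context_length + 1:
            window = buffer[: context_length + 1]
            yield window[:-1], window[1:]
            # Advance by context_length, so the last token of this window
            # becomes the first input token of the next window.
            buffer = buffer[context_length:]


windows = list(iter_windows([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9]], context_length=4))
for input_ids, labels in windows:
    print(input_ids, labels)
# [0, 1, 2, 3] [1, 2, 3, 4]
# [4, 5, 6, 7] [5, 6, 7, 8]
```

Two properties are visible here: a window can span a document boundary (tokens 4 and 5 come from the first document, 6 to 8 from the second), and tokens left in the buffer at the end (8 and 9 in this run) are never emitted as a partial window.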
data/dedup.py ADDED
@@ -0,0 +1,59 @@
+ """Exact and near-duplicate detection helpers."""
+
+ from __future__ import annotations
+
+ import hashlib
+ import re
+ from collections import defaultdict
+ from typing import Iterable
+
+
+ TOKEN_RE = re.compile(r"\w+")
+
+
+ def exact_content_hash(text: str) -> str:
+     """Return an exact content hash."""
+     return hashlib.sha1(text.encode("utf-8")).hexdigest()
+
+
+ def shingles(text: str, n: int = 5) -> set[str]:
+     """Build token shingles for near-duplicate detection."""
+     tokens = TOKEN_RE.findall(text.lower())
+     if len(tokens) < n:
+         return {" ".join(tokens)} if tokens else set()
+     return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}
+
+
+ def jaccard_similarity(left: str, right: str, n: int = 5) -> float:
+     """Compute shingle-level Jaccard similarity."""
+     left_set = shingles(left, n)
+     right_set = shingles(right, n)
+     if not left_set and not right_set:
+         return 1.0
+     if not left_set or not right_set:
+         return 0.0
+     return len(left_set & right_set) / len(left_set | right_set)
+
+
+ def deduplicate_records(records: Iterable[dict[str, object]], near_dup_threshold: float = 0.92) -> list[dict[str, object]]:
+     """Drop exact and near-duplicate records."""
+     exact_seen: set[str] = set()
+     buckets: dict[str, list[dict[str, object]]] = defaultdict(list)
+     kept: list[dict[str, object]] = []
+     for record in records:
+         text = str(record["text"])
+         digest = exact_content_hash(text)
+         if digest in exact_seen:
+             continue
+         signature = digest[:8]
+         near_duplicate = False
+         for candidate in buckets[signature]:
+             if jaccard_similarity(text, str(candidate["text"])) >= near_dup_threshold:
+                 near_duplicate = True
+                 break
+         if near_duplicate:
+             continue
+         exact_seen.add(digest)
+         buckets[signature].append(record)
+         kept.append(record)
+     return kept
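In `deduplicate_records`, the bucket key is the first eight hex characters of the exact SHA-1 digest. Two texts that differ by even one character will therefore almost always land in different buckets, so the Jaccard comparison is rarely reached for genuine near-duplicates. One common alternative is to derive bucket keys from MinHash values over the shingle set, so that similar texts collide with a probability tied to their Jaccard similarity. The sketch below is a minimal illustration of that idea and is not the commit's method; `minhash_signature`, `estimated_jaccard`, the seeded BLAKE2b hashing, and the choice of 16 hashes are all assumptions.

```python
from __future__ import annotations

import hashlib
import re

TOKEN_RE = re.compile(r"\w+")


def shingles(text: str, n: int = 5) -> set[str]:
    """Same shingling rule as data/dedup.py."""
    tokens = TOKEN_RE.findall(text.lower())
    if len(tokens) < n:
        return {" ".join(tokens)} if tokens else set()
    return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def _seeded_hash(seed: int, value: str) -> int:
    """Deterministic 64-bit hash of one shingle under one seed."""
    payload = f"{seed}:{value}".encode("utf-8")
    return int.from_bytes(hashlib.blake2b(payload, digest_size=8).digest(), "big")


def minhash_signature(text: str, num_hashes: int = 16) -> tuple[int, ...]:
    """Return one minimum hash per seed; empty texts get an all-zero signature."""
    items = shingles(text)
    if not items:
        return (0,) * num_hashes
    return tuple(
        min(_seeded_hash(seed, item) for item in items) for seed in range(num_hashes)
    )


def estimated_jaccard(left: tuple[int, ...], right: tuple[int, ...]) -> float:
    """Fraction of matching positions, an estimate of shingle-level Jaccard."""
    return sum(a == b for a, b in zip(left, right)) / len(left)


sig = minhash_signature("the quick brown fox jumps over the lazy dog near the river bank")
print(len(sig), estimated_jaccard(sig, sig))
```

A bucket key could then be a slice of the signature (a "band"), for example `sig[:4]`, with each record inserted under several bands. The exact Jaccard check from `jaccard_similarity` would still run on the candidates that share a bucket.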
data/filter.py ADDED
@@ -0,0 +1,135 @@
+ """Corpus filtering, safety, and quality heuristics."""
+
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass
+ from typing import Iterable
+
+
+ EMAIL_RE = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE)
+ PHONE_RE = re.compile(r"(?:(?:\+?\d{1,3})?[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}")
+ SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
+ HTML_RE = re.compile(r"<[^>]+>")
+ MULTISPACE_RE = re.compile(r"[ \t]+")
+ NSFW_TERMS = {"porn", "explicit sex", "rape"}
+ HATE_TERMS = {"kill all", "ethnic cleansing"}
+ ALLOWED_LICENSES = {"permissive", "restricted"}
+ ALLOWED_LANGS = {"en", "es", "fr", "de", "hi", "zh", "ar", "pt"}
+
+
+ @dataclass(frozen=True)
+ class FilterConfig:
+     """Policy controls for the filtering pipeline."""
+
+     minimum_chars: int = 200
+     maximum_chars: int = 200_000
+     minimum_alpha_ratio: float = 0.45
+     minimum_quality_score: float = 0.20
+     language_confidence_threshold: float = 0.65
+
+
+ def normalize_text(text: str) -> str:
+     """Strip tags and normalize whitespace."""
+     text = HTML_RE.sub(" ", text)
+     text = MULTISPACE_RE.sub(" ", text)
+     return text.strip()
+
+
+ def detect_language(text: str) -> tuple[str, float]:
+     """Use a light heuristic to assign a language code."""
+     ascii_ratio = sum(ch.isascii() for ch in text) / max(len(text), 1)
+     devanagari = sum("\u0900" <= ch <= "\u097f" for ch in text)
+     cjk = sum("\u4e00" <= ch <= "\u9fff" for ch in text)
+     arabic = sum("\u0600" <= ch <= "\u06ff" for ch in text)
+     if cjk > 8:
+         return "zh", 0.95
+     if arabic > 8:
+         return "ar", 0.95
+     if devanagari > 8:
+         return "hi", 0.95
+     if ascii_ratio > 0.9:
+         return "en", 0.80
+     return "unknown", 0.40
+
+
+ def quality_score(text: str) -> float:
+     """Score text using length, punctuation, and alphabetic density."""
+     if not text:
+         return 0.0
+     alpha_ratio = sum(ch.isalpha() for ch in text) / len(text)
+     punct_ratio = sum(ch in ".,;:!?()[]{}" for ch in text) / len(text)
+     line_count = text.count("\n") + 1
+     score = min(len(text) / 4000.0, 1.0) * 0.4 + alpha_ratio * 0.4 + min(punct_ratio * 8.0, 1.0) * 0.2
+     if line_count < 2 and len(text) > 10_000:
+         score *= 0.85
+     return round(score, 4)
+
+
+ def quality_tier(score: float) -> str:
+     """Map a numeric score to a quality tier."""
+     if score >= 0.70:
+         return "high"
+     if score >= 0.40:
+         return "medium"
+     return "low"
+
+
+ def strip_pii(text: str) -> str:
+     """Mask basic email, phone, and SSN patterns."""
+     text = EMAIL_RE.sub("[EMAIL]", text)
+     text = PHONE_RE.sub("[PHONE]", text)
+     text = SSN_RE.sub("[SSN]", text)
+     return text
+
+
+ def passes_safety_filter(text: str) -> bool:
+     """Reject obviously unsafe content with simple keyword checks."""
+     lower = text.lower()
+     if any(term in lower for term in NSFW_TERMS):
+         return False
+     if any(term in lower for term in HATE_TERMS):
+         return False
+     return True
+
+
+ def license_allowed(category: str) -> bool:
+     """Return whether the source license category is allowed."""
+     return category in ALLOWED_LICENSES
+
+
+ def filter_record(record: dict[str, object], config: FilterConfig = FilterConfig()) -> dict[str, object] | None:
+     """Apply the full filter pipeline to one record."""
+     if not license_allowed(str(record.get("license_category", ""))):
+         return None
+     text = normalize_text(str(record.get("text", "")))
+     if not (config.minimum_chars <= len(text) <= config.maximum_chars):
+         return None
+     lang, confidence = detect_language(text)
+     if lang not in ALLOWED_LANGS or confidence < config.language_confidence_threshold:
+         return None
+     text = strip_pii(text)
+     if not passes_safety_filter(text):
+         return None
+     score = quality_score(text)
+     if score < config.minimum_quality_score:
+         return None
+     return {
+         **record,
+         "text": text,
+         "lang": lang,
+         "lang_confidence": confidence,
+         "quality_score": score,
+         "quality_tier": quality_tier(score),
+         "token_count_estimate": max(1, len(text) // 4),
+     }
+
+
+ def filter_corpus(records: Iterable[dict[str, object]], config: FilterConfig = FilterConfig()) -> list[dict[str, object]]:
+     """Filter a corpus in memory."""
+     kept: list[dict[str, object]] = []
+     for record in records:
+         filtered = filter_record(record, config)
+         if filtered is not None:
+             kept.append(filtered)
+     return kept
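The `quality_score` formula combines three weighted terms: length (weight 0.4, saturating at 4,000 characters), alphabetic ratio (weight 0.4), and punctuation ratio (weight 0.2, saturating once punctuation reaches 12.5% of characters). The sketch below breaks one score into those terms for a synthetic input. `quality_components` is a hypothetical helper that mirrors the arithmetic above, and the sample string is artificial.

```python
def quality_components(text: str) -> dict[str, float]:
    """Split the score from data/filter.py into its three weighted terms."""
    if not text:
        return {"length": 0.0, "alpha": 0.0, "punct": 0.0, "score": 0.0}
    alpha_ratio = sum(ch.isalpha() for ch in text) / len(text)
    punct_ratio = sum(ch in ".,;:!?()[]{}" for ch in text) / len(text)
    length_term = min(len(text) / 4000.0, 1.0) * 0.4
    alpha_term = alpha_ratio * 0.4
    punct_term = min(punct_ratio * 8.0, 1.0) * 0.2
    score = length_term + alpha_term + punct_term
    # Same penalty as the source: very long text with no line breaks.
    if text.count("\n") + 1 < 2 and len(text) > 10_000:
        score *= 0.85
    return {
        "length": length_term,
        "alpha": alpha_term,
        "punct": punct_term,
        "score": round(score, 4),
    }


# 400 characters: 320 letters and 80 periods.
sample = "a" * 320 + "." * 80
parts = quality_components(sample)
print(parts["score"])  # 0.56 = 0.04 (length) + 0.32 (alpha) + 0.20 (punct, capped)
```

A score of 0.56 falls in the `medium` tier (at least 0.40, below 0.70). The example also shows that the heuristic rewards character statistics rather than meaning, since a string of repeated letters and periods scores as medium quality.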
data/ingest.py ADDED
@@ -0,0 +1,79 @@
+ """Raw corpus ingestion utilities."""
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable, Iterator
+
+
+ @dataclass(frozen=True)
+ class SourceSpec:
+     """Describes one raw corpus source."""
+
+     name: str
+     domain_tag: str
+     quality_tier: str
+     license_category: str
+     estimated_tokens: int
+     path: str
+     text_key: str = "text"
+
+
+ SOURCE_REGISTRY: tuple[SourceSpec, ...] = (
+     SourceSpec("general_web", "general", "medium", "permissive", 20_000_000_000, "data/raw/general_web.jsonl"),
+     SourceSpec("code", "code", "high", "permissive", 8_000_000_000, "data/raw/code.jsonl"),
+     SourceSpec("math_science", "math", "high", "permissive", 4_000_000_000, "data/raw/math_science.jsonl"),
+     SourceSpec("books_longform", "general", "high", "restricted", 5_000_000_000, "data/raw/books.jsonl"),
+     SourceSpec("multilingual", "multilingual", "medium", "permissive", 3_000_000_000, "data/raw/multilingual.jsonl"),
+     SourceSpec("synthetic", "reasoning", "high", "permissive", 1_000_000_000, "data/raw/synthetic.jsonl"),
+ )
+
+
+ def iter_jsonl(path: Path, text_key: str = "text") -> Iterator[dict[str, object]]:
+     """Yield JSONL records from disk."""
+     with path.open("r", encoding="utf-8") as handle:
+         for line in handle:
+             line = line.strip()
+             if not line:
+                 continue
+             payload = json.loads(line)
+             text = payload.get(text_key)
+             if not isinstance(text, str) or not text.strip():
+                 continue
+             yield payload
+
+
+ def stream_source(spec: SourceSpec) -> Iterator[dict[str, object]]:
+     """Yield normalized records for one configured source."""
+     path = Path(spec.path)
+     if not path.exists():
+         return iter(())
+     return (
+         {
+             "id": stable_record_id(spec.name, record[spec.text_key]),
+             "text": record[spec.text_key],
+             "domain_tag": spec.domain_tag,
+             "quality_tier": spec.quality_tier,
+             "license_category": spec.license_category,
+             "source_name": spec.name,
+         }
+         for record in iter_jsonl(path, spec.text_key)
+     )
+
+
+ def stream_all_sources(sources: Iterable[SourceSpec] = SOURCE_REGISTRY) -> Iterator[dict[str, object]]:
+     """Yield records from every source in the registry."""
+     for spec in sources:
+         yield from stream_source(spec)
+
+
+ def stable_record_id(source_name: str, text: str) -> str:
+     """Hash a source+text pair into a stable content id."""
+     digest = hashlib.sha256()
+     digest.update(source_name.encode("utf-8"))
+     digest.update(b"\0")
+     digest.update(text.encode("utf-8"))
+     return digest.hexdigest()
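`stable_record_id` feeds a NUL byte between the source name and the text before hashing. A plausible reason, stated here as an interpretation rather than something the commit documents, is to keep the (source, text) pair unambiguous: without a separator, different pairs can concatenate to the same byte string. The sketch below compares the construction above with a hypothetical `naive_record_id` that omits the separator.

```python
import hashlib


def stable_record_id(source_name: str, text: str) -> str:
    """Same construction as data/ingest.py: sha256(source + NUL + text)."""
    digest = hashlib.sha256()
    digest.update(source_name.encode("utf-8"))
    digest.update(b"\0")
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()


def naive_record_id(source_name: str, text: str) -> str:
    """Hypothetical variant without the separator, for comparison only."""
    return hashlib.sha256((source_name + text).encode("utf-8")).hexdigest()


# Both pairs concatenate to "abc", so the naive ids are identical.
print(naive_record_id("ab", "c") == naive_record_id("a", "bc"))    # True
# With the NUL separator the hashed bytes differ: b"ab\0c" vs b"a\0bc".
print(stable_record_id("ab", "c") == stable_record_id("a", "bc"))  # False
```

Because the id depends on the source name as well as the text, the same document ingested from two sources gets two different ids.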
data/shard.py ADDED
@@ -0,0 +1,95 @@
+ """Tokenization, manifesting, and Parquet sharding."""
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable
+
+ try:
+     import pyarrow as pa
+     import pyarrow.parquet as pq
+ except ImportError:  # pragma: no cover - optional at import time
+     pa = None
+     pq = None
+
+
+ SCHEMA_COLUMNS = ("id", "text", "tokens", "domain_tag", "quality_tier", "lang", "token_count", "split")
+
+
+ @dataclass(frozen=True)
+ class ShardConfig:
+     """Parameters for Parquet shard writing."""
+
+     output_dir: str
+     shard_size: int = 2048
+     validation_ratio: float = 0.01
+     test_ratio: float = 0.001
+
+
+ def assign_split(record_id: str, validation_ratio: float, test_ratio: float) -> str:
+     """Assign a deterministic split from the content id."""
+     value = int(record_id[:8], 16) / 0xFFFFFFFF
+     if value < test_ratio:
+         return "test"
+     if value < test_ratio + validation_ratio:
+         return "validation"
+     return "train"
+
+
+ def build_manifest(shard_paths: Iterable[Path]) -> dict[str, object]:
+     """Create a manifest describing shard files."""
+     shard_paths = list(shard_paths)
+     digest = hashlib.sha256()
+     for path in shard_paths:
+         digest.update(path.name.encode("utf-8"))
+         digest.update(str(path.stat().st_size).encode("utf-8"))
+     return {
+         "format": "parquet",
+         "schema": list(SCHEMA_COLUMNS),
+         "shards": [path.name for path in shard_paths],
+         "dataset_hash": digest.hexdigest(),
+     }
+
+
+ def write_shards(records: Iterable[dict[str, object]], tokenizer, config: ShardConfig) -> dict[str, object]:
+     """Write tokenized records to Parquet shards and emit a manifest."""
+     if pa is None or pq is None:
+         raise ImportError("pyarrow is required to write parquet shards.")
+     output_dir = Path(config.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+     buffer: list[dict[str, object]] = []
+     shard_paths: list[Path] = []
+     shard_index = 0
+     for record in records:
+         tokens = tokenizer.encode(str(record["text"]), out_type=int)
+         row = {
+             "id": str(record["id"]),
+             "text": str(record["text"]),
+             "tokens": tokens,
+             "domain_tag": str(record["domain_tag"]),
+             "quality_tier": str(record["quality_tier"]),
+             "lang": str(record["lang"]),
+             "token_count": len(tokens),
+             "split": assign_split(str(record["id"]), config.validation_ratio, config.test_ratio),
+         }
+         buffer.append(row)
+         if len(buffer) >= config.shard_size:
+             shard_paths.append(_flush_shard(output_dir, shard_index, buffer))
+             shard_index += 1
+             buffer = []
+     if buffer:
+         shard_paths.append(_flush_shard(output_dir, shard_index, buffer))
+     manifest = build_manifest(shard_paths)
+     (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
+     return manifest
+
+
+ def _flush_shard(output_dir: Path, shard_index: int, rows: list[dict[str, object]]) -> Path:
+     """Persist one Parquet shard."""
+     table = pa.table({column: [row[column] for row in rows] for column in SCHEMA_COLUMNS})
+     path = output_dir / f"shard-{shard_index:05d}.parquet"
+     pq.write_table(table, path)
+     return path
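`assign_split` maps the first 32 bits of the record id to a number in [0, 1] and compares it against the cumulative ratios, so the same id always lands in the same split regardless of processing order. The sketch below checks that behaviour on synthetic ids. The `doc-{i}` id scheme and the sample size of 20,000 are assumptions made for this illustration; in the pipeline the ids come from `stable_record_id`.

```python
import hashlib
from collections import Counter


def assign_split(record_id: str, validation_ratio: float, test_ratio: float) -> str:
    """Same rule as data/shard.py: bucket the leading 32 bits of the hex id."""
    value = int(record_id[:8], 16) / 0xFFFFFFFF
    if value < test_ratio:
        return "test"
    if value < test_ratio + validation_ratio:
        return "validation"
    return "train"


# Synthetic SHA-256 ids, standing in for real content ids.
ids = [hashlib.sha256(f"doc-{i}".encode("utf-8")).hexdigest() for i in range(20_000)]
counts = Counter(assign_split(rid, 0.01, 0.001) for rid in ids)

for split in ("train", "validation", "test"):
    print(split, counts[split], f"{counts[split] / len(ids):.4f}")
```

With uniformly distributed ids the observed fractions should sit near 0.989, 0.01 and 0.001, with some sampling noise; the test split in particular is small enough that its share fluctuates noticeably at this sample size.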
docs/COMMANDS.md ADDED
@@ -0,0 +1,84 @@
+ # SAGE Commands
+
+ This file is the short command-only reference for the repo.
+
+ ## Install
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Run tests
+
+ ```bash
+ pytest -q
+ ```
+
+ ## Train tokenizer
+
+ ```bash
+ python -m tokenizer.train_tokenizer \
+   --input data/raw/general_web.txt data/raw/code.txt \
+   --model-prefix tokenizer/tokenizer \
+   --vocab-size 50000
+ ```
+
+ ## Validate tokenizer
+
+ ```bash
+ bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model
+ ```
+
+ ## Start a short training smoke run
+
+ ```bash
+ python -m train.trainer \
+   --train-shards data/processed/shard-00000.parquet \
+   --validation-shards data/processed/shard-00001.parquet \
+   --output-dir runs/smoke \
+   --steps 20 \
+   --disable-wandb
+ ```
+
+ ## Start full training
+
+ ```bash
+ python -m train.trainer \
+   --model-config configs/model/1b.yaml \
+   --schedule-config configs/train/schedule.yaml \
+   --train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \
+   --validation-shards data/processed/shard-00002.parquet \
+   --output-dir runs/sage-1b
+ ```
+
+ ## Run eval harness
+
+ ```bash
+ bash scripts/run_eval.sh
+ ```
+
+ ## Start GPU server
+
+ ```bash
+ bash scripts/run_serve.sh
+ ```
+
+ ## Start CPU server
+
+ ```bash
+ bash scripts/run_serve_cpu.sh
+ ```
+
+ ## Check server health
+
+ ```bash
+ curl http://127.0.0.1:8000/health
+ ```
+
+ ## Generate tokens from the API
+
+ ```bash
+ curl -X POST http://127.0.0.1:8000/generate \
+   -H "Content-Type: application/json" \
+   -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
+ ```
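The final `curl` call can also be issued from Python. The sketch below builds the same POST request with the standard library only; the endpoint, port, and the `input_ids` and `max_new_tokens` fields are taken from the command above, while `build_generate_request` is a hypothetical helper. The response schema is not shown in this part of the diff, so the sketch stops short of parsing it.

```python
import json
import urllib.request


def build_generate_request(
    input_ids: list[int],
    max_new_tokens: int,
    base_url: str = "http://127.0.0.1:8000",
) -> urllib.request.Request:
    """Build the same POST request as the curl example, without sending it."""
    body = json.dumps(
        {"input_ids": input_ids, "max_new_tokens": max_new_tokens}
    ).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/generate",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_generate_request([1, 42, 99], max_new_tokens=8)
print(request.get_method(), request.full_url)

# To send it against a running server:
# with urllib.request.urlopen(request, timeout=30) as response:
#     print(response.read().decode("utf-8"))
```

Note that the API takes token ids rather than text, so a client would need the trained tokenizer to encode a prompt before calling `/generate`.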
docs/flow_llm.mmd ADDED
@@ -0,0 +1,140 @@
+ flowchart TB
+
+ %% =========================================================
+ %% SAGE - Simplified Operational Flow
+ %% This file is pure Mermaid so .mmd renderers can open it.
+ %% =========================================================
+
+ user["You / Operator"]
+
+ subgraph inputs["Inputs"]
+     raw["Raw text / JSONL corpus"]
+     cfg_model["configs/model/1b.yaml"]
+     cfg_train["configs/train/schedule.yaml"]
+     cfg_data["configs/data/mix.yaml"]
+ end
+
+ subgraph tokenizer["Tokenizer Stage"]
+     tok_train["tokenizer/train_tokenizer.py"]
+     tok_validate["tokenizer/validate_tokenizer.py"]
+     tok_model["tokenizer.model + tokenizer.vocab"]
+ end
+
+ subgraph prep["Data Preparation Stage"]
+     ingest["data/ingest.py"]
+     filter["data/filter.py"]
+     dedup["data/dedup.py"]
+     shard["data/shard.py"]
+     parquet["Parquet shards + manifest.json"]
+     packed["data/dataset.py<br/>PackedDataset"]
+ end
+
+ subgraph model["Model Stage"]
+     model_cfg["model/config.py"]
+     rms["RMSNorm"]
+     rope["RoPE"]
+     attn["GQA Attention + SDPA"]
+     mlp["SwiGLU MLP"]
+     blocks["Transformer Blocks x24"]
+     sage["SageTransformer"]
+ end
+
+ subgraph train["Training Stage"]
+     hw["train/hardware.py"]
+     opt["train/optimizer.py"]
+     loss["train/loss.py"]
+     ckpt["train/checkpoint.py"]
+     trainer["train/trainer.py"]
+     metrics["runs/<name>/metrics.jsonl"]
+     saves["runs/<name>/ckpt_step_xxxxxxx.pt"]
+ end
+
+ subgraph evals["Evaluation Stage"]
+     ppl["eval/perplexity.py"]
+     bench["eval/benchmarks.py"]
+     longctx["eval/long_context.py"]
+     regress["eval/regression.py"]
+ end
+
+ subgraph serving["Serving Stage"]
+     kv["serve/kv_cache.py"]
+     quant["serve/quantize.py"]
+     api["serve/server.py"]
+     cpu["serve/server_cpu.py"]
+     health["/health"]
+     generate["/generate"]
+ end
+
+ user --> raw
+ user --> cfg_model
+ user --> cfg_train
+ user --> cfg_data
+
+ raw --> tok_train
+ tok_train --> tok_model
+ tok_model --> tok_validate
+
+ raw --> ingest
+ cfg_data --> ingest
+ ingest --> filter
+ filter --> dedup
+ dedup --> shard
+ tok_model --> shard
+ shard --> parquet
+ parquet --> packed
+
+ cfg_model --> model_cfg
+ model_cfg --> rms
+ model_cfg --> rope
+ model_cfg --> attn
+ model_cfg --> mlp
+ rms --> blocks
+ rope --> attn
+ attn --> blocks
+ mlp --> blocks
+ blocks --> sage
+
+ packed --> trainer
+ cfg_train --> trainer
+ cfg_model --> trainer
+ hw --> trainer
+ opt --> trainer
+ loss --> trainer
+ ckpt --> trainer
+ sage --> trainer
+
+ trainer --> metrics
+ trainer --> saves
+ trainer --> ppl
+
+ sage --> ppl
+ sage --> bench
+ sage --> longctx
+ ppl --> regress
+ bench --> regress
+ longctx --> regress
+
+ sage --> kv
+ sage --> quant
+ sage --> api
+ quant --> cpu
+ kv --> api
+ api --> health
+ api --> generate
+ cpu --> health
+
+ classDef input fill:#0f172a,stroke:#93c5fd,color:#ffffff
+ classDef token fill:#1d4ed8,stroke:#bfdbfe,color:#ffffff
+ classDef prep fill:#0f766e,stroke:#99f6e4,color:#ffffff
+ classDef model fill:#581c87,stroke:#d8b4fe,color:#ffffff
+ classDef train fill:#92400e,stroke:#fde68a,color:#ffffff
+ classDef eval fill:#991b1b,stroke:#fecaca,color:#ffffff
+ classDef serve fill:#166534,stroke:#bbf7d0,color:#ffffff
+
+ class user,raw,cfg_model,cfg_train,cfg_data input
+ class tok_train,tok_validate,tok_model token
+ class ingest,filter,dedup,shard,parquet,packed prep
+ class model_cfg,rms,rope,attn,mlp,blocks,sage model
+ class hw,opt,loss,ckpt,trainer,metrics,saves train
+ class ppl,bench,longctx,regress eval
+ class kv,quant,api,cpu,health,generate serve
docs/llm_Arch.mmd ADDED
@@ -0,0 +1,202 @@
+ ---
+ title: SAGE 1B System Architecture
+ ---
+ flowchart TB
+
+ %% =========================================================
+ %% SAGE 1B - End-to-End Architecture and Flow Overview
+ %% =========================================================
+
+ user["Developer / Operator"]
+
+ subgraph repo["SAGE Repository"]
+     direction TB
+
+     subgraph configs["configs/"]
+         cfg_model["model/1b.yaml<br/>24L, 2048 d_model, 16Q / 8KV, 4096 ctx"]
+         cfg_data["data/mix.yaml<br/>corpus weights + split ratios"]
+         cfg_train["train/schedule.yaml<br/>LR, warmup, checkpoints, logging"]
+     end
+
+     subgraph tokenizer["tokenizer/"]
+         tok_train["train_tokenizer.py<br/>SentencePiece BPE training"]
+         tok_validate["validate_tokenizer.py<br/>roundtrip + edge-case checks"]
+         tok_model["tokenizer.model / tokenizer.vocab"]
+     end
+
+     subgraph data_layer["data/"]
+         ingest["ingest.py<br/>source registry + raw record streaming"]
+         filter["filter.py<br/>license, lang, PII, safety, quality"]
+         dedup["dedup.py<br/>exact + near-duplicate removal"]
+         shard["shard.py<br/>tokenize -> parquet shards + manifest"]
+         dataset["dataset.py<br/>PackedDataset + skip(n_batches)"]
+     end
+
+     subgraph model_layer["model/"]
+         model_cfg["config.py<br/>ModelConfig"]
+         rmsnorm["rmsnorm.py<br/>pre-norm RMSNorm"]
+         rope["rope.py<br/>RoPE cache + apply_rope"]
+         attn["attention.py<br/>fused QKV + GQA + SDPA"]
+         mlp["mlp.py<br/>SwiGLU FFN"]
+         block["block.py<br/>Transformer block"]
+         full_model["model.py<br/>SageTransformer"]
+     end
+
+     subgraph train_layer["train/"]
+         hw["hardware.py<br/>device / dtype / batch routing"]
+         dist["distributed.py<br/>single / DDP / FSDP strategy"]
+         opt["optimizer.py<br/>AdamW + cosine schedule"]
+         loss["loss.py<br/>masked next-token cross entropy"]
+         ckpt["checkpoint.py<br/>save / prune / resume"]
+         trainer["trainer.py<br/>main training loop"]
+     end
+
+     subgraph eval_layer["eval/"]
+         ppl["perplexity.py<br/>validation loss + perplexity"]
+         benches["benchmarks.py<br/>benchmark harness registry"]
+         longctx["long_context.py<br/>needle-in-haystack probes"]
+         regress["regression.py<br/>checkpoint metric comparison"]
+     end
+
+     subgraph serve_layer["serve/"]
+         kv["kv_cache.py<br/>cache container"]
+         quant["quantize.py<br/>int8 export + GGUF command helper"]
+         gpu_api["server.py<br/>FastAPI GPU server"]
+         cpu_api["server_cpu.py<br/>FastAPI CPU readiness surface"]
+     end
+
+     subgraph scripts["scripts/"]
+         s_data["run_data_pipeline.sh"]
+         s_train["run_training.sh"]
+         s_eval["run_eval.sh"]
+         s_serve["run_serve.sh / run_serve_cpu.sh"]
+     end
+
+     subgraph outputs["Runtime Outputs"]
+         raw["Raw text / JSONL corpora"]
+         parquet["Parquet shards + manifest.json"]
+         runs["runs/<name>/metrics.jsonl"]
+         checkpoints["runs/<name>/ckpt_step_xxxxxxx.pt"]
+         api_out["/health + /generate responses"]
+     end
+ end
+
+ %% =========================================================
+ %% Top-level usage
+ %% =========================================================
+
+ user --> s_data
+ user --> s_train
+ user --> s_eval
+ user --> s_serve
+
+ s_data --> tok_train
+ s_train --> trainer
+ s_eval --> benches
+ s_serve --> gpu_api
+ s_serve --> cpu_api
+
+ %% =========================================================
+ %% Tokenizer flow
+ %% =========================================================
+
+ raw --> tok_train
+ tok_train --> tok_model
+ tok_model --> tok_validate
+
+ %% =========================================================
+ %% Data preparation flow
+ %% =========================================================
+
+ raw --> ingest
+ ingest --> filter
+ filter --> dedup
+ dedup --> shard
+ tok_model --> shard
+ shard --> parquet
+ parquet --> dataset
+
+ cfg_data --> ingest
+ cfg_data --> filter
+ cfg_data --> shard
+
+ %% =========================================================
+ %% Model construction flow
+ %% =========================================================
+
+ cfg_model --> model_cfg
+ model_cfg --> rmsnorm
+ model_cfg --> rope
+ model_cfg --> attn
+ model_cfg --> mlp
+ rmsnorm --> block
+ rope --> attn
+ attn --> block
+ mlp --> block
+ block --> full_model
+
+ %% =========================================================
+ %% Training flow
+ %% =========================================================
+
+ cfg_train --> opt
+ cfg_train --> trainer
+ cfg_train --> ckpt
+ model_cfg --> full_model
+ dataset --> trainer
+ full_model --> trainer
+ hw --> trainer
+ dist --> trainer
+ opt --> trainer
+ loss --> trainer
+ ckpt --> trainer
+
+ trainer --> runs
+ trainer --> checkpoints
+ trainer --> ppl
+
+ %% =========================================================
+ %% Evaluation flow
+ %% =========================================================
+
+ full_model --> ppl
+ full_model --> benches
+ full_model --> longctx
+ ppl --> regress
+ benches --> regress
+ longctx --> regress
+
+ %% =========================================================
+ %% Serving flow
+ %% =========================================================
+
+ full_model --> kv
+ full_model --> quant
+ full_model --> gpu_api
+ quant --> cpu_api
+ kv --> gpu_api
+ hw --> gpu_api
+ gpu_api --> api_out
+ cpu_api --> api_out
+
+ %% =========================================================
+ %% Visual grouping
+ %% =========================================================
+
+ classDef config fill:#1f2937,stroke:#93c5fd,color:#ffffff
+ classDef pipeline fill:#0f766e,stroke:#5eead4,color:#ffffff
+ classDef model fill:#4c1d95,stroke:#c4b5fd,color:#ffffff
+ classDef train fill:#92400e,stroke:#fcd34d,color:#ffffff
+ classDef eval fill:#7f1d1d,stroke:#fca5a5,color:#ffffff
+ classDef serve fill:#065f46,stroke:#86efac,color:#ffffff
+ classDef io fill:#111827,stroke:#9ca3af,color:#ffffff
+ classDef actor fill:#2563eb,stroke:#bfdbfe,color:#ffffff
+
+ class user actor
+ class cfg_model,cfg_data,cfg_train config
+ class tok_train,tok_validate,ingest,filter,dedup,shard,dataset pipeline
+ class model_cfg,rmsnorm,rope,attn,mlp,block,full_model model
+ class hw,dist,opt,loss,ckpt,trainer train
+ class ppl,benches,longctx,regress eval
+ class kv,quant,gpu_api,cpu_api serve
+ class raw,parquet,runs,checkpoints,api_out,tok_model,s_data,s_train,s_eval,s_serve io
eval/__init__.py ADDED
@@ -0,0 +1 @@
+ """Evaluation helpers for SAGE."""
eval/benchmarks.py ADDED
@@ -0,0 +1,42 @@
+ """Benchmark harness registration for SAGE."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class BenchmarkResult:
+     """A normalized benchmark result."""
+
+     name: str
+     status: str
+     score: float | None
+     detail: str
+
+
+ BENCHMARKS = (
+     "hellaswag",
+     "winogrande",
+     "arc_easy",
+     "arc_challenge",
+     "gsm8k",
+     "math",
+     "humaneval",
+     "mbpp",
+ )
+
+
+ def run_registered_benchmarks(model, tokenizer=None) -> list[BenchmarkResult]:
+     """Return a placeholder "skipped" result for each registered benchmark; nothing is executed."""
+     _ = model
+     _ = tokenizer
+     return [
+         BenchmarkResult(
+             name=name,
+             status="skipped",
+             score=None,
+             detail="Benchmark harness registered; dataset/task execution is external to unit tests.",
+         )
+         for name in BENCHMARKS
+     ]
eval/long_context.py ADDED
@@ -0,0 +1,24 @@
+ """Long-context retrieval evaluation helpers."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class RetrievalProbe:
+     """A synthetic retrieval probe for long-context checks."""
+
+     prompt: str
+     needle: str
+     expected_index: int
+
+
+ def build_needle_in_haystack_probe(context_length: int) -> RetrievalProbe:
+     """Create a deterministic retrieval prompt for smoke tests."""
+     needle = "SAGE_LONG_CONTEXT_NEEDLE"
+     haystack = ["token"] * max(context_length - 16, 16)
+     insert_at = min(len(haystack) // 2, max(context_length // 4, 1))
+     haystack.insert(insert_at, needle)
+     prompt = " ".join(haystack)
+     return RetrievalProbe(prompt=prompt, needle=needle, expected_index=insert_at)
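A probe like the one above is typically scored by checking whether the model's completion reproduces the needle. The sketch below is not part of the commit: `build_prompt` mirrors the probe construction so the example is self-contained, and `score_retrieval` is a hypothetical helper name chosen for illustration.

```python
"""Sketch: scoring a needle-in-a-haystack probe (hypothetical helpers, not in the commit)."""


def build_prompt(context_length: int, needle: str = "SAGE_LONG_CONTEXT_NEEDLE") -> tuple[str, int]:
    """Mirror the probe construction: filler words with one needle inserted."""
    haystack = ["token"] * max(context_length - 16, 16)
    insert_at = min(len(haystack) // 2, max(context_length // 4, 1))
    haystack.insert(insert_at, needle)
    return " ".join(haystack), insert_at


def score_retrieval(completion: str, needle: str) -> bool:
    """Count the probe as passed when the completion contains the needle verbatim."""
    return needle in completion


prompt, index = build_prompt(64)
words = prompt.split(" ")
# The returned index is a word offset into the prompt, not a token offset.
found = words[index]
```

Note that the index counts whitespace-separated words. For larger values of `context_length` the `min(...)` picks `context_length // 4`, so the needle sits roughly a quarter of the way into the prompt.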
eval/perplexity.py ADDED
@@ -0,0 +1,39 @@
+ """Validation perplexity evaluation."""
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+
+ from train.loss import masked_cross_entropy
+
+
+ @torch.no_grad()
+ def evaluate_perplexity(
+     model: torch.nn.Module,
+     dataloader,
+     device: torch.device,
+     dtype: torch.dtype | None = None,
+     max_batches: int = 16,
+ ) -> dict[str, float]:
+     """Return the mean per-batch loss and its perplexity over at most `max_batches` batches; leaves the model in train mode."""
+     model.eval()
+     losses: list[float] = []
+     for index, batch in enumerate(dataloader):
+         if index >= max_batches:
+             break
+         input_ids = batch["input_ids"].to(device)
+         labels = batch["labels"].to(device)
+         loss_mask = batch["loss_mask"].to(device)
+         if dtype is not None and device.type != "cpu":
+             with torch.amp.autocast(device_type=device.type, dtype=dtype):
+                 logits, _ = model(input_ids)
+                 loss = masked_cross_entropy(logits, labels, loss_mask)
+         else:
+             logits, _ = model(input_ids)
+             loss = masked_cross_entropy(logits, labels, loss_mask)
+         losses.append(float(loss))
+     model.train()
+     mean_loss = sum(losses) / max(len(losses), 1)
+     return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}
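The loop above averages one loss value per batch, so batches with few unmasked tokens count as much as batches with many. The following sketch contrasts that with a token-weighted mean. It assumes each batch loss is already a mean over unmasked tokens (the behaviour of `masked_cross_entropy` is not shown in this diff), and the function names and numbers are illustrative only.

```python
import math


def batch_mean_perplexity(batches: list[tuple[float, int]]) -> float:
    """Average the per-batch mean losses, as the evaluation loop above does."""
    mean_loss = sum(loss for loss, _ in batches) / max(len(batches), 1)
    return math.exp(min(mean_loss, 20.0))


def token_weighted_perplexity(batches: list[tuple[float, int]]) -> float:
    """Weight each batch's mean loss by the number of unmasked tokens it covers."""
    total_tokens = sum(count for _, count in batches)
    if total_tokens == 0:
        return float("nan")  # nothing was evaluated
    mean_loss = sum(loss * count for loss, count in batches) / total_tokens
    return math.exp(min(mean_loss, 20.0))


# (mean loss, unmasked token count) per batch; made-up values for illustration.
example = [(2.0, 100), (4.0, 300)]
unweighted = batch_mean_perplexity(example)      # exp(3.0)
weighted = token_weighted_perplexity(example)    # exp(3.5)
```

The two agree when every batch carries the same number of unmasked tokens; they diverge when packing or masking makes batch sizes uneven.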
eval/regression.py ADDED
@@ -0,0 +1,15 @@
+ """Checkpoint-to-checkpoint regression checks."""
+
+ from __future__ import annotations
+
+
+ def compare_metrics(previous: dict[str, float], current: dict[str, float], threshold: float = 0.005) -> dict[str, object]:
+     """Flag metrics that fell by more than `threshold` (relative); assumes higher is better for every key."""
+     regressions: list[str] = []
+     for key, prev_value in previous.items():
+         curr_value = current.get(key)
+         if curr_value is None:
+             continue
+         if curr_value < prev_value * (1.0 - threshold):
+             regressions.append(key)
+     return {"regressions": regressions, "passed": not regressions}
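Because `compare_metrics` only flags values that decrease, a rising loss or perplexity would pass unnoticed. One possible extension is sketched below; `compare_metrics_directional` and the `LOWER_IS_BETTER` set are hypothetical names, not part of the commit, and the metric keys are assumptions about what callers pass in.

```python
# Hypothetical set of metric names where a smaller value is an improvement.
LOWER_IS_BETTER = frozenset({"loss", "perplexity"})


def compare_metrics_directional(
    previous: dict[str, float],
    current: dict[str, float],
    threshold: float = 0.005,
    lower_is_better: frozenset[str] = LOWER_IS_BETTER,
) -> dict[str, object]:
    """Flag relative changes beyond `threshold` in the direction that is worse for each metric."""
    regressions: list[str] = []
    for key, prev_value in previous.items():
        curr_value = current.get(key)
        if curr_value is None:
            continue  # metric missing from the new run; skipped, as in the original
        if key in lower_is_better:
            worse = curr_value > prev_value * (1.0 + threshold)
        else:
            worse = curr_value < prev_value * (1.0 - threshold)
        if worse:
            regressions.append(key)
    return {"regressions": regressions, "passed": not regressions}


report = compare_metrics_directional(
    {"hellaswag": 0.50, "perplexity": 10.0},
    {"hellaswag": 0.51, "perplexity": 11.0},
)
```

Here accuracy improved but perplexity rose by 10%, so the report lists `perplexity` as a regression.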
hf_push.py ADDED
@@ -0,0 +1,40 @@
+ """Upload the current SAGE repository contents to the Hugging Face Hub."""
+
+ from __future__ import annotations
+
+ from huggingface_hub import HfApi
+
+
+ REPO_ID = "sage002/sage"
+
+
+ def main() -> None:
+     """Replace the remote Hugging Face repo contents with the local folder state."""
+     api = HfApi()
+     print(f"Syncing current repository to {REPO_ID}...")
+     api.upload_folder(
+         folder_path=".",
+         repo_id=REPO_ID,
+         repo_type="model",
+         ignore_patterns=[
+             ".git/*",
+             ".venv/*",
+             "__pycache__/*",
+             "*.pyc",
+             "checkpoints/*",
+             "runs/*",
+             "wandb/*",
+             "data/raw/*",
+             "data/processed/*",
+             "tokenizer/*.model",
+             "tokenizer/*.vocab",
+             "tokenizer/training_corpus.txt",
+         ],
+         delete_patterns="*",
+         commit_message="feat: rewrite SAGE 1B architecture and replace legacy repo contents",
+     )
+     print("Sync complete.")
+
+
+ if __name__ == "__main__":
+     main()
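Since `delete_patterns="*"` replaces the remote contents, it can help to preview which local paths the ignore list would exclude before syncing. The sketch below uses only the standard library and assumes the patterns behave like `fnmatch`-style globs; whether the Hub client matches them in exactly this way should be checked against the `huggingface_hub` documentation. `is_ignored` is a hypothetical helper, not part of the commit.

```python
from fnmatch import fnmatchcase

# Same patterns as in hf_push.py above.
IGNORE_PATTERNS = [
    ".git/*",
    ".venv/*",
    "__pycache__/*",
    "*.pyc",
    "checkpoints/*",
    "runs/*",
    "wandb/*",
    "data/raw/*",
    "data/processed/*",
    "tokenizer/*.model",
    "tokenizer/*.vocab",
    "tokenizer/training_corpus.txt",
]


def is_ignored(path: str, patterns: list[str] = IGNORE_PATTERNS) -> bool:
    """Return True when a repo-relative POSIX path matches any glob pattern."""
    return any(fnmatchcase(path, pattern) for pattern in patterns)


candidates = ["model/attention.py", "checkpoints/step_100.pt", "tokenizer/sage.model"]
kept = [path for path in candidates if not is_ignored(path)]
```

Running a check like this over the working tree gives a quick list of what would be uploaded under the stated assumption.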
model/__init__.py ADDED
@@ -0,0 +1 @@
+ """Model architecture for SAGE."""
model/attention.py ADDED
@@ -0,0 +1,76 @@
+ """Grouped-query attention with SDPA and KV-cache support."""
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from model.config import ModelConfig
+ from model.rope import apply_rope
+
+
+ def repeat_kv(x: torch.Tensor, num_groups: int) -> torch.Tensor:
+     """Expand KV heads to match the number of query heads."""
+     if num_groups == 1:
+         return x
+     batch, kv_heads, seq_len, head_dim = x.shape
+     x = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim)
+     return x.reshape(batch, kv_heads * num_groups, seq_len, head_dim)
+
+
+ class GQAAttention(nn.Module):
+     """Fused-QKV grouped-query attention."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.num_heads = config.num_attn_heads
+         self.num_kv_heads = config.num_kv_heads
+         self.head_dim = config.head_dim
+         self.num_groups = self.num_heads // self.num_kv_heads
+         qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
+         self.qkv_proj = nn.Linear(config.d_model, qkv_dim, bias=False)
+         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+         self.dropout = config.dropout
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         """Compute self-attention and return the updated KV cache; with a cache, no causal mask is applied, so pass one token per step."""
+         batch_size, seq_len, _ = hidden_states.shape
+         qkv = self.qkv_proj(hidden_states)
+         q_dim = self.num_heads * self.head_dim
+         kv_dim = self.num_kv_heads * self.head_dim
+         q, k, v = qkv.split((q_dim, kv_dim, kv_dim), dim=-1)
+
+         q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+         v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+         q_rope, k_rope = apply_rope(q, repeat_kv(k, self.num_groups), cos, sin)
+         k = k_rope[:, :: self.num_groups, :, :]
+
+         if past_key_value is not None:
+             past_key, past_value = past_key_value
+             k = torch.cat([past_key, k], dim=-2)
+             v = torch.cat([past_value, v], dim=-2)
+
+         expanded_k = repeat_kv(k, self.num_groups)
+         expanded_v = repeat_kv(v, self.num_groups)
+         attn_output = F.scaled_dot_product_attention(
+             q_rope,
+             expanded_k,
+             expanded_v,
+             attn_mask=None,
+             dropout_p=self.dropout if self.training else 0.0,
+             is_causal=past_key_value is None,
+         )
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.config.d_model)
+         return self.out_proj(attn_output), (k, v)
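In the forward pass above, `is_causal` is only set when there is no cache. That is sufficient for single-token decoding, where the one new query may attend to every cached key, but a multi-token chunk fed together with a cache would need an explicit mask. The sketch below shows the shape of such a mask in plain Python; `cached_causal_mask` is a hypothetical helper and is not part of the commit.

```python
def cached_causal_mask(past_len: int, new_len: int) -> list[list[bool]]:
    """Build a new_len x (past_len + new_len) mask where True means "may attend".

    Row i belongs to the i-th new token, whose absolute position is past_len + i.
    It may see every cached key and every new key up to and including itself.
    """
    total_len = past_len + new_len
    mask = []
    for i in range(new_len):
        query_pos = past_len + i
        mask.append([key_pos <= query_pos for key_pos in range(total_len)])
    return mask


# Two cached positions, three new tokens.
chunk_mask = cached_causal_mask(past_len=2, new_len=3)
```

A boolean tensor with this layout could be passed as `attn_mask` to `F.scaled_dot_product_attention`, where `True` marks positions that take part in attention. With `new_len == 1` every entry is `True`, which is why the single-token decode path works without a mask.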
model/block.py ADDED
@@ -0,0 +1,37 @@
+ """Transformer block for the dense SAGE model."""
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from model.attention import GQAAttention
+ from model.config import ModelConfig
+ from model.mlp import SwiGLUMLP
+ from model.rmsnorm import RMSNorm
+
+
+ class TransformerBlock(nn.Module):
+     """Pre-norm transformer block with attention and SwiGLU."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
+         self.attn = GQAAttention(config)
+         self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
+         self.mlp = SwiGLUMLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cos: torch.Tensor,
+         sin: torch.Tensor,
+         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         """Forward pass with residual connections."""
+         attn_output, present = self.attn(self.norm1(hidden_states), cos, sin, past_key_value=past_key_value)
+         hidden_states = hidden_states + attn_output
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states, present
model/config.py ADDED
@@ -0,0 +1,48 @@
+ """Model configuration for SAGE."""
+
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import yaml
+
+
+ @dataclass
+ class ModelConfig:
+     """Configuration for the dense SAGE decoder-only transformer."""
+
+     name: str = "sage-1b"
+     num_layers: int = 24
+     d_model: int = 2048
+     num_attn_heads: int = 16
+     num_kv_heads: int = 8
+     head_dim: int = 128
+     ffn_hidden_dim: int = 5632
+     vocab_size: int = 50_000
+     context_length: int = 4096
+     rope_base_frequency: int = 500_000
+     rope_scaling_factor: float = 1.0
+     dropout: float = 0.0
+     tie_word_embeddings: bool = True
+     rms_norm_eps: float = 1.0e-5
+     initializer_range: float = 0.02
+
+     def __post_init__(self) -> None:
+         if self.num_attn_heads * self.head_dim != self.d_model:
+             raise ValueError("num_attn_heads * head_dim must equal d_model.")
+         if self.num_attn_heads % self.num_kv_heads != 0:
+             raise ValueError("num_attn_heads must be divisible by num_kv_heads.")
+         if self.ffn_hidden_dim % 256 != 0:
+             raise ValueError("ffn_hidden_dim must be a multiple of 256.")
+
+     @classmethod
+     def from_yaml(cls, path: str | Path) -> "ModelConfig":
+         """Load a config from YAML."""
+         payload = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
+         return cls(**payload)
+
+     def to_dict(self) -> dict[str, Any]:
+         """Serialize the config to a dict."""
+         return asdict(self)
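The default configuration can be sanity-checked with a little arithmetic. The sketch below counts parameters from the layer shapes that appear elsewhere in this commit (bias-free fused QKV, a `d_model` x `d_model` output projection, three SwiGLU matrices, two RMSNorm weight vectors per block, one final norm). `count_parameters` is a hypothetical helper written for this illustration.

```python
def count_parameters(
    num_layers: int = 24,
    d_model: int = 2048,
    num_attn_heads: int = 16,
    num_kv_heads: int = 8,
    head_dim: int = 128,
    ffn_hidden_dim: int = 5632,
    vocab_size: int = 50_000,
    tie_word_embeddings: bool = True,
) -> int:
    """Count weights of the dense decoder described by the defaults above."""
    qkv = d_model * (num_attn_heads + 2 * num_kv_heads) * head_dim  # fused QKV projection
    out = d_model * d_model                                         # attention output projection
    mlp = 3 * d_model * ffn_hidden_dim                              # gate, up, and down projections
    norms = 2 * d_model                                             # two RMSNorm weights per block
    per_layer = qkv + out + mlp + norms
    embed = vocab_size * d_model
    head = 0 if tie_word_embeddings else vocab_size * d_model       # tied head shares the embedding
    final_norm = d_model
    return num_layers * per_layer + embed + head + final_norm


total = count_parameters()
untied = count_parameters(tie_word_embeddings=False)
```

With the defaults this comes to roughly 1.23 billion parameters, which is consistent with the `sage-1b` name; untying the output head would add another 102.4 million.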
model/mlp.py ADDED
@@ -0,0 +1,23 @@
+ """SwiGLU feed-forward module."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from model.config import ModelConfig
+
+
+ class SwiGLUMLP(nn.Module):
+     """Bias-free SwiGLU feed-forward network."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
+         self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
+         self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply SwiGLU and project back to the model width."""
+         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
model/model.py ADDED
@@ -0,0 +1,73 @@
+ """Full dense decoder-only transformer model for SAGE."""
+
+ from __future__ import annotations
+
+ import math
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from model.block import TransformerBlock
+ from model.config import ModelConfig
+ from model.rope import build_rope_cache
+ from model.rmsnorm import RMSNorm
+
+
+ class SageTransformer(nn.Module):
+     """A dense Llama-style decoder-only transformer."""
+
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
+         self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         if config.tie_word_embeddings:
+             self.lm_head.weight = self.embed_tokens.weight
+         cos, sin = build_rope_cache(
+             seq_len=config.context_length,
+             head_dim=config.head_dim,
+             base_frequency=config.rope_base_frequency,
+             scaling_factor=config.rope_scaling_factor,
+         )
+         self.register_buffer("rope_cos", cos, persistent=False)
+         self.register_buffer("rope_sin", sin, persistent=False)
+         self._reset_parameters()
+
+     def _reset_parameters(self) -> None:
+         """Apply scaled initialization to the model."""
+         embed_std = 1.0 / math.sqrt(self.config.d_model)
+         nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=embed_std)
+         for module in self.modules():
+             if not isinstance(module, nn.Linear):
+                 continue
+             std = self.config.initializer_range
+             if module is self.lm_head and self.config.tie_word_embeddings:
+                 continue
+             if module.out_features == self.config.d_model:
+                 std = std / math.sqrt(2 * self.config.num_layers)
+             nn.init.normal_(module.weight, mean=0.0, std=std)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+     ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
+         """Return logits and the updated KV cache."""
+         batch_size, seq_len = input_ids.shape
+         hidden_states = self.embed_tokens(input_ids)
+         past_key_values = past_key_values or [None] * self.config.num_layers
+         start = 0
+         if past_key_values[0] is not None:
+             start = past_key_values[0][0].size(-2)
+         cos = self.rope_cos[start : start + seq_len].to(hidden_states.device)
+         sin = self.rope_sin[start : start + seq_len].to(hidden_states.device)
+         presents: list[tuple[torch.Tensor, torch.Tensor]] = []
+         for layer, past in zip(self.layers, past_key_values):
+             hidden_states, present = layer(hidden_states, cos, sin, past_key_value=past)
+             presents.append(present)
+         hidden_states = self.norm(hidden_states)
+         logits = self.lm_head(hidden_states)
+         return logits, presents
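The model returns one `(key, value)` pair per layer, each stored with `num_kv_heads` heads rather than the full query head count. A rough memory estimate for that cache can be sketched as follows; `kv_cache_bytes` is a hypothetical helper, and the two-bytes-per-element default assumes a 16-bit dtype such as bfloat16, which the diff does not specify.

```python
def kv_cache_bytes(
    seq_len: int,
    batch_size: int = 1,
    num_layers: int = 24,
    num_kv_heads: int = 8,
    head_dim: int = 128,
    bytes_per_element: int = 2,  # assumption: fp16/bf16 cache
) -> int:
    """Estimate KV-cache size: keys plus values, for every layer and cached position."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return batch_size * seq_len * per_token


full_context = kv_cache_bytes(seq_len=4096)                    # default GQA layout
without_gqa = kv_cache_bytes(seq_len=4096, num_kv_heads=16)    # if every query head had its own KV head
```

Under these assumptions a single full 4096-token sequence needs 384 MiB of cache, and halving the KV heads relative to the 16 query heads is what keeps it at half of the multi-head figure.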
model/rmsnorm.py ADDED
@@ -0,0 +1,23 @@
+ """RMSNorm implementation used by SAGE."""
+
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+
+
+ class RMSNorm(nn.Module):
+     """Root mean square normalization with float32 accumulation."""
+
+     def __init__(self, dim: int, eps: float = 1.0e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Normalize the last dimension and cast back to the input dtype."""
+         if x.ndim < 2:
+             raise ValueError("RMSNorm expects at least 2 dimensions.")
+         variance = x.float().pow(2).mean(dim=-1, keepdim=True)
+         normalized = x.float() * torch.rsqrt(variance + self.eps)
+         return (normalized.to(dtype=x.dtype)) * self.weight
model/rope.py ADDED
@@ -0,0 +1,57 @@
+ """Rotary positional embedding helpers."""
+
+ from __future__ import annotations
+
+ import torch
+
+
+ def _scaled_positions(seq_len: int, scaling_factor: float, device: torch.device) -> torch.Tensor:
+     """Divide positions by the scaling factor (linear position interpolation, a simplified stand-in for YaRN)."""
+     positions = torch.arange(seq_len, device=device, dtype=torch.float32)
+     if scaling_factor > 1.0:
+         positions = positions / scaling_factor
+     return positions
+
+
+ def build_rope_cache(
+     seq_len: int,
+     head_dim: int,
+     base_frequency: int = 500_000,
+     scaling_factor: float = 1.0,
+     device: torch.device | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Precompute cosine and sine tables for RoPE."""
+     if head_dim % 2 != 0:
+         raise ValueError("head_dim must be even for RoPE.")
+     device = device or torch.device("cpu")
+     positions = _scaled_positions(seq_len, scaling_factor, device)
+     inv_freq = 1.0 / (base_frequency ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim))
+     freqs = torch.outer(positions, inv_freq)
+     cos = torch.cos(freqs)
+     sin = torch.sin(freqs)
+     return cos, sin
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     """Rotate the last dimension in pairs."""
+     even = x[..., ::2]
+     odd = x[..., 1::2]
+     rotated = torch.stack((-odd, even), dim=-1)
+     return rotated.flatten(start_dim=-2)
+
+
+ def apply_rope(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """Apply rotary embeddings to query and key tensors."""
+     if q.shape != k.shape:
+         raise ValueError("q and k must share the same shape for RoPE application.")
+     seq_len = q.size(-2)
+     cos = cos[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
+     sin = sin[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
+     q_out = (q * cos) + (rotate_half(q) * sin)
+     k_out = (k * cos) + (rotate_half(k) * sin)
+     return q_out, k_out
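The helpers above rotate each consecutive (even, odd) pair of a head vector by an angle that grows with position. A useful consequence is that the dot product of a rotated query and key depends only on their relative offset. The plain-Python sketch below reproduces the same pairwise rotation for a single vector to show that property; `rope_rotate` and `dot` are illustrative names, not functions from the commit.

```python
import math


def rope_rotate(vec: list[float], position: float, base: float = 500_000.0) -> list[float]:
    """Rotate consecutive (even, odd) pairs of `vec` by position-dependent angles."""
    head_dim = len(vec)
    out: list[float] = []
    for pair in range(head_dim // 2):
        # Same frequency schedule as build_rope_cache: base ** (-(2 * pair) / head_dim).
        angle = position * base ** (-2.0 * pair / head_dim)
        x, y = vec[2 * pair], vec[2 * pair + 1]
        out.append(x * math.cos(angle) - y * math.sin(angle))
        out.append(x * math.sin(angle) + y * math.cos(angle))
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]
near = dot(rope_rotate(q, 3), rope_rotate(k, 1))    # offset of 2 near the start
far = dot(rope_rotate(q, 10), rope_rotate(k, 8))    # same offset, shifted by 7
```

Both scores agree up to floating-point error, which is what lets a cached key rotated at its original position be reused when later queries arrive.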
requirements.txt CHANGED
@@ -1,27 +1,13 @@
- # SAGE - Self-Adaptive General Engine
- # ======================================
- # Core dependencies
-
- # PyTorch - GPU compatibility notes:
- # - For Tesla P100 (sm_60), V100, T4, A100: torch>=2.1.0
- # - For older GPUs (sm_60): use torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
- # - The code auto-detects GPU compatibility and falls back to CPU if needed
  torch>=2.1.0
-
- # Tokenization & Data
- tiktoken>=0.5.1
- datasets>=2.14.0
-
- # Vector search (for RAG)
- faiss-cpu>=1.7.4
-
- # Utilities
- tqdm>=4.66.1
- numpy<2.0.0
-
- # Quantization (optional GPU support)
- bitsandbytes>=0.41.0
-
- # Model Hub & Experiment Tracking
- huggingface_hub>=0.20.0
- wandb>=0.16.0
+ fastapi>=0.110.0
+ uvicorn>=0.29.0
+ python-multipart>=0.0.9
+ pydantic>=2.7.0
+ pyyaml>=6.0.1
+ sentencepiece>=0.2.0
+ pyarrow>=16.0.0
+ psutil>=5.9.8
+ wandb>=0.17.0
+ pytest>=8.2.0
+ httpx>=0.27.0
+ bitsandbytes>=0.43.0
sage/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- """
2
- SAGE — Self-Adaptive General Engine
3
- A complete mini-LLM system built from scratch.
4
- """
5
-
6
- __version__ = "1.0.0"
7
-
8
- from .model import SageModel
9
- from .config import SageConfig
10
- from .data import SageTokenizer
11
- from .inference import generate
12
- from .memory import ConversationHistory, RAGManager
13
- from .train import train
14
- from .finetune import finetune_instruction as finetune
15
- from .utils import get_compatible_device
sage/cli.py DELETED
@@ -1,299 +0,0 @@
1
- """
2
- SAGE CLI — Interactive Terminal Interface
3
- ==========================================
4
- Provides a REPL with slash-commands for training, fine-tuning, quantization,
5
- RAG toggling, and real-time chat with streaming output.
6
- """
7
-
8
- import sys
9
- import os
10
- import torch
11
- from typing import Optional
12
-
13
- from .config import SageConfig
14
- from .model import SageModel
15
- from .data import SageTokenizer
16
- from .train import train
17
- from .inference import generate
18
- from .finetune import finetune_instruction, DEMO_INSTRUCTION_SAMPLES
19
- from .optimize import quantize_int8
20
- from .memory import RAGManager, ConversationHistory
21
- from .utils import setup_logger, save_checkpoint, load_checkpoint
22
- from . import __version__
23
-
24
- logger = setup_logger("sage.cli")
25
-
26
- # ===================================================================
27
- # Banner
28
- # ===================================================================
29
-
30
- BANNER = r"""
31
- ╔══════════════════════════════════════════════════════════════╗
32
- ║ ║
33
- ║ ███████ █████ ██████ ███████ ║
34
- ║ ██ ██ ██ ██ ██ ║
35
- ║ ███████ ███████ ██ ███ █████ ║
36
- ║ ██ ██ ██ ██ ██ ██ ║
37
- ║ ███████ ██ ██ ██████ ███████ ║
38
- ║ ║
39
- ║ Self-Adaptive General Engine v{version} ║
40
- ║ ║
41
- ╚══════════════════════════════════════════════════════════════╝
42
- """
43
-
44
-
45
- def print_banner(model: SageModel, config: SageConfig) -> None:
46
- """Display startup banner with model statistics."""
47
- base_model = getattr(model, "module", model)
48
- total_params = sum(p.numel() for p in base_model.parameters())
49
- trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
50
-
51
- print(BANNER.format(version=__version__))
52
- print(f" Model params : {total_params:,} ({total_params/1e6:.1f}M)")
53
- print(f" Trainable : {trainable_params:,}")
54
- print(f" Context length: {config.max_seq_len}")
55
- print(f" Device : {config.device}")
56
- print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}")
57
- print()
58
- print(" Type /help for commands, or start chatting!\n")
59
-
60
-
61
- # ===================================================================
62
- # Help text
63
- # ===================================================================
64
-
65
- HELP_TEXT = """
66
- Available Commands:
67
- /train [steps] Train the model (default: 100 steps)
68
- /finetune [steps] Instruction-tune with LoRA (default: 200 steps)
69
- /save Save current model checkpoint
70
- /load Load latest checkpoint
71
- /quantize Quantize model to INT8 (CPU only)
72
- /rag on|off Enable/disable retrieval-augmented generation
73
- /rag add <text> Add a document for RAG retrieval
74
- /clear Clear conversation history
75
- /help Show this message
76
- /exit Exit SAGE
77
- """
78
-
79
-
80
- # ===================================================================
81
- # Command handlers
82
- # ===================================================================
83
-
84
- def handle_train(model, config, tokenizer, args):
85
- """Handle /train [steps]"""
86
- steps = 100
87
- if args:
88
- try:
89
- steps = int(args[0])
90
- except ValueError:
91
- print(f" Invalid step count: {args[0]}")
92
- return model
93
-
94
- print(f"\n Starting training for {steps} steps …\n")
95
- model = train(model, config, total_steps=steps, tokenizer=tokenizer, resume=True)
96
-
97
- # Show a quick sample after training
98
- print("\n --- Sample generation after training ---")
99
- generate(model, tokenizer, "Once upon a time", max_new_tokens=80, stream=True, device=config.device)
100
- print()
101
- return model
102
-
103
-
104
- def handle_finetune(model, config, tokenizer, args):
105
- """Handle /finetune [steps]"""
106
- steps = 200
107
- if args:
108
- try:
109
- steps = int(args[0])
110
- except ValueError:
111
- print(f" Invalid step count: {args[0]}")
112
- return model
113
-
114
- print(f"\n Starting instruction fine-tuning for {steps} steps (LoRA) …\n")
115
- model = finetune_instruction(
116
- model, config,
117
- samples=DEMO_INSTRUCTION_SAMPLES,
118
- total_steps=steps,
119
- use_lora=True,
120
- tokenizer=tokenizer,
121
- )
122
-
123
- print("\n --- Sample after fine-tuning ---")
124
- prompt = "### Instruction:\nWhat is the speed of light?\n\n### Response:\n"
125
- generate(model, tokenizer, prompt, max_new_tokens=100, stream=True, device=config.device)
126
- print()
127
- return model
128
-
129
-
130
- def handle_save(model, config):
131
- """Handle /save"""
132
- path = save_checkpoint(model, None, 0, config.checkpoint_dir)
133
- print(f" Model saved to {path}")
134
-
135
-
136
- def handle_load(model, config):
137
- """Handle /load"""
138
- model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
139
- model = model.to(config.device)
140
- print(f" Model loaded (step {step})")
141
- return model
142
-
143
-
144
- def handle_quantize(model):
145
- """Handle /quantize"""
146
- print(" Quantizing model to INT8 (model will be on CPU) …")
147
- model = quantize_int8(model)
148
- print(" Quantization complete.")
149
- return model
150
-
151
-
152
- def handle_rag(rag_manager: RAGManager, args):
153
- """Handle /rag on|off|add <text>"""
154
- if not args:
155
- state = "enabled" if rag_manager.enabled else "disabled"
156
- print(f" RAG is currently {state} ({rag_manager.store.size} chunks indexed)")
157
- return
158
-
159
- subcmd = args[0].lower()
160
- if subcmd == "on":
161
- rag_manager.toggle(True)
162
- print(" RAG enabled.")
163
- elif subcmd == "off":
164
- rag_manager.toggle(False)
165
- print(" RAG disabled.")
166
- elif subcmd == "add":
167
- text = " ".join(args[1:])
168
- if text:
169
- rag_manager.add_documents([text])
170
- print(f" Document added. Store now has {rag_manager.store.size} chunks.")
171
- else:
172
- print(" Usage: /rag add <your document text here>")
173
- else:
174
- print(" Usage: /rag on|off|add <text>")
175
-
176
-
177
- # ===================================================================
178
- # Main REPL
179
- # ===================================================================
180
-
181
- def main() -> None:
182
- """Entry point for the SAGE interactive CLI."""
183
- config = SageConfig()
184
- tokenizer = SageTokenizer()
185
-
186
- # Ensure vocab_size matches the tokenizer
187
- config.vocab_size = tokenizer.vocab_size
188
-
189
- print(" Initializing SAGE model …")
190
- model = SageModel(config)
191
- model = model.to(config.device)
192
-
193
- if torch.cuda.is_available() and torch.cuda.device_count() > 1:
194
- print(f" Multi-GPU detected! Wrapping model in DataParallel across {torch.cuda.device_count()} GPUs.")
195
- model = torch.nn.DataParallel(model)
196
-
197
- # Attempt to load existing checkpoint
198
- model, _, loaded_step = load_checkpoint(
199
- model, None, config.checkpoint_dir, device=str(config.device)
200
- )
201
- if loaded_step > 0:
202
- print(f" Resumed from checkpoint at step {loaded_step}")
203
-
204
- print_banner(model, config)
205
-
206
- # Initialize RAG and conversation history
207
- rag_manager = RAGManager(model, tokenizer, config.device)
208
- history = ConversationHistory(tokenizer, max_tokens=config.max_seq_len - 128)
209
-
210
- # ---------- One-liner CLI arguments ----------
211
- if len(sys.argv) > 1:
212
- cmd = sys.argv[1].lower()
213
- args = sys.argv[2:]
214
- if cmd == "--train":
215
- handle_train(model, config, tokenizer, args)
216
- elif cmd == "--finetune":
217
- handle_finetune(model, config, tokenizer, args)
218
- elif cmd == "--quantize":
219
- handle_quantize(model)
220
- else:
221
- print(f" Unknown argument: {cmd}")
222
- print(" Usage: --train [steps] | --finetune [steps] | --quantize")
223
- return
224
-
225
- # ---------- REPL loop ----------
226
- while True:
227
- try:
228
- user_input = input("You: ").strip()
229
- except (EOFError, KeyboardInterrupt):
230
- print("\n Goodbye!")
231
- break
232
-
233
- if not user_input:
234
- continue
235
-
236
- # ---------- Slash commands ----------
237
- if user_input.startswith("/"):
238
- parts = user_input.split()
239
- cmd = parts[0].lower()
240
- args = parts[1:]
241
-
242
- if cmd == "/exit":
243
- print(" Goodbye!")
244
- break
245
- elif cmd == "/help":
246
- print(HELP_TEXT)
247
- elif cmd == "/train":
248
- model = handle_train(model, config, tokenizer, args)
249
- elif cmd == "/finetune":
250
- model = handle_finetune(model, config, tokenizer, args)
251
- elif cmd == "/save":
252
- handle_save(model, config)
253
- elif cmd == "/load":
254
- model = handle_load(model, config)
255
- # Re-attach to RAG manager since model changed
256
- rag_manager.model = model
257
- elif cmd == "/quantize":
258
- model = handle_quantize(model)
259
- rag_manager.model = model
260
- elif cmd == "/rag":
261
- handle_rag(rag_manager, args)
262
- elif cmd == "/clear":
263
- history.clear()
264
- print(" Conversation history cleared.")
265
- else:
266
- print(f" Unknown command: {cmd}. Type /help for a list.")
267
- continue
268
-
269
- # ---------- Chat mode ----------
270
- # Build prompt with history and optional RAG context
271
- rag_context = rag_manager.retrieve_context(user_input)
272
- prompt = history.build_prompt(user_input, rag_context=rag_context)
273
-
274
- history.add("user", user_input)
275
-
276
- print("SAGE: ", end="", flush=True)
277
- response = generate(
278
- model,
279
- tokenizer,
280
- prompt,
281
- max_new_tokens=256,
282
- temperature=0.8,
283
- top_k=50,
284
- top_p=0.9,
285
- stream=True,
286
- device=config.device,
287
- )
288
-
289
- # Extract only the SAGE response part from the full generation
290
- if "SAGE:" in response:
291
- reply = response.split("SAGE:")[-1].strip()
292
- else:
293
- reply = response[len(prompt):].strip()
294
-
295
- history.add("assistant", reply)
296
-
297
-
298
- if __name__ == "__main__":
299
- main()
sage/config.py DELETED
@@ -1,48 +0,0 @@
- import os
- from dataclasses import dataclass, field
- from typing import Any
-
- @dataclass
- class SageConfig:
-     # Model dimensions corresponding to T4 (16GB VRAM) fit
-     d_model: int = 512
-     n_heads: int = 8
-     n_kv_heads: int = 4  # GQA: must divide n_heads
-     n_layers: int = 6
-     d_ff: int = 2048
-
-     # MoE (Mixture of Experts) config
-     n_experts: int = 4
-     num_experts_per_tok: int = 2
-
-     # Vocabulary and sequence parameters
-     vocab_size: int = 100277  # Default for tiktoken "cl100k_base"
-     max_seq_len: int = 1024
-
-     # Regularization
-     dropout: float = 0.1
-
-     # Training Loop defaults
-     batch_size: int = 4
-     gradient_accumulation_steps: int = 16
-     learning_rate: float = 3e-4
-     min_learning_rate: float = 1e-5
-     warmup_steps: int = 100
-     weight_decay: float = 0.01
-     max_grad_norm: float = 1.0
-
-     # Checkpointing and path details
-     checkpoint_dir: str = "checkpoints"
-     project_name: str = "sage-v2"
-
-     # Cache for device (set on first access)
-     _device: Any = field(default=None, repr=False)
-
-     @property
-     def device(self):
-         """Returns the best available device with CUDA compatibility checking."""
-         if self._device is None:
-             # Import here to avoid circular imports
-             from .utils import get_compatible_device
-             self._device = get_compatible_device()
-         return self._device
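
Review note: the defaults in the deleted `SageConfig` imply a fixed number of tokens per optimizer step, and the `n_kv_heads` comment implies a divisibility constraint. The sketch below works both out with plain arithmetic. The helper names `tokens_per_optimizer_step` and `check_gqa` are hypothetical and do not exist in the repository.

```python
# Hypothetical helpers for illustration; not part of the deleted sage/config.py.

def tokens_per_optimizer_step(batch_size: int, grad_accum_steps: int, seq_len: int) -> int:
    """Tokens that contribute to one optimizer update on a single device."""
    return batch_size * grad_accum_steps * seq_len


def check_gqa(n_heads: int, n_kv_heads: int, d_model: int) -> int:
    """Validate grouped-query attention shapes; return query heads per KV head."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_kv_heads must divide n_heads")
    return n_heads // n_kv_heads


# Values copied from the SageConfig defaults shown in the diff.
print(tokens_per_optimizer_step(batch_size=4, grad_accum_steps=16, seq_len=1024))  # 65536
print(check_gqa(n_heads=8, n_kv_heads=4, d_model=512))  # 2
```

With these defaults, each optimizer step would see 65,536 tokens, and every KV head would be shared by two query heads.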
sage/data.py DELETED
@@ -1,255 +0,0 @@
- """
- SAGE Data Pipeline
- ==================
- Handles tokenization (tiktoken), streaming dataset loading from HuggingFace,
- text cleaning, chunking into fixed-length sequences, and batched DataLoader
- construction with shuffle buffering.
- """
-
- import re
- import random
- import tiktoken
- import torch
- from torch.utils.data import IterableDataset, DataLoader
- from typing import Iterator, List, Optional
- from .config import SageConfig
- from .utils import setup_logger
-
- logger = setup_logger("sage.data")
-
- # ---------------------------------------------------------------------------
- # Tokenizer wrapper
- # ---------------------------------------------------------------------------
-
- class SageTokenizer:
-     """Thin wrapper around tiktoken providing encode/decode and special tokens."""
-
-     def __init__(self, encoding_name: str = "cl100k_base"):
-         self.enc = tiktoken.get_encoding(encoding_name)
-         self.encoding_name = encoding_name
-
-         # Use the last token in the vocabulary as the EOS sentinel.
-         # tiktoken doesn't expose a dedicated EOS, so we pick one that
-         # won't collide with real text.
-         self.eos_token_id: int = self.enc.n_vocab - 1
-         self.pad_token_id: int = self.enc.n_vocab - 2
-         self.vocab_size: int = self.enc.n_vocab
-
-     def encode(self, text: str, add_eos: bool = False) -> List[int]:
-         """Encode text to token IDs."""
-         tokens = self.enc.encode(text, allowed_special="all")
-         if add_eos:
-             tokens.append(self.eos_token_id)
-         return tokens
-
-     def decode(self, tokens: List[int]) -> str:
-         """Decode token IDs back to text, filtering out special sentinel IDs."""
-         # Filter out our custom pad/eos sentinels before decoding
-         filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)]
-         return self.enc.decode(filtered)
-
- # ---------------------------------------------------------------------------
- # Text cleaning
- # ---------------------------------------------------------------------------
-
- _HTML_TAG_RE = re.compile(r"<[^>]+>")
- _MULTI_SPACE_RE = re.compile(r"[ \t]+")
- _MULTI_NEWLINE_RE = re.compile(r"\n{3,}")
-
-
- def clean_text(text: str) -> str:
-     """Strip HTML tags, collapse whitespace, and trim to reasonable length."""
-     text = _HTML_TAG_RE.sub("", text)  # remove HTML tags
-     text = _MULTI_SPACE_RE.sub(" ", text)  # collapse horizontal whitespace
-     text = _MULTI_NEWLINE_RE.sub("\n\n", text)  # collapse vertical whitespace
-     return text.strip()
-
- # ---------------------------------------------------------------------------
- # Streaming iterable dataset
- # ---------------------------------------------------------------------------
-
- class StreamingTextDataset(IterableDataset):
-     """
-     An IterableDataset that streams data from HuggingFace ``datasets``,
-     tokenizes on the fly, and yields fixed-length chunks.
-
-     It maintains an internal shuffle buffer so that consecutive chunks are
-     not always from the same document.
-     """
-
-     def __init__(
-         self,
-         dataset_name: str = "HuggingFaceFW/fineweb-edu",
-         split: str = "train",
-         seq_len: int = 512,
-         tokenizer: Optional[SageTokenizer] = None,
-         shuffle_buffer_size: int = 1000,
-         text_field: str = "text",
-         min_doc_len: int = 50,
-         max_doc_len: int = 50000,
-     ):
-         super().__init__()
-         self.dataset_name = dataset_name
-         self.split = split
-         self.seq_len = seq_len
-         self.tokenizer = tokenizer or SageTokenizer()
-         self.shuffle_buffer_size = shuffle_buffer_size
-         self.text_field = text_field
-         self.min_doc_len = min_doc_len
-         self.max_doc_len = max_doc_len
-
-         # Auto-adjust configuration based on popular datasets
-         if "fineweb-edu" in dataset_name.lower():
-             self.text_field = "text"
-             self.split = "train" if split == "train" else split
-         elif "tinystories" in dataset_name.lower():
-             self.text_field = "text"
-
-     def _stream_tokens(self) -> Iterator[int]:
-         """Yields individual token IDs from the HuggingFace dataset stream."""
-         try:
-             from datasets import load_dataset
-         except ImportError:
-             raise ImportError(
-                 "The 'datasets' library is required. Install it with: "
-                 "pip install datasets"
-             )
-
-         logger.info(
-             f"Streaming dataset '{self.dataset_name}' (split={self.split}) …"
-         )
-         ds = load_dataset(
-             self.dataset_name,
-             split=self.split,
-             streaming=True,
-         )
-
-         for sample in ds:
-             raw = sample.get(self.text_field, "")
-             if not raw:
-                 continue
-
-             text = clean_text(raw)
-
-             # Filter documents that are too short or too long
-             if len(text) < self.min_doc_len or len(text) > self.max_doc_len:
-                 continue
-
-             tokens = self.tokenizer.encode(text, add_eos=True)
-             yield from tokens
-
-     def _chunk_tokens(self) -> Iterator[torch.Tensor]:
-         """Groups raw token stream into fixed-length chunks of (seq_len + 1).
-
-         The extra token is needed so that input = chunk[:-1] and
-         target = chunk[1:] for next-token-prediction.
-         """
-         chunk: List[int] = []
-         for tok in self._stream_tokens():
-             chunk.append(tok)
-             if len(chunk) == self.seq_len + 1:
-                 yield torch.tensor(chunk, dtype=torch.long)
-                 chunk = []
-         # Discard any trailing partial chunk
-
-     def __iter__(self) -> Iterator[torch.Tensor]:
-         """Yields shuffled chunks from an internal buffer."""
-         buffer: List[torch.Tensor] = []
-         for chunk in self._chunk_tokens():
-             buffer.append(chunk)
-             if len(buffer) >= self.shuffle_buffer_size:
-                 random.shuffle(buffer)
-                 while len(buffer) > self.shuffle_buffer_size // 2:
-                     yield buffer.pop()
-         # Flush remaining items
-         random.shuffle(buffer)
-         yield from buffer
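
Review note: the chunking and shuffle-buffer logic in the deleted `StreamingTextDataset` can be exercised without `torch` or `datasets`. The following is a minimal sketch on plain Python lists. The function names `chunk_tokens` and `buffered_shuffle` are hypothetical stand-ins for `_chunk_tokens` and `__iter__`, and passing an explicit `random.Random` instance is an assumption made for reproducibility.

```python
import random
from typing import Iterable, Iterator, List


def chunk_tokens(tokens: Iterable[int], seq_len: int) -> Iterator[List[int]]:
    """Group a token stream into chunks of seq_len + 1; drop the partial tail."""
    chunk: List[int] = []
    for tok in tokens:
        chunk.append(tok)
        if len(chunk) == seq_len + 1:
            yield chunk
            chunk = []


def buffered_shuffle(items: Iterable, buffer_size: int, rng: random.Random) -> Iterator:
    """Approximate shuffling: fill a buffer, shuffle it, drain it to half size."""
    buffer: list = []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            rng.shuffle(buffer)
            while len(buffer) > buffer_size // 2:
                yield buffer.pop()
    # Flush whatever is left once the stream ends.
    rng.shuffle(buffer)
    yield from buffer


chunks = list(chunk_tokens(range(10), seq_len=3))
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]; tokens 8 and 9 are discarded

# Each chunk yields an (input, target) pair shifted by one position.
inputs, targets = chunks[0][:-1], chunks[0][1:]
print(inputs, targets)  # [0, 1, 2] [1, 2, 3]

shuffled = list(buffered_shuffle(range(20), buffer_size=8, rng=random.Random(0)))
```

Because the chunks do not overlap, the first token of every chunk appears only as an input and never as a target. That cost is small at `seq_len=1024` but worth knowing about.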
-
- # ---------------------------------------------------------------------------
- # DataLoader factory
- # ---------------------------------------------------------------------------
-
- def create_dataloader(
-     config: SageConfig,
-     dataset_name: str = "HuggingFaceFW/fineweb-edu",
-     split: str = "train",
-     tokenizer: Optional[SageTokenizer] = None,
- ) -> DataLoader:
-     """Creates a streaming DataLoader ready for the training loop."""
-     tok = tokenizer or SageTokenizer()
-     ds = StreamingTextDataset(
-         dataset_name=dataset_name,
-         split=split,
-         seq_len=config.max_seq_len,
-         tokenizer=tok,
-     )
-     return DataLoader(
-         ds,
-         batch_size=config.batch_size,
-         num_workers=2,
-         pin_memory=True,
-         drop_last=True,
-     )
-
- # ---------------------------------------------------------------------------
- # Instruction-tuning data helpers
- # ---------------------------------------------------------------------------
-
- INSTRUCTION_TEMPLATE = (
-     "### Instruction:\n{instruction}\n\n### Response:\n{response}"
- )
-
-
- def format_instruction_sample(instruction: str, response: str) -> str:
-     """Formats an instruction/response pair into the chat template."""
-     return INSTRUCTION_TEMPLATE.format(
-         instruction=instruction.strip(),
-         response=response.strip(),
-     )
-
-
- def create_instruction_batch(
-     samples: List[dict],
-     tokenizer: SageTokenizer,
-     max_len: int = 512,
- ) -> dict:
-     """
-     Tokenize a list of {instruction, response} dicts and produce input_ids,
-     labels, and a loss_mask that zeros out the instruction portion.
-
-     Returns a dict with keys: input_ids, labels, loss_mask — all as tensors.
-     """
-     all_input_ids: List[List[int]] = []
-     all_labels: List[List[int]] = []
-     all_masks: List[List[int]] = []
-
-     for sample in samples:
-         instruction_text = f"### Instruction:\n{sample['instruction'].strip()}\n\n### Response:\n"
-         response_text = sample["response"].strip()
-         full_text = instruction_text + response_text
-
-         instruction_tokens = tokenizer.encode(instruction_text)
-         full_tokens = tokenizer.encode(full_text, add_eos=True)
-
-         # Truncate to max_len
-         full_tokens = full_tokens[:max_len]
-         n_instruction = min(len(instruction_tokens), len(full_tokens))
-
-         # Labels are the same as input shifted by 1 (handled by caller),
-         # but we need a mask to zero out loss on instruction tokens.
-         mask = [0] * n_instruction + [1] * (len(full_tokens) - n_instruction)
-
-         # Pad to max_len
-         pad_len = max_len - len(full_tokens)
-         full_tokens += [tokenizer.pad_token_id] * pad_len
-         mask += [0] * pad_len
-
-         all_input_ids.append(full_tokens)
-         all_labels.append(full_tokens)  # shift will be done in the loss fn
-         all_masks.append(mask)
-
-     return {
-         "input_ids": torch.tensor(all_input_ids, dtype=torch.long),
-         "labels": torch.tensor(all_labels, dtype=torch.long),
-         "loss_mask": torch.tensor(all_masks, dtype=torch.float32),
-     }
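
Review note: the deleted `create_instruction_batch` builds a loss mask over instruction, response, and padding positions, and the caller later shifts it by one position. The sketch below reproduces that bookkeeping on plain lists so the alignment can be checked by hand. `build_loss_mask` and `masked_mean` are hypothetical names, and the per-token loss values are made up purely to show which positions count.

```python
from typing import List


def build_loss_mask(n_instruction: int, n_total: int, max_len: int) -> List[int]:
    """0 for instruction tokens, 1 for response tokens, 0 for padding."""
    n_total = min(n_total, max_len)            # truncate to max_len
    n_instruction = min(n_instruction, n_total)
    mask = [0] * n_instruction + [1] * (n_total - n_instruction)
    return mask + [0] * (max_len - n_total)    # pad to max_len


def masked_mean(losses: List[float], mask: List[int]) -> float:
    """Average the loss over masked-in positions; never divide by zero."""
    denom = max(sum(mask), 1)
    return sum(l * m for l, m in zip(losses, mask)) / denom


mask = build_loss_mask(n_instruction=3, n_total=6, max_len=8)
print(mask)  # [0, 0, 0, 1, 1, 1, 0, 0]

# Next-token prediction: position i predicts token i + 1, so the mask is
# shifted the same way as the labels (mask[1:]).
shift_mask = mask[1:]
print(shift_mask)  # [0, 0, 1, 1, 1, 0, 0]

# Illustrative per-token losses; only the three response targets count.
per_token_loss = [9.0, 9.0, 1.0, 2.0, 3.0, 9.0, 9.0]
print(masked_mean(per_token_loss, shift_mask))  # 2.0
```

One caveat: the original counts instruction tokens by encoding the instruction prefix on its own. With a BPE tokenizer, the prefix of the full encoding is not guaranteed to match that standalone encoding, so the mask boundary may be off by a token in some cases.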
sage/finetune.py DELETED
@@ -1,268 +0,0 @@
- """
- SAGE Fine-Tuning
- ================
- Provides two fine-tuning modes:
-
- 1. **Instruction tuning** — trains on instruction/response pairs with loss
-    masked on the instruction portion.
- 2. **LoRA (Low-Rank Adaptation)** — injects small trainable matrices into
-    attention layers while keeping the base model frozen.
- """
-
- import math
- import time
- import copy
- import torch
- import torch.nn as nn
- from torch.amp import GradScaler, autocast
- from tqdm import tqdm
- import wandb
- from typing import Optional, List
-
- from .config import SageConfig
- from .model import SageModel, CausalSelfAttention
- from .data import SageTokenizer, create_instruction_batch
- from .train import create_optimizer, get_lr, set_lr
- from .utils import setup_logger, save_checkpoint
-
- logger = setup_logger("sage.finetune")
-
-
- # ===================================================================
- # LoRA Implementation
- # ===================================================================
-
- class LoRALinear(nn.Module):
-     """
-     Wraps an existing ``nn.Linear`` with a low-rank adapter (A @ B).
-
-     During fine-tuning only *A* and *B* are trained; the original weight
-     is frozen. After fine-tuning the adapter can be merged back into
-     the original weight for zero-overhead inference.
-     """
-
-     def __init__(self, original: nn.Linear, rank: int = 8, alpha: float = 16.0):
-         super().__init__()
-         self.original = original
-         self.rank = rank
-         self.alpha = alpha
-         self.scaling = alpha / rank
-
-         in_features = original.in_features
-         out_features = original.out_features
-
-         # Low-rank matrices
-         device, dtype = original.weight.device, original.weight.dtype
-         self.lora_A = nn.Parameter(torch.randn(in_features, rank, device=device, dtype=dtype) * 0.01)
-         self.lora_B = nn.Parameter(torch.zeros(rank, out_features, device=device, dtype=dtype))
-
-         # Freeze the original weight
-         self.original.weight.requires_grad = False
-         if self.original.bias is not None:
-             self.original.bias.requires_grad = False
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """original(x) + x @ A @ B * scaling"""
-         base_out = self.original(x)
-         lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
-         return base_out + lora_out
-
-     def merge(self) -> nn.Linear:
-         """Merge LoRA weights back into the original linear layer."""
-         merged = copy.deepcopy(self.original)
-         merged.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling
-         merged.weight.requires_grad = True
-         return merged
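
Review note: for a merged layer to reproduce `forward`, the weight update has to match the `[out_features, in_features]` layout that `nn.Linear` uses, which works out to `scaling * (B.T @ A.T)`. The deleted `merge` applies an extra transpose, which yields an `[in_features, out_features]` matrix. As far as I can tell that would fail on shape for the non-square K/V projections under GQA, and it would add the transposed delta for the square ones. The check below is a small pure-Python sketch that avoids a `torch` dependency. The tiny matrices and helper names are made up for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def add_scaled(w: Matrix, delta: Matrix, s: float) -> Matrix:
    """Return w + s * delta (shapes must match)."""
    return [[wv + s * dv for wv, dv in zip(wr, dr)] for wr, dr in zip(w, delta)]


# in_features=3, out_features=2, rank=1; the values are arbitrary.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]   # [out, in], like nn.Linear.weight
A = [[1.0], [2.0], [3.0]]                 # [in, rank]
B = [[0.5, -1.0]]                         # [rank, out]
scaling = 2.0
x = [[1.0, 1.0, 1.0]]                     # one input row, [1, in]

# Adapter forward pass: x @ W.T + (x @ A @ B) * scaling
base_out = matmul(x, transpose(W))
lora_out = matmul(matmul(x, A), B)
adapter_out = add_scaled(base_out, lora_out, scaling)

# Merge: the delta must be [out, in], i.e. B.T @ A.T with no further transpose.
delta = matmul(transpose(B), transpose(A))
W_merged = add_scaled(W, delta, scaling)
merged_out = matmul(x, transpose(W_merged))

print(adapter_out)  # [[9.0, -12.0]]
print(merged_out)   # [[9.0, -12.0]]
```

The same identity would hold for the `torch` tensors in `LoRALinear`, assuming `lora_A` is `[in, rank]` and `lora_B` is `[rank, out]` as in the diff.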
-
-
- # ---------------------------------------------------------------------------
- # LoRA injection / removal helpers
- # ---------------------------------------------------------------------------
-
- def inject_lora(model: SageModel, rank: int = 8, alpha: float = 16.0) -> SageModel:
-     """
-     Replace the Q, K, V, O projection layers in every attention block with
-     LoRA-wrapped versions. Returns the same model (mutated in-place).
-     """
-     base_model = getattr(model, "module", model)
-     for layer in base_model.layers:
-         attn: CausalSelfAttention = layer.attn
-         attn.wq = LoRALinear(attn.wq, rank=rank, alpha=alpha)
-         attn.wk = LoRALinear(attn.wk, rank=rank, alpha=alpha)
-         attn.wv = LoRALinear(attn.wv, rank=rank, alpha=alpha)
-         attn.wo = LoRALinear(attn.wo, rank=rank, alpha=alpha)
-
-     # Log trainable parameter count
-     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
-     total = sum(p.numel() for p in model.parameters())
-     logger.info(
-         f"LoRA injected (rank={rank}). Trainable: {trainable:,} / {total:,} "
-         f"({100 * trainable / total:.2f}%)"
-     )
-     return model
-
-
- def merge_lora(model: SageModel) -> SageModel:
-     """
-     Merge all LoRA adapters back into the base weights and replace the
-     LoRALinear wrappers with plain nn.Linear modules.
-     """
-     base_model = getattr(model, "module", model)
-     for layer in base_model.layers:
-         attn: CausalSelfAttention = layer.attn
-         for name in ("wq", "wk", "wv", "wo"):
-             module = getattr(attn, name)
-             if isinstance(module, LoRALinear):
-                 setattr(attn, name, module.merge())
-     logger.info("LoRA weights merged into base model.")
-     return model
-
-
- # ===================================================================
- # Instruction fine-tuning loop
- # ===================================================================
-
- def finetune_instruction(
-     model: SageModel,
-     config: SageConfig,
-     samples: List[dict],
-     total_steps: int = 200,
-     use_lora: bool = True,
-     lora_rank: int = 8,
-     tokenizer: Optional[SageTokenizer] = None,
- ) -> SageModel:
-     """
-     Fine-tune the model on instruction/response pairs.
-
-     Parameters
-     ----------
-     model : SageModel
-     config : SageConfig
-     samples : list[dict]
-         Each dict must contain ``instruction`` and ``response`` string keys.
-     total_steps : int
-     use_lora : bool
-         If True, inject LoRA adapters before training.
-     lora_rank : int
-     tokenizer : SageTokenizer, optional
-
-     Returns
-     -------
-     SageModel — the fine-tuned model (LoRA merged if applicable).
-     """
-     # --- TURBO MODE: TF32 & COMPILE ---
-     if torch.cuda.is_available():
-         torch.set_float32_matmul_precision('high')
-
-     device = config.device
-     model = model.to(device)
-
-     # Wrap model with torch.compile for graph-level optimization
-     if hasattr(torch, "compile"):
-         logger.info("Turbo Mode: Compiling fine-tune engine...")
-         base = getattr(model, "module", model)
-         compiled_base = torch.compile(base, mode="reduce-overhead")
-         if hasattr(model, "module"):
-             model.module = compiled_base
-         else:
-             model = compiled_base
-
-     tok = tokenizer or SageTokenizer()
-
-     if use_lora:
-         model = inject_lora(model, rank=lora_rank)
-
-     # ------- W&B Logging -------
-     wandb.init(
-         project=config.project_name,
-         name=f"finetune-{time.strftime('%Y%m%d-%H%M')}",
-         config=config.__dict__,
-     )
-
-     optimizer = create_optimizer(model, config)
-
-     # AMP setup
-     use_amp = device.type == "cuda"
-     amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16
-     scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16))
-
-     model.train()
-     pbar = tqdm(range(total_steps), desc="Fine-tuning", unit="step")
-     accum_loss = 0.0
-
-     for step in pbar:
-         lr = get_lr(step, config, total_steps)
-         set_lr(optimizer, lr)
-
-         # Build a batch by sampling from the instruction dataset
-         batch_size = min(config.batch_size, len(samples))
-         import random
-         batch_samples = random.choices(samples, k=batch_size)
-         batch = create_instruction_batch(batch_samples, tok, max_len=config.max_seq_len)
-
-         input_ids = batch["input_ids"].to(device)
-         labels = batch["labels"].to(device)
-         loss_mask = batch["loss_mask"].to(device)
-
-         optimizer.zero_grad(set_to_none=True)
-
-         with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
-             logits, _ = model(input_ids)
-             # Shift: predict next token
-             shift_logits = logits[:, :-1, :].contiguous()
-             shift_labels = labels[:, 1:].contiguous()
-             shift_mask = loss_mask[:, 1:].contiguous()
-
-             # Compute per-token loss
-             per_token_loss = nn.functional.cross_entropy(
-                 shift_logits.view(-1, shift_logits.size(-1)),
-                 shift_labels.view(-1),
-                 reduction="none",
-             )
-             per_token_loss = per_token_loss.view(shift_labels.size())
-
-             # Mask out instruction tokens so we only learn from responses
-             masked_loss = (per_token_loss * shift_mask).sum() / shift_mask.sum().clamp(min=1)
-
-         scaler.scale(masked_loss).backward()
-         scaler.unscale_(optimizer)
-         torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
-         scaler.step(optimizer)
-         scaler.update()
-
-         accum_loss += masked_loss.item()
-
-         if (step + 1) % 10 == 0:
-             avg = accum_loss / 10
-             pbar.set_postfix(loss=f"{avg:.4f}", lr=f"{lr:.2e}")
-             logger.info(f"finetune step={step+1} | loss={avg:.4f}")
-             wandb.log({
-                 "finetune/loss": avg,
-                 "finetune/lr": lr,
-             }, step=step + 1)
-             accum_loss = 0.0
-
-     # Merge LoRA weights back for clean inference
-     if use_lora:
-         model = merge_lora(model)
-
-     save_checkpoint(model, None, total_steps, config.checkpoint_dir, filename="sage_finetuned.pt")
-     logger.info("Instruction fine-tuning complete. Checkpoint saved as sage_finetuned.pt")
-     wandb.finish()
-     return model
-
-
- # ---------------------------------------------------------------------------
- # Demo instruction samples (used when no dataset is provided)
- # ---------------------------------------------------------------------------
-
- DEMO_INSTRUCTION_SAMPLES = [
-     {"instruction": "What is the capital of France?", "response": "The capital of France is Paris."},
-     {"instruction": "Explain gravity in simple terms.", "response": "Gravity is the force that pulls objects toward each other. The more mass an object has, the stronger its gravitational pull."},
-     {"instruction": "Write a short poem about the ocean.", "response": "Waves crash upon the sandy shore,\nThe ocean's song forevermore.\nDeep blue stretching to the sky,\nSeagulls dance and clouds float by."},
-     {"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."},
-     {"instruction": "Summarize photosynthesis.", "response": "Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen, providing energy for the plant."},
-     {"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than a year on Venus. It takes Venus 243 Earth days to rotate once on its axis but only 225 Earth days to orbit the Sun."},
-     {"instruction": "How do airplanes fly?", "response": "Airplanes fly by generating lift through their wings. Air moves faster over the curved top of the wing than the flat bottom, creating lower pressure above and higher pressure below, which pushes the wing upward."},
-     {"instruction": "What is machine learning?", "response": "Machine learning is a branch of artificial intelligence where computers learn patterns from data instead of being explicitly programmed, allowing them to make predictions or decisions."},
- ]
sage/inference.py DELETED
@@ -1,171 +0,0 @@
- """
- SAGE Inference Engine
- =====================
- Text generation with greedy, temperature, top-k, and nucleus (top-p) sampling.
- Supports KV-cache for O(1)-per-token generation and streaming output.
- """
-
- import sys
- import torch
- import torch.nn.functional as F
- from typing import Optional, List
-
- from .config import SageConfig
- from .model import SageModel
- from .data import SageTokenizer
- from .utils import setup_logger
-
- logger = setup_logger("sage.inference")
-
-
- # ---------------------------------------------------------------------------
- # Sampling helpers
- # ---------------------------------------------------------------------------
-
- def _top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
-     """Zero out all logits outside the top-k highest values."""
-     if k <= 0 or k >= logits.size(-1):
-         return logits
-     values, _ = torch.topk(logits, k)
-     min_val = values[:, -1].unsqueeze(-1)
-     return torch.where(logits < min_val, torch.full_like(logits, float("-inf")), logits)
-
-
- def _top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
-     """Nucleus sampling: keep the smallest set of tokens whose cumulative
-     probability exceeds *p*."""
-     if p >= 1.0:
-         return logits
-     sorted_logits, sorted_idx = torch.sort(logits, descending=True)
-     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-     # Identify tokens to remove (cumulative prob exceeds p)
-     sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= p
-     sorted_logits[sorted_mask] = float("-inf")
-
-     # Scatter back to original order
-     logits = logits.scatter(1, sorted_idx, sorted_logits)
-     return logits
-
-
- def sample_next_token(
-     logits: torch.Tensor,
-     temperature: float = 0.8,
-     top_k: int = 50,
-     top_p: float = 0.9,
-     greedy: bool = False,
- ) -> torch.Tensor:
-     """
-     Given raw logits for the last position, sample or greedily select the
-     next token.
-
-     Parameters
-     ----------
-     logits : Tensor [batch, vocab]
-     temperature : float
-     top_k : int
-     top_p : float
-     greedy : bool — if True, ignore temperature/top-k/top-p and pick argmax.
-
-     Returns
-     -------
-     Tensor [batch, 1]
-     """
-     if greedy:
-         return logits.argmax(dim=-1, keepdim=True)
-
-     logits = logits / max(temperature, 1e-8)
-     logits = _top_k_filter(logits, top_k)
-     logits = _top_p_filter(logits, top_p)
-
-     probs = F.softmax(logits, dim=-1)
-     return torch.multinomial(probs, num_samples=1)
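
Review note: the nucleus rule in the deleted `_top_p_filter` drops a token once the probability mass of the tokens ranked above it has already reached `p`, so the top-ranked token always survives. The sketch below applies the same rule to a plain list of probabilities. `top_p_keep` and `renormalize` are hypothetical names, and the distribution is invented for the example.

```python
from typing import List


def top_p_keep(probs: List[float], p: float) -> List[int]:
    """Return indices kept by nucleus filtering, highest probability first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    mass_before = 0.0  # cumulative probability of tokens ranked above this one
    for i in order:
        if mass_before >= p:
            break  # the nucleus already covers p; drop this and all later tokens
        kept.append(i)
        mass_before += probs[i]
    return kept


def renormalize(probs: List[float], kept: List[int]) -> List[float]:
    """Zero out dropped tokens and rescale the rest to sum to 1."""
    total = sum(probs[i] for i in kept)
    keep = set(kept)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]


probs = [0.5, 0.3, 0.15, 0.05]

print(top_p_keep(probs, p=0.7))  # [0, 1]
print(top_p_keep(probs, p=0.9))  # [0, 1, 2]
filtered = renormalize(probs, top_p_keep(probs, p=0.7))
```

In the tensor version, the dropped logits are set to `-inf` and the renormalization happens implicitly in the final `softmax`.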
-
-
- # ---------------------------------------------------------------------------
- # Main generation function
- # ---------------------------------------------------------------------------
-
- @torch.no_grad()
- def generate(
-     model: SageModel,
-     tokenizer: SageTokenizer,
-     prompt: str,
-     max_new_tokens: int = 256,
-     temperature: float = 0.8,
-     top_k: int = 50,
-     top_p: float = 0.9,
-     greedy: bool = False,
-     stream: bool = True,
-     device: Optional[torch.device] = None,
- ) -> str:
-     """
-     Generate text from *prompt* using the SAGE model.
-
-     Parameters
-     ----------
-     model : SageModel
-     tokenizer : SageTokenizer
-     prompt : str
-     max_new_tokens : int
-     temperature, top_k, top_p : sampling hyper-parameters
-     greedy : bool — use argmax decoding
-     stream : bool — print tokens as they are generated
-     device : torch.device
-
-     Returns
-     -------
-     str — the complete generated text (prompt + new tokens).
-     """
-     if device is None:
-         device = next(model.parameters()).device
-
-     base_model = getattr(model, "module", model)
-     base_model.eval()
-
-     # Encode prompt
-     prompt_tokens = tokenizer.encode(prompt)
-     if not prompt_tokens:
-         prompt_tokens = [tokenizer.eos_token_id]
-
-     input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
-
-     generated_tokens: List[int] = list(prompt_tokens)
-     kv_caches = None
-
-     # --- Prefill: run the full prompt through the model once ---
-     logits, kv_caches = base_model(input_ids)
-     next_logits = logits[:, -1, :]
-
-     for _ in range(max_new_tokens):
-         next_id = sample_next_token(
-             next_logits,
-             temperature=temperature,
-             top_k=top_k,
-             top_p=top_p,
-             greedy=greedy,
-         )
-
-         token_id = next_id.item()
-
-         # Stop on EOS
-         if token_id == tokenizer.eos_token_id:
-             break
-
-         generated_tokens.append(token_id)
-
-         # Stream output: decode and print only the new token
-         if stream:
-             token_str = tokenizer.decode([token_id])
-             print(token_str, end="", flush=True)
-
-         # --- Decode step: feed only the new token, reuse KV-cache ---
-         next_input = next_id.view(1, 1)
-         logits, kv_caches = base_model(next_input, kv_caches=kv_caches)
-         next_logits = logits[:, -1, :]
-
-     if stream:
-         print()  # newline after streaming completes
-
-     base_model.train()
-     return tokenizer.decode(generated_tokens)
sage/memory.py DELETED
@@ -1,240 +0,0 @@
- """
- SAGE Memory & RAG Module
- =========================
- Provides:
- - A FAISS-backed vector store for retrieval-augmented generation (RAG).
- - A rolling conversation-history manager that truncates intelligently
-   to stay within the model's context window.
- """
-
- import os
- import numpy as np
- import torch
- import torch.nn.functional as F
- from typing import List, Optional, Tuple
-
- from .data import SageTokenizer
- from .utils import setup_logger
-
- logger = setup_logger("sage.memory")
-
-
- # ===================================================================
- # Simple embedding helper (uses mean-pooled token embeddings)
- # ===================================================================
-
- def _embed_text(text: str, tokenizer: SageTokenizer, model: torch.nn.Module, device: torch.device) -> np.ndarray:
-     """
-     Produce a fixed-length embedding for *text* by mean-pooling the
-     model's token embeddings. This is lightweight and avoids a full
-     forward pass — suitable for a small retrieval index.
-     """
-     tokens = tokenizer.encode(text)
-     if not tokens:
-         # Return a zero vector when text is empty
-         d_model = model.wte.weight.shape[1]
-         return np.zeros(d_model, dtype=np.float32)
-
-     ids = torch.tensor([tokens], dtype=torch.long, device=device)
-     with torch.no_grad():
-         embeddings = model.wte(ids)  # [1, seq_len, d_model]
-         mean_emb = embeddings.mean(dim=1)  # [1, d_model]
-         # L2-normalize for cosine similarity in FAISS
-         mean_emb = F.normalize(mean_emb, p=2, dim=-1)
-     return mean_emb.squeeze(0).cpu().numpy()
-
-
- # ===================================================================
- # FAISS-backed Vector Store
- # ===================================================================
-
- class VectorStore:
-     """
-     A lightweight document store backed by FAISS (Inner Product index,
-     which equals cosine similarity when vectors are L2-normalized).
-     """
-
-     def __init__(self, dim: int):
-         try:
-             import faiss
-         except ImportError:
-             raise ImportError(
-                 "FAISS is required for RAG. Install it with: pip install faiss-cpu"
-             )
-         self.dim = dim
-         self.index = faiss.IndexFlatIP(dim)  # inner-product (cosine after L2-norm)
-         self.documents: List[str] = []
-         logger.info(f"VectorStore initialized (dim={dim})")
-
-     def add(self, texts: List[str], embeddings: np.ndarray) -> None:
-         """Add documents and their embeddings to the store."""
-         assert embeddings.shape[0] == len(texts)
-         assert embeddings.shape[1] == self.dim
-         self.index.add(embeddings.astype(np.float32))
-         self.documents.extend(texts)
-         logger.info(f"Added {len(texts)} documents. Total: {len(self.documents)}")
-
-     def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[Tuple[str, float]]:
-         """Return the top-k most similar documents with their scores."""
-         if self.index.ntotal == 0:
-             return []
-         query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
-         scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
-         results = []
-         for score, idx in zip(scores[0], indices[0]):
-             if idx < 0:
-                 continue
-             results.append((self.documents[idx], float(score)))
-         return results
-
-     @property
-     def size(self) -> int:
-         return self.index.ntotal
-
-
- # ===================================================================
- # RAG Manager
- # ===================================================================
-
- class RAGManager:
-     """
-     High-level retrieval-augmented generation manager.
-
-     Call ``add_documents`` to ingest text, then ``retrieve_context`` at
-     inference time to prepend relevant chunks to the user prompt.
-     """
-
-     def __init__(
-         self,
-         model: torch.nn.Module,
-         tokenizer: SageTokenizer,
-         device: torch.device,
-         chunk_size: int = 200,
-         chunk_overlap: int = 50,
-     ):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.device = device
-         self.chunk_size = chunk_size
-         self.chunk_overlap = chunk_overlap
-
-         d_model = model.wte.weight.shape[1]
-         self.store = VectorStore(dim=d_model)
-         self.enabled = False
-
-     def _chunk_text(self, text: str) -> List[str]:
-         """Split text into overlapping word-level chunks."""
-         words = text.split()
-         chunks: List[str] = []
-         start = 0
-         while start < len(words):
-             end = start + self.chunk_size
-             chunk = " ".join(words[start:end])
-             chunks.append(chunk)
-             start += self.chunk_size - self.chunk_overlap
-         return chunks
-
-     def add_documents(self, texts: List[str]) -> None:
-         """Chunk and embed documents, then add to the vector store."""
-         all_chunks: List[str] = []
-         for text in texts:
-             all_chunks.extend(self._chunk_text(text))
-
-         if not all_chunks:
-             logger.warning("No document chunks to add.")
-             return
-
-         embeddings = np.stack([
-             _embed_text(chunk, self.tokenizer, self.model, self.device)
-             for chunk in all_chunks
-         ])
-         self.store.add(all_chunks, embeddings)
-
-     def retrieve_context(self, query: str, top_k: int = 3) -> str:
-         """
-         Retrieve the top-k most relevant chunks for *query* and
-         concatenate them into a context string.
-         """
-         if not self.enabled or self.store.size == 0:
-             return ""
-
-         q_emb = _embed_text(query, self.tokenizer, self.model, self.device)
-         results = self.store.search(q_emb, top_k=top_k)
-
-         if not results:
-             return ""
-
-         context_parts = [f"[Context {i+1}] {doc}" for i, (doc, _score) in enumerate(results)]
-         return "\n\n".join(context_parts) + "\n\n"
-
-     def toggle(self, on: bool) -> None:
-         self.enabled = on
-         state = "enabled" if on else "disabled"
-         logger.info(f"RAG {state}. Store contains {self.store.size} chunks.")
174
-
175
-
176
- # ===================================================================
177
- # Conversation History Manager
178
- # ===================================================================
179
-
180
- DEFAULT_SYSTEM_PROMPT = (
181
- "You are SAGE, a high-quality reasoning assistant. "
182
- "Your goal is to provide accurate, structured, and deep logical explanations.\n\n"
183
- "CRITICAL GUIDELINES:\n"
184
- "1. THINKING PHASE: You must ALWAYS start your response with a <thinking> section. "
185
- "In this section, break down the user's request, identify key constraints, and plan your logical steps.\n"
186
- "2. RESPONSE PHASE: After completing your internal reasoning, provide your final answer within <response> tags.\n"
187
- "3. QUALITY: Prioritize step-by-step mathematical or logical derivation over short answers.\n"
188
- "4. NO REPETITION: Avoid filler words or circular logic.\n\n"
189
- "RESPONSE TEMPLATE:\n"
190
- "<thinking>\n[Step-by-step logic here]\n</thinking>\n"
191
- "<response>\n[Final clear answer here]\n</response>"
192
- )
193
-
194
- class ConversationHistory:
195
- """
196
- Rolling conversation history that stays within a token budget.
197
-
198
- Older turns are dropped when the history would exceed the context window.
199
- """
200
-
201
- def __init__(self, tokenizer: SageTokenizer, max_tokens: int = 900):
202
- self.tokenizer = tokenizer
203
- self.max_tokens = max_tokens
204
- self.turns: List[dict] = [] # [{"role": "user"/"assistant", "text": ...}, ...]
205
-
206
- def add(self, role: str, text: str) -> None:
207
- """Record a new conversational turn."""
208
- self.turns.append({"role": role, "text": text})
209
- self._trim()
210
-
211
- def _trim(self) -> None:
212
- """Drop oldest turns until the total token count is within budget."""
213
- while self._total_tokens() > self.max_tokens and len(self.turns) > 1:
214
- self.turns.pop(0)
215
-
216
- def _total_tokens(self) -> int:
217
- return sum(len(self.tokenizer.encode(t["text"])) for t in self.turns)
218
-
219
- def build_prompt(self, new_user_message: str, rag_context: str = "") -> str:
220
- """
221
- Assemble the full prompt from history + RAG context + new message.
222
- """
223
- parts: List[str] = []
224
-
225
- parts.append(DEFAULT_SYSTEM_PROMPT)
226
-
227
- if rag_context:
228
- parts.append(rag_context)
229
-
230
- for turn in self.turns:
231
- prefix = "User:" if turn["role"] == "user" else "SAGE:"
232
- parts.append(f"{prefix} {turn['text']}")
233
-
234
- parts.append(f"User: {new_user_message}")
235
- parts.append("SAGE:")
236
-
237
- return "\n".join(parts)
238
-
239
- def clear(self) -> None:
240
- self.turns.clear()
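The removed `RAGManager._chunk_text` splits text into overlapping word-level chunks by advancing `chunk_size - chunk_overlap` words per step. The sketch below illustrates that stride logic in isolation. The name `chunk_words` and the explicit overlap check are assumptions added for illustration and are not part of the deleted file; the check is there because a stride of zero or less would keep the original `while` loop from terminating.

```python
from typing import List


def chunk_words(text: str, chunk_size: int = 200, chunk_overlap: int = 50) -> List[str]:
    """Split text into overlapping word-level chunks (illustrative helper)."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        # Guard added for this sketch: a non-positive stride would never advance.
        raise ValueError("chunk_overlap must satisfy 0 <= chunk_overlap < chunk_size")

    words = text.split()
    stride = chunk_size - chunk_overlap  # words advanced between chunk starts
    chunks: List[str] = []
    for start in range(0, len(words), stride):
        chunks.append(" ".join(words[start:start + chunk_size]))
    return chunks


if __name__ == "__main__":
    sample = " ".join(f"w{i}" for i in range(10))
    for chunk in chunk_words(sample, chunk_size=4, chunk_overlap=2):
        print(chunk)
```

As in the deleted code, the final chunks may be shorter than `chunk_size` because slicing stops at the end of the word list.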
 
sage/model.py DELETED
@@ -1,267 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
- from typing import Optional, Tuple
6
- from .config import SageConfig
7
-
8
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
9
- """Precomputes rotary positional embedding frequencies."""
10
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
11
- t = torch.arange(end, device=freqs.device, dtype=torch.float32)
12
- freqs = torch.outer(t, freqs).float()
13
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
14
- return freqs_cis
15
-
16
- def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
17
- """Applies rotary positional embeddings to queries and keys."""
18
- # Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real)
19
- if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2:
20
- freqs_cis = torch.view_as_complex(freqs_cis)
21
-
22
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
23
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
24
-
25
- # Reshape freqs_cis to broadcast with xq_ and xk_
26
- # xq_, xk_ shape: [batch, seq_len, n_heads, dim_head//2]
27
- # freqs_cis shape: [seq_len, dim_head//2]
28
- freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
29
-
30
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
31
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
32
-
33
- return xq_out.type_as(xq), xk_out.type_as(xk)
34
-
35
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
36
- """Repeat Key/Value heads n_rep times to match number of Query heads."""
37
- if n_rep == 1:
38
- return x
39
- B, T, n_kv_heads, head_dim = x.size()
40
- return (
41
- x[:, :, :, None, :]
42
- .expand(B, T, n_kv_heads, n_rep, head_dim)
43
- .reshape(B, T, n_kv_heads * n_rep, head_dim)
44
- )
45
-
46
- class CausalSelfAttention(nn.Module):
47
- def __init__(self, config: SageConfig):
48
- super().__init__()
49
- self.n_heads = config.n_heads
50
- self.n_kv_heads = config.n_kv_heads
51
- self.n_rep = self.n_heads // self.n_kv_heads
52
- self.d_model = config.d_model
53
- assert self.d_model % self.n_heads == 0
54
- self.head_dim = self.d_model // self.n_heads
55
-
56
- self.wq = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False)
57
- self.wk = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False)
58
- self.wv = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False)
59
- self.wo = nn.Linear(self.d_model, self.d_model, bias=False)
60
-
61
- self.resid_dropout = nn.Dropout(config.dropout)
62
-
63
- # Flash attention handles causality via is_causal flag if seq_len > 1
64
-
65
- def forward(
66
- self,
67
- x: torch.Tensor,
68
- freqs_cis: torch.Tensor,
69
- kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
70
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
71
- B, T, C = x.size() # batch, seq_len, d_model
72
- q, k, v = self.wq(x), self.wk(x), self.wv(x)
73
-
74
- q = q.view(B, T, self.n_heads, self.head_dim)
75
- k = k.view(B, T, self.n_kv_heads, self.head_dim)
76
- v = v.view(B, T, self.n_kv_heads, self.head_dim)
77
-
78
- q, k = apply_rotary_emb(q, k, freqs_cis)
79
-
80
- if kv_cache is not None:
81
- # We are generating token by token
82
- k_cache, v_cache = kv_cache
83
- k = torch.cat([k_cache, k], dim=1)
84
- v = torch.cat([v_cache, v], dim=1)
85
- new_kv_cache = (k, v)
86
- else:
87
- new_kv_cache = None
88
-
89
- # Repeat KV heads to match Q heads (GQA)
90
- k = repeat_kv(k, self.n_rep)
91
- v = repeat_kv(v, self.n_rep)
92
-
93
- # Move heads to correct dimension: (B, n_heads, T, head_dim)
94
- q = q.transpose(1, 2)
95
- k = k.transpose(1, 2)
96
- v = v.transpose(1, 2)
97
-
98
- # Flash attention natively supported via scaled_dot_product_attention
99
- is_causal = (kv_cache is None and T > 1)
100
- try:
101
- y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.resid_dropout.p if self.training else 0.0, is_causal=is_causal)
102
- except Exception:
103
- # Manual attention fallback for older architectures (like P100 sm_60)
104
- attn_weights = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
105
- if is_causal:
106
- # Use a causal mask
107
- causal_mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T)
108
- attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf'))
109
-
110
- attn_weights = F.softmax(attn_weights, dim=-1)
111
- if self.training:
112
- attn_weights = self.resid_dropout(attn_weights)
113
-
114
- y = attn_weights @ v
115
-
116
- y = y.transpose(1, 2).contiguous().view(B, T, C)
117
- y = self.resid_dropout(self.wo(y))
118
-
119
- return y, new_kv_cache
120
-
121
- class ExpertFFN(nn.Module):
122
- def __init__(self, config: SageConfig):
123
- super().__init__()
124
- self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
125
- self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
126
- self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
127
- self.dropout = nn.Dropout(config.dropout)
128
-
129
- def forward(self, x: torch.Tensor) -> torch.Tensor:
130
- # SwiGLU activation structure
131
- hidden = F.silu(self.w1(x)) * self.w3(x)
132
- return self.dropout(self.w2(hidden))
133
-
134
- class MoE(nn.Module):
135
- def __init__(self, config: SageConfig):
136
- super().__init__()
137
- self.n_experts = config.n_experts
138
- self.top_k = config.num_experts_per_tok
139
- self.d_model = config.d_model
140
-
141
- self.router = nn.Linear(self.d_model, self.n_experts, bias=False)
142
- self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(self.n_experts)])
143
-
144
- def forward(self, x: torch.Tensor) -> torch.Tensor:
145
- B, T, C = x.size()
146
- x_flat = x.view(-1, C) # [B*T, C]
147
-
148
- router_logits = self.router(x_flat) # [B*T, n_experts]
149
- routing_weights = F.softmax(router_logits, dim=-1)
150
-
151
- # Select Top K experts
152
- routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # [B*T, top_k]
153
- routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # re-normalize
154
-
155
- final_out = torch.zeros_like(x_flat)
156
-
157
- # Iterate over experts and compute their outputs
158
- for i, expert in enumerate(self.experts):
159
- # Find which tokens chose this expert
160
- expert_mask = (selected_experts == i)
161
- token_idx, kth_expert = torch.where(expert_mask)
162
-
163
- if token_idx.shape[0] > 0:
164
- expert_inputs = x_flat[token_idx]
165
- expert_outputs = expert(expert_inputs)
166
-
167
- # Apply router weight
168
- weights = routing_weights[token_idx, kth_expert].unsqueeze(-1)
169
- final_out[token_idx] += expert_outputs * weights
170
-
171
- return final_out.view(B, T, C)
172
-
173
- class TransformerBlock(nn.Module):
174
- def __init__(self, config: SageConfig):
175
- super().__init__()
176
- self.norm1 = nn.LayerNorm(config.d_model)
177
- self.attn = CausalSelfAttention(config)
178
- self.norm2 = nn.LayerNorm(config.d_model)
179
- self.moe = MoE(config)
180
-
181
- def forward(
182
- self,
183
- x: torch.Tensor,
184
- freqs_cis: torch.Tensor,
185
- kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
186
- ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
187
- # Pre-LayerNorm architecture
188
- h, new_kv_cache = self.attn(self.norm1(x), freqs_cis, kv_cache)
189
- x = x + h
190
- x = x + self.moe(self.norm2(x))
191
- return x, new_kv_cache
192
-
193
- class SageModel(nn.Module):
194
- def __init__(self, config: SageConfig):
195
- super().__init__()
196
- self.config = config
197
-
198
- self.wte = nn.Embedding(config.vocab_size, config.d_model)
199
- self.drop = nn.Dropout(config.dropout)
200
-
201
- self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
202
-
203
- self.ln_f = nn.LayerNorm(config.d_model)
204
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
205
-
206
- # Weight tying
207
- self.wte.weight = self.lm_head.weight
208
-
209
- # Precompute RoPE frequencies
210
- self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False)
211
-
212
- self.apply(self._init_weights)
213
-
214
- def _init_weights(self, module):
215
- if isinstance(module, nn.Linear):
216
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
217
- if module.bias is not None:
218
- torch.nn.init.zeros_(module.bias)
219
- elif isinstance(module, nn.Embedding):
220
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
221
- elif isinstance(module, nn.LayerNorm):
222
- torch.nn.init.zeros_(module.bias)
223
- torch.nn.init.ones_(module.weight)
224
-
225
- def forward(
226
- self,
227
- idx: torch.Tensor,
228
- kv_caches: Optional[list] = None
229
- ) -> Tuple[torch.Tensor, Optional[list]]:
230
- B, T = idx.size()
231
-
232
- if kv_caches is not None:
233
- # generating context, token is at specific position
234
- start_pos = kv_caches[0][0].shape[1]
235
- else:
236
- start_pos = 0
237
-
238
- freqs_cis = self.freqs_cis[start_pos : start_pos + T]
239
-
240
- x = self.drop(self.wte(idx))
241
-
242
- new_kv_caches = []
243
- for i, layer in enumerate(self.layers):
244
- kv_cache = kv_caches[i] if kv_caches else None
245
-
246
- # Use gradient checkpointing during training
247
- if self.training and kv_cache is None:
248
- def create_custom_forward(module):
249
- def custom_forward(x_in, freqs_cis_in):
250
- return module(x_in, freqs_cis_in, None)
251
- return custom_forward
252
-
253
- x, new_kv_cache = torch.utils.checkpoint.checkpoint(
254
- create_custom_forward(layer),
255
- x, freqs_cis,
256
- use_reentrant=False
257
- )
258
- else:
259
- x, new_kv_cache = layer(x, freqs_cis, kv_cache)
260
-
261
- if new_kv_cache is not None:
262
- new_kv_caches.append(new_kv_cache)
263
-
264
- x = self.ln_f(x)
265
- logits = self.lm_head(x) # [B, T, vocab_size]
266
-
267
- return logits, new_kv_caches if len(new_kv_caches) > 0 else None
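The removed `precompute_freqs_cis` and `apply_rotary_emb` implement rotary positional embeddings by viewing each consecutive feature pair as a complex number and multiplying it by a unit-magnitude complex factor whose angle is `position * frequency`. The following is a minimal pure-Python sketch of that idea for a single head vector, using the standard library instead of PyTorch. The helper names `rope_frequencies`, `rope_rotate`, and `dot` are hypothetical and do not appear in the deleted file.

```python
import cmath
from typing import List


def rope_frequencies(dim: int, theta: float = 10000.0) -> List[float]:
    """One rotation frequency per feature pair: theta ** (-i / dim) for i = 0, 2, 4, ..."""
    return [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]


def rope_rotate(vec: List[float], pos: int, theta: float = 10000.0) -> List[float]:
    """Rotate each (even, odd) feature pair of vec by the angle pos * frequency."""
    if len(vec) % 2 != 0:
        raise ValueError("head dimension must be even")
    out: List[float] = []
    for pair_idx, freq in enumerate(rope_frequencies(len(vec), theta)):
        z = complex(vec[2 * pair_idx], vec[2 * pair_idx + 1])
        z *= cmath.rect(1.0, pos * freq)  # unit-magnitude factor e^{i * pos * freq}
        out.extend([z.real, z.imag])
    return out


def dot(a: List[float], b: List[float]) -> float:
    """Plain dot product, standing in for one attention score."""
    return sum(x * y for x, y in zip(a, b))


if __name__ == "__main__":
    q = [0.3, -1.2, 0.7, 0.5]
    k = [1.0, 0.4, -0.6, 0.9]
    # Both pairs of positions are 3 apart, so the two scores should agree
    # up to floating-point error.
    near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
    far = dot(rope_rotate(q, 13), rope_rotate(k, 10))
    print(abs(near - far) < 1e-9)
```

The usage example checks the property that motivates the technique: after rotation, the query-key score depends only on the relative offset between the two positions, not on their absolute values.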
 
sage/optimize.py DELETED
@@ -1,164 +0,0 @@
1
- """
2
- SAGE Optimization Layer
3
- =======================
4
- Post-training quantization (INT8), optional pruning, and knowledge-distillation
5
- loss utilities.
6
- """
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.utils.prune as prune
11
- from typing import Optional
12
-
13
- from .model import SageModel
14
- from .config import SageConfig
15
- from .utils import setup_logger
16
-
17
- logger = setup_logger("sage.optimize")
18
-
19
-
20
- # ===================================================================
21
- # INT8 Dynamic Quantization
22
- # ===================================================================
23
-
24
- def quantize_int8(model: SageModel) -> nn.Module:
25
- """
26
- Apply dynamic INT8 quantization to all Linear layers in the model.
27
-
28
- This reduces model size by ~2-4x and can speed up CPU inference.
29
- The model is moved to CPU before quantization because PyTorch's
30
- dynamic quantization only supports CPU tensors.
31
-
32
- Returns
33
- -------
34
- nn.Module — the quantized model (on CPU).
35
- """
36
- base_model = getattr(model, "module", model)
37
- base_model = base_model.cpu().eval()
38
-
39
- quantized = torch.quantization.quantize_dynamic(
40
- base_model,
41
- {nn.Linear}, # quantize all linear layers
42
- dtype=torch.qint8,
43
- )
44
-
45
- # Report size reduction
46
- orig_size = sum(p.numel() * p.element_size() for p in base_model.parameters())
47
- # Quantized parameters may not report element_size correctly,
48
- # so we estimate based on INT8 = 1 byte per weight.
49
- quant_size = sum(p.numel() for p in quantized.parameters()) # * 1 byte
50
- logger.info(
51
- f"Quantization complete. "
52
- f"Original: {orig_size / 1e6:.1f} MB → Quantized: ~{quant_size / 1e6:.1f} MB (INT8)"
53
- )
54
- return quantized
55
-
56
-
57
- # ===================================================================
58
- # Weight Pruning
59
- # ===================================================================
60
-
61
- def prune_model(model: SageModel, amount: float = 0.3) -> SageModel:
62
- """
63
- Apply unstructured L1 pruning to all Linear layers, removing the
64
- *amount* fraction of weights with the smallest magnitude.
65
-
66
- Parameters
67
- ----------
68
- model : SageModel
69
- amount : float
70
- Fraction of weights to prune (0.0 – 1.0).
71
-
72
- Returns
73
- -------
74
- SageModel — the pruned model (pruning masks are permanent after this call).
75
- """
76
- pruned_count = 0
77
- total_count = 0
78
-
79
- base_model = getattr(model, "module", model)
80
- for name, module in base_model.named_modules():
81
- if isinstance(module, nn.Linear):
82
- prune.l1_unstructured(module, name="weight", amount=amount)
83
- prune.remove(module, "weight") # make the pruning permanent
84
- pruned_count += (module.weight == 0).sum().item()
85
- total_count += module.weight.numel()
86
-
87
- sparsity = pruned_count / max(total_count, 1) * 100
88
- logger.info(
89
- f"Pruning complete. {pruned_count:,} / {total_count:,} weights zeroed "
90
- f"({sparsity:.1f}% sparsity)"
91
- )
92
- return model
93
-
94
-
95
- # ===================================================================
96
- # Knowledge Distillation Loss
97
- # ===================================================================
98
-
99
- def distillation_loss(
100
- student_logits: torch.Tensor,
101
- teacher_logits: torch.Tensor,
102
- labels: torch.Tensor,
103
- temperature: float = 2.0,
104
- alpha: float = 0.5,
105
- ignore_index: int = -100,
106
- ) -> torch.Tensor:
107
- """
108
- Combined knowledge-distillation loss.
109
-
110
- ``L = alpha * KL(softmax(teacher/T), softmax(student/T)) * T^2
111
- + (1 - alpha) * CE(student, labels)``
112
-
113
- Parameters
114
- ----------
115
- student_logits : Tensor [B, T, V]
116
- teacher_logits : Tensor [B, T, V]
117
- labels : Tensor [B, T]
118
- temperature : float
119
- alpha : float — weight for the distillation term (0 → pure CE, 1 → pure KD).
120
- ignore_index : int — label value to ignore in cross-entropy.
121
-
122
- Returns
123
- -------
124
- Tensor (scalar)
125
- """
126
- # Soft targets
127
- soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
128
- soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
129
-
130
- kd_loss = torch.nn.functional.kl_div(
131
- soft_student.view(-1, soft_student.size(-1)),
132
- soft_teacher.view(-1, soft_teacher.size(-1)),
133
- reduction="batchmean",
134
- ) * (temperature ** 2)
135
-
136
- # Hard-label cross-entropy
137
- ce_loss = torch.nn.functional.cross_entropy(
138
- student_logits.view(-1, student_logits.size(-1)),
139
- labels.view(-1),
140
- ignore_index=ignore_index,
141
- )
142
-
143
- return alpha * kd_loss + (1 - alpha) * ce_loss
144
-
145
-
146
- # ===================================================================
147
- # torch.compile wrapper (PyTorch 2.0+)
148
- # ===================================================================
149
-
150
- def try_compile(model: nn.Module) -> nn.Module:
151
- """
152
- Attempt to compile the model with ``torch.compile`` for faster
153
- execution. Falls back gracefully if compilation is not available.
154
- """
155
- if hasattr(torch, "compile"):
156
- try:
157
- compiled = torch.compile(model)
158
- logger.info("Model compiled with torch.compile for accelerated execution.")
159
- return compiled
160
- except Exception as e:
161
- logger.warning(f"torch.compile failed ({e}). Using eager mode.")
162
- else:
163
- logger.info("torch.compile not available (requires PyTorch 2.0+). Using eager mode.")
164
- return model
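The docstring of the removed `distillation_loss` gives the combined objective as `alpha * KL(teacher/T, student/T) * T^2 + (1 - alpha) * CE(student, labels)`. The sketch below evaluates that formula for a single token position using plain Python lists, which may make the role of `temperature` and `alpha` easier to see. The function names `log_softmax` and `distillation_loss_single` are hypothetical, and the single-position simplification is an assumption of this example (the deleted code averages over all `B * T` positions with `reduction="batchmean"`).

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def distillation_loss_single(
    student_logits: List[float],
    teacher_logits: List[float],
    label: int,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> float:
    """alpha * KL(teacher || student) * T^2 + (1 - alpha) * CE, for one position."""
    log_p_student = log_softmax(student_logits, temperature)
    log_p_teacher = log_softmax(teacher_logits, temperature)

    # KL divergence from the softened teacher distribution to the student one.
    kd = sum(
        math.exp(lt) * (lt - ls)
        for lt, ls in zip(log_p_teacher, log_p_student)
    )
    kd *= temperature ** 2  # keeps gradient scale comparable across temperatures

    # Hard-label cross-entropy uses the unsoftened student distribution.
    ce = -log_softmax(student_logits)[label]
    return alpha * kd + (1.0 - alpha) * ce


if __name__ == "__main__":
    print(distillation_loss_single([2.0, 0.5, -1.0], [1.5, 1.0, -0.5], label=0))
```

With `alpha=0` the result reduces to ordinary cross-entropy, and with identical student and teacher logits the distillation term vanishes.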
 
sage/train.py DELETED
@@ -1,266 +0,0 @@
1
- """
2
- SAGE Training System
3
- ====================
4
- Complete training loop with AdamW, cosine-decay LR schedule, mixed-precision
5
- (AMP), gradient accumulation, gradient clipping, and checkpoint management.
6
- """
7
-
8
- import math
9
- import time
10
- import torch
11
- import torch.nn as nn
12
- from torch.amp import GradScaler, autocast
13
- from tqdm import tqdm
14
- import wandb
15
- from typing import Optional
16
-
17
- from .config import SageConfig
18
- from .model import SageModel
19
- from .data import SageTokenizer, create_dataloader
20
- from .utils import setup_logger, save_checkpoint, load_checkpoint
21
-
22
- logger = setup_logger("sage.train")
23
-
24
-
25
- # ---------------------------------------------------------------------------
26
- # Learning-rate scheduler helpers
27
- # ---------------------------------------------------------------------------
28
-
29
- def get_lr(step: int, config: SageConfig, total_steps: int) -> float:
30
- """Cosine decay with linear warmup. Returns the learning rate for *step*."""
31
- if step < config.warmup_steps:
32
- # Linear warmup
33
- return config.learning_rate * (step + 1) / config.warmup_steps
34
-
35
- # Cosine decay phase
36
- decay_steps = total_steps - config.warmup_steps
37
- progress = (step - config.warmup_steps) / max(1, decay_steps)
38
- coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
39
- return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
40
-
41
-
42
- def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
43
- """Manually sets the learning rate for every parameter group."""
44
- for pg in optimizer.param_groups:
45
- pg["lr"] = lr
46
-
47
-
48
- # ---------------------------------------------------------------------------
49
- # Optimizer factory
50
- # ---------------------------------------------------------------------------
51
-
52
- def create_optimizer(model: SageModel, config: SageConfig) -> torch.optim.AdamW:
53
- """
54
- Create an AdamW optimizer with weight-decay applied only to weight
55
- matrices (not biases or LayerNorm parameters).
56
- """
57
- decay_params = []
58
- no_decay_params = []
59
-
60
- for name, param in model.named_parameters():
61
- if not param.requires_grad:
62
- continue
63
- # Biases and LayerNorm weights should not be decayed
64
- if param.ndim == 1 or "bias" in name:
65
- no_decay_params.append(param)
66
- else:
67
- decay_params.append(param)
68
-
69
- param_groups = [
70
- {"params": decay_params, "weight_decay": config.weight_decay},
71
- {"params": no_decay_params, "weight_decay": 0.0},
72
- ]
73
-
74
- # Enable Fused AdamW for 10% speedup if CUDA is active
75
- use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames
76
- optimizer = torch.optim.AdamW(
77
- param_groups,
78
- lr=config.learning_rate,
79
- betas=(0.9, 0.95),
80
- eps=1e-8,
81
- fused=use_fused,
82
- )
83
- return optimizer
84
-
85
-
86
- # ---------------------------------------------------------------------------
87
- # Main training loop
88
- # ---------------------------------------------------------------------------
89
-
90
- def train(
91
- model: SageModel,
92
- config: SageConfig,
93
- total_steps: int = 500,
94
- dataset_name: str = "roneneldan/TinyStories",
95
- resume: bool = True,
96
- tokenizer: Optional[SageTokenizer] = None,
97
- ) -> SageModel:
98
- """
99
- Run pre-training for *total_steps* gradient-update steps.
100
-
101
- Parameters
102
- ----------
103
- model : SageModel
104
- The model to train (will be moved to config.device).
105
- config : SageConfig
106
- Hyperparameters.
107
- total_steps : int
108
- Number of optimizer steps to run.
109
- dataset_name : str
110
- HuggingFace dataset identifier.
111
- resume : bool
112
- If True, attempt to load the latest checkpoint before training.
113
- tokenizer : SageTokenizer, optional
114
- Tokenizer instance; one will be created if not supplied.
115
-
116
- Returns
117
- -------
118
- SageModel
119
- The trained model (on config.device).
120
- """
121
- # --- TURBO MODE: TF32 & COMPILE ---
122
- if torch.cuda.is_available():
123
- torch.set_float32_matmul_precision('high')
124
-
125
- device = config.device
126
- model = model.to(device)
127
-
128
- # Wrap model with torch.compile for graph-level optimization
129
- if hasattr(torch, "compile"):
130
- try:
131
- logger.info("Turbo Mode: Compiling model graph...")
132
- base = getattr(model, "module", model)
133
- compiled_base = torch.compile(base, mode="reduce-overhead")
134
- if hasattr(model, "module"):
135
- model.module = compiled_base
136
- else:
137
- model = compiled_base
138
- except (ValueError, RuntimeError, ImportError) as e:
139
- # Graceful fallback: numpy compatibility issues or other compilation errors
140
- logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}")
141
-
142
- tok = tokenizer or SageTokenizer()
143
- optimizer = create_optimizer(model, config)
144
-
145
- # ------- resume from checkpoint if available -------
146
- start_step = 0
147
- if resume:
148
- model, optimizer, start_step = load_checkpoint(
149
- model, optimizer, config.checkpoint_dir, device=str(device)
150
- )
151
- if start_step >= total_steps:
152
- logger.info("Checkpoint already at or past requested steps. Nothing to do.")
153
- return model
154
-
155
- # ------- mixed precision setup -------
156
- use_amp = device.type == "cuda"
157
- # prefer bf16 if the GPU supports it
158
- amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16
159
- scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16))
160
-
161
- # ------- data loader -------
162
- loader = create_dataloader(config, dataset_name=dataset_name, tokenizer=tok)
163
- data_iter = iter(loader)
164
-
165
- # ------- W&B Logging -------
166
- wandb.init(
167
- project=config.project_name,
168
- name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}",
169
- config=config.__dict__,
170
- )
171
-
172
- # ------- gradient checkpointing (saves VRAM) -------
173
- base_model = getattr(model, "module", model)
174
- if hasattr(base_model, "layers"):
175
- for layer in base_model.layers:
176
- layer: nn.Module
177
- # PyTorch gradient checkpointing
178
- try:
179
- from torch.utils.checkpoint import checkpoint # noqa: F401
180
- # We wrap the forward below instead, using it at call-site.
181
- except ImportError:
182
- pass
183
-
184
- # ------- training loop -------
185
- model.train()
186
- accum_loss = 0.0
187
- log_interval = 10
188
- t0 = time.time()
189
-
190
- pbar = tqdm(range(start_step, total_steps), desc="Training", unit="step")
191
- micro_step = 0
192
-
193
- for step in pbar:
194
- # Update learning rate
195
- lr = get_lr(step, config, total_steps)
196
- set_lr(optimizer, lr)
197
-
198
- # Accumulate gradients over multiple micro-batches
199
- optimizer.zero_grad(set_to_none=True)
200
- step_loss = 0.0
201
-
202
- for micro in range(config.gradient_accumulation_steps):
203
- try:
204
- batch = next(data_iter)
205
- except StopIteration:
206
- # Restart the data stream when exhausted
207
- data_iter = iter(loader)
208
- batch = next(data_iter)
209
-
210
- batch = batch.to(device)
211
- inputs = batch[:, :-1] # all tokens except last
212
- targets = batch[:, 1:] # all tokens except first
213
-
214
- with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
215
- logits, _ = model(inputs)
216
- loss = nn.functional.cross_entropy(
217
- logits.reshape(-1, logits.size(-1)),
218
- targets.reshape(-1),
219
- ignore_index=tok.pad_token_id,
220
- )
221
- # Scale loss by accumulation steps so the effective loss
222
- # is independent of the number of micro-batches.
223
- loss = loss / config.gradient_accumulation_steps
224
-
225
- scaler.scale(loss).backward()
226
- step_loss += loss.item()
227
-
228
- # Gradient clipping (unscale first for correct norm computation)
229
- scaler.unscale_(optimizer)
230
- torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
231
-
232
- scaler.step(optimizer)
233
- scaler.update()
234
-
235
- accum_loss += step_loss
236
- micro_step += 1
237
-
238
- # ------- logging -------
239
- if (step + 1) % log_interval == 0 or step == total_steps - 1:
240
- avg_loss = accum_loss / log_interval
241
- elapsed = time.time() - t0
242
- perplexity = math.exp(min(avg_loss, 20)) # clamp to avoid overflow
243
- pbar.set_postfix(
244
- loss=f"{avg_loss:.4f}",
245
- ppl=f"{perplexity:.2f}",
246
- lr=f"{lr:.2e}",
247
- elapsed=f"{elapsed:.1f}s",
248
- )
249
- logger.info(
250
- f"step={step+1} | loss={avg_loss:.4f} | ppl={perplexity:.2f} | lr={lr:.2e}"
251
- )
252
- wandb.log({
253
- "train/loss": avg_loss,
254
- "train/perplexity": perplexity,
255
- "train/lr": lr,
256
- }, step=step + 1)
257
- accum_loss = 0.0
258
-
259
- # ------- checkpoint every 100 steps -------
260
- if (step + 1) % 100 == 0 or step == total_steps - 1:
261
- save_checkpoint(model, optimizer, step + 1, config.checkpoint_dir)
262
- logger.info(f"Checkpoint saved at step {step + 1}")
263
-
264
- logger.info("Training complete.")
265
- wandb.finish()
266
- return model
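The removed `get_lr` combines a linear warmup with a cosine decay toward a minimum learning rate. The sketch below restates that schedule as a standalone function so it can be inspected without `SageConfig`. The parameter names `peak_lr` and `min_lr` stand in for `config.learning_rate` and `config.min_learning_rate`, and the numeric values in the usage example are illustrative assumptions rather than settings from the repository.

```python
import math


def lr_at_step(
    step: int,
    total_steps: int,
    warmup_steps: int,
    peak_lr: float,
    min_lr: float,
) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warmup: reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps

    decay_steps = total_steps - warmup_steps
    progress = (step - warmup_steps) / max(1, decay_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 down to 0
    return min_lr + coeff * (peak_lr - min_lr)


if __name__ == "__main__":
    # Hypothetical settings chosen only to show the shape of the curve.
    for step in (0, 9, 10, 55, 100):
        print(step, lr_at_step(step, total_steps=100, warmup_steps=10,
                               peak_lr=1e-3, min_lr=1e-4))
```

Because the warmup branch uses `step + 1`, the first step already has a non-zero learning rate, matching the behavior of the deleted function.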
 
sage/utils.py DELETED
@@ -1,143 +0,0 @@
- import os
- import logging
- import torch
- from typing import Optional, Tuple
-
- def _get_logger(name: str) -> logging.Logger:
-     """Simple logger getter to avoid circular imports."""
-     logger = logging.getLogger(name)
-     if not logger.handlers:
-         logger.setLevel(logging.INFO)
-         console_handler = logging.StreamHandler()
-         console_handler.setLevel(logging.INFO)
-         formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
-         console_handler.setFormatter(formatter)
-         logger.addHandler(console_handler)
-         logger.propagate = False
-     return logger
-
- def get_compatible_device() -> torch.device:
-     """
-     Returns the best available device with CUDA compatibility checking.
-
-     Automatically detects GPU compute capability and falls back to CPU
-     if the current PyTorch installation doesn't support the GPU.
-     """
-     logger = _get_logger("sage.device")
-
-     # Check CUDA availability and compatibility
-     if torch.cuda.is_available():
-         gpu_name = torch.cuda.get_device_name(0)
-         capability = torch.cuda.get_device_capability()
-         major, minor = capability
-         sm_version = f"sm_{major}{minor}"
-
-         logger.info(f"Detected GPU: {gpu_name} (CUDA Capability: {sm_version})")
-
-         # PyTorch 2.0+ minimum is sm_70, PyTorch 1.13 supports sm_60
-         # Check if we can actually run model operations (embedding, linear, etc.)
-         try:
-             # Test 1: Basic tensor operation
-             test_tensor = torch.zeros(2, 4).cuda()
-             _ = test_tensor + test_tensor
-
-             # Test 2: Embedding (this is where P100/sm_60 often fails)
-             import torch.nn as nn
-             test_emb = nn.Embedding(10, 8).cuda()
-             test_indices = torch.tensor([0, 1, 2], dtype=torch.long).cuda()
-             _ = test_emb(test_indices)
-
-             # Test 3: Linear layer
-             test_linear = nn.Linear(8, 4).cuda()
-             _ = test_linear(test_emb(test_indices))
-
-             logger.info(f"✅ GPU is compatible with current PyTorch")
-             return torch.device("cuda")
-         except RuntimeError as e:
-             if "no kernel image is available" in str(e).lower():
-                 logger.warning(f"⚠️ GPU {sm_version} not supported by current PyTorch")
-                 logger.warning(f" Current PyTorch supports: {torch.cuda.get_arch_list() or 'sm_70+'}")
-                 logger.warning(f" Install compatible PyTorch:")
-                 if major < 7:
-                     logger.warning(f" !pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121")
-                 else:
-                     logger.warning(f" !pip install torch --index-url https://download.pytorch.org/whl/cu118")
-                 logger.warning(f" Falling back to CPU...")
-             else:
-                 raise
-
-     # Check MPS (Apple Silicon)
-     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-         logger.info("Using Apple Silicon (MPS)")
-         return torch.device("mps")
-
-     logger.info("Using CPU")
-     return torch.device("cpu")
-
- def setup_logger(name: str) -> logging.Logger:
-     """Sets up a standardized logger for the SAGE system."""
-     logger = logging.getLogger(name)
-     if not logger.handlers:
-         logger.setLevel(logging.INFO)
-         console_handler = logging.StreamHandler()
-         console_handler.setLevel(logging.INFO)
-         formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
-         console_handler.setFormatter(formatter)
-         logger.addHandler(console_handler)
-         # Prevent propagation to the root logger to avoid double printing
-         logger.propagate = False
-     return logger
-
- def save_checkpoint(
-     model: torch.nn.Module,
-     optimizer: Optional[torch.optim.Optimizer],
-     step: int,
-     checkpoint_dir: str,
-     filename: str = "sage_latest.pt"
- ) -> str:
-     """Saves the model and optimizer state to a checkpoint file."""
-     os.makedirs(checkpoint_dir, exist_ok=True)
-     path = os.path.join(checkpoint_dir, filename)
-
-     base_model = getattr(model, "module", model)
-     checkpoint = {
-         'step': step,
-         'model_state_dict': base_model.state_dict(),
-     }
-
-     if optimizer is not None:
-         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
-
-     torch.save(checkpoint, path)
-     return path
-
- def load_checkpoint(
-     model: torch.nn.Module,
-     optimizer: Optional[torch.optim.Optimizer],
-     checkpoint_dir: str,
-     filename: str = "sage_latest.pt",
-     device: str = "cpu"
- ) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer], int]:
-     """Loads a checkpoint and restores the model and optimizer states."""
-     path = os.path.join(checkpoint_dir, filename)
-
-     if not os.path.exists(path):
-         logger = setup_logger("utils")
-         logger.warning(f"No checkpoint found at {path}. Starting from scratch.")
-         return model, optimizer, 0
-
-     # Load to CPU first to avoid VRAM spikes, then the module will be moved later if needed
-     checkpoint = torch.load(path, map_location=device)
-
-     base_model = getattr(model, "module", model)
-     base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
-
-     if optimizer is not None and 'optimizer_state_dict' in checkpoint:
-         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-
-     step = checkpoint.get('step', 0)
-
-     logger = setup_logger("utils")
-     logger.info(f"Loaded checkpoint from {path} at step {step}")
-
-     return model, optimizer, step
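
The removed `setup_logger` helper only attaches a handler when the named logger has none, which is what keeps repeated calls from printing each message twice. A small standalone sketch of that guard, assuming only the standard `logging` module; the logger name `sage.demo` is a hypothetical placeholder:

```python
import logging


def setup_logger(name: str) -> logging.Logger:
    """Configure a named logger once; later calls return it unchanged."""
    logger = logging.getLogger(name)
    if not logger.handlers:  # configure only on the first call
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)
        # Keep records away from the root logger to avoid double printing.
        logger.propagate = False
    return logger


first = setup_logger("sage.demo")
second = setup_logger("sage.demo")

# Both calls return the same object, and only one handler is attached.
print(first is second, len(first.handlers))
```

Without the `if not logger.handlers` check, every call would add another `StreamHandler` and each log record would be emitted once per handler.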
sage_single.py DELETED
@@ -1,824 +0,0 @@
- #!/usr/bin/env python3
- """
- SAGE — Self-Adaptive General Engine (Single-File Edition)
- =========================================================
- A complete mini-LLM in one file. Run with:
-
- python sage_single.py
-
- All architecture, data, training, inference, fine-tuning, quantization,
- RAG, and CLI components are included below.
- """
-
- import os
- import re
- import sys
- import math
- import copy
- import time
- import random
- import logging
- from dataclasses import dataclass
- from typing import Iterator, List, Optional, Tuple
-
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.nn.utils.prune as prune
- from torch.amp import GradScaler, autocast
- from torch.utils.data import IterableDataset, DataLoader
- from tqdm import tqdm
- import tiktoken
- import wandb
-
- __version__ = "1.0.0"
-
-
- # ===================================================================
- # Section 1 — Configuration
- # ===================================================================
-
- @dataclass
- class SageConfig:
-     d_model: int = 512
-     n_heads: int = 8
-     n_kv_heads: int = 4
-     n_layers: int = 6
-     d_ff: int = 2048
-     n_experts: int = 4
-     num_experts_per_tok: int = 2
-     vocab_size: int = 100277
-     max_seq_len: int = 1024
-     dropout: float = 0.1
-     batch_size: int = 4
-     gradient_accumulation_steps: int = 16
-     learning_rate: float = 3e-4
-     min_learning_rate: float = 1e-5
-     warmup_steps: int = 100
-     weight_decay: float = 0.01
-     max_grad_norm: float = 1.0
-     checkpoint_dir: str = "checkpoints"
-     project_name: str = "sage-v2"
-
-     @property
-     def device(self):
-         if torch.cuda.is_available():
-             return torch.device("cuda")
-         if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-             return torch.device("mps")
-         return torch.device("cpu")
-
-
- # ===================================================================
- # Section 2 — Logging & Checkpoint Utilities
- # ===================================================================
-
- def setup_logger(name: str) -> logging.Logger:
-     logger = logging.getLogger(name)
-     if not logger.handlers:
-         logger.setLevel(logging.INFO)
-         h = logging.StreamHandler()
-         h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s", datefmt="%H:%M:%S"))
-         logger.addHandler(h)
-         logger.propagate = False
-     return logger
-
- logger = setup_logger("sage")
-
- def save_checkpoint(model, optimizer, step, checkpoint_dir, filename="sage_latest.pt"):
-     os.makedirs(checkpoint_dir, exist_ok=True)
-     path = os.path.join(checkpoint_dir, filename)
-     base = getattr(model, "module", model)
-     ckpt = {"step": step, "model_state_dict": base.state_dict()}
-     if optimizer is not None:
-         ckpt["optimizer_state_dict"] = optimizer.state_dict()
-     torch.save(ckpt, path)
-     return path
-
- def load_checkpoint(model, optimizer, checkpoint_dir, filename="sage_latest.pt", device="cpu"):
-     path = os.path.join(checkpoint_dir, filename)
-     if not os.path.exists(path):
-         logger.warning(f"No checkpoint at {path}, starting fresh.")
-         return model, optimizer, 0
-     ckpt = torch.load(path, map_location=device)
-     base = getattr(model, "module", model)
-     base.load_state_dict(ckpt["model_state_dict"], strict=False)
-     if optimizer and "optimizer_state_dict" in ckpt:
-         optimizer.load_state_dict(ckpt["optimizer_state_dict"])
-     step = ckpt.get("step", 0)
-     logger.info(f"Loaded checkpoint from {path} (step {step})")
-     return model, optimizer, step
-
-
- # ===================================================================
- # Section 3 — Tokenizer
- # ===================================================================
-
- class SageTokenizer:
-     def __init__(self, encoding_name="cl100k_base"):
-         self.enc = tiktoken.get_encoding(encoding_name)
-         self.eos_token_id = self.enc.n_vocab - 1
-         self.pad_token_id = self.enc.n_vocab - 2
-         self.vocab_size = self.enc.n_vocab
-
-     def encode(self, text, add_eos=False):
-         tokens = self.enc.encode(text, allowed_special="all")
-         if add_eos:
-             tokens.append(self.eos_token_id)
-         return tokens
-
-     def decode(self, tokens):
-         filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)]
-         return self.enc.decode(filtered)
-
-
- # ===================================================================
- # Section 4 — Model Architecture (RoPE, Attention, MoE, Transformer)
- # ===================================================================
-
- def precompute_freqs_cis(dim, end, theta=10000.0):
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
-     t = torch.arange(end, dtype=torch.float32)
-     freqs = torch.outer(t, freqs)
-     return torch.polar(torch.ones_like(freqs), freqs)
-
- def apply_rotary_emb(xq, xk, freqs_cis):
-     # Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real)
-     if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2:
-         freqs_cis = torch.view_as_complex(freqs_cis)
-     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-     fc = freqs_cis.unsqueeze(0).unsqueeze(2)
-     xq_out = torch.view_as_real(xq_ * fc).flatten(3)
-     xk_out = torch.view_as_real(xk_ * fc).flatten(3)
-     return xq_out.type_as(xq), xk_out.type_as(xk)
-
- def repeat_kv(x, n_rep):
-     if n_rep == 1: return x
-     B, T, n_kv_heads, head_dim = x.size()
-     return x[:, :, :, None, :].expand(B, T, n_kv_heads, n_rep, head_dim).reshape(B, T, n_kv_heads * n_rep, head_dim)
-
- class CausalSelfAttention(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.n_heads = config.n_heads
-         self.n_kv_heads = config.n_kv_heads
-         self.n_rep = self.n_heads // self.n_kv_heads
-         self.d_model = config.d_model
-         self.head_dim = config.d_model // config.n_heads
-         self.wq = nn.Linear(config.d_model, config.n_heads * self.head_dim, bias=False)
-         self.wk = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False)
-         self.wv = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False)
-         self.wo = nn.Linear(config.d_model, config.d_model, bias=False)
-         self.resid_dropout = nn.Dropout(config.dropout)
-
-     def forward(self, x, freqs_cis, kv_cache=None):
-         B, T, C = x.size()
-         q, k, v = self.wq(x), self.wk(x), self.wv(x)
-         q = q.view(B, T, self.n_heads, self.head_dim)
-         k = k.view(B, T, self.n_kv_heads, self.head_dim)
-         v = v.view(B, T, self.n_kv_heads, self.head_dim)
-         q, k = apply_rotary_emb(q, k, freqs_cis)
-         if kv_cache is not None:
-             k = torch.cat([kv_cache[0], k], dim=1)
-             v = torch.cat([kv_cache[1], v], dim=1)
-             new_kv = (k, v)
-         else:
-             new_kv = None
-         k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
-         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
-         is_causal = kv_cache is None and T > 1
-         try:
-             y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0 if not self.training else 0.1, is_causal=is_causal)
-         except Exception:
-             attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
-             if is_causal:
-                 mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T)
-                 attn = attn.masked_fill(mask == 0, float('-inf'))
-             attn = F.softmax(attn, dim=-1)
-             if self.training: attn = F.dropout(attn, p=0.1)
-             y = attn @ v
-         y = y.transpose(1, 2).contiguous().view(B, T, C)
-         return self.resid_dropout(self.wo(y)), new_kv
-
- class ExpertFFN(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
-         self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
-         self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
-         self.dropout = nn.Dropout(config.dropout)
-
-     def forward(self, x):
-         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
-
- class MoE(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.n_experts = config.n_experts
-         self.top_k = config.num_experts_per_tok
-         self.router = nn.Linear(config.d_model, config.n_experts, bias=False)
-         self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(config.n_experts)])
-
-     def forward(self, x):
-         B, T, C = x.size()
-         flat = x.view(-1, C)
-         weights = F.softmax(self.router(flat), dim=-1)
-         weights, indices = torch.topk(weights, self.top_k, dim=-1)
-         weights = weights / weights.sum(dim=-1, keepdim=True)
-         out = torch.zeros_like(flat)
-         for i, expert in enumerate(self.experts):
-             mask = (indices == i)
-             tok_idx, kth = torch.where(mask)
-             if tok_idx.shape[0] > 0:
-                 out[tok_idx] += expert(flat[tok_idx]) * weights[tok_idx, kth].unsqueeze(-1)
-         return out.view(B, T, C)
-
- class TransformerBlock(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.norm1 = nn.LayerNorm(config.d_model)
-         self.attn = CausalSelfAttention(config)
-         self.norm2 = nn.LayerNorm(config.d_model)
-         self.moe = MoE(config)
-
-     def forward(self, x, freqs_cis, kv_cache=None):
-         h, new_kv = self.attn(self.norm1(x), freqs_cis, kv_cache)
-         x = x + h
-         x = x + self.moe(self.norm2(x))
-         return x, new_kv
-
- class SageModel(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.wte = nn.Embedding(config.vocab_size, config.d_model)
-         self.drop = nn.Dropout(config.dropout)
-         self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
-         self.ln_f = nn.LayerNorm(config.d_model)
-         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # tied
-         self.wte.weight = self.lm_head.weight
-         self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False)
-         self.apply(self._init_weights)
-
-     def _init_weights(self, m):
-         if isinstance(m, nn.Linear):
-             nn.init.normal_(m.weight, std=0.02)
-             if m.bias is not None: nn.init.zeros_(m.bias)
-         elif isinstance(m, nn.Embedding):
-             nn.init.normal_(m.weight, std=0.02)
-         elif isinstance(m, nn.LayerNorm):
-             nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
-
-     def forward(self, idx, kv_caches=None):
-         B, T = idx.size()
-         start = kv_caches[0][0].shape[1] if kv_caches else 0
-         fc = self.freqs_cis[start:start + T]
-         x = self.drop(self.wte(idx))
-         new_kvs = []
-         for i, layer in enumerate(self.layers):
-             kv = kv_caches[i] if kv_caches else None
-             if self.training and kv is None:
-                 def create_custom_forward(module):
-                     def custom_forward(x_in, freqs_cis_in):
-                         return module(x_in, freqs_cis_in, None)
-                     return custom_forward
-                 x, nkv = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, fc, use_reentrant=False)
-             else:
-                 x, nkv = layer(x, fc, kv)
-             if nkv is not None: new_kvs.append(nkv)
-         return self.lm_head(self.ln_f(x)), new_kvs if new_kvs else None
-
-
- # ===================================================================
- # Section 5 — Data Pipeline
- # ===================================================================
-
- _HTML_RE = re.compile(r"<[^>]+>")
-
- def clean_text(text):
-     text = _HTML_RE.sub("", text)
-     text = re.sub(r"[ \t]+", " ", text)
-     text = re.sub(r"\n{3,}", "\n\n", text)
-     return text.strip()
-
- class StreamingTextDataset(IterableDataset):
-     def __init__(self, dataset_name="HuggingFaceFW/fineweb-edu", split="train", seq_len=512, tokenizer=None, buffer_size=1000, text_field="text"):
-         super().__init__()
-         self.dataset_name, self.split, self.seq_len = dataset_name, split, seq_len
-         self.tokenizer = tokenizer or SageTokenizer()
-         self.buffer_size, self.text_field = buffer_size, text_field
-         if "fineweb-edu" in dataset_name.lower(): self.text_field = "text"
-         elif "tinystories" in dataset_name.lower(): self.text_field = "text"
-
-     def _tokens(self):
-         from datasets import load_dataset
-         ds = load_dataset(self.dataset_name, split=self.split, streaming=True)
-         for s in ds:
-             raw = s.get(self.text_field, "")
-             if not raw or len(raw) < 50: continue
-             text = clean_text(raw)
-             yield from self.tokenizer.encode(text, add_eos=True)
-
-     def __iter__(self):
-         chunk, buf = [], []
-         for tok in self._tokens():
-             chunk.append(tok)
-             if len(chunk) == self.seq_len + 1:
-                 buf.append(torch.tensor(chunk, dtype=torch.long))
-                 chunk = []
-                 if len(buf) >= self.buffer_size:
-                     random.shuffle(buf)
-                     while len(buf) > self.buffer_size // 2: yield buf.pop()
-         random.shuffle(buf)
-         yield from buf
-
- def create_dataloader(config, dataset_name="HuggingFaceFW/fineweb-edu", tokenizer=None):
-     tok = tokenizer or SageTokenizer()
-     ds = StreamingTextDataset(dataset_name=dataset_name, seq_len=config.max_seq_len, tokenizer=tok)
-     return DataLoader(ds, batch_size=config.batch_size, num_workers=2, pin_memory=True, drop_last=True)
-
-
- # ===================================================================
- # Section 6 — Training
- # ===================================================================
-
- def get_lr(step, config, total_steps):
-     if step < config.warmup_steps:
-         return config.learning_rate * (step + 1) / config.warmup_steps
-     progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps)
-     coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
-     return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
-
- def create_optimizer(model, config):
-     decay, no_decay = [], []
-     for n, p in model.named_parameters():
-         if not p.requires_grad: continue
-         (no_decay if p.ndim == 1 or "bias" in n else decay).append(p)
-     # Enable Fused AdamW for 10% speedup if CUDA is active
-     use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames
-     return torch.optim.AdamW([
-         {"params": decay, "weight_decay": config.weight_decay},
-         {"params": no_decay, "weight_decay": 0.0},
-     ], lr=config.learning_rate, betas=(0.9, 0.95), fused=use_fused)
-
- def train_model(model, config, total_steps=500, dataset_name="roneneldan/TinyStories", resume=True, tokenizer=None):
-     device = config.device
-     # --- TURBO MODE: TF32 & COMPILE ---
-     if torch.cuda.is_available():
-         torch.set_float32_matmul_precision('high')
-
-     model = model.to(device)
-     tok = tokenizer or SageTokenizer()
-
-     # Wrap model with torch.compile for graph-level optimization
-     # mode="reduce-overhead" is ideal for smaller-to-medium models like SAGE
-     if hasattr(torch, "compile"):
-         try:
-             logger.info("Turbo Mode: Compiling model graph...")
-             # Compile the base model (unwrapped from DataParallel if present)
-             base = getattr(model, "module", model)
-             compiled_base = torch.compile(base, mode="reduce-overhead")
-             if hasattr(model, "module"):
-                 model.module = compiled_base
-             else:
-                 model = compiled_base
-         except (ValueError, RuntimeError, ImportError) as e:
-             # Graceful fallback: numpy compatibility issues or other compilation errors
-             logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}")
-             # Continue with uncompiled model
-
-     opt = create_optimizer(model, config)
-     start_step = 0
-     if resume:
-         model, opt, start_step = load_checkpoint(model, opt, config.checkpoint_dir, device=str(device))
-         if start_step >= total_steps: return model
-     use_amp = device.type == "cuda"
-     amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16
-     scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16)
-     loader = create_dataloader(config, dataset_name, tok)
-     data_iter = iter(loader)
-     wandb.init(project=config.project_name, name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__)
-     model.train()
-     accum_loss, t0 = 0.0, time.time()
-     pbar = tqdm(range(start_step, total_steps), desc="Training")
-     for step in pbar:
-         lr = get_lr(step, config, total_steps)
-         for pg in opt.param_groups: pg["lr"] = lr
-         opt.zero_grad(set_to_none=True)
-         step_loss = 0.0
-         for _ in range(config.gradient_accumulation_steps):
-             try: batch = next(data_iter)
-             except StopIteration: data_iter = iter(loader); batch = next(data_iter)
-             batch = batch.to(device)
-             with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
-                 logits, _ = model(batch[:, :-1])
-                 loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1), ignore_index=tok.pad_token_id)
-                 loss = loss / config.gradient_accumulation_steps
-             scaler.scale(loss).backward()
-             step_loss += loss.item()
-         scaler.unscale_(opt)
-         nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
-         scaler.step(opt); scaler.update()
-         accum_loss += step_loss
-         if (step + 1) % 10 == 0:
-             avg = accum_loss / 10
-             pbar.set_postfix(loss=f"{avg:.4f}", ppl=f"{math.exp(min(avg,20)):.1f}", lr=f"{lr:.2e}")
-             wandb.log({"train/loss": avg, "train/perplexity": math.exp(min(avg, 20)), "train/lr": lr}, step=step + 1)
-             accum_loss = 0.0
-         if (step + 1) % 100 == 0:
-             save_checkpoint(model, opt, step + 1, config.checkpoint_dir)
-     save_checkpoint(model, opt, total_steps, config.checkpoint_dir)
-     logger.info("Training complete.")
-     wandb.finish()
-     return model
-
-
- # ===================================================================
- # Section 7 — Inference
- # ===================================================================
-
- def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, greedy=False):
-     if greedy: return logits.argmax(-1, keepdim=True)
-     logits = logits / max(temperature, 1e-8)
-     if 0 < top_k < logits.size(-1):
-         v, _ = torch.topk(logits, top_k)
-         logits[logits < v[:, -1:]] = float("-inf")
-     if top_p < 1.0:
-         sorted_l, sorted_i = torch.sort(logits, descending=True)
-         cum = torch.cumsum(F.softmax(sorted_l, -1), -1)
-         mask = cum - F.softmax(sorted_l, -1) >= top_p
-         sorted_l[mask] = float("-inf")
-         logits = logits.scatter(1, sorted_i, sorted_l)
-     return torch.multinomial(F.softmax(logits, -1), 1)
-
- @torch.no_grad()
- def generate(model, tokenizer, prompt, max_new=256, temperature=0.8, top_k=50, top_p=0.9, stream=True, device=None):
-     device = device or next(model.parameters()).device
-     base = getattr(model, "module", model)
-     base.eval()
-     ids = tokenizer.encode(prompt) or [tokenizer.eos_token_id]
-     inp = torch.tensor([ids], dtype=torch.long, device=device)
-     logits, kvs = base(inp)
-     gen = list(ids)
-     nl = logits[:, -1, :]
-     for _ in range(max_new):
-         nid = sample_next(nl, temperature, top_k, top_p)
-         tid = nid.item()
-         if tid == tokenizer.eos_token_id: break
-         gen.append(tid)
-         if stream: print(tokenizer.decode([tid]), end="", flush=True)
-         logits, kvs = base(nid.view(1, 1), kv_caches=kvs)
-         nl = logits[:, -1, :]
-     if stream: print()
-     base.train()
-     return tokenizer.decode(gen)
-
-
- # ===================================================================
- # Section 8 — LoRA Fine-tuning
- # ===================================================================
-
- class LoRALinear(nn.Module):
-     def __init__(self, original, rank=8, alpha=16.0):
-         super().__init__()
-         self.original = original
-         self.scaling = alpha / rank
-         device, dtype = original.weight.device, original.weight.dtype
-         self.lora_A = nn.Parameter(torch.randn(original.in_features, rank, device=device, dtype=dtype) * 0.01)
-         self.lora_B = nn.Parameter(torch.zeros(rank, original.out_features, device=device, dtype=dtype))
-         original.weight.requires_grad = False
-         if original.bias is not None: original.bias.requires_grad = False
-
-     def forward(self, x):
-         return self.original(x) + (x @ self.lora_A @ self.lora_B) * self.scaling
-
-     def merge(self):
-         m = copy.deepcopy(self.original)
-         m.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling
-         m.weight.requires_grad = True
-         return m
-
- def inject_lora(model, rank=8, alpha=16.0):
-     base = getattr(model, "module", model)
-     for layer in base.layers:
-         a = layer.attn
-         for name in ("wq", "wk", "wv", "wo"):
-             setattr(a, name, LoRALinear(getattr(a, name), rank, alpha))
-     tp = sum(p.numel() for p in base.parameters() if p.requires_grad)
-     logger.info(f"LoRA injected (rank={rank}). Trainable params: {tp:,}")
-     return model
-
- def merge_lora(model):
-     base = getattr(model, "module", model)
-     for layer in base.layers:
-         a = layer.attn
-         for name in ("wq", "wk", "wv", "wo"):
-             m = getattr(a, name)
-             if isinstance(m, LoRALinear): setattr(a, name, m.merge())
-     logger.info("LoRA merged.")
-     return model
-
- INSTRUCTION_TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n{response}"
-
- DEMO_SAMPLES = [
-     {"instruction": "What is the capital of France?", "response": "The capital of France is Paris."},
-     {"instruction": "Explain gravity simply.", "response": "Gravity pulls objects toward each other. More mass means stronger pull."},
-     {"instruction": "Write a short poem about the ocean.", "response": "Waves crash on sandy shore,\nThe ocean sings forevermore.\nDeep blue meets the sky,\nSeagulls dance and clouds float by."},
-     {"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."},
-     {"instruction": "Summarize photosynthesis.", "response": "Plants convert sunlight, water, and CO2 into glucose and oxygen."},
-     {"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than its year — 243 Earth days to rotate vs 225 to orbit the Sun."},
-     {"instruction": "How do airplanes fly?", "response": "Wings generate lift because air moves faster over the curved top, creating lower pressure above."},
-     {"instruction": "What is machine learning?", "response": "ML is AI where computers learn patterns from data instead of being explicitly programmed."},
- ]
-
- def create_instruction_batch(samples, tokenizer, max_len=512):
-     all_ids, all_masks = [], []
-     for s in samples:
-         inst_text = f"### Instruction:\n{s['instruction'].strip()}\n\n### Response:\n"
-         full_text = inst_text + s["response"].strip()
-         inst_toks = tokenizer.encode(inst_text)
-         full_toks = tokenizer.encode(full_text, add_eos=True)[:max_len]
-         ni = min(len(inst_toks), len(full_toks))
-         mask = [0] * ni + [1] * (len(full_toks) - ni)
-         pad = max_len - len(full_toks)
-         full_toks += [tokenizer.pad_token_id] * pad
-         mask += [0] * pad
-         all_ids.append(full_toks); all_masks.append(mask)
-     return {"input_ids": torch.tensor(all_ids), "labels": torch.tensor(all_ids), "loss_mask": torch.tensor(all_masks, dtype=torch.float32)}
-
- def finetune(model, config, samples=None, steps=200, use_lora=True, tokenizer=None):
-     device = config.device; model = model.to(device)
-     tok = tokenizer or SageTokenizer()
-     samples = samples or DEMO_SAMPLES
-     if use_lora: model = inject_lora(model)
-     opt = create_optimizer(model, config)
-     use_amp = device.type == "cuda"
-     amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16
-     scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16)
-     wandb.init(project=config.project_name, name=f"finetune-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__)
-     model.train(); accum = 0.0
-     for step in tqdm(range(steps), desc="Fine-tuning"):
-         lr = get_lr(step, config, steps)
-         for pg in opt.param_groups: pg["lr"] = lr
-         batch = create_instruction_batch(random.choices(samples, k=min(config.batch_size, len(samples))), tok, config.max_seq_len)
-         ids, labels, mask = batch["input_ids"].to(device), batch["labels"].to(device), batch["loss_mask"].to(device)
-         opt.zero_grad(set_to_none=True)
-         with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
-             logits, _ = model(ids)
-             sl, slb, sm = logits[:, :-1, :].contiguous(), labels[:, 1:].contiguous(), mask[:, 1:].contiguous()
-             ptl = F.cross_entropy(sl.view(-1, sl.size(-1)), slb.view(-1), reduction="none").view(slb.size())
-             loss = (ptl * sm).sum() / sm.sum().clamp(min=1)
-         scaler.scale(loss).backward()
-         scaler.unscale_(opt); nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
-         scaler.step(opt); scaler.update()
-         accum += loss.item()
-         if (step + 1) % 10 == 0: accum = 0.0
-     if use_lora: model = merge_lora(model)
-     save_checkpoint(model, None, steps, config.checkpoint_dir, "sage_finetuned.pt")
-     logger.info("Fine-tuning complete.")
-     wandb.finish()
-     return model
-
-
- # ===================================================================
- # Section 9 — Optimization (Quantize / Prune)
- # ===================================================================
-
- def quantize_int8(model):
-     base = getattr(model, "module", model)
-     model = base.cpu().eval()
-     q = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
-     logger.info("INT8 quantization complete.")
-     return q
-
- def prune_model(model, amount=0.3):
-     base = getattr(model, "module", model)
-     for _, m in base.named_modules():
-         if isinstance(m, nn.Linear):
-             prune.l1_unstructured(m, "weight", amount=amount)
-             prune.remove(m, "weight")
-     logger.info(f"Pruning complete ({amount*100:.0f}% sparsity target).")
-     return model
-
-
- # ===================================================================
- # Section 10 — RAG & Memory
- # ===================================================================
-
- def _embed(text, tokenizer, model, device):
-     toks = tokenizer.encode(text)
-     base = getattr(model, "module", model)
-     if not toks: return np.zeros(base.wte.weight.shape[1], dtype=np.float32)
-     with torch.no_grad():
-         emb = base.wte(torch.tensor([toks], device=device)).mean(1)
-         emb = F.normalize(emb, p=2, dim=-1)
-     return emb.squeeze(0).cpu().numpy()
-
- class VectorStore:
-     def __init__(self, dim):
-         self.dim = dim; self.docs = []; self.index = None
-         try:
-             import faiss
-             self.index = faiss.IndexFlatIP(dim)
-         except ImportError:
-             logger.warning("FAISS not installed. RAG will use brute-force search.")
-
-     def add(self, texts, embeddings):
-         if self.index is not None:
-             self.index.add(embeddings.astype(np.float32))
-         else:
-             # Brute-force fallback
-             if not hasattr(self, '_embeddings'):
-                 self._embeddings = []
-             self._embeddings.extend(embeddings.astype(np.float32))
-         self.docs.extend(texts)
-
-     def search(self, qemb, k=3):
-         if not self.docs: return []
-         k = min(k, len(self.docs))
-         if self.index is not None:
-             scores, idx = self.index.search(qemb.reshape(1, -1).astype(np.float32), k)
-             return [(self.docs[i], float(s)) for s, i in zip(scores[0], idx[0]) if i >= 0]
-         else:
-             # Brute-force cosine similarity
-             # Uses the module-level np: a local `import numpy as np` here would make np local to search() and raise UnboundLocalError in the FAISS branch above
-             qemb = qemb.reshape(1, -1).astype(np.float32)
-             embs = np.array(self._embeddings)
-             sims = np.dot(embs, qemb.T).flatten()
-             top_k = np.argsort(sims)[-k:][::-1]
-             return [(self.docs[i], float(sims[i])) for i in top_k]
-
-     @property
-     def size(self): return len(self.docs)
-
- class RAGManager:
-     def __init__(self, model, tokenizer, device, chunk_size=200):
-         self.model, self.tokenizer, self.device = model, tokenizer, device
-         self.chunk_size = chunk_size  # previously accepted but ignored
-         base = getattr(model, "module", model)
-         self.store = VectorStore(base.wte.weight.shape[1])
-         self.enabled = False
-
-     def add_documents(self, texts):
-         chunks = []
-         # Stride of 3/4 chunk gives 25% overlap (150-word stride for the default chunk_size=200)
-         stride = max(1, self.chunk_size * 3 // 4)
-         for t in texts:
-             words = t.split()
-             for i in range(0, len(words), stride):
-                 chunks.append(" ".join(words[i:i+self.chunk_size]))
-         if chunks:
-             embs = np.stack([_embed(c, self.tokenizer, self.model, self.device) for c in chunks])
-             self.store.add(chunks, embs)
-
-     def retrieve(self, query, k=3):
-         if not self.enabled or not self.store.size: return ""
-         qe = _embed(query, self.tokenizer, self.model, self.device)
-         results = self.store.search(qe, k)
-         return "\n\n".join(f"[Context {i+1}] {d}" for i, (d, _) in enumerate(results)) + "\n\n" if results else ""
-
-     def toggle(self, on): self.enabled = on
-
- DEFAULT_SYSTEM_PROMPT = (
-     "You are a high-quality reasoning assistant model.\n"
-     "You must ONLY learn from high-quality instruction and reasoning datasets.\n"
-     "You must IGNORE any previously trained low-quality or repetitive patterns.\n\n"
-     "Training preference rules:\n"
-     "1. Prioritize step-by-step reasoning over short or repetitive answers.\n"
-     "2. Always produce structured logical explanations when solving problems.\n"
-     "3. Avoid repetition, filler words, or looped phrases.\n"
-     "4. Prefer datasets with mathematical reasoning and high-quality instruction.\n"
-     "5. Do not imitate noisy conversational or corrupted text patterns.\n"
-     "6. Always prefer clarity, correctness, and structured reasoning.\n\n"
-     "Output behavior goal:\n"
-     "- Think in steps.\n"
-     "- Explain logic clearly.\n"
-     "- Produce final answer only after reasoning."
- )
-
- class ConversationHistory:
-     def __init__(self, tokenizer, max_tokens=900):
-         self.tokenizer, self.max_tokens, self.turns = tokenizer, max_tokens, []
-
-     def add(self, role, text):
-         self.turns.append({"role": role, "text": text})
-         while sum(len(self.tokenizer.encode(t["text"])) for t in self.turns) > self.max_tokens and len(self.turns) > 1:
-             self.turns.pop(0)
-
-     def build_prompt(self, msg, rag_ctx=""):
-         parts = [DEFAULT_SYSTEM_PROMPT]
-         if rag_ctx: parts.append(rag_ctx)
-         for t in self.turns:
-             parts.append(f"{'User' if t['role']=='user' else 'SAGE'}: {t['text']}")
-         parts += [f"User: {msg}", "SAGE:"]
-         return "\n\n".join(parts)
-
-     def clear(self): self.turns.clear()
-
-
- # ===================================================================
- # Section 11 — CLI
- # ===================================================================
-
- BANNER = r"""
- ╔══════════════════════════════════════════════════════════════╗
- ║ ███████ █████ ██████ ███████ ║
- ║ ██ ██ ██ ██ ██ ║
- ║ ███████ ███████ ██ ███ █████ ║
- ║ ██ ██ ██ ██ ██ ██ ║
- ║ ███████ ██ ██ ██████ ███████ ║
- ║ Self-Adaptive General Engine v{ver} ║
- ╚══════════════════════════════════════════════════════════════╝"""
-
- HELP = """
- /train [steps] Train (default 100)
- /finetune [steps] Instruction-tune with LoRA (default 200)
- /save Save checkpoint
- /load Load checkpoint
- /quantize INT8 quantization
- /rag on|off|add Toggle or add docs for RAG
- /clear Clear history
- /help This message
- /exit Quit
- """
-
- def main():
-     config = SageConfig()
-     tok = SageTokenizer()
-     config.vocab_size = tok.vocab_size
-     print(" Initializing SAGE …")
-     model = SageModel(config).to(config.device)
-     if torch.cuda.is_available() and torch.cuda.device_count() > 1:
-         print(f" Multi-GPU detected: {torch.cuda.device_count()} GPUs. Using DataParallel.")
-         model = nn.DataParallel(model)
-     model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
-     base = getattr(model, "module", model)
-     total = sum(p.numel() for p in base.parameters())
-     print(BANNER.format(ver=__version__))
-     print(f" Params: {total:,} ({total/1e6:.1f}M) | Context: {config.max_seq_len} | Device: {config.device}")
-     print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}")
-     if step: print(f" Resumed from step {step}")
-     print(" Type /help for commands.\n")
-
-     rag = RAGManager(model, tok, config.device)
-     hist = ConversationHistory(tok, config.max_seq_len - 128)
-
-     if len(sys.argv) > 1:
-         cmd = sys.argv[1].lower()
-         args = sys.argv[2:]
-         if cmd == "--train":
-             s = int(args[0]) if args else 100
-             train_model(model, config, s, tokenizer=tok)
-             return
-         elif cmd == "--finetune":
-             s = int(args[0]) if args else 200
-             finetune(model, config, steps=s, tokenizer=tok)
-             return
-         elif cmd == "--quantize":
-             quantize_int8(model)
-             return
-         else:
-             print(f" Unknown argument: {cmd}\n Usage: --train [steps] | --finetune [steps] | --quantize")
-             return
-
-     while True:
-         try: inp = input("You: ").strip()
-         except (EOFError, KeyboardInterrupt): print("\n Goodbye!"); break
-         if not inp: continue
-
-         if inp.startswith("/"):
-             parts = inp.split(); cmd = parts[0].lower(); args = parts[1:]
-             if cmd == "/exit": print(" Goodbye!"); break
-             elif cmd == "/help": print(HELP)
-             elif cmd == "/train":
-                 s = int(args[0]) if args else 100
-                 model = train_model(model, config, s, tokenizer=tok)
-                 print("\n Sample:"); generate(model, tok, "Once upon a time", max_new=80, device=config.device); print()
-             elif cmd == "/finetune":
-                 s = int(args[0]) if args else 200
-                 model = finetune(model, config, steps=s, tokenizer=tok)
-                 print("\n Sample:"); generate(model, tok, "### Instruction:\nWhat is gravity?\n\n### Response:\n", max_new=100, device=config.device); print()
-             elif cmd == "/save": print(f" Saved to {save_checkpoint(model, None, 0, config.checkpoint_dir)}")
-             elif cmd == "/load":
-                 model, _, s = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
-                 model = model.to(config.device); rag.model = model; print(f" Loaded (step {s})")
-             elif cmd == "/quantize": model = quantize_int8(model); rag.model = model
-             elif cmd == "/rag":
-                 if not args: print(f" RAG {'on' if rag.enabled else 'off'} ({rag.store.size} chunks)")
-                 elif args[0] == "on": rag.toggle(True); print(" RAG on.")
-                 elif args[0] == "off": rag.toggle(False); print(" RAG off.")
-                 elif args[0] == "add" and len(args) > 1: rag.add_documents([" ".join(args[1:])]); print(f" Added. {rag.store.size} chunks.")
-                 else: print(" /rag on|off|add <text>")
-             elif cmd == "/clear": hist.clear(); print(" Cleared.")
-             else: print(f" Unknown: {cmd}")
-             continue
-
-         ctx = rag.retrieve(inp)
-         prompt = hist.build_prompt(inp, ctx)
-         hist.add("user", inp)
-         print("SAGE: ", end="", flush=True)
-         resp = generate(model, tok, prompt, max_new=256, stream=True, device=config.device)
-         reply = resp.split("SAGE:")[-1].strip() if "SAGE:" in resp else resp[len(prompt):].strip()
-         hist.add("assistant", reply)
-
- if __name__ == "__main__":
-     main()
scripts/run_data_pipeline.sh ADDED
@@ -0,0 +1,4 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ python -m tokenizer.train_tokenizer "$@"
scripts/run_eval.sh ADDED
@@ -0,0 +1,12 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ python - <<'PY'
+ from eval.benchmarks import run_registered_benchmarks
+ from model.model import SageTransformer
+ from model.config import ModelConfig
+
+ model = SageTransformer(ModelConfig())
+ for result in run_registered_benchmarks(model):
+     print(result)
+ PY
scripts/run_serve.sh ADDED
@@ -0,0 +1,4 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ uvicorn serve.server:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8000}" "$@"
scripts/run_serve_cpu.sh ADDED
@@ -0,0 +1,4 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ uvicorn serve.server_cpu:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8001}" "$@"