feat: rewrite SAGE 1B architecture and replace legacy repo contents
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -1
- .gitignore +8 -2
- README.md +402 -110
- SAGE_KAGGLE_GUIDE.md +0 -146
- SAGE_V3_ROADMAP.md +0 -52
- configs/data/mix.yaml +25 -0
- configs/model/1b.yaml +13 -0
- configs/model/3b.yaml +13 -0
- configs/model/7b.yaml +13 -0
- configs/train/schedule.yaml +14 -0
- data/__init__.py +1 -0
- data/dataset.py +92 -0
- data/dedup.py +59 -0
- data/filter.py +135 -0
- data/ingest.py +79 -0
- data/shard.py +95 -0
- docs/COMMANDS.md +84 -0
- docs/flow_llm.mmd +140 -0
- docs/llm_Arch.mmd +202 -0
- eval/__init__.py +1 -0
- eval/benchmarks.py +42 -0
- eval/long_context.py +24 -0
- eval/perplexity.py +39 -0
- eval/regression.py +15 -0
- hf_push.py +40 -0
- model/__init__.py +1 -0
- model/attention.py +76 -0
- model/block.py +37 -0
- model/config.py +48 -0
- model/mlp.py +23 -0
- model/model.py +73 -0
- model/rmsnorm.py +23 -0
- model/rope.py +57 -0
- requirements.txt +12 -26
- sage/__init__.py +0 -15
- sage/cli.py +0 -299
- sage/config.py +0 -48
- sage/data.py +0 -255
- sage/finetune.py +0 -268
- sage/inference.py +0 -171
- sage/memory.py +0 -240
- sage/model.py +0 -267
- sage/optimize.py +0 -164
- sage/train.py +0 -266
- sage/utils.py +0 -143
- sage_single.py +0 -824
- scripts/run_data_pipeline.sh +4 -0
- scripts/run_eval.sh +12 -0
- scripts/run_serve.sh +4 -0
- scripts/run_serve_cpu.sh +4 -0
.gitattributes
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
# Git files
|
| 6 |
.git/*
|
| 7 |
.gitignore
|
| 8 |
-
|
| 9 |
# Python virtual environments
|
| 10 |
.venv/*
|
| 11 |
venv/*
|
|
|
|
| 5 |
# Git files
|
| 6 |
.git/*
|
| 7 |
.gitignore
|
| 8 |
+
hf_push.py
|
| 9 |
# Python virtual environments
|
| 10 |
.venv/*
|
| 11 |
venv/*
|
.gitignore
CHANGED
|
@@ -3,6 +3,8 @@ __pycache__/
|
|
| 3 |
*.py[cod]
|
| 4 |
*$py.class
|
| 5 |
|
|
|
|
|
|
|
| 6 |
# C extensions
|
| 7 |
*.so
|
| 8 |
|
|
@@ -25,8 +27,6 @@ share/python-wheels/
|
|
| 25 |
.installed.cfg
|
| 26 |
*.egg
|
| 27 |
MANIFEST
|
| 28 |
-
hf_push.py
|
| 29 |
-
.hugging_face_ignore
|
| 30 |
|
| 31 |
# PyInstaller
|
| 32 |
# Usually these files are written by a python script, before a-one-file pack
|
|
@@ -112,6 +112,12 @@ celerybeat.pid
|
|
| 112 |
|
| 113 |
# Sage Project Specific
|
| 114 |
checkpoints/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
.venv/
|
| 116 |
.env
|
| 117 |
.DS_Store
|
|
|
|
| 3 |
*.py[cod]
|
| 4 |
*$py.class
|
| 5 |
|
| 6 |
+
wandb/
|
| 7 |
+
|
| 8 |
# C extensions
|
| 9 |
*.so
|
| 10 |
|
|
|
|
| 27 |
.installed.cfg
|
| 28 |
*.egg
|
| 29 |
MANIFEST
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# PyInstaller
|
| 32 |
# Usually these files are written by a python script, before a-one-file pack
|
|
|
|
| 112 |
|
| 113 |
# Sage Project Specific
|
| 114 |
checkpoints/
|
| 115 |
+
runs/
|
| 116 |
+
tokenizer/*.model
|
| 117 |
+
tokenizer/*.vocab
|
| 118 |
+
tokenizer/training_corpus.txt
|
| 119 |
+
data/raw/
|
| 120 |
+
data/processed/
|
| 121 |
.venv/
|
| 122 |
.env
|
| 123 |
.DS_Store
|
README.md
CHANGED
|
@@ -1,163 +1,455 @@
|
|
| 1 |
-
# SAGE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
- **Mixture of Experts (MoE)**: Efficient scaling with a learned router selecting top-k experts per token.
|
| 18 |
-
- **Rotary Positional Embeddings (RoPE)**: Enhanced long-sequence generalization.
|
| 19 |
-
- **KV-Cache Inference**: O(1) time-per-token generation for high-speed response.
|
| 20 |
-
- **Retrieval-Augmented Generation (RAG)**: Integration with FAISS for document-based context lookup.
|
| 21 |
-
- **Efficient Fine-Tuning**: Support for LoRA and instruction tuning with loss masking.
|
| 22 |
-
- **Post-Training Quantization**: INT8 support to reduce memory footprint.
|
| 23 |
-
- **Interactive CLI**: A full REPL (Read-Eval-Print Loop) for chatting and system management.
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
-
|
| 28 |
|
| 29 |
-
```
|
| 30 |
-
|
| 31 |
-
├── model.py # Core architecture (Transformer, MoE, RoPE, Attention)
|
| 32 |
-
├── data.py # Tokenization (tiktoken) & Streaming Datasets (HuggingFace)
|
| 33 |
-
├── train.py # Pre-training loop with AdamW, AMP, and Cosine Decay
|
| 34 |
-
├── inference.py # Text generation (Greedy, Temp, Top-k, Top-p sampling)
|
| 35 |
-
├── finetune.py # LoRA implementation & Instruction Tuning
|
| 36 |
-
├── optimize.py # INT8 Quantization & Pruning utilities
|
| 37 |
-
├── memory.py # RAG Vector Store & Conversation History
|
| 38 |
-
├── cli.py # Interactive Terminal Interface
|
| 39 |
-
├── utils.py # Logging, Checkpointing, and Helper functions
|
| 40 |
-
├── config.py # Central Hyperparameter Configuration
|
| 41 |
-
└── requirements.txt # System dependencies
|
| 42 |
-
sage_single.py # Consolidated single-file version for easy portability
|
| 43 |
```
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
| 51 |
|
| 52 |
```bash
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
git clone https://huggingface.co/sage002/sage
|
| 59 |
-
cd sage
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
```
|
| 64 |
|
| 65 |
-
|
| 66 |
-
- **PyTorch**: Core deep learning framework.
|
| 67 |
-
- **tiktoken**: Fast BPE tokenization (OpenAI's cl100k_base).
|
| 68 |
-
- **datasets**: For streaming training data from HuggingFace.
|
| 69 |
-
- **faiss-cpu**: For vector-based retrieval (RAG).
|
| 70 |
-
- **tqdm**: Progress bars for training.
|
| 71 |
-
- **bitsandbytes**: (Optional) For advanced quantization.
|
| 72 |
|
| 73 |
-
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
##
|
| 76 |
|
| 77 |
-
|
| 78 |
-
You can run the modular version or the single-file version:
|
| 79 |
|
| 80 |
```bash
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
```
|
| 87 |
|
| 88 |
-
###
|
| 89 |
-
Once launched, simply type your message to chat with SAGE. The system uses a rolling conversation history to maintain context.
|
| 90 |
|
| 91 |
-
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
|
| 95 |
-
|
| 96 |
|
| 97 |
-
### Non-Interactive "One-Liner" Commands
|
| 98 |
-
If you want to bypass the chat interface and just run a training job, pass the command as a CLI argument:
|
| 99 |
```bash
|
| 100 |
-
|
| 101 |
-
python sage_single.py --finetune 200 # Instruction-tune for 200 steps
|
| 102 |
-
python sage_single.py --quantize # Apply INT8 quantization
|
| 103 |
```
|
| 104 |
|
| 105 |
-
|
| 106 |
-
If you are inside the chat interface, use the slash commands:
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
|
| 112 |
-
###
|
| 113 |
-
Perform instruction fine-tuning using LoRA adapters.
|
| 114 |
-
- `/finetune 200` — Trains on instruction/response pairs and merges weights.
|
| 115 |
|
| 116 |
-
|
|
|
|
|
|
|
| 117 |
|
| 118 |
-
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
| `/rag on` | Enable Retrieval-Augmented Generation. |
|
| 126 |
-
| `/rag add <text>` | Add new knowledge to SAGE's retrieval database. |
|
| 127 |
-
| `/clear` | Clear the current conversation history. |
|
| 128 |
-
| `/help` | Show the list of available commands. |
|
| 129 |
-
| `/exit` | Exit the program cleanly. |
|
| 130 |
|
| 131 |
-
|
| 132 |
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
##
|
| 136 |
-
SAGE swaps standard FFN layers for MoE blocks. Each block contains 4 experts, where exactly 2 are activated per token via a learned linear router. This allows for higher total capacity without increasing the computational cost per token.
|
| 137 |
|
| 138 |
-
|
| 139 |
-
Positions are encoded via complex-valued rotations of query and key vectors. This allows SAGE to better handle sequences longer than what it was trained on compared to absolute position embeddings.
|
| 140 |
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
|
| 149 |
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
-
|
| 155 |
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
|
| 160 |
-
|
| 161 |
-
SAGE is an experimental engine. While architecturally complete, the quality of generated responses depends heavily on the amount of training data and compute steps provided.
|
| 162 |
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAGE 1B
|
| 2 |
+
|
| 3 |
+
SAGE is a root-level rewrite of this repository into a production-style dense language model project. The current baseline is a 1B-class decoder-only transformer with RMSNorm, RoPE, grouped-query attention, SwiGLU, SentencePiece, resumable training, Parquet-backed datasets, and FastAPI serving.
|
| 4 |
+
|
| 5 |
+
This README is written as a practical operator guide. It tells you:
|
| 6 |
+
|
| 7 |
+
- what the project contains
|
| 8 |
+
- what is already implemented
|
| 9 |
+
- what commands to run
|
| 10 |
+
- what files are inputs and outputs
|
| 11 |
+
- what parts are scaffolding versus fully wired
|
| 12 |
+
|
| 13 |
+
## What SAGE Is
|
| 14 |
+
|
| 15 |
+
SAGE is organized into these layers:
|
| 16 |
+
|
| 17 |
+
1. `tokenizer/`
|
| 18 |
+
Trains and validates a SentencePiece tokenizer.
|
| 19 |
+
2. `data/`
|
| 20 |
+
Handles raw corpus ingest, filtering, deduplication, sharding, and packed datasets.
|
| 21 |
+
3. `model/`
|
| 22 |
+
Implements the dense decoder-only transformer.
|
| 23 |
+
4. `train/`
|
| 24 |
+
Handles optimizer setup, scheduler, hardware detection, checkpoints, and the training loop.
|
| 25 |
+
5. `eval/`
|
| 26 |
+
Provides perplexity evaluation and benchmark harness registration.
|
| 27 |
+
6. `serve/`
|
| 28 |
+
Exposes FastAPI servers and quantization helpers.
|
| 29 |
+
|
| 30 |
+
## Current Baseline
|
| 31 |
+
|
| 32 |
+
| Component | Value |
|
| 33 |
+
| --- | --- |
|
| 34 |
+
| Layers | 24 |
|
| 35 |
+
| d_model | 2048 |
|
| 36 |
+
| Attention heads | 16 |
|
| 37 |
+
| KV heads | 8 |
|
| 38 |
+
| Head dim | 128 |
|
| 39 |
+
| FFN dim | 5632 |
|
| 40 |
+
| Context length | 4096 |
|
| 41 |
+
| Vocab size | 50000 |
|
| 42 |
+
| Norm | RMSNorm |
|
| 43 |
+
| Positional encoding | RoPE |
|
| 44 |
+
| Attention | GQA + SDPA |
|
| 45 |
+
| Activation | SwiGLU |
|
| 46 |
+
| Weight tying | Enabled |
|
| 47 |
+
|
| 48 |
+
## Repository Layout
|
| 49 |
|
| 50 |
+
```text
|
| 51 |
+
configs/
|
| 52 |
+
model/ model YAMLs for 1B, 3B, 7B
|
| 53 |
+
data/ corpus mix and shard config
|
| 54 |
+
train/ LR, checkpoint, and logging schedule
|
| 55 |
+
data/
|
| 56 |
+
ingest.py raw source registry and streaming helpers
|
| 57 |
+
filter.py license/lang/PII/safety/quality filtering
|
| 58 |
+
dedup.py exact and near-duplicate removal
|
| 59 |
+
shard.py tokenization + parquet shard writing + manifest
|
| 60 |
+
dataset.py packed iterable dataset with resume skip()
|
| 61 |
+
tokenizer/
|
| 62 |
+
train_tokenizer.py
|
| 63 |
+
validate_tokenizer.py
|
| 64 |
+
model/
|
| 65 |
+
config.py
|
| 66 |
+
rmsnorm.py
|
| 67 |
+
rope.py
|
| 68 |
+
attention.py
|
| 69 |
+
mlp.py
|
| 70 |
+
block.py
|
| 71 |
+
model.py
|
| 72 |
+
train/
|
| 73 |
+
loss.py
|
| 74 |
+
optimizer.py
|
| 75 |
+
checkpoint.py
|
| 76 |
+
distributed.py
|
| 77 |
+
hardware.py
|
| 78 |
+
trainer.py
|
| 79 |
+
eval/
|
| 80 |
+
perplexity.py
|
| 81 |
+
benchmarks.py
|
| 82 |
+
long_context.py
|
| 83 |
+
regression.py
|
| 84 |
+
serve/
|
| 85 |
+
kv_cache.py
|
| 86 |
+
quantize.py
|
| 87 |
+
server.py
|
| 88 |
+
server_cpu.py
|
| 89 |
+
scripts/
|
| 90 |
+
run_data_pipeline.sh
|
| 91 |
+
run_training.sh
|
| 92 |
+
run_eval.sh
|
| 93 |
+
run_serve.sh
|
| 94 |
+
run_serve_cpu.sh
|
| 95 |
+
run_validate_tokenizer.sh
|
| 96 |
+
tests/
|
| 97 |
+
```
|
| 98 |
|
| 99 |
+
## What Is Fully Working vs. What Is Scaffolded
|
| 100 |
|
| 101 |
+
### Working now
|
| 102 |
|
| 103 |
+
- tokenizer training
|
| 104 |
+
- tokenizer validation
|
| 105 |
+
- data filtering and dedup helpers
|
| 106 |
+
- packed dataset logic
|
| 107 |
+
- dense transformer forward pass
|
| 108 |
+
- checkpoint save and resume
|
| 109 |
+
- hardware detection
|
| 110 |
+
- trainer entrypoint
|
| 111 |
+
- FastAPI health and basic generate endpoint
|
| 112 |
+
- unit and smoke tests
|
| 113 |
|
| 114 |
+
### Scaffolded but not yet a full production runner
|
| 115 |
|
| 116 |
+
- benchmark execution against downloaded external datasets
|
| 117 |
+
- a single end-to-end corpus build command that downloads and preprocesses public corpora automatically
|
| 118 |
+
- production-grade multi-node launch tooling
|
| 119 |
+
- real llama.cpp server wiring beyond availability checks
|
| 120 |
|
| 121 |
+
That means the core codebase is real, but you still need to provide your own corpus files and Parquet shards before running a training job.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
## Install
|
| 124 |
|
| 125 |
+
Create and activate a virtual environment, then install dependencies:
|
| 126 |
|
| 127 |
+
```bash
|
| 128 |
+
pip install -r requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
```
|
| 130 |
|
| 131 |
+
Recommended optional extras:
|
| 132 |
+
|
| 133 |
+
- `sentencepiece` is required for tokenizer training and validation
|
| 134 |
+
- `bitsandbytes` is useful for 8-bit experiments
|
| 135 |
+
- `llama.cpp` or `llama-cpp-python` is needed for the CPU serving path
|
| 136 |
+
|
| 137 |
+
## Quick Start
|
| 138 |
|
| 139 |
+
If you want the shortest path to verifying the repo:
|
| 140 |
|
| 141 |
+
1. Install dependencies.
|
| 142 |
+
2. Run tests.
|
| 143 |
+
3. Start the FastAPI server.
|
| 144 |
|
| 145 |
```bash
|
| 146 |
+
pytest -q
|
| 147 |
+
uvicorn serve.server:app --host 127.0.0.1 --port 8000
|
| 148 |
+
```
|
| 149 |
|
| 150 |
+
Then check:
|
|
|
|
|
|
|
| 151 |
|
| 152 |
+
```bash
|
| 153 |
+
curl http://127.0.0.1:8000/health
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
## Command Reference
|
| 157 |
+
|
| 158 |
+
The detailed command guide is in [docs/COMMANDS.md](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/docs/COMMANDS.md:1). The most important commands are below.
|
| 159 |
+
|
| 160 |
+
### 1. Train tokenizer
|
| 161 |
+
|
| 162 |
+
Cross-platform Python command:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
python -m tokenizer.train_tokenizer \
|
| 166 |
+
--input data/raw/general_web.txt data/raw/code.txt \
|
| 167 |
+
--model-prefix tokenizer/tokenizer \
|
| 168 |
+
--vocab-size 50000
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
Linux/macOS/WSL wrapper:
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
bash scripts/run_data_pipeline.sh \
|
| 175 |
+
--input data/raw/general_web.txt data/raw/code.txt \
|
| 176 |
+
--model-prefix tokenizer/tokenizer \
|
| 177 |
+
--vocab-size 50000
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Outputs:
|
| 181 |
+
|
| 182 |
+
- `tokenizer/tokenizer.model`
|
| 183 |
+
- `tokenizer/tokenizer.vocab`
|
| 184 |
+
- `tokenizer/training_corpus.txt`
|
| 185 |
+
|
| 186 |
+
### 2. Validate tokenizer
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
python - <<'PY'
|
| 190 |
+
from tokenizer.validate_tokenizer import validate_model_file
|
| 191 |
+
validate_model_file("tokenizer/tokenizer.model")
|
| 192 |
+
print("tokenizer ok")
|
| 193 |
+
PY
|
| 194 |
```
|
| 195 |
|
| 196 |
+
Or:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
```bash
|
| 199 |
+
bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model
|
| 200 |
+
```
|
| 201 |
|
| 202 |
+
### 3. Train the model
|
| 203 |
|
| 204 |
+
Training expects existing Parquet shards. Example:
|
|
|
|
| 205 |
|
| 206 |
```bash
|
| 207 |
+
python -m train.trainer \
|
| 208 |
+
--model-config configs/model/1b.yaml \
|
| 209 |
+
--schedule-config configs/train/schedule.yaml \
|
| 210 |
+
--train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \
|
| 211 |
+
--validation-shards data/processed/shard-00002.parquet \
|
| 212 |
+
--output-dir runs/sage-1b
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
Useful options:
|
| 216 |
|
| 217 |
+
- `--steps 100` for a short smoke run
|
| 218 |
+
- `--disable-wandb` to disable offline W&B logging
|
| 219 |
+
|
| 220 |
+
Example smoke run:
|
| 221 |
+
|
| 222 |
+
```bash
|
| 223 |
+
python -m train.trainer \
|
| 224 |
+
--train-shards data/processed/shard-00000.parquet \
|
| 225 |
+
--validation-shards data/processed/shard-00001.parquet \
|
| 226 |
+
--output-dir runs/smoke \
|
| 227 |
+
--steps 20 \
|
| 228 |
+
--disable-wandb
|
| 229 |
```
|
| 230 |
|
| 231 |
+
### 4. Run evaluation harness
|
|
|
|
| 232 |
|
| 233 |
+
```bash
|
| 234 |
+
bash scripts/run_eval.sh
|
| 235 |
+
```
|
| 236 |
|
| 237 |
+
This currently prints the registered benchmark surfaces. It is a harness check, not a full benchmark download-and-run pipeline.
|
| 238 |
|
| 239 |
+
### 5. Start the GPU server
|
| 240 |
|
|
|
|
|
|
|
| 241 |
```bash
|
| 242 |
+
uvicorn serve.server:app --host 0.0.0.0 --port 8000
|
|
|
|
|
|
|
| 243 |
```
|
| 244 |
|
| 245 |
+
Or:
|
|
|
|
| 246 |
|
| 247 |
+
```bash
|
| 248 |
+
bash scripts/run_serve.sh
|
| 249 |
+
```
|
| 250 |
|
| 251 |
+
### 6. Start the CPU server
|
|
|
|
|
|
|
| 252 |
|
| 253 |
+
```bash
|
| 254 |
+
uvicorn serve.server_cpu:app --host 0.0.0.0 --port 8001
|
| 255 |
+
```
|
| 256 |
|
| 257 |
+
Or:
|
| 258 |
|
| 259 |
+
```bash
|
| 260 |
+
bash scripts/run_serve_cpu.sh
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
### 7. Call the generate endpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
+
The current server takes token IDs directly, not raw text strings.
|
| 266 |
|
| 267 |
+
```bash
|
| 268 |
+
curl -X POST http://127.0.0.1:8000/generate \
|
| 269 |
+
-H "Content-Type: application/json" \
|
| 270 |
+
-d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
|
| 271 |
+
```
|
| 272 |
+
|
| 273 |
+
Response shape:
|
| 274 |
+
|
| 275 |
+
```json
|
| 276 |
+
{
|
| 277 |
+
"tokens": [1, 42, 99, 123, 456]
|
| 278 |
+
}
|
| 279 |
+
```
|
| 280 |
|
| 281 |
+
## How Training Works
|
|
|
|
| 282 |
|
| 283 |
+
The training flow is:
|
|
|
|
| 284 |
|
| 285 |
+
1. load model config from `configs/model/*.yaml`
|
| 286 |
+
2. load schedule config from `configs/train/schedule.yaml`
|
| 287 |
+
3. detect hardware in `train/hardware.py`
|
| 288 |
+
4. build optimizer and cosine scheduler
|
| 289 |
+
5. load latest checkpoint if one exists
|
| 290 |
+
6. call `PackedDataset.skip()` so resume does not replay already-trained batches
|
| 291 |
+
7. run forward/backward with autocast on CUDA or MPS
|
| 292 |
+
8. clip gradients
|
| 293 |
+
9. log metrics to `metrics.jsonl` and optionally offline W&B
|
| 294 |
+
10. run validation perplexity at eval intervals
|
| 295 |
+
11. save checkpoint every configured interval
|
| 296 |
|
| 297 |
+
Important output files during training:
|
| 298 |
|
| 299 |
+
- `runs/<run-name>/metrics.jsonl`
|
| 300 |
+
- `runs/<run-name>/ckpt_step_0001000.pt`
|
| 301 |
+
- later checkpoints in the same folder
|
| 302 |
+
|
| 303 |
+
## How Data Is Expected to Look
|
| 304 |
+
|
| 305 |
+
### Raw text files for tokenizer training
|
| 306 |
+
|
| 307 |
+
Simple UTF-8 text files are enough:
|
| 308 |
+
|
| 309 |
+
```text
|
| 310 |
+
This is a training document.
|
| 311 |
+
This is another one.
|
| 312 |
+
```
|
| 313 |
|
| 314 |
+
### Raw JSONL records for ingest/filter work
|
| 315 |
+
|
| 316 |
+
The ingest layer assumes records like:
|
| 317 |
+
|
| 318 |
+
```json
|
| 319 |
+
{"text": "example text"}
|
| 320 |
+
```
|
| 321 |
+
|
| 322 |
+
### Processed Parquet shards for training
|
| 323 |
+
|
| 324 |
+
The trainer expects Parquet rows with at least:
|
| 325 |
+
|
| 326 |
+
- `tokens`
|
| 327 |
+
- `split`
|
| 328 |
+
|
| 329 |
+
The sharding helper writes:
|
| 330 |
+
|
| 331 |
+
- `id`
|
| 332 |
+
- `text`
|
| 333 |
+
- `tokens`
|
| 334 |
+
- `domain_tag`
|
| 335 |
+
- `quality_tier`
|
| 336 |
+
- `lang`
|
| 337 |
+
- `token_count`
|
| 338 |
+
- `split`
|
| 339 |
+
|
| 340 |
+
## Main Config Files
|
| 341 |
+
|
| 342 |
+
### [configs/model/1b.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/model/1b.yaml:1)
|
| 343 |
+
|
| 344 |
+
Controls the model shape:
|
| 345 |
+
|
| 346 |
+
- layers
|
| 347 |
+
- hidden size
|
| 348 |
+
- heads
|
| 349 |
+
- KV heads
|
| 350 |
+
- FFN size
|
| 351 |
+
- vocab size
|
| 352 |
+
- context length
|
| 353 |
+
|
| 354 |
+
### [configs/data/mix.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/data/mix.yaml:1)
|
| 355 |
+
|
| 356 |
+
Controls corpus weights and split ratios.
|
| 357 |
+
|
| 358 |
+
### [configs/train/schedule.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/train/schedule.yaml:1)
|
| 359 |
+
|
| 360 |
+
Controls:
|
| 361 |
+
|
| 362 |
+
- total token target
|
| 363 |
+
- LR schedule
|
| 364 |
+
- warmup
|
| 365 |
+
- checkpoint interval
|
| 366 |
+
- log interval
|
| 367 |
+
- eval interval
|
| 368 |
+
|
| 369 |
+
## Common Workflows
|
| 370 |
+
|
| 371 |
+
### Workflow A: verify the repo
|
| 372 |
+
|
| 373 |
+
```bash
|
| 374 |
+
pip install -r requirements.txt
|
| 375 |
+
pytest -q
|
| 376 |
+
```
|
| 377 |
+
|
| 378 |
+
### Workflow B: train tokenizer only
|
| 379 |
+
|
| 380 |
+
```bash
|
| 381 |
+
python -m tokenizer.train_tokenizer --input data/raw/general_web.txt --model-prefix tokenizer/tokenizer
|
| 382 |
+
python - <<'PY'
|
| 383 |
+
from tokenizer.validate_tokenizer import validate_model_file
|
| 384 |
+
validate_model_file("tokenizer/tokenizer.model")
|
| 385 |
+
print("ok")
|
| 386 |
+
PY
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
### Workflow C: smoke-train on local shards
|
| 390 |
+
|
| 391 |
+
```bash
|
| 392 |
+
python -m train.trainer \
|
| 393 |
+
--train-shards data/processed/shard-00000.parquet \
|
| 394 |
+
--validation-shards data/processed/shard-00001.parquet \
|
| 395 |
+
--output-dir runs/smoke \
|
| 396 |
+
--steps 20 \
|
| 397 |
+
--disable-wandb
|
| 398 |
+
```
|
| 399 |
+
|
| 400 |
+
### Workflow D: serve locally
|
| 401 |
+
|
| 402 |
+
```bash
|
| 403 |
+
uvicorn serve.server:app --host 127.0.0.1 --port 8000
|
| 404 |
+
curl http://127.0.0.1:8000/health
|
| 405 |
+
```
|
| 406 |
+
|
| 407 |
+
## Troubleshooting
|
| 408 |
+
|
| 409 |
+
### `No training shards provided`
|
| 410 |
+
|
| 411 |
+
You launched the trainer without `--train-shards`. The trainer is working as designed, but it needs Parquet shard paths.
|
| 412 |
+
|
| 413 |
+
### `ModuleNotFoundError: sentencepiece`
|
| 414 |
+
|
| 415 |
+
Install dependencies:
|
| 416 |
+
|
| 417 |
+
```bash
|
| 418 |
+
pip install -r requirements.txt
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
### FastAPI starts but generate is not useful
|
| 422 |
+
|
| 423 |
+
That is expected right now if you have not trained or loaded a checkpoint. The server instantiates the model architecture, but it does not yet load a trained checkpoint automatically.
|
| 424 |
+
|
| 425 |
+
### CPU server says llama.cpp is unavailable
|
| 426 |
+
|
| 427 |
+
Install `llama.cpp` or `llama-cpp-python`. The current CPU server is a readiness surface, not a bundled llama.cpp runtime.
|
| 428 |
+
|
| 429 |
+
## Tests
|
| 430 |
+
|
| 431 |
+
Run the full suite:
|
| 432 |
+
|
| 433 |
+
```bash
|
| 434 |
+
pytest -q
|
| 435 |
+
```
|
| 436 |
|
| 437 |
+
Coverage areas:
|
| 438 |
|
| 439 |
+
- tokenizer roundtrip validation
|
| 440 |
+
- model shapes
|
| 441 |
+
- attention math
|
| 442 |
+
- data filtering and packing
|
| 443 |
+
- checkpoint roundtrip
|
| 444 |
+
- hardware summaries
|
| 445 |
+
- FastAPI health endpoints
|
| 446 |
|
| 447 |
+
## Next Practical Step
|
| 448 |
|
| 449 |
+
If you want the fastest real progress from here, the next step is:
|
|
|
|
| 450 |
|
| 451 |
+
1. prepare a small local corpus
|
| 452 |
+
2. train the tokenizer
|
| 453 |
+
3. write Parquet shards with `data/shard.py`
|
| 454 |
+
4. run a `--steps 20` smoke training job
|
| 455 |
+
5. only then start extending benchmark or serving behavior
|
SAGE_KAGGLE_GUIDE.md
DELETED
|
@@ -1,146 +0,0 @@
|
|
| 1 |
-
# 🪐 SAGE: Kaggle & Colab Quickstart Guide
|
| 2 |
-
|
| 3 |
-
Welcome to the **Self-Adaptive General Engine (SAGE)**. This guide will help you get SAGE v2 running on a cloud environment (like Kaggle's 2x T4 or Google Colab) in under 5 minutes.
|
| 4 |
-
|
| 5 |
-
---
|
| 6 |
-
|
| 7 |
-
## 🛠️ Step 1: Environment Setup
|
| 8 |
-
|
| 9 |
-
Run this cell first to install dependencies and fix any common binary incompatibilities (like the Numpy/Torch mismatch).
|
| 10 |
-
|
| 11 |
-
```python
|
| 12 |
-
# Install PyTorch 2.1 with CUDA 12.1 (supports Tesla P100 sm_60)
|
| 13 |
-
!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
| 14 |
-
|
| 15 |
-
# Install other dependencies
|
| 16 |
-
!pip install "numpy<2.0.0" --force-reinstall
|
| 17 |
-
!pip install bitsandbytes tqdm tiktoken faiss-cpu datasets wandb --upgrade
|
| 18 |
-
|
| 19 |
-
print("✅ Environment ready. Please RESTART YOUR KERNEL now if this is your first run.")
|
| 20 |
-
```
|
| 21 |
-
|
| 22 |
-
---
|
| 23 |
-
|
| 24 |
-
## 🔑 Step 2: Weights & Biases Logging (Optional but Recommended)
|
| 25 |
-
|
| 26 |
-
To track your training progress with professional charts:
|
| 27 |
-
|
| 28 |
-
1. Get your API Key from [wandb.ai/authorize](https://wandb.ai/authorize).
|
| 29 |
-
2. Add it to your Kaggle **Secrets** with the label `WANDB_API_KEY`.
|
| 30 |
-
3. Run this:
|
| 31 |
-
|
| 32 |
-
```python
|
| 33 |
-
import wandb
|
| 34 |
-
from kaggle_secrets import UserSecretsClient
|
| 35 |
-
try:
|
| 36 |
-
user_secrets = UserSecretsClient()
|
| 37 |
-
wandb.login(key=user_secrets.get_secret("WANDB_API_KEY"))
|
| 38 |
-
except:
|
| 39 |
-
import os
|
| 40 |
-
os.environ["WANDB_MODE"] = "offline"
|
| 41 |
-
print("⚠️ W&B Secret not found. Running in offline mode.")
|
| 42 |
-
```
|
| 43 |
-
|
| 44 |
-
---
|
| 45 |
-
|
| 46 |
-
## 💬 Step 3: Launch the SAGE Chat Interface
|
| 47 |
-
|
| 48 |
-
This is a premium, multi-GPU enabled chat widget. Paste this into a cell to start interacting with SAGE.
|
| 49 |
-
|
| 50 |
-
**Note:** SAGE automatically detects GPU compatibility and falls back to CPU if needed.
|
| 51 |
-
|
| 52 |
-
```python
|
| 53 |
-
import sys, os, torch, random
|
| 54 |
-
import torch.nn as nn
|
| 55 |
-
import ipywidgets as widgets
|
| 56 |
-
from IPython.display import display, HTML
|
| 57 |
-
|
| 58 |
-
# Verify SAGE is accessible (debugging import issues)
|
| 59 |
-
if not os.path.exists('sage/__init__.py'):
|
| 60 |
-
print("❌ ERROR: sage/ folder not found in current directory!")
|
| 61 |
-
print(" Make sure you've cloned the repo: !git clone https://github.com/er-del/sage.git")
|
| 62 |
-
raise ImportError("sage module not found")
|
| 63 |
-
|
| 64 |
-
# Add current directory to path if needed
|
| 65 |
-
if '' not in sys.path and '.' not in sys.path:
|
| 66 |
-
sys.path.insert(0, '')
|
| 67 |
-
|
| 68 |
-
# Import SAGE
|
| 69 |
-
from sage import SageModel, SageConfig, SageTokenizer, generate, ConversationHistory, train as train_model, finetune
|
| 70 |
-
from sage import __version__ as sage_version
|
| 71 |
-
|
| 72 |
-
# Verify import worked
|
| 73 |
-
print(f"✅ SAGE v{sage_version} loaded successfully")
|
| 74 |
-
|
| 75 |
-
# -- Initialization --
|
| 76 |
-
config = SageConfig()
|
| 77 |
-
# Note: config.device automatically checks GPU compatibility and falls back to CPU if needed
|
| 78 |
-
device = config.device
|
| 79 |
-
print(f"🖥️ Using device: {device}")
|
| 80 |
-
|
| 81 |
-
tokenizer = SageTokenizer()
|
| 82 |
-
history = ConversationHistory(tokenizer, max_tokens=1024)
|
| 83 |
-
model = SageModel(config)
|
| 84 |
-
|
| 85 |
-
# -- Multi-GPU Logic (only if CUDA is actually being used) --
|
| 86 |
-
if device.type == "cuda":
|
| 87 |
-
gpu_count = torch.cuda.device_count()
|
| 88 |
-
if gpu_count > 1:
|
| 89 |
-
print(f"🚀 Multi-GPU active: {gpu_count} GPUs.")
|
| 90 |
-
model = nn.DataParallel(model)
|
| 91 |
-
model = model.to(device)
|
| 92 |
-
|
| 93 |
-
# -- Load Weights --
|
| 94 |
-
ckpt_path = "checkpoints/sage_latest.pt"
|
| 95 |
-
if os.path.exists(ckpt_path):
|
| 96 |
-
base_model = getattr(model, "module", model)
|
| 97 |
-
ckpt = torch.load(ckpt_path, map_location=device)
|
| 98 |
-
base_model.load_state_dict(ckpt['model_state_dict'])
|
| 99 |
-
print("✅ Weights loaded from checkpoint.")
|
| 100 |
-
else:
|
| 101 |
-
print("⚠️ RANDOM WEIGHTS (Type /train <steps> to begin learning).")
|
| 102 |
-
|
| 103 |
-
# -- Render UI --
|
| 104 |
-
chat_display = widgets.Output(layout={'border': '1px solid #444', 'height': '450px', 'overflow_y': 'scroll', 'padding': '10px'})
|
| 105 |
-
text_input = widgets.Text(placeholder="Chat or type /train 1000...", layout={'width': '80%'})
|
| 106 |
-
send_button = widgets.Button(description="Send", button_style='primary', layout={'width': '18%'})
|
| 107 |
-
display(HTML("<style>.user-msg { background: #2b2d42; color: #fff; padding: 10px; border-radius: 10px; margin: 5px; border-left: 5px solid #ef233c; } .sage-msg { background: #1a1b2e; color: #fff; padding: 10px; border-radius: 10px; margin: 5px; border-left: 5px solid #4cc9f0; }</style>"))
|
| 108 |
-
|
| 109 |
-
def on_send(_=None):
|
| 110 |
-
user_text = text_input.value.strip()
|
| 111 |
-
if not user_text: return
|
| 112 |
-
text_input.value = ""
|
| 113 |
-
with chat_display:
|
| 114 |
-
if user_text.startswith("/train"):
|
| 115 |
-
steps = int(user_text.split()[1]) if len(user_text.split()) > 1 else 100
|
| 116 |
-
print(f"🚀 TRAINING STARTING ({steps} steps)...")
|
| 117 |
-
train_model(model, config, total_steps=steps)
|
| 118 |
-
print("✅ DONE.")
|
| 119 |
-
return
|
| 120 |
-
display(HTML(f'<div class="user-msg"><b>You:</b> {user_text}</div>'))
|
| 121 |
-
response = generate(model, tokenizer, history.build_prompt(user_text), stream=False)
|
| 122 |
-
res = response.split("SAGE:")[-1].split("</response>")[0].replace("<response>", "").strip()
|
| 123 |
-
history.add("user", user_text); history.add("assistant", res)
|
| 124 |
-
display(HTML(f'<div class="sage-msg"><b>SAGE:</b> {res}</div>'))
|
| 125 |
-
|
| 126 |
-
text_input.on_submit(on_send); send_button.on_click(lambda b: on_send())
|
| 127 |
-
display(chat_display, widgets.HBox([text_input, send_button]))
|
| 128 |
-
```
|
| 129 |
-
|
| 130 |
-
---
|
| 131 |
-
|
| 132 |
-
## 🎮 Command Cheat Sheet
|
| 133 |
-
|
| 134 |
-
| Command | Action |
|
| 135 |
-
| :--- | :--- |
|
| 136 |
-
| `/train <steps>` | Starts pre-training (Base knowledge). Recommended: 5000+ |
|
| 137 |
-
| `/clear` | Resets the conversation history. |
|
| 138 |
-
| `/finetune <steps>` | (Coming Soon) Starts instruction fine-tuning. |
|
| 139 |
-
|
| 140 |
-
---
|
| 141 |
-
|
| 142 |
-
## 💡 Pro Tips for T4 GPUs
|
| 143 |
-
|
| 144 |
-
1. **Batch Size**: The default `batch_size=4` with `gradient_accumulation=16` is perfect for a 2x T4 setup (32GB VRAM total).
|
| 145 |
-
2. **Persistence**: Kaggle outputs are deleted when the session ends. Make sure to **download** the `checkpoints/` folder or sync it to **Hugging Face** regularly.
|
| 146 |
-
3. **Patience**: Loss will fluctuate. Look for a steady downward trend on your W&B dashboard!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SAGE_V3_ROADMAP.md
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
# SAGE v3: The "Long-Vision" Roadmap 🗺️
|
| 2 |
-
|
| 3 |
-
This document outlines the high-impact architectural upgrades that will transform SAGE into a multi-thousand token reasoning assistant with multimedia capabilities.
|
| 4 |
-
|
| 5 |
-
---
|
| 6 |
-
|
| 7 |
-
## 🏗️ 1. Long-Context Scaling (YaRN / RoPE-Interpolation)
|
| 8 |
-
|
| 9 |
-
**Goal**: Increase SAGE's maximum comprehension from 1,024 to **4,096+ tokens**.
|
| 10 |
-
|
| 11 |
-
### Technical Strategy:
|
| 12 |
-
Currently, our `freqs_cis` are precomputed for a fixed window. In v3, we will implement **NTK-Aware Interpolation**.
|
| 13 |
-
- **Implementation**: We will add a `scaling_factor` to `SageConfig`.
|
| 14 |
-
- **Logic**: During inference, if the sequence length exceeds the original training window, we will "stretch" the rotary frequencies dynamically rather than letting them overflow.
|
| 15 |
-
- **Benefit**: SAGE can read entire source code files or long essays without "losing its mind" at the 1,025th token.
|
| 16 |
-
|
| 17 |
-
---
|
| 18 |
-
|
| 19 |
-
## 📂 2. Interactive RAG (Kaggle UI Integration)
|
| 20 |
-
|
| 21 |
-
**Goal**: Allow users to "Upload and Chat" with any file instantly in the notebook.
|
| 22 |
-
|
| 23 |
-
### Technical Strategy:
|
| 24 |
-
- **Widget Update**: Add a `widgets.FileUpload` component to the Kaggle chat interface.
|
| 25 |
-
- **Auto-Ingestion**: When a file is uploaded, a background hook will:
|
| 26 |
-
1. Parse the text (PDF, `.py`, `.md`).
|
| 27 |
-
2. Chunk it into 200-token segments.
|
| 28 |
-
3. Generate embeddings and add them to the **FAISS Vector Store**.
|
| 29 |
-
- **Real-time Recall**: SAGE will automatically pull context from these uploaded files using the `retrieve_context` logic we've already built.
|
| 30 |
-
|
| 31 |
-
---
|
| 32 |
-
|
| 33 |
-
## 👁️ 3. Multimodal Foundation (Vision Projection)
|
| 34 |
-
|
| 35 |
-
**Goal**: Let SAGE "see" images.
|
| 36 |
-
|
| 37 |
-
### Technical Strategy:
|
| 38 |
-
Since SAGE is a small, efficient model (133M), it is the perfect candidate for a **Vision-Language Model (VLM)**.
|
| 39 |
-
- **Architecture**: We will add a frozen **CLIP-ViT** image encoder.
|
| 40 |
-
- **The Bridge**: We will implement a `VisionProjector` (a simple 2-layer MLP) that converts CLIP image embeddings (e.g., 768-dim) into SAGE token embeddings (512-dim).
|
| 41 |
-
- **Outcome**: You will be able to provide an image URL and a prompt like "What is in this picture?", and SAGE will respond based on the visual tokens.
|
| 42 |
-
|
| 43 |
-
---
|
| 44 |
-
|
| 45 |
-
## ⚡ 4. Training Stability: LayerNorm Tuning
|
| 46 |
-
|
| 47 |
-
To support these advanced features, we will move to **RMSNorm** (Root Mean Square Layer Normalization) for even faster convergence and better numerical stability on the double-T4 setup.
|
| 48 |
-
|
| 49 |
-
---
|
| 50 |
-
|
| 51 |
-
### Which one first?
|
| 52 |
-
We can begin implementing **RoPE Scaling** immediately to give SAGE a massive context boost without needing new weights. Just let me know when you're ready!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/data/mix.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_sources:
|
| 2 |
+
general_web:
|
| 3 |
+
weight_percent: 55
|
| 4 |
+
quality_tiers: [high, medium]
|
| 5 |
+
code:
|
| 6 |
+
weight_percent: 15
|
| 7 |
+
quality_tiers: [high, medium]
|
| 8 |
+
math_science:
|
| 9 |
+
weight_percent: 12
|
| 10 |
+
quality_tiers: [high, medium]
|
| 11 |
+
books_longform:
|
| 12 |
+
weight_percent: 10
|
| 13 |
+
quality_tiers: [high, medium]
|
| 14 |
+
multilingual:
|
| 15 |
+
weight_percent: 5
|
| 16 |
+
quality_tiers: [high, medium]
|
| 17 |
+
synthetic:
|
| 18 |
+
weight_percent: 3
|
| 19 |
+
quality_tiers: [high]
|
| 20 |
+
splits:
|
| 21 |
+
train: 0.989
|
| 22 |
+
validation: 0.01
|
| 23 |
+
test: 0.001
|
| 24 |
+
shard_size_bytes: 2147483648
|
| 25 |
+
format: parquet
|
configs/model/1b.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sage-1b
|
| 2 |
+
num_layers: 24
|
| 3 |
+
d_model: 2048
|
| 4 |
+
num_attn_heads: 16
|
| 5 |
+
num_kv_heads: 8
|
| 6 |
+
head_dim: 128
|
| 7 |
+
ffn_hidden_dim: 5632
|
| 8 |
+
vocab_size: 50000
|
| 9 |
+
context_length: 4096
|
| 10 |
+
rope_base_frequency: 500000
|
| 11 |
+
rope_scaling_factor: 1.0
|
| 12 |
+
dropout: 0.0
|
| 13 |
+
tie_word_embeddings: true
|
configs/model/3b.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sage-3b
|
| 2 |
+
num_layers: 28
|
| 3 |
+
d_model: 3072
|
| 4 |
+
num_attn_heads: 24
|
| 5 |
+
num_kv_heads: 8
|
| 6 |
+
head_dim: 128
|
| 7 |
+
ffn_hidden_dim: 8192
|
| 8 |
+
vocab_size: 50000
|
| 9 |
+
context_length: 8192
|
| 10 |
+
rope_base_frequency: 500000
|
| 11 |
+
rope_scaling_factor: 1.0
|
| 12 |
+
dropout: 0.0
|
| 13 |
+
tie_word_embeddings: true
|
configs/model/7b.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: sage-7b
|
| 2 |
+
num_layers: 32
|
| 3 |
+
d_model: 4096
|
| 4 |
+
num_attn_heads: 32
|
| 5 |
+
num_kv_heads: 8
|
| 6 |
+
head_dim: 128
|
| 7 |
+
ffn_hidden_dim: 11008
|
| 8 |
+
vocab_size: 50000
|
| 9 |
+
context_length: 8192
|
| 10 |
+
rope_base_frequency: 500000
|
| 11 |
+
rope_scaling_factor: 1.0
|
| 12 |
+
dropout: 0.0
|
| 13 |
+
tie_word_embeddings: true
|
configs/train/schedule.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run_name: sage-1b-pretrain
|
| 2 |
+
total_tokens: 50000000000
|
| 3 |
+
effective_batch_tokens: 2000000
|
| 4 |
+
peak_learning_rate: 3.0e-4
|
| 5 |
+
min_learning_rate: 3.0e-5
|
| 6 |
+
warmup_steps: 2000
|
| 7 |
+
weight_decay: 0.1
|
| 8 |
+
betas: [0.9, 0.95]
|
| 9 |
+
adam_eps: 1.0e-8
|
| 10 |
+
gradient_clip_norm: 1.0
|
| 11 |
+
checkpoint_interval: 1000
|
| 12 |
+
log_interval: 10
|
| 13 |
+
eval_interval: 1000
|
| 14 |
+
seed: 42
|
data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Data pipeline modules for SAGE."""
|
data/dataset.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Packed training dataset with deterministic resume support."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Iterable, Iterator
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.utils.data import IterableDataset
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import pyarrow.parquet as pq
|
| 15 |
+
except ImportError: # pragma: no cover - optional at import time
|
| 16 |
+
pq = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class DatasetConfig:
|
| 21 |
+
"""Configuration for packing token streams into training batches."""
|
| 22 |
+
|
| 23 |
+
shard_paths: tuple[str, ...]
|
| 24 |
+
context_length: int
|
| 25 |
+
split: str = "train"
|
| 26 |
+
seed: int = 42
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PackedDataset(IterableDataset):
|
| 30 |
+
"""Iterate packed token sequences with document-boundary masks."""
|
| 31 |
+
|
| 32 |
+
def __init__(self, config: DatasetConfig):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.config = config
|
| 35 |
+
self._skip = 0
|
| 36 |
+
|
| 37 |
+
def skip(self, n_batches: int) -> None:
|
| 38 |
+
"""Fast-forward the iterator by discarding the first n batches."""
|
| 39 |
+
self._skip = max(0, int(n_batches))
|
| 40 |
+
|
| 41 |
+
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
| 42 |
+
skipped = 0
|
| 43 |
+
for batch in self._generate():
|
| 44 |
+
if skipped < self._skip:
|
| 45 |
+
skipped += 1
|
| 46 |
+
continue
|
| 47 |
+
yield batch
|
| 48 |
+
|
| 49 |
+
def _generate(self) -> Iterator[dict[str, torch.Tensor]]:
|
| 50 |
+
token_buffer: list[int] = []
|
| 51 |
+
boundary_buffer: list[int] = []
|
| 52 |
+
for row in self._iter_rows():
|
| 53 |
+
tokens = list(row["tokens"])
|
| 54 |
+
if len(tokens) < 2:
|
| 55 |
+
continue
|
| 56 |
+
token_buffer.extend(tokens)
|
| 57 |
+
boundary_buffer.extend([0] * (len(tokens) - 1) + [1])
|
| 58 |
+
while len(token_buffer) >= self.config.context_length + 1:
|
| 59 |
+
window_tokens = token_buffer[: self.config.context_length + 1]
|
| 60 |
+
window_boundaries = boundary_buffer[: self.config.context_length + 1]
|
| 61 |
+
yield pack_sequence(window_tokens, window_boundaries)
|
| 62 |
+
token_buffer = token_buffer[self.config.context_length :]
|
| 63 |
+
boundary_buffer = boundary_buffer[self.config.context_length :]
|
| 64 |
+
|
| 65 |
+
def _iter_rows(self) -> Iterator[dict[str, object]]:
|
| 66 |
+
if pq is None:
|
| 67 |
+
raise ImportError("pyarrow is required to read parquet shards.")
|
| 68 |
+
shard_paths = [Path(path) for path in self.config.shard_paths]
|
| 69 |
+
rng = random.Random(self.config.seed)
|
| 70 |
+
shard_paths = shard_paths[:]
|
| 71 |
+
rng.shuffle(shard_paths)
|
| 72 |
+
for path in shard_paths:
|
| 73 |
+
table = pq.read_table(path, columns=["tokens", "split"])
|
| 74 |
+
rows = table.to_pylist()
|
| 75 |
+
for row in rows:
|
| 76 |
+
if row["split"] != self.config.split:
|
| 77 |
+
continue
|
| 78 |
+
yield row
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def pack_sequence(tokens: list[int], boundaries: list[int]) -> dict[str, torch.Tensor]:
|
| 82 |
+
"""Turn one packed token window into model-ready tensors."""
|
| 83 |
+
input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
|
| 84 |
+
labels = torch.tensor(tokens[1:], dtype=torch.long)
|
| 85 |
+
loss_mask = torch.ones_like(input_ids, dtype=torch.float32)
|
| 86 |
+
attention_document_mask = torch.tensor(boundaries[:-1], dtype=torch.long)
|
| 87 |
+
return {
|
| 88 |
+
"input_ids": input_ids,
|
| 89 |
+
"labels": labels,
|
| 90 |
+
"loss_mask": loss_mask,
|
| 91 |
+
"document_boundaries": attention_document_mask,
|
| 92 |
+
}
|
data/dedup.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Exact and near-duplicate detection helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import re
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from typing import Iterable
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
TOKEN_RE = re.compile(r"\w+")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def exact_content_hash(text: str) -> str:
|
| 15 |
+
"""Return an exact content hash."""
|
| 16 |
+
return hashlib.sha1(text.encode("utf-8")).hexdigest()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def shingles(text: str, n: int = 5) -> set[str]:
|
| 20 |
+
"""Build token shingles for near-duplicate detection."""
|
| 21 |
+
tokens = TOKEN_RE.findall(text.lower())
|
| 22 |
+
if len(tokens) < n:
|
| 23 |
+
return {" ".join(tokens)} if tokens else set()
|
| 24 |
+
return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def jaccard_similarity(left: str, right: str, n: int = 5) -> float:
|
| 28 |
+
"""Compute shingle-level Jaccard similarity."""
|
| 29 |
+
left_set = shingles(left, n)
|
| 30 |
+
right_set = shingles(right, n)
|
| 31 |
+
if not left_set and not right_set:
|
| 32 |
+
return 1.0
|
| 33 |
+
if not left_set or not right_set:
|
| 34 |
+
return 0.0
|
| 35 |
+
return len(left_set & right_set) / len(left_set | right_set)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def deduplicate_records(records: Iterable[dict[str, object]], near_dup_threshold: float = 0.92) -> list[dict[str, object]]:
|
| 39 |
+
"""Drop exact and near-duplicate records."""
|
| 40 |
+
exact_seen: set[str] = set()
|
| 41 |
+
buckets: dict[str, list[dict[str, object]]] = defaultdict(list)
|
| 42 |
+
kept: list[dict[str, object]] = []
|
| 43 |
+
for record in records:
|
| 44 |
+
text = str(record["text"])
|
| 45 |
+
digest = exact_content_hash(text)
|
| 46 |
+
if digest in exact_seen:
|
| 47 |
+
continue
|
| 48 |
+
signature = digest[:8]
|
| 49 |
+
near_duplicate = False
|
| 50 |
+
for candidate in buckets[signature]:
|
| 51 |
+
if jaccard_similarity(text, str(candidate["text"])) >= near_dup_threshold:
|
| 52 |
+
near_duplicate = True
|
| 53 |
+
break
|
| 54 |
+
if near_duplicate:
|
| 55 |
+
continue
|
| 56 |
+
exact_seen.add(digest)
|
| 57 |
+
buckets[signature].append(record)
|
| 58 |
+
kept.append(record)
|
| 59 |
+
return kept
|
data/filter.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Corpus filtering, safety, and quality heuristics."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
EMAIL_RE = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE)
|
| 11 |
+
PHONE_RE = re.compile(r"(?:(?:\+?\d{1,3})?[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}")
|
| 12 |
+
SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b")
|
| 13 |
+
HTML_RE = re.compile(r"<[^>]+>")
|
| 14 |
+
MULTISPACE_RE = re.compile(r"[ \t]+")
|
| 15 |
+
NSFW_TERMS = {"porn", "explicit sex", "rape"}
|
| 16 |
+
HATE_TERMS = {"kill all", "ethnic cleansing"}
|
| 17 |
+
ALLOWED_LICENSES = {"permissive", "restricted"}
|
| 18 |
+
ALLOWED_LANGS = {"en", "es", "fr", "de", "hi", "zh", "ar", "pt"}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(frozen=True)
|
| 22 |
+
class FilterConfig:
|
| 23 |
+
"""Policy controls for the filtering pipeline."""
|
| 24 |
+
|
| 25 |
+
minimum_chars: int = 200
|
| 26 |
+
maximum_chars: int = 200_000
|
| 27 |
+
minimum_alpha_ratio: float = 0.45
|
| 28 |
+
minimum_quality_score: float = 0.20
|
| 29 |
+
language_confidence_threshold: float = 0.65
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def normalize_text(text: str) -> str:
|
| 33 |
+
"""Strip tags and normalize whitespace."""
|
| 34 |
+
text = HTML_RE.sub(" ", text)
|
| 35 |
+
text = MULTISPACE_RE.sub(" ", text)
|
| 36 |
+
return text.strip()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def detect_language(text: str) -> tuple[str, float]:
|
| 40 |
+
"""Use a light heuristic to assign a language code."""
|
| 41 |
+
ascii_ratio = sum(ch.isascii() for ch in text) / max(len(text), 1)
|
| 42 |
+
devanagari = sum("\u0900" <= ch <= "\u097f" for ch in text)
|
| 43 |
+
cjk = sum("\u4e00" <= ch <= "\u9fff" for ch in text)
|
| 44 |
+
arabic = sum("\u0600" <= ch <= "\u06ff" for ch in text)
|
| 45 |
+
if cjk > 8:
|
| 46 |
+
return "zh", 0.95
|
| 47 |
+
if arabic > 8:
|
| 48 |
+
return "ar", 0.95
|
| 49 |
+
if devanagari > 8:
|
| 50 |
+
return "hi", 0.95
|
| 51 |
+
if ascii_ratio > 0.9:
|
| 52 |
+
return "en", 0.80
|
| 53 |
+
return "unknown", 0.40
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def quality_score(text: str) -> float:
|
| 57 |
+
"""Score text using length, punctuation, and alphabetic density."""
|
| 58 |
+
if not text:
|
| 59 |
+
return 0.0
|
| 60 |
+
alpha_ratio = sum(ch.isalpha() for ch in text) / len(text)
|
| 61 |
+
punct_ratio = sum(ch in ".,;:!?()[]{}" for ch in text) / len(text)
|
| 62 |
+
line_count = text.count("\n") + 1
|
| 63 |
+
score = min(len(text) / 4000.0, 1.0) * 0.4 + alpha_ratio * 0.4 + min(punct_ratio * 8.0, 1.0) * 0.2
|
| 64 |
+
if line_count < 2 and len(text) > 10_000:
|
| 65 |
+
score *= 0.85
|
| 66 |
+
return round(score, 4)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def quality_tier(score: float) -> str:
|
| 70 |
+
"""Map a numeric score to a quality tier."""
|
| 71 |
+
if score >= 0.70:
|
| 72 |
+
return "high"
|
| 73 |
+
if score >= 0.40:
|
| 74 |
+
return "medium"
|
| 75 |
+
return "low"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def strip_pii(text: str) -> str:
|
| 79 |
+
"""Mask basic email, phone, and SSN patterns."""
|
| 80 |
+
text = EMAIL_RE.sub("[EMAIL]", text)
|
| 81 |
+
text = PHONE_RE.sub("[PHONE]", text)
|
| 82 |
+
text = SSN_RE.sub("[SSN]", text)
|
| 83 |
+
return text
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def passes_safety_filter(text: str) -> bool:
|
| 87 |
+
"""Reject obviously unsafe content with simple keyword checks."""
|
| 88 |
+
lower = text.lower()
|
| 89 |
+
if any(term in lower for term in NSFW_TERMS):
|
| 90 |
+
return False
|
| 91 |
+
if any(term in lower for term in HATE_TERMS):
|
| 92 |
+
return False
|
| 93 |
+
return True
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def license_allowed(category: str) -> bool:
|
| 97 |
+
"""Return whether the source license category is allowed."""
|
| 98 |
+
return category in ALLOWED_LICENSES
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def filter_record(record: dict[str, object], config: FilterConfig = FilterConfig()) -> dict[str, object] | None:
|
| 102 |
+
"""Apply the full filter pipeline to one record."""
|
| 103 |
+
if not license_allowed(str(record.get("license_category", ""))):
|
| 104 |
+
return None
|
| 105 |
+
text = normalize_text(str(record.get("text", "")))
|
| 106 |
+
if not (config.minimum_chars <= len(text) <= config.maximum_chars):
|
| 107 |
+
return None
|
| 108 |
+
lang, confidence = detect_language(text)
|
| 109 |
+
if lang not in ALLOWED_LANGS or confidence < config.language_confidence_threshold:
|
| 110 |
+
return None
|
| 111 |
+
text = strip_pii(text)
|
| 112 |
+
if not passes_safety_filter(text):
|
| 113 |
+
return None
|
| 114 |
+
score = quality_score(text)
|
| 115 |
+
if score < config.minimum_quality_score:
|
| 116 |
+
return None
|
| 117 |
+
return {
|
| 118 |
+
**record,
|
| 119 |
+
"text": text,
|
| 120 |
+
"lang": lang,
|
| 121 |
+
"lang_confidence": confidence,
|
| 122 |
+
"quality_score": score,
|
| 123 |
+
"quality_tier": quality_tier(score),
|
| 124 |
+
"token_count_estimate": max(1, len(text) // 4),
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def filter_corpus(records: Iterable[dict[str, object]], config: FilterConfig = FilterConfig()) -> list[dict[str, object]]:
|
| 129 |
+
"""Filter a corpus in memory."""
|
| 130 |
+
kept: list[dict[str, object]] = []
|
| 131 |
+
for record in records:
|
| 132 |
+
filtered = filter_record(record, config)
|
| 133 |
+
if filtered is not None:
|
| 134 |
+
kept.append(filtered)
|
| 135 |
+
return kept
|
data/ingest.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Raw corpus ingestion utilities."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Iterable, Iterator
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
|
| 13 |
+
class SourceSpec:
|
| 14 |
+
"""Describes one raw corpus source."""
|
| 15 |
+
|
| 16 |
+
name: str
|
| 17 |
+
domain_tag: str
|
| 18 |
+
quality_tier: str
|
| 19 |
+
license_category: str
|
| 20 |
+
estimated_tokens: int
|
| 21 |
+
path: str
|
| 22 |
+
text_key: str = "text"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
SOURCE_REGISTRY: tuple[SourceSpec, ...] = (
|
| 26 |
+
SourceSpec("general_web", "general", "medium", "permissive", 20_000_000_000, "data/raw/general_web.jsonl"),
|
| 27 |
+
SourceSpec("code", "code", "high", "permissive", 8_000_000_000, "data/raw/code.jsonl"),
|
| 28 |
+
SourceSpec("math_science", "math", "high", "permissive", 4_000_000_000, "data/raw/math_science.jsonl"),
|
| 29 |
+
SourceSpec("books_longform", "general", "high", "restricted", 5_000_000_000, "data/raw/books.jsonl"),
|
| 30 |
+
SourceSpec("multilingual", "multilingual", "medium", "permissive", 3_000_000_000, "data/raw/multilingual.jsonl"),
|
| 31 |
+
SourceSpec("synthetic", "reasoning", "high", "permissive", 1_000_000_000, "data/raw/synthetic.jsonl"),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def iter_jsonl(path: Path, text_key: str = "text") -> Iterator[dict[str, object]]:
|
| 36 |
+
"""Yield JSONL records from disk."""
|
| 37 |
+
with path.open("r", encoding="utf-8") as handle:
|
| 38 |
+
for line in handle:
|
| 39 |
+
line = line.strip()
|
| 40 |
+
if not line:
|
| 41 |
+
continue
|
| 42 |
+
payload = json.loads(line)
|
| 43 |
+
text = payload.get(text_key)
|
| 44 |
+
if not isinstance(text, str) or not text.strip():
|
| 45 |
+
continue
|
| 46 |
+
yield payload
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def stream_source(spec: SourceSpec) -> Iterator[dict[str, object]]:
|
| 50 |
+
"""Yield normalized records for one configured source."""
|
| 51 |
+
path = Path(spec.path)
|
| 52 |
+
if not path.exists():
|
| 53 |
+
return iter(())
|
| 54 |
+
return (
|
| 55 |
+
{
|
| 56 |
+
"id": stable_record_id(spec.name, record[spec.text_key]),
|
| 57 |
+
"text": record[spec.text_key],
|
| 58 |
+
"domain_tag": spec.domain_tag,
|
| 59 |
+
"quality_tier": spec.quality_tier,
|
| 60 |
+
"license_category": spec.license_category,
|
| 61 |
+
"source_name": spec.name,
|
| 62 |
+
}
|
| 63 |
+
for record in iter_jsonl(path, spec.text_key)
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def stream_all_sources(sources: Iterable[SourceSpec] = SOURCE_REGISTRY) -> Iterator[dict[str, object]]:
|
| 68 |
+
"""Yield records from every source in the registry."""
|
| 69 |
+
for spec in sources:
|
| 70 |
+
yield from stream_source(spec)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def stable_record_id(source_name: str, text: str) -> str:
|
| 74 |
+
"""Hash a source+text pair into a stable content id."""
|
| 75 |
+
digest = hashlib.sha256()
|
| 76 |
+
digest.update(source_name.encode("utf-8"))
|
| 77 |
+
digest.update(b"\0")
|
| 78 |
+
digest.update(text.encode("utf-8"))
|
| 79 |
+
return digest.hexdigest()
|
data/shard.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tokenization, manifesting, and Parquet sharding."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import hashlib
|
| 6 |
+
import json
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Iterable
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import pyarrow as pa
|
| 13 |
+
import pyarrow.parquet as pq
|
| 14 |
+
except ImportError: # pragma: no cover - optional at import time
|
| 15 |
+
pa = None
|
| 16 |
+
pq = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
SCHEMA_COLUMNS = ("id", "text", "tokens", "domain_tag", "quality_tier", "lang", "token_count", "split")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
|
| 23 |
+
class ShardConfig:
|
| 24 |
+
"""Parameters for Parquet shard writing."""
|
| 25 |
+
|
| 26 |
+
output_dir: str
|
| 27 |
+
shard_size: int = 2048
|
| 28 |
+
validation_ratio: float = 0.01
|
| 29 |
+
test_ratio: float = 0.001
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def assign_split(record_id: str, validation_ratio: float, test_ratio: float) -> str:
|
| 33 |
+
"""Assign a deterministic split from the content id."""
|
| 34 |
+
value = int(record_id[:8], 16) / 0xFFFFFFFF
|
| 35 |
+
if value < test_ratio:
|
| 36 |
+
return "test"
|
| 37 |
+
if value < test_ratio + validation_ratio:
|
| 38 |
+
return "validation"
|
| 39 |
+
return "train"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_manifest(shard_paths: Iterable[Path]) -> dict[str, object]:
|
| 43 |
+
"""Create a manifest describing shard files."""
|
| 44 |
+
shard_paths = list(shard_paths)
|
| 45 |
+
digest = hashlib.sha256()
|
| 46 |
+
for path in shard_paths:
|
| 47 |
+
digest.update(path.name.encode("utf-8"))
|
| 48 |
+
digest.update(str(path.stat().st_size).encode("utf-8"))
|
| 49 |
+
return {
|
| 50 |
+
"format": "parquet",
|
| 51 |
+
"schema": list(SCHEMA_COLUMNS),
|
| 52 |
+
"shards": [path.name for path in shard_paths],
|
| 53 |
+
"dataset_hash": digest.hexdigest(),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def write_shards(records: Iterable[dict[str, object]], tokenizer, config: ShardConfig) -> dict[str, object]:
|
| 58 |
+
"""Write tokenized records to Parquet shards and emit a manifest."""
|
| 59 |
+
if pa is None or pq is None:
|
| 60 |
+
raise ImportError("pyarrow is required to write parquet shards.")
|
| 61 |
+
output_dir = Path(config.output_dir)
|
| 62 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 63 |
+
buffer: list[dict[str, object]] = []
|
| 64 |
+
shard_paths: list[Path] = []
|
| 65 |
+
shard_index = 0
|
| 66 |
+
for record in records:
|
| 67 |
+
tokens = tokenizer.encode(str(record["text"]), out_type=int)
|
| 68 |
+
row = {
|
| 69 |
+
"id": str(record["id"]),
|
| 70 |
+
"text": str(record["text"]),
|
| 71 |
+
"tokens": tokens,
|
| 72 |
+
"domain_tag": str(record["domain_tag"]),
|
| 73 |
+
"quality_tier": str(record["quality_tier"]),
|
| 74 |
+
"lang": str(record["lang"]),
|
| 75 |
+
"token_count": len(tokens),
|
| 76 |
+
"split": assign_split(str(record["id"]), config.validation_ratio, config.test_ratio),
|
| 77 |
+
}
|
| 78 |
+
buffer.append(row)
|
| 79 |
+
if len(buffer) >= config.shard_size:
|
| 80 |
+
shard_paths.append(_flush_shard(output_dir, shard_index, buffer))
|
| 81 |
+
shard_index += 1
|
| 82 |
+
buffer = []
|
| 83 |
+
if buffer:
|
| 84 |
+
shard_paths.append(_flush_shard(output_dir, shard_index, buffer))
|
| 85 |
+
manifest = build_manifest(shard_paths)
|
| 86 |
+
(output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
|
| 87 |
+
return manifest
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _flush_shard(output_dir: Path, shard_index: int, rows: list[dict[str, object]]) -> Path:
|
| 91 |
+
"""Persist one Parquet shard."""
|
| 92 |
+
table = pa.table({column: [row[column] for row in rows] for column in SCHEMA_COLUMNS})
|
| 93 |
+
path = output_dir / f"shard-{shard_index:05d}.parquet"
|
| 94 |
+
pq.write_table(table, path)
|
| 95 |
+
return path
|
docs/COMMANDS.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAGE Commands
|
| 2 |
+
|
| 3 |
+
This file is the short command-only reference for the repo.
|
| 4 |
+
|
| 5 |
+
## Install
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
pip install -r requirements.txt
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
## Run tests
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
pytest -q
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## Train tokenizer
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
python -m tokenizer.train_tokenizer \
|
| 21 |
+
--input data/raw/general_web.txt data/raw/code.txt \
|
| 22 |
+
--model-prefix tokenizer/tokenizer \
|
| 23 |
+
--vocab-size 50000
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
## Validate tokenizer
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
## Start a short training smoke run
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
python -m train.trainer \
|
| 36 |
+
--train-shards data/processed/shard-00000.parquet \
|
| 37 |
+
--validation-shards data/processed/shard-00001.parquet \
|
| 38 |
+
--output-dir runs/smoke \
|
| 39 |
+
--steps 20 \
|
| 40 |
+
--disable-wandb
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Start full training
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
python -m train.trainer \
|
| 47 |
+
--model-config configs/model/1b.yaml \
|
| 48 |
+
--schedule-config configs/train/schedule.yaml \
|
| 49 |
+
--train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \
|
| 50 |
+
--validation-shards data/processed/shard-00002.parquet \
|
| 51 |
+
--output-dir runs/sage-1b
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Run eval harness
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
bash scripts/run_eval.sh
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Start GPU server
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
bash scripts/run_serve.sh
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Start CPU server
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
bash scripts/run_serve_cpu.sh
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## Check server health
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
curl http://127.0.0.1:8000/health
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## Generate tokens from the API
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
curl -X POST http://127.0.0.1:8000/generate \
|
| 82 |
+
-H "Content-Type: application/json" \
|
| 83 |
+
-d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}"
|
| 84 |
+
```
|
docs/flow_llm.mmd
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flowchart TB
|
| 2 |
+
|
| 3 |
+
%% =========================================================
|
| 4 |
+
%% SAGE - Simplified Operational Flow
|
| 5 |
+
%% This file is pure Mermaid so .mmd renderers can open it.
|
| 6 |
+
%% =========================================================
|
| 7 |
+
|
| 8 |
+
user["You / Operator"]
|
| 9 |
+
|
| 10 |
+
subgraph inputs["Inputs"]
|
| 11 |
+
raw["Raw text / JSONL corpus"]
|
| 12 |
+
cfg_model["configs/model/1b.yaml"]
|
| 13 |
+
cfg_train["configs/train/schedule.yaml"]
|
| 14 |
+
cfg_data["configs/data/mix.yaml"]
|
| 15 |
+
end
|
| 16 |
+
|
| 17 |
+
subgraph tokenizer["Tokenizer Stage"]
|
| 18 |
+
tok_train["tokenizer/train_tokenizer.py"]
|
| 19 |
+
tok_validate["tokenizer/validate_tokenizer.py"]
|
| 20 |
+
tok_model["tokenizer.model + tokenizer.vocab"]
|
| 21 |
+
end
|
| 22 |
+
|
| 23 |
+
subgraph prep["Data Preparation Stage"]
|
| 24 |
+
ingest["data/ingest.py"]
|
| 25 |
+
filter["data/filter.py"]
|
| 26 |
+
dedup["data/dedup.py"]
|
| 27 |
+
shard["data/shard.py"]
|
| 28 |
+
parquet["Parquet shards + manifest.json"]
|
| 29 |
+
packed["data/dataset.py<br/>PackedDataset"]
|
| 30 |
+
end
|
| 31 |
+
|
| 32 |
+
subgraph model["Model Stage"]
|
| 33 |
+
model_cfg["model/config.py"]
|
| 34 |
+
rms["RMSNorm"]
|
| 35 |
+
rope["RoPE"]
|
| 36 |
+
attn["GQA Attention + SDPA"]
|
| 37 |
+
mlp["SwiGLU MLP"]
|
| 38 |
+
blocks["Transformer Blocks x24"]
|
| 39 |
+
sage["SageTransformer"]
|
| 40 |
+
end
|
| 41 |
+
|
| 42 |
+
subgraph train["Training Stage"]
|
| 43 |
+
hw["train/hardware.py"]
|
| 44 |
+
opt["train/optimizer.py"]
|
| 45 |
+
loss["train/loss.py"]
|
| 46 |
+
ckpt["train/checkpoint.py"]
|
| 47 |
+
trainer["train/trainer.py"]
|
| 48 |
+
metrics["runs/<name>/metrics.jsonl"]
|
| 49 |
+
saves["runs/<name>/ckpt_step_xxxxxxx.pt"]
|
| 50 |
+
end
|
| 51 |
+
|
| 52 |
+
subgraph evals["Evaluation Stage"]
|
| 53 |
+
ppl["eval/perplexity.py"]
|
| 54 |
+
bench["eval/benchmarks.py"]
|
| 55 |
+
longctx["eval/long_context.py"]
|
| 56 |
+
regress["eval/regression.py"]
|
| 57 |
+
end
|
| 58 |
+
|
| 59 |
+
subgraph serving["Serving Stage"]
|
| 60 |
+
kv["serve/kv_cache.py"]
|
| 61 |
+
quant["serve/quantize.py"]
|
| 62 |
+
api["serve/server.py"]
|
| 63 |
+
cpu["serve/server_cpu.py"]
|
| 64 |
+
health["/health"]
|
| 65 |
+
generate["/generate"]
|
| 66 |
+
end
|
| 67 |
+
|
| 68 |
+
user --> raw
|
| 69 |
+
user --> cfg_model
|
| 70 |
+
user --> cfg_train
|
| 71 |
+
user --> cfg_data
|
| 72 |
+
|
| 73 |
+
raw --> tok_train
|
| 74 |
+
tok_train --> tok_model
|
| 75 |
+
tok_model --> tok_validate
|
| 76 |
+
|
| 77 |
+
raw --> ingest
|
| 78 |
+
cfg_data --> ingest
|
| 79 |
+
ingest --> filter
|
| 80 |
+
filter --> dedup
|
| 81 |
+
dedup --> shard
|
| 82 |
+
tok_model --> shard
|
| 83 |
+
shard --> parquet
|
| 84 |
+
parquet --> packed
|
| 85 |
+
|
| 86 |
+
cfg_model --> model_cfg
|
| 87 |
+
model_cfg --> rms
|
| 88 |
+
model_cfg --> rope
|
| 89 |
+
model_cfg --> attn
|
| 90 |
+
model_cfg --> mlp
|
| 91 |
+
rms --> blocks
|
| 92 |
+
rope --> attn
|
| 93 |
+
attn --> blocks
|
| 94 |
+
mlp --> blocks
|
| 95 |
+
blocks --> sage
|
| 96 |
+
|
| 97 |
+
packed --> trainer
|
| 98 |
+
cfg_train --> trainer
|
| 99 |
+
cfg_model --> trainer
|
| 100 |
+
hw --> trainer
|
| 101 |
+
opt --> trainer
|
| 102 |
+
loss --> trainer
|
| 103 |
+
ckpt --> trainer
|
| 104 |
+
sage --> trainer
|
| 105 |
+
|
| 106 |
+
trainer --> metrics
|
| 107 |
+
trainer --> saves
|
| 108 |
+
trainer --> ppl
|
| 109 |
+
|
| 110 |
+
sage --> ppl
|
| 111 |
+
sage --> bench
|
| 112 |
+
sage --> longctx
|
| 113 |
+
ppl --> regress
|
| 114 |
+
bench --> regress
|
| 115 |
+
longctx --> regress
|
| 116 |
+
|
| 117 |
+
sage --> kv
|
| 118 |
+
sage --> quant
|
| 119 |
+
sage --> api
|
| 120 |
+
quant --> cpu
|
| 121 |
+
kv --> api
|
| 122 |
+
api --> health
|
| 123 |
+
api --> generate
|
| 124 |
+
cpu --> health
|
| 125 |
+
|
| 126 |
+
classDef input fill:#0f172a,stroke:#93c5fd,color:#ffffff
|
| 127 |
+
classDef token fill:#1d4ed8,stroke:#bfdbfe,color:#ffffff
|
| 128 |
+
classDef prep fill:#0f766e,stroke:#99f6e4,color:#ffffff
|
| 129 |
+
classDef model fill:#581c87,stroke:#d8b4fe,color:#ffffff
|
| 130 |
+
classDef train fill:#92400e,stroke:#fde68a,color:#ffffff
|
| 131 |
+
classDef eval fill:#991b1b,stroke:#fecaca,color:#ffffff
|
| 132 |
+
classDef serve fill:#166534,stroke:#bbf7d0,color:#ffffff
|
| 133 |
+
|
| 134 |
+
class user,raw,cfg_model,cfg_train,cfg_data input
|
| 135 |
+
class tok_train,tok_validate,tok_model token
|
| 136 |
+
class ingest,filter,dedup,shard,parquet,packed prep
|
| 137 |
+
class model_cfg,rms,rope,attn,mlp,blocks,sage model
|
| 138 |
+
class hw,opt,loss,ckpt,trainer,metrics,saves train
|
| 139 |
+
class ppl,bench,longctx,regress eval
|
| 140 |
+
class kv,quant,api,cpu,health,generate serve
|
docs/llm_Arch.mmd
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SAGE 1B System Architecture
|
| 3 |
+
---
|
| 4 |
+
flowchart TB
|
| 5 |
+
|
| 6 |
+
%% =========================================================
|
| 7 |
+
%% SAGE 1B - End-to-End Architecture and Flow Overview
|
| 8 |
+
%% =========================================================
|
| 9 |
+
|
| 10 |
+
user["Developer / Operator"]
|
| 11 |
+
|
| 12 |
+
subgraph repo["SAGE Repository"]
|
| 13 |
+
direction TB
|
| 14 |
+
|
| 15 |
+
subgraph configs["configs/"]
|
| 16 |
+
cfg_model["model/1b.yaml<br/>24L, 2048 d_model, 16Q / 8KV, 4096 ctx"]
|
| 17 |
+
cfg_data["data/mix.yaml<br/>corpus weights + split ratios"]
|
| 18 |
+
cfg_train["train/schedule.yaml<br/>LR, warmup, checkpoints, logging"]
|
| 19 |
+
end
|
| 20 |
+
|
| 21 |
+
subgraph tokenizer["tokenizer/"]
|
| 22 |
+
tok_train["train_tokenizer.py<br/>SentencePiece BPE training"]
|
| 23 |
+
tok_validate["validate_tokenizer.py<br/>roundtrip + edge-case checks"]
|
| 24 |
+
tok_model["tokenizer.model / tokenizer.vocab"]
|
| 25 |
+
end
|
| 26 |
+
|
| 27 |
+
subgraph data_layer["data/"]
|
| 28 |
+
ingest["ingest.py<br/>source registry + raw record streaming"]
|
| 29 |
+
filter["filter.py<br/>license, lang, PII, safety, quality"]
|
| 30 |
+
dedup["dedup.py<br/>exact + near-duplicate removal"]
|
| 31 |
+
shard["shard.py<br/>tokenize -> parquet shards + manifest"]
|
| 32 |
+
dataset["dataset.py<br/>PackedDataset + skip(n_batches)"]
|
| 33 |
+
end
|
| 34 |
+
|
| 35 |
+
subgraph model_layer["model/"]
|
| 36 |
+
model_cfg["config.py<br/>ModelConfig"]
|
| 37 |
+
rmsnorm["rmsnorm.py<br/>pre-norm RMSNorm"]
|
| 38 |
+
rope["rope.py<br/>RoPE cache + apply_rope"]
|
| 39 |
+
attn["attention.py<br/>fused QKV + GQA + SDPA"]
|
| 40 |
+
mlp["mlp.py<br/>SwiGLU FFN"]
|
| 41 |
+
block["block.py<br/>Transformer block"]
|
| 42 |
+
full_model["model.py<br/>SageTransformer"]
|
| 43 |
+
end
|
| 44 |
+
|
| 45 |
+
subgraph train_layer["train/"]
|
| 46 |
+
hw["hardware.py<br/>device / dtype / batch routing"]
|
| 47 |
+
dist["distributed.py<br/>single / DDP / FSDP strategy"]
|
| 48 |
+
opt["optimizer.py<br/>AdamW + cosine schedule"]
|
| 49 |
+
loss["loss.py<br/>masked next-token cross entropy"]
|
| 50 |
+
ckpt["checkpoint.py<br/>save / prune / resume"]
|
| 51 |
+
trainer["trainer.py<br/>main training loop"]
|
| 52 |
+
end
|
| 53 |
+
|
| 54 |
+
subgraph eval_layer["eval/"]
|
| 55 |
+
ppl["perplexity.py<br/>validation loss + perplexity"]
|
| 56 |
+
benches["benchmarks.py<br/>benchmark harness registry"]
|
| 57 |
+
longctx["long_context.py<br/>needle-in-haystack probes"]
|
| 58 |
+
regress["regression.py<br/>checkpoint metric comparison"]
|
| 59 |
+
end
|
| 60 |
+
|
| 61 |
+
subgraph serve_layer["serve/"]
|
| 62 |
+
kv["kv_cache.py<br/>cache container"]
|
| 63 |
+
quant["quantize.py<br/>int8 export + GGUF command helper"]
|
| 64 |
+
gpu_api["server.py<br/>FastAPI GPU server"]
|
| 65 |
+
cpu_api["server_cpu.py<br/>FastAPI CPU readiness surface"]
|
| 66 |
+
end
|
| 67 |
+
|
| 68 |
+
subgraph scripts["scripts/"]
|
| 69 |
+
s_data["run_data_pipeline.sh"]
|
| 70 |
+
s_train["run_training.sh"]
|
| 71 |
+
s_eval["run_eval.sh"]
|
| 72 |
+
s_serve["run_serve.sh / run_serve_cpu.sh"]
|
| 73 |
+
end
|
| 74 |
+
|
| 75 |
+
subgraph outputs["Runtime Outputs"]
|
| 76 |
+
raw["Raw text / JSONL corpora"]
|
| 77 |
+
parquet["Parquet shards + manifest.json"]
|
| 78 |
+
runs["runs/<name>/metrics.jsonl"]
|
| 79 |
+
checkpoints["runs/<name>/ckpt_step_xxxxxxx.pt"]
|
| 80 |
+
api_out["/health + /generate responses"]
|
| 81 |
+
end
|
| 82 |
+
end
|
| 83 |
+
|
| 84 |
+
%% =========================================================
|
| 85 |
+
%% Top-level usage
|
| 86 |
+
%% =========================================================
|
| 87 |
+
|
| 88 |
+
user --> s_data
|
| 89 |
+
user --> s_train
|
| 90 |
+
user --> s_eval
|
| 91 |
+
user --> s_serve
|
| 92 |
+
|
| 93 |
+
s_data --> tok_train
|
| 94 |
+
s_train --> trainer
|
| 95 |
+
s_eval --> benches
|
| 96 |
+
s_serve --> gpu_api
|
| 97 |
+
s_serve --> cpu_api
|
| 98 |
+
|
| 99 |
+
%% =========================================================
|
| 100 |
+
%% Tokenizer flow
|
| 101 |
+
%% =========================================================
|
| 102 |
+
|
| 103 |
+
raw --> tok_train
|
| 104 |
+
tok_train --> tok_model
|
| 105 |
+
tok_model --> tok_validate
|
| 106 |
+
|
| 107 |
+
%% =========================================================
|
| 108 |
+
%% Data preparation flow
|
| 109 |
+
%% =========================================================
|
| 110 |
+
|
| 111 |
+
raw --> ingest
|
| 112 |
+
ingest --> filter
|
| 113 |
+
filter --> dedup
|
| 114 |
+
dedup --> shard
|
| 115 |
+
tok_model --> shard
|
| 116 |
+
shard --> parquet
|
| 117 |
+
parquet --> dataset
|
| 118 |
+
|
| 119 |
+
cfg_data --> ingest
|
| 120 |
+
cfg_data --> filter
|
| 121 |
+
cfg_data --> shard
|
| 122 |
+
|
| 123 |
+
%% =========================================================
|
| 124 |
+
%% Model construction flow
|
| 125 |
+
%% =========================================================
|
| 126 |
+
|
| 127 |
+
cfg_model --> model_cfg
|
| 128 |
+
model_cfg --> rmsnorm
|
| 129 |
+
model_cfg --> rope
|
| 130 |
+
model_cfg --> attn
|
| 131 |
+
model_cfg --> mlp
|
| 132 |
+
rmsnorm --> block
|
| 133 |
+
rope --> attn
|
| 134 |
+
attn --> block
|
| 135 |
+
mlp --> block
|
| 136 |
+
block --> full_model
|
| 137 |
+
|
| 138 |
+
%% =========================================================
|
| 139 |
+
%% Training flow
|
| 140 |
+
%% =========================================================
|
| 141 |
+
|
| 142 |
+
cfg_train --> opt
|
| 143 |
+
cfg_train --> trainer
|
| 144 |
+
cfg_train --> ckpt
|
| 145 |
+
model_cfg --> full_model
|
| 146 |
+
dataset --> trainer
|
| 147 |
+
full_model --> trainer
|
| 148 |
+
hw --> trainer
|
| 149 |
+
dist --> trainer
|
| 150 |
+
opt --> trainer
|
| 151 |
+
loss --> trainer
|
| 152 |
+
ckpt --> trainer
|
| 153 |
+
|
| 154 |
+
trainer --> runs
|
| 155 |
+
trainer --> checkpoints
|
| 156 |
+
trainer --> ppl
|
| 157 |
+
|
| 158 |
+
%% =========================================================
|
| 159 |
+
%% Evaluation flow
|
| 160 |
+
%% =========================================================
|
| 161 |
+
|
| 162 |
+
full_model --> ppl
|
| 163 |
+
full_model --> benches
|
| 164 |
+
full_model --> longctx
|
| 165 |
+
ppl --> regress
|
| 166 |
+
benches --> regress
|
| 167 |
+
longctx --> regress
|
| 168 |
+
|
| 169 |
+
%% =========================================================
|
| 170 |
+
%% Serving flow
|
| 171 |
+
%% =========================================================
|
| 172 |
+
|
| 173 |
+
full_model --> kv
|
| 174 |
+
full_model --> quant
|
| 175 |
+
full_model --> gpu_api
|
| 176 |
+
quant --> cpu_api
|
| 177 |
+
kv --> gpu_api
|
| 178 |
+
hw --> gpu_api
|
| 179 |
+
gpu_api --> api_out
|
| 180 |
+
cpu_api --> api_out
|
| 181 |
+
|
| 182 |
+
%% =========================================================
|
| 183 |
+
%% Visual grouping
|
| 184 |
+
%% =========================================================
|
| 185 |
+
|
| 186 |
+
classDef config fill:#1f2937,stroke:#93c5fd,color:#ffffff
|
| 187 |
+
classDef pipeline fill:#0f766e,stroke:#5eead4,color:#ffffff
|
| 188 |
+
classDef model fill:#4c1d95,stroke:#c4b5fd,color:#ffffff
|
| 189 |
+
classDef train fill:#92400e,stroke:#fcd34d,color:#ffffff
|
| 190 |
+
classDef eval fill:#7f1d1d,stroke:#fca5a5,color:#ffffff
|
| 191 |
+
classDef serve fill:#065f46,stroke:#86efac,color:#ffffff
|
| 192 |
+
classDef io fill:#111827,stroke:#9ca3af,color:#ffffff
|
| 193 |
+
classDef actor fill:#2563eb,stroke:#bfdbfe,color:#ffffff
|
| 194 |
+
|
| 195 |
+
class user actor
|
| 196 |
+
class cfg_model,cfg_data,cfg_train config
|
| 197 |
+
class tok_train,tok_validate,ingest,filter,dedup,shard,dataset pipeline
|
| 198 |
+
class model_cfg,rmsnorm,rope,attn,mlp,block,full_model model
|
| 199 |
+
class hw,dist,opt,loss,ckpt,trainer train
|
| 200 |
+
class ppl,benches,longctx,regress eval
|
| 201 |
+
class kv,quant,gpu_api,cpu_api serve
|
| 202 |
+
class raw,parquet,runs,checkpoints,api_out,tok_model,s_data,s_train,s_eval,s_serve io
|
eval/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation helpers for SAGE."""
|
eval/benchmarks.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark harness registration for SAGE."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
|
| 9 |
+
class BenchmarkResult:
|
| 10 |
+
"""A normalized benchmark result."""
|
| 11 |
+
|
| 12 |
+
name: str
|
| 13 |
+
status: str
|
| 14 |
+
score: float | None
|
| 15 |
+
detail: str
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
BENCHMARKS = (
|
| 19 |
+
"hellaswag",
|
| 20 |
+
"winogrande",
|
| 21 |
+
"arc_easy",
|
| 22 |
+
"arc_challenge",
|
| 23 |
+
"gsm8k",
|
| 24 |
+
"math",
|
| 25 |
+
"humaneval",
|
| 26 |
+
"mbpp",
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def run_registered_benchmarks(model, tokenizer=None) -> list[BenchmarkResult]:
|
| 31 |
+
"""Return a lightweight result set for the configured benchmarks."""
|
| 32 |
+
_ = model
|
| 33 |
+
_ = tokenizer
|
| 34 |
+
return [
|
| 35 |
+
BenchmarkResult(
|
| 36 |
+
name=name,
|
| 37 |
+
status="skipped",
|
| 38 |
+
score=None,
|
| 39 |
+
detail="Benchmark harness registered; dataset/task execution is external to unit tests.",
|
| 40 |
+
)
|
| 41 |
+
for name in BENCHMARKS
|
| 42 |
+
]
|
eval/long_context.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Long-context retrieval evaluation helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
|
| 9 |
+
class RetrievalProbe:
|
| 10 |
+
"""A synthetic retrieval probe for long-context checks."""
|
| 11 |
+
|
| 12 |
+
prompt: str
|
| 13 |
+
needle: str
|
| 14 |
+
expected_index: int
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def build_needle_in_haystack_probe(context_length: int) -> RetrievalProbe:
|
| 18 |
+
"""Create a deterministic retrieval prompt for smoke tests."""
|
| 19 |
+
needle = "SAGE_LONG_CONTEXT_NEEDLE"
|
| 20 |
+
haystack = ["token"] * max(context_length - 16, 16)
|
| 21 |
+
insert_at = min(len(haystack) // 2, max(context_length // 4, 1))
|
| 22 |
+
haystack.insert(insert_at, needle)
|
| 23 |
+
prompt = " ".join(haystack)
|
| 24 |
+
return RetrievalProbe(prompt=prompt, needle=needle, expected_index=insert_at)
|
eval/perplexity.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validation perplexity evaluation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from train.loss import masked_cross_entropy
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.no_grad()
|
| 13 |
+
def evaluate_perplexity(
|
| 14 |
+
model: torch.nn.Module,
|
| 15 |
+
dataloader,
|
| 16 |
+
device: torch.device,
|
| 17 |
+
dtype: torch.dtype | None = None,
|
| 18 |
+
max_batches: int = 16,
|
| 19 |
+
) -> dict[str, float]:
|
| 20 |
+
"""Evaluate average loss and perplexity on a validation loader."""
|
| 21 |
+
model.eval()
|
| 22 |
+
losses: list[float] = []
|
| 23 |
+
for index, batch in enumerate(dataloader):
|
| 24 |
+
if index >= max_batches:
|
| 25 |
+
break
|
| 26 |
+
input_ids = batch["input_ids"].to(device)
|
| 27 |
+
labels = batch["labels"].to(device)
|
| 28 |
+
loss_mask = batch["loss_mask"].to(device)
|
| 29 |
+
if dtype is not None and device.type != "cpu":
|
| 30 |
+
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
| 31 |
+
logits, _ = model(input_ids)
|
| 32 |
+
loss = masked_cross_entropy(logits, labels, loss_mask)
|
| 33 |
+
else:
|
| 34 |
+
logits, _ = model(input_ids)
|
| 35 |
+
loss = masked_cross_entropy(logits, labels, loss_mask)
|
| 36 |
+
losses.append(float(loss))
|
| 37 |
+
model.train()
|
| 38 |
+
mean_loss = sum(losses) / max(len(losses), 1)
|
| 39 |
+
return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))}
|
eval/regression.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint-to-checkpoint regression checks."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def compare_metrics(previous: dict[str, float], current: dict[str, float], threshold: float = 0.005) -> dict[str, object]:
|
| 7 |
+
"""Flag metric drops larger than the configured threshold."""
|
| 8 |
+
regressions: list[str] = []
|
| 9 |
+
for key, prev_value in previous.items():
|
| 10 |
+
curr_value = current.get(key)
|
| 11 |
+
if curr_value is None:
|
| 12 |
+
continue
|
| 13 |
+
if curr_value < prev_value * (1.0 - threshold):
|
| 14 |
+
regressions.append(key)
|
| 15 |
+
return {"regressions": regressions, "passed": not regressions}
|
hf_push.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Upload the current SAGE repository contents to the Hugging Face Hub."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import HfApi
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
REPO_ID = "sage002/sage"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main() -> None:
|
| 12 |
+
"""Replace the remote Hugging Face repo contents with the local folder state."""
|
| 13 |
+
api = HfApi()
|
| 14 |
+
print(f"Syncing current repository to {REPO_ID}...")
|
| 15 |
+
api.upload_folder(
|
| 16 |
+
folder_path=".",
|
| 17 |
+
repo_id=REPO_ID,
|
| 18 |
+
repo_type="model",
|
| 19 |
+
ignore_patterns=[
|
| 20 |
+
".git/*",
|
| 21 |
+
".venv/*",
|
| 22 |
+
"__pycache__/*",
|
| 23 |
+
"*.pyc",
|
| 24 |
+
"checkpoints/*",
|
| 25 |
+
"runs/*",
|
| 26 |
+
"wandb/*",
|
| 27 |
+
"data/raw/*",
|
| 28 |
+
"data/processed/*",
|
| 29 |
+
"tokenizer/*.model",
|
| 30 |
+
"tokenizer/*.vocab",
|
| 31 |
+
"tokenizer/training_corpus.txt",
|
| 32 |
+
],
|
| 33 |
+
delete_patterns="*",
|
| 34 |
+
commit_message="feat: rewrite SAGE 1B architecture and replace legacy repo contents",
|
| 35 |
+
)
|
| 36 |
+
print("Sync complete.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
main()
|
model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Model architecture for SAGE."""
|
model/attention.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Grouped-query attention with SDPA and KV-cache support."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from model.config import ModelConfig
|
| 12 |
+
from model.rope import apply_rope
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def repeat_kv(x: torch.Tensor, num_groups: int) -> torch.Tensor:
|
| 16 |
+
"""Expand KV heads to match the number of query heads."""
|
| 17 |
+
if num_groups == 1:
|
| 18 |
+
return x
|
| 19 |
+
batch, kv_heads, seq_len, head_dim = x.shape
|
| 20 |
+
x = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim)
|
| 21 |
+
return x.reshape(batch, kv_heads * num_groups, seq_len, head_dim)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GQAAttention(nn.Module):
|
| 25 |
+
"""Fused-QKV grouped-query attention."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, config: ModelConfig):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.config = config
|
| 30 |
+
self.num_heads = config.num_attn_heads
|
| 31 |
+
self.num_kv_heads = config.num_kv_heads
|
| 32 |
+
self.head_dim = config.head_dim
|
| 33 |
+
self.num_groups = self.num_heads // self.num_kv_heads
|
| 34 |
+
qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
|
| 35 |
+
self.qkv_proj = nn.Linear(config.d_model, qkv_dim, bias=False)
|
| 36 |
+
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
| 37 |
+
self.dropout = config.dropout
|
| 38 |
+
|
| 39 |
+
def forward(
|
| 40 |
+
self,
|
| 41 |
+
hidden_states: torch.Tensor,
|
| 42 |
+
cos: torch.Tensor,
|
| 43 |
+
sin: torch.Tensor,
|
| 44 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 45 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
| 46 |
+
"""Compute causal self-attention and return an updated KV cache."""
|
| 47 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 48 |
+
qkv = self.qkv_proj(hidden_states)
|
| 49 |
+
q_end = self.num_heads * self.head_dim
|
| 50 |
+
k_end = q_end + self.num_kv_heads * self.head_dim
|
| 51 |
+
q, k, v = qkv.split((q_end, self.num_kv_heads * self.head_dim, self.num_kv_heads * self.head_dim), dim=-1)
|
| 52 |
+
|
| 53 |
+
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 54 |
+
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 55 |
+
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
| 56 |
+
|
| 57 |
+
q_rope, k_rope = apply_rope(q, repeat_kv(k, self.num_groups), cos, sin)
|
| 58 |
+
k = k_rope[:, :: self.num_groups, :, :]
|
| 59 |
+
|
| 60 |
+
if past_key_value is not None:
|
| 61 |
+
past_key, past_value = past_key_value
|
| 62 |
+
k = torch.cat([past_key, k], dim=-2)
|
| 63 |
+
v = torch.cat([past_value, v], dim=-2)
|
| 64 |
+
|
| 65 |
+
expanded_k = repeat_kv(k, self.num_groups)
|
| 66 |
+
expanded_v = repeat_kv(v, self.num_groups)
|
| 67 |
+
attn_output = F.scaled_dot_product_attention(
|
| 68 |
+
q_rope,
|
| 69 |
+
expanded_k,
|
| 70 |
+
expanded_v,
|
| 71 |
+
attn_mask=None,
|
| 72 |
+
dropout_p=self.dropout if self.training else 0.0,
|
| 73 |
+
is_causal=past_key_value is None,
|
| 74 |
+
)
|
| 75 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.config.d_model)
|
| 76 |
+
return self.out_proj(attn_output), (k, v)
|
model/block.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Transformer block for the dense SAGE model."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from model.attention import GQAAttention
|
| 11 |
+
from model.config import ModelConfig
|
| 12 |
+
from model.mlp import SwiGLUMLP
|
| 13 |
+
from model.rmsnorm import RMSNorm
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TransformerBlock(nn.Module):
|
| 17 |
+
"""Pre-norm transformer block with attention and SwiGLU."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, config: ModelConfig):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
|
| 22 |
+
self.attn = GQAAttention(config)
|
| 23 |
+
self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
|
| 24 |
+
self.mlp = SwiGLUMLP(config)
|
| 25 |
+
|
| 26 |
+
def forward(
|
| 27 |
+
self,
|
| 28 |
+
hidden_states: torch.Tensor,
|
| 29 |
+
cos: torch.Tensor,
|
| 30 |
+
sin: torch.Tensor,
|
| 31 |
+
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 32 |
+
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
| 33 |
+
"""Forward pass with residual connections."""
|
| 34 |
+
attn_output, present = self.attn(self.norm1(hidden_states), cos, sin, past_key_value=past_key_value)
|
| 35 |
+
hidden_states = hidden_states + attn_output
|
| 36 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 37 |
+
return hidden_states, present
|
model/config.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model configuration for SAGE."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import asdict, dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import yaml
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class ModelConfig:
|
| 14 |
+
"""Configuration for the dense SAGE decoder-only transformer."""
|
| 15 |
+
|
| 16 |
+
name: str = "sage-1b"
|
| 17 |
+
num_layers: int = 24
|
| 18 |
+
d_model: int = 2048
|
| 19 |
+
num_attn_heads: int = 16
|
| 20 |
+
num_kv_heads: int = 8
|
| 21 |
+
head_dim: int = 128
|
| 22 |
+
ffn_hidden_dim: int = 5632
|
| 23 |
+
vocab_size: int = 50_000
|
| 24 |
+
context_length: int = 4096
|
| 25 |
+
rope_base_frequency: int = 500_000
|
| 26 |
+
rope_scaling_factor: float = 1.0
|
| 27 |
+
dropout: float = 0.0
|
| 28 |
+
tie_word_embeddings: bool = True
|
| 29 |
+
rms_norm_eps: float = 1.0e-5
|
| 30 |
+
initializer_range: float = 0.02
|
| 31 |
+
|
| 32 |
+
def __post_init__(self) -> None:
|
| 33 |
+
if self.num_attn_heads * self.head_dim != self.d_model:
|
| 34 |
+
raise ValueError("num_attn_heads * head_dim must equal d_model.")
|
| 35 |
+
if self.num_attn_heads % self.num_kv_heads != 0:
|
| 36 |
+
raise ValueError("num_attn_heads must be divisible by num_kv_heads.")
|
| 37 |
+
if self.ffn_hidden_dim % 256 != 0:
|
| 38 |
+
raise ValueError("ffn_hidden_dim must be a multiple of 256.")
|
| 39 |
+
|
| 40 |
+
@classmethod
|
| 41 |
+
def from_yaml(cls, path: str | Path) -> "ModelConfig":
|
| 42 |
+
"""Load a config from YAML."""
|
| 43 |
+
payload = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
|
| 44 |
+
return cls(**payload)
|
| 45 |
+
|
| 46 |
+
def to_dict(self) -> dict[str, Any]:
|
| 47 |
+
"""Serialize the config to a dict."""
|
| 48 |
+
return asdict(self)
|
model/mlp.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SwiGLU feed-forward module."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from model.config import ModelConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SwiGLUMLP(nn.Module):
|
| 13 |
+
"""Bias-free SwiGLU feed-forward network."""
|
| 14 |
+
|
| 15 |
+
def __init__(self, config: ModelConfig):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
|
| 18 |
+
self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False)
|
| 19 |
+
self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=False)
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 22 |
+
"""Apply SwiGLU and project back to the model width."""
|
| 23 |
+
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
model/model.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Full dense decoder-only transformer model for SAGE."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from model.block import TransformerBlock
|
| 12 |
+
from model.config import ModelConfig
|
| 13 |
+
from model.rope import build_rope_cache
|
| 14 |
+
from model.rmsnorm import RMSNorm
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SageTransformer(nn.Module):
|
| 18 |
+
"""A dense Llama-style decoder-only transformer."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config: ModelConfig):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.config = config
|
| 23 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
|
| 24 |
+
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
|
| 25 |
+
self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps)
|
| 26 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 27 |
+
if config.tie_word_embeddings:
|
| 28 |
+
self.lm_head.weight = self.embed_tokens.weight
|
| 29 |
+
cos, sin = build_rope_cache(
|
| 30 |
+
seq_len=config.context_length,
|
| 31 |
+
head_dim=config.head_dim,
|
| 32 |
+
base_frequency=config.rope_base_frequency,
|
| 33 |
+
scaling_factor=config.rope_scaling_factor,
|
| 34 |
+
)
|
| 35 |
+
self.register_buffer("rope_cos", cos, persistent=False)
|
| 36 |
+
self.register_buffer("rope_sin", sin, persistent=False)
|
| 37 |
+
self._reset_parameters()
|
| 38 |
+
|
| 39 |
+
def _reset_parameters(self) -> None:
|
| 40 |
+
"""Apply scaled initialization to the model."""
|
| 41 |
+
embed_std = 1.0 / math.sqrt(self.config.d_model)
|
| 42 |
+
nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=embed_std)
|
| 43 |
+
for module in self.modules():
|
| 44 |
+
if not isinstance(module, nn.Linear):
|
| 45 |
+
continue
|
| 46 |
+
std = self.config.initializer_range
|
| 47 |
+
if module is self.lm_head and self.config.tie_word_embeddings:
|
| 48 |
+
continue
|
| 49 |
+
if module.out_features == self.config.d_model:
|
| 50 |
+
std = std / math.sqrt(2 * self.config.num_layers)
|
| 51 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 52 |
+
|
| 53 |
+
def forward(
|
| 54 |
+
self,
|
| 55 |
+
input_ids: torch.Tensor,
|
| 56 |
+
past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
|
| 57 |
+
) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
|
| 58 |
+
"""Return logits and the updated KV cache."""
|
| 59 |
+
batch_size, seq_len = input_ids.shape
|
| 60 |
+
hidden_states = self.embed_tokens(input_ids)
|
| 61 |
+
past_key_values = past_key_values or [None] * self.config.num_layers
|
| 62 |
+
start = 0
|
| 63 |
+
if past_key_values[0] is not None:
|
| 64 |
+
start = past_key_values[0][0].size(-2)
|
| 65 |
+
cos = self.rope_cos[start : start + seq_len].to(hidden_states.device)
|
| 66 |
+
sin = self.rope_sin[start : start + seq_len].to(hidden_states.device)
|
| 67 |
+
presents: list[tuple[torch.Tensor, torch.Tensor]] = []
|
| 68 |
+
for layer, past in zip(self.layers, past_key_values):
|
| 69 |
+
hidden_states, present = layer(hidden_states, cos, sin, past_key_value=past)
|
| 70 |
+
presents.append(present)
|
| 71 |
+
hidden_states = self.norm(hidden_states)
|
| 72 |
+
logits = self.lm_head(hidden_states)
|
| 73 |
+
return logits, presents
|
model/rmsnorm.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RMSNorm implementation used by SAGE."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RMSNorm(nn.Module):
|
| 10 |
+
"""Root mean square normalization with float32 accumulation."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, dim: int, eps: float = 1.0e-5):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.eps = eps
|
| 15 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 16 |
+
|
| 17 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 18 |
+
"""Normalize the last dimension and cast back to the input dtype."""
|
| 19 |
+
if x.ndim < 2:
|
| 20 |
+
raise ValueError("RMSNorm expects at least 2 dimensions.")
|
| 21 |
+
variance = x.float().pow(2).mean(dim=-1, keepdim=True)
|
| 22 |
+
normalized = x.float() * torch.rsqrt(variance + self.eps)
|
| 23 |
+
return (normalized.to(dtype=x.dtype)) * self.weight
|
model/rope.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Rotary positional embedding helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _scaled_positions(seq_len: int, scaling_factor: float, device: torch.device) -> torch.Tensor:
|
| 9 |
+
"""Apply a simple YaRN-style position scaling factor."""
|
| 10 |
+
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
|
| 11 |
+
if scaling_factor > 1.0:
|
| 12 |
+
positions = positions / scaling_factor
|
| 13 |
+
return positions
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_rope_cache(
|
| 17 |
+
seq_len: int,
|
| 18 |
+
head_dim: int,
|
| 19 |
+
base_frequency: int = 500_000,
|
| 20 |
+
scaling_factor: float = 1.0,
|
| 21 |
+
device: torch.device | None = None,
|
| 22 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 23 |
+
"""Precompute cosine and sine tables for RoPE."""
|
| 24 |
+
if head_dim % 2 != 0:
|
| 25 |
+
raise ValueError("head_dim must be even for RoPE.")
|
| 26 |
+
device = device or torch.device("cpu")
|
| 27 |
+
positions = _scaled_positions(seq_len, scaling_factor, device)
|
| 28 |
+
inv_freq = 1.0 / (base_frequency ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim))
|
| 29 |
+
freqs = torch.outer(positions, inv_freq)
|
| 30 |
+
cos = torch.cos(freqs)
|
| 31 |
+
sin = torch.sin(freqs)
|
| 32 |
+
return cos, sin
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
"""Rotate the last dimension in pairs."""
|
| 37 |
+
even = x[..., ::2]
|
| 38 |
+
odd = x[..., 1::2]
|
| 39 |
+
rotated = torch.stack((-odd, even), dim=-1)
|
| 40 |
+
return rotated.flatten(start_dim=-2)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def apply_rope(
|
| 44 |
+
q: torch.Tensor,
|
| 45 |
+
k: torch.Tensor,
|
| 46 |
+
cos: torch.Tensor,
|
| 47 |
+
sin: torch.Tensor,
|
| 48 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 49 |
+
"""Apply rotary embeddings to query and key tensors."""
|
| 50 |
+
if q.shape != k.shape:
|
| 51 |
+
raise ValueError("q and k must share the same shape for RoPE application.")
|
| 52 |
+
seq_len = q.size(-2)
|
| 53 |
+
cos = cos[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
|
| 54 |
+
sin = sin[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1)
|
| 55 |
+
q_out = (q * cos) + (rotate_half(q) * sin)
|
| 56 |
+
k_out = (k * cos) + (rotate_half(k) * sin)
|
| 57 |
+
return q_out, k_out
|
requirements.txt
CHANGED
|
@@ -1,27 +1,13 @@
|
|
| 1 |
-
# SAGE - Self-Adaptive General Engine
|
| 2 |
-
# ======================================
|
| 3 |
-
# Core dependencies
|
| 4 |
-
|
| 5 |
-
# PyTorch - GPU compatibility notes:
|
| 6 |
-
# - For Tesla P100 (sm_60), V100, T4, A100: torch>=2.1.0
|
| 7 |
-
# - For older GPUs (sm_60): use torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121
|
| 8 |
-
# - The code auto-detects GPU compatibility and falls back to CPU if needed
|
| 9 |
torch>=2.1.0
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# Quantization (optional GPU support)
|
| 23 |
-
bitsandbytes>=0.41.0
|
| 24 |
-
|
| 25 |
-
# Model Hub & Experiment Tracking
|
| 26 |
-
huggingface_hub>=0.20.0
|
| 27 |
-
wandb>=0.16.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
torch>=2.1.0
|
| 2 |
+
fastapi>=0.110.0
|
| 3 |
+
uvicorn>=0.29.0
|
| 4 |
+
python-multipart>=0.0.9
|
| 5 |
+
pydantic>=2.7.0
|
| 6 |
+
pyyaml>=6.0.1
|
| 7 |
+
sentencepiece>=0.2.0
|
| 8 |
+
pyarrow>=16.0.0
|
| 9 |
+
psutil>=5.9.8
|
| 10 |
+
wandb>=0.17.0
|
| 11 |
+
pytest>=8.2.0
|
| 12 |
+
httpx>=0.27.0
|
| 13 |
+
bitsandbytes>=0.43.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/__init__.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE — Self-Adaptive General Engine
|
| 3 |
-
A complete mini-LLM system built from scratch.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
__version__ = "1.0.0"
|
| 7 |
-
|
| 8 |
-
from .model import SageModel
|
| 9 |
-
from .config import SageConfig
|
| 10 |
-
from .data import SageTokenizer
|
| 11 |
-
from .inference import generate
|
| 12 |
-
from .memory import ConversationHistory, RAGManager
|
| 13 |
-
from .train import train
|
| 14 |
-
from .finetune import finetune_instruction as finetune
|
| 15 |
-
from .utils import get_compatible_device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/cli.py
DELETED
|
@@ -1,299 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE CLI — Interactive Terminal Interface
|
| 3 |
-
==========================================
|
| 4 |
-
Provides a REPL with slash-commands for training, fine-tuning, quantization,
|
| 5 |
-
RAG toggling, and real-time chat with streaming output.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import sys
|
| 9 |
-
import os
|
| 10 |
-
import torch
|
| 11 |
-
from typing import Optional
|
| 12 |
-
|
| 13 |
-
from .config import SageConfig
|
| 14 |
-
from .model import SageModel
|
| 15 |
-
from .data import SageTokenizer
|
| 16 |
-
from .train import train
|
| 17 |
-
from .inference import generate
|
| 18 |
-
from .finetune import finetune_instruction, DEMO_INSTRUCTION_SAMPLES
|
| 19 |
-
from .optimize import quantize_int8
|
| 20 |
-
from .memory import RAGManager, ConversationHistory
|
| 21 |
-
from .utils import setup_logger, save_checkpoint, load_checkpoint
|
| 22 |
-
from . import __version__
|
| 23 |
-
|
| 24 |
-
logger = setup_logger("sage.cli")
|
| 25 |
-
|
| 26 |
-
# ===================================================================
|
| 27 |
-
# Banner
|
| 28 |
-
# ===================================================================
|
| 29 |
-
|
| 30 |
-
BANNER = r"""
|
| 31 |
-
╔══════════════════════════════════════════════════════════════╗
|
| 32 |
-
║ ║
|
| 33 |
-
║ ███████ █████ ██████ ███████ ║
|
| 34 |
-
║ ██ ██ ██ ██ ██ ║
|
| 35 |
-
║ ███████ ███████ ██ ███ █████ ║
|
| 36 |
-
║ ██ ██ ██ ██ ██ ██ ║
|
| 37 |
-
║ ███████ ██ ██ ██████ ███████ ║
|
| 38 |
-
║ ║
|
| 39 |
-
║ Self-Adaptive General Engine v{version} ║
|
| 40 |
-
║ ║
|
| 41 |
-
╚══════════════════════════════════════════════════════════════╝
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def print_banner(model: SageModel, config: SageConfig) -> None:
|
| 46 |
-
"""Display startup banner with model statistics."""
|
| 47 |
-
base_model = getattr(model, "module", model)
|
| 48 |
-
total_params = sum(p.numel() for p in base_model.parameters())
|
| 49 |
-
trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
|
| 50 |
-
|
| 51 |
-
print(BANNER.format(version=__version__))
|
| 52 |
-
print(f" Model params : {total_params:,} ({total_params/1e6:.1f}M)")
|
| 53 |
-
print(f" Trainable : {trainable_params:,}")
|
| 54 |
-
print(f" Context length: {config.max_seq_len}")
|
| 55 |
-
print(f" Device : {config.device}")
|
| 56 |
-
print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}")
|
| 57 |
-
print()
|
| 58 |
-
print(" Type /help for commands, or start chatting!\n")
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
# ===================================================================
|
| 62 |
-
# Help text
|
| 63 |
-
# ===================================================================
|
| 64 |
-
|
| 65 |
-
HELP_TEXT = """
|
| 66 |
-
Available Commands:
|
| 67 |
-
/train [steps] Train the model (default: 100 steps)
|
| 68 |
-
/finetune [steps] Instruction-tune with LoRA (default: 200 steps)
|
| 69 |
-
/save Save current model checkpoint
|
| 70 |
-
/load Load latest checkpoint
|
| 71 |
-
/quantize Quantize model to INT8 (CPU only)
|
| 72 |
-
/rag on|off Enable/disable retrieval-augmented generation
|
| 73 |
-
/rag add <text> Add a document for RAG retrieval
|
| 74 |
-
/clear Clear conversation history
|
| 75 |
-
/help Show this message
|
| 76 |
-
/exit Exit SAGE
|
| 77 |
-
"""
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
# ===================================================================
|
| 81 |
-
# Command handlers
|
| 82 |
-
# ===================================================================
|
| 83 |
-
|
| 84 |
-
def handle_train(model, config, tokenizer, args):
|
| 85 |
-
"""Handle /train [steps]"""
|
| 86 |
-
steps = 100
|
| 87 |
-
if args:
|
| 88 |
-
try:
|
| 89 |
-
steps = int(args[0])
|
| 90 |
-
except ValueError:
|
| 91 |
-
print(f" Invalid step count: {args[0]}")
|
| 92 |
-
return model
|
| 93 |
-
|
| 94 |
-
print(f"\n Starting training for {steps} steps …\n")
|
| 95 |
-
model = train(model, config, total_steps=steps, tokenizer=tokenizer, resume=True)
|
| 96 |
-
|
| 97 |
-
# Show a quick sample after training
|
| 98 |
-
print("\n --- Sample generation after training ---")
|
| 99 |
-
generate(model, tokenizer, "Once upon a time", max_new_tokens=80, stream=True, device=config.device)
|
| 100 |
-
print()
|
| 101 |
-
return model
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def handle_finetune(model, config, tokenizer, args):
|
| 105 |
-
"""Handle /finetune [steps]"""
|
| 106 |
-
steps = 200
|
| 107 |
-
if args:
|
| 108 |
-
try:
|
| 109 |
-
steps = int(args[0])
|
| 110 |
-
except ValueError:
|
| 111 |
-
print(f" Invalid step count: {args[0]}")
|
| 112 |
-
return model
|
| 113 |
-
|
| 114 |
-
print(f"\n Starting instruction fine-tuning for {steps} steps (LoRA) …\n")
|
| 115 |
-
model = finetune_instruction(
|
| 116 |
-
model, config,
|
| 117 |
-
samples=DEMO_INSTRUCTION_SAMPLES,
|
| 118 |
-
total_steps=steps,
|
| 119 |
-
use_lora=True,
|
| 120 |
-
tokenizer=tokenizer,
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
print("\n --- Sample after fine-tuning ---")
|
| 124 |
-
prompt = "### Instruction:\nWhat is the speed of light?\n\n### Response:\n"
|
| 125 |
-
generate(model, tokenizer, prompt, max_new_tokens=100, stream=True, device=config.device)
|
| 126 |
-
print()
|
| 127 |
-
return model
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def handle_save(model, config):
|
| 131 |
-
"""Handle /save"""
|
| 132 |
-
path = save_checkpoint(model, None, 0, config.checkpoint_dir)
|
| 133 |
-
print(f" Model saved to {path}")
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def handle_load(model, config):
|
| 137 |
-
"""Handle /load"""
|
| 138 |
-
model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
|
| 139 |
-
model = model.to(config.device)
|
| 140 |
-
print(f" Model loaded (step {step})")
|
| 141 |
-
return model
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
def handle_quantize(model):
|
| 145 |
-
"""Handle /quantize"""
|
| 146 |
-
print(" Quantizing model to INT8 (model will be on CPU) …")
|
| 147 |
-
model = quantize_int8(model)
|
| 148 |
-
print(" Quantization complete.")
|
| 149 |
-
return model
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
def handle_rag(rag_manager: RAGManager, args):
|
| 153 |
-
"""Handle /rag on|off|add <text>"""
|
| 154 |
-
if not args:
|
| 155 |
-
state = "enabled" if rag_manager.enabled else "disabled"
|
| 156 |
-
print(f" RAG is currently {state} ({rag_manager.store.size} chunks indexed)")
|
| 157 |
-
return
|
| 158 |
-
|
| 159 |
-
subcmd = args[0].lower()
|
| 160 |
-
if subcmd == "on":
|
| 161 |
-
rag_manager.toggle(True)
|
| 162 |
-
print(" RAG enabled.")
|
| 163 |
-
elif subcmd == "off":
|
| 164 |
-
rag_manager.toggle(False)
|
| 165 |
-
print(" RAG disabled.")
|
| 166 |
-
elif subcmd == "add":
|
| 167 |
-
text = " ".join(args[1:])
|
| 168 |
-
if text:
|
| 169 |
-
rag_manager.add_documents([text])
|
| 170 |
-
print(f" Document added. Store now has {rag_manager.store.size} chunks.")
|
| 171 |
-
else:
|
| 172 |
-
print(" Usage: /rag add <your document text here>")
|
| 173 |
-
else:
|
| 174 |
-
print(" Usage: /rag on|off|add <text>")
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
# ===================================================================
|
| 178 |
-
# Main REPL
|
| 179 |
-
# ===================================================================
|
| 180 |
-
|
| 181 |
-
def main() -> None:
|
| 182 |
-
"""Entry point for the SAGE interactive CLI."""
|
| 183 |
-
config = SageConfig()
|
| 184 |
-
tokenizer = SageTokenizer()
|
| 185 |
-
|
| 186 |
-
# Ensure vocab_size matches the tokenizer
|
| 187 |
-
config.vocab_size = tokenizer.vocab_size
|
| 188 |
-
|
| 189 |
-
print(" Initializing SAGE model …")
|
| 190 |
-
model = SageModel(config)
|
| 191 |
-
model = model.to(config.device)
|
| 192 |
-
|
| 193 |
-
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
|
| 194 |
-
print(f" Multi-GPU detected! Wrapping model in DataParallel across {torch.cuda.device_count()} GPUs.")
|
| 195 |
-
model = torch.nn.DataParallel(model)
|
| 196 |
-
|
| 197 |
-
# Attempt to load existing checkpoint
|
| 198 |
-
model, _, loaded_step = load_checkpoint(
|
| 199 |
-
model, None, config.checkpoint_dir, device=str(config.device)
|
| 200 |
-
)
|
| 201 |
-
if loaded_step > 0:
|
| 202 |
-
print(f" Resumed from checkpoint at step {loaded_step}")
|
| 203 |
-
|
| 204 |
-
print_banner(model, config)
|
| 205 |
-
|
| 206 |
-
# Initialize RAG and conversation history
|
| 207 |
-
rag_manager = RAGManager(model, tokenizer, config.device)
|
| 208 |
-
history = ConversationHistory(tokenizer, max_tokens=config.max_seq_len - 128)
|
| 209 |
-
|
| 210 |
-
# ---------- One-liner CLI arguments ----------
|
| 211 |
-
if len(sys.argv) > 1:
|
| 212 |
-
cmd = sys.argv[1].lower()
|
| 213 |
-
args = sys.argv[2:]
|
| 214 |
-
if cmd == "--train":
|
| 215 |
-
handle_train(model, config, tokenizer, args)
|
| 216 |
-
elif cmd == "--finetune":
|
| 217 |
-
handle_finetune(model, config, tokenizer, args)
|
| 218 |
-
elif cmd == "--quantize":
|
| 219 |
-
handle_quantize(model)
|
| 220 |
-
else:
|
| 221 |
-
print(f" Unknown argument: {cmd}")
|
| 222 |
-
print(" Usage: --train [steps] | --finetune [steps] | --quantize")
|
| 223 |
-
return
|
| 224 |
-
|
| 225 |
-
# ---------- REPL loop ----------
|
| 226 |
-
while True:
|
| 227 |
-
try:
|
| 228 |
-
user_input = input("You: ").strip()
|
| 229 |
-
except (EOFError, KeyboardInterrupt):
|
| 230 |
-
print("\n Goodbye!")
|
| 231 |
-
break
|
| 232 |
-
|
| 233 |
-
if not user_input:
|
| 234 |
-
continue
|
| 235 |
-
|
| 236 |
-
# ---------- Slash commands ----------
|
| 237 |
-
if user_input.startswith("/"):
|
| 238 |
-
parts = user_input.split()
|
| 239 |
-
cmd = parts[0].lower()
|
| 240 |
-
args = parts[1:]
|
| 241 |
-
|
| 242 |
-
if cmd == "/exit":
|
| 243 |
-
print(" Goodbye!")
|
| 244 |
-
break
|
| 245 |
-
elif cmd == "/help":
|
| 246 |
-
print(HELP_TEXT)
|
| 247 |
-
elif cmd == "/train":
|
| 248 |
-
model = handle_train(model, config, tokenizer, args)
|
| 249 |
-
elif cmd == "/finetune":
|
| 250 |
-
model = handle_finetune(model, config, tokenizer, args)
|
| 251 |
-
elif cmd == "/save":
|
| 252 |
-
handle_save(model, config)
|
| 253 |
-
elif cmd == "/load":
|
| 254 |
-
model = handle_load(model, config)
|
| 255 |
-
# Re-attach to RAG manager since model changed
|
| 256 |
-
rag_manager.model = model
|
| 257 |
-
elif cmd == "/quantize":
|
| 258 |
-
model = handle_quantize(model)
|
| 259 |
-
rag_manager.model = model
|
| 260 |
-
elif cmd == "/rag":
|
| 261 |
-
handle_rag(rag_manager, args)
|
| 262 |
-
elif cmd == "/clear":
|
| 263 |
-
history.clear()
|
| 264 |
-
print(" Conversation history cleared.")
|
| 265 |
-
else:
|
| 266 |
-
print(f" Unknown command: {cmd}. Type /help for a list.")
|
| 267 |
-
continue
|
| 268 |
-
|
| 269 |
-
# ---------- Chat mode ----------
|
| 270 |
-
# Build prompt with history and optional RAG context
|
| 271 |
-
rag_context = rag_manager.retrieve_context(user_input)
|
| 272 |
-
prompt = history.build_prompt(user_input, rag_context=rag_context)
|
| 273 |
-
|
| 274 |
-
history.add("user", user_input)
|
| 275 |
-
|
| 276 |
-
print("SAGE: ", end="", flush=True)
|
| 277 |
-
response = generate(
|
| 278 |
-
model,
|
| 279 |
-
tokenizer,
|
| 280 |
-
prompt,
|
| 281 |
-
max_new_tokens=256,
|
| 282 |
-
temperature=0.8,
|
| 283 |
-
top_k=50,
|
| 284 |
-
top_p=0.9,
|
| 285 |
-
stream=True,
|
| 286 |
-
device=config.device,
|
| 287 |
-
)
|
| 288 |
-
|
| 289 |
-
# Extract only the SAGE response part from the full generation
|
| 290 |
-
if "SAGE:" in response:
|
| 291 |
-
reply = response.split("SAGE:")[-1].strip()
|
| 292 |
-
else:
|
| 293 |
-
reply = response[len(prompt):].strip()
|
| 294 |
-
|
| 295 |
-
history.add("assistant", reply)
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
if __name__ == "__main__":
|
| 299 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/config.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from dataclasses import dataclass, field
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
-
@dataclass
|
| 6 |
-
class SageConfig:
|
| 7 |
-
# Model dimensions corresponding to T4 (16GB VRAM) fit
|
| 8 |
-
d_model: int = 512
|
| 9 |
-
n_heads: int = 8
|
| 10 |
-
n_kv_heads: int = 4 # GQA: must divide n_heads
|
| 11 |
-
n_layers: int = 6
|
| 12 |
-
d_ff: int = 2048
|
| 13 |
-
|
| 14 |
-
# MoE (Mixture of Experts) config
|
| 15 |
-
n_experts: int = 4
|
| 16 |
-
num_experts_per_tok: int = 2
|
| 17 |
-
|
| 18 |
-
# Vocabulary and sequence parameters
|
| 19 |
-
vocab_size: int = 100277 # Default for tiktoken "cl100k_base"
|
| 20 |
-
max_seq_len: int = 1024
|
| 21 |
-
|
| 22 |
-
# Regularization
|
| 23 |
-
dropout: float = 0.1
|
| 24 |
-
|
| 25 |
-
# Training Loop defaults
|
| 26 |
-
batch_size: int = 4
|
| 27 |
-
gradient_accumulation_steps: int = 16
|
| 28 |
-
learning_rate: float = 3e-4
|
| 29 |
-
min_learning_rate: float = 1e-5
|
| 30 |
-
warmup_steps: int = 100
|
| 31 |
-
weight_decay: float = 0.01
|
| 32 |
-
max_grad_norm: float = 1.0
|
| 33 |
-
|
| 34 |
-
# Checkpointing and path details
|
| 35 |
-
checkpoint_dir: str = "checkpoints"
|
| 36 |
-
project_name: str = "sage-v2"
|
| 37 |
-
|
| 38 |
-
# Cache for device (set on first access)
|
| 39 |
-
_device: Any = field(default=None, repr=False)
|
| 40 |
-
|
| 41 |
-
@property
|
| 42 |
-
def device(self):
|
| 43 |
-
"""Returns the best available device with CUDA compatibility checking."""
|
| 44 |
-
if self._device is None:
|
| 45 |
-
# Import here to avoid circular imports
|
| 46 |
-
from .utils import get_compatible_device
|
| 47 |
-
self._device = get_compatible_device()
|
| 48 |
-
return self._device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/data.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Data Pipeline
|
| 3 |
-
==================
|
| 4 |
-
Handles tokenization (tiktoken), streaming dataset loading from HuggingFace,
|
| 5 |
-
text cleaning, chunking into fixed-length sequences, and batched DataLoader
|
| 6 |
-
construction with shuffle buffering.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import re
|
| 10 |
-
import random
|
| 11 |
-
import tiktoken
|
| 12 |
-
import torch
|
| 13 |
-
from torch.utils.data import IterableDataset, DataLoader
|
| 14 |
-
from typing import Iterator, List, Optional
|
| 15 |
-
from .config import SageConfig
|
| 16 |
-
from .utils import setup_logger
|
| 17 |
-
|
| 18 |
-
logger = setup_logger("sage.data")
|
| 19 |
-
|
| 20 |
-
# ---------------------------------------------------------------------------
|
| 21 |
-
# Tokenizer wrapper
|
| 22 |
-
# ---------------------------------------------------------------------------
|
| 23 |
-
|
| 24 |
-
class SageTokenizer:
|
| 25 |
-
"""Thin wrapper around tiktoken providing encode/decode and special tokens."""
|
| 26 |
-
|
| 27 |
-
def __init__(self, encoding_name: str = "cl100k_base"):
|
| 28 |
-
self.enc = tiktoken.get_encoding(encoding_name)
|
| 29 |
-
self.encoding_name = encoding_name
|
| 30 |
-
|
| 31 |
-
# Use the last token in the vocabulary as the EOS sentinel.
|
| 32 |
-
# tiktoken doesn't expose a dedicated EOS, so we pick one that
|
| 33 |
-
# won't collide with real text.
|
| 34 |
-
self.eos_token_id: int = self.enc.n_vocab - 1
|
| 35 |
-
self.pad_token_id: int = self.enc.n_vocab - 2
|
| 36 |
-
self.vocab_size: int = self.enc.n_vocab
|
| 37 |
-
|
| 38 |
-
def encode(self, text: str, add_eos: bool = False) -> List[int]:
|
| 39 |
-
"""Encode text to token IDs."""
|
| 40 |
-
tokens = self.enc.encode(text, allowed_special="all")
|
| 41 |
-
if add_eos:
|
| 42 |
-
tokens.append(self.eos_token_id)
|
| 43 |
-
return tokens
|
| 44 |
-
|
| 45 |
-
def decode(self, tokens: List[int]) -> str:
|
| 46 |
-
"""Decode token IDs back to text, filtering out special sentinel IDs."""
|
| 47 |
-
# Filter out our custom pad/eos sentinels before decoding
|
| 48 |
-
filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)]
|
| 49 |
-
return self.enc.decode(filtered)
|
| 50 |
-
|
| 51 |
-
# ---------------------------------------------------------------------------
|
| 52 |
-
# Text cleaning
|
| 53 |
-
# ---------------------------------------------------------------------------
|
| 54 |
-
|
| 55 |
-
_HTML_TAG_RE = re.compile(r"<[^>]+>")
|
| 56 |
-
_MULTI_SPACE_RE = re.compile(r"[ \t]+")
|
| 57 |
-
_MULTI_NEWLINE_RE = re.compile(r"\n{3,}")
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def clean_text(text: str) -> str:
|
| 61 |
-
"""Strip HTML tags, collapse whitespace, and trim to reasonable length."""
|
| 62 |
-
text = _HTML_TAG_RE.sub("", text) # remove HTML tags
|
| 63 |
-
text = _MULTI_SPACE_RE.sub(" ", text) # collapse horizontal whitespace
|
| 64 |
-
text = _MULTI_NEWLINE_RE.sub("\n\n", text) # collapse vertical whitespace
|
| 65 |
-
return text.strip()
|
| 66 |
-
|
| 67 |
-
# ---------------------------------------------------------------------------
|
| 68 |
-
# Streaming iterable dataset
|
| 69 |
-
# ---------------------------------------------------------------------------
|
| 70 |
-
|
| 71 |
-
class StreamingTextDataset(IterableDataset):
|
| 72 |
-
"""
|
| 73 |
-
An IterableDataset that streams data from HuggingFace ``datasets``,
|
| 74 |
-
tokenizes on the fly, and yields fixed-length chunks.
|
| 75 |
-
|
| 76 |
-
It maintains an internal shuffle buffer so that consecutive chunks are
|
| 77 |
-
not always from the same document.
|
| 78 |
-
"""
|
| 79 |
-
|
| 80 |
-
def __init__(
|
| 81 |
-
self,
|
| 82 |
-
dataset_name: str = "HuggingFaceFW/fineweb-edu",
|
| 83 |
-
split: str = "train",
|
| 84 |
-
seq_len: int = 512,
|
| 85 |
-
tokenizer: Optional[SageTokenizer] = None,
|
| 86 |
-
shuffle_buffer_size: int = 1000,
|
| 87 |
-
text_field: str = "text",
|
| 88 |
-
min_doc_len: int = 50,
|
| 89 |
-
max_doc_len: int = 50000,
|
| 90 |
-
):
|
| 91 |
-
super().__init__()
|
| 92 |
-
self.dataset_name = dataset_name
|
| 93 |
-
self.split = split
|
| 94 |
-
self.seq_len = seq_len
|
| 95 |
-
self.tokenizer = tokenizer or SageTokenizer()
|
| 96 |
-
self.shuffle_buffer_size = shuffle_buffer_size
|
| 97 |
-
self.text_field = text_field
|
| 98 |
-
self.min_doc_len = min_doc_len
|
| 99 |
-
self.max_doc_len = max_doc_len
|
| 100 |
-
|
| 101 |
-
# Auto-adjust configuration based on popular datasets
|
| 102 |
-
if "fineweb-edu" in dataset_name.lower():
|
| 103 |
-
self.text_field = "text"
|
| 104 |
-
self.split = "train" if split == "train" else split
|
| 105 |
-
elif "tinystories" in dataset_name.lower():
|
| 106 |
-
self.text_field = "text"
|
| 107 |
-
|
| 108 |
-
def _stream_tokens(self) -> Iterator[int]:
|
| 109 |
-
"""Yields individual token IDs from the HuggingFace dataset stream."""
|
| 110 |
-
try:
|
| 111 |
-
from datasets import load_dataset
|
| 112 |
-
except ImportError:
|
| 113 |
-
raise ImportError(
|
| 114 |
-
"The 'datasets' library is required. Install it with: "
|
| 115 |
-
"pip install datasets"
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
logger.info(
|
| 119 |
-
f"Streaming dataset '{self.dataset_name}' (split={self.split}) …"
|
| 120 |
-
)
|
| 121 |
-
ds = load_dataset(
|
| 122 |
-
self.dataset_name,
|
| 123 |
-
split=self.split,
|
| 124 |
-
streaming=True,
|
| 125 |
-
)
|
| 126 |
-
|
| 127 |
-
for sample in ds:
|
| 128 |
-
raw = sample.get(self.text_field, "")
|
| 129 |
-
if not raw:
|
| 130 |
-
continue
|
| 131 |
-
|
| 132 |
-
text = clean_text(raw)
|
| 133 |
-
|
| 134 |
-
# Filter documents that are too short or too long
|
| 135 |
-
if len(text) < self.min_doc_len or len(text) > self.max_doc_len:
|
| 136 |
-
continue
|
| 137 |
-
|
| 138 |
-
tokens = self.tokenizer.encode(text, add_eos=True)
|
| 139 |
-
yield from tokens
|
| 140 |
-
|
| 141 |
-
def _chunk_tokens(self) -> Iterator[torch.Tensor]:
|
| 142 |
-
"""Groups raw token stream into fixed-length chunks of (seq_len + 1).
|
| 143 |
-
|
| 144 |
-
The extra token is needed so that input = chunk[:-1] and
|
| 145 |
-
target = chunk[1:] for next-token-prediction.
|
| 146 |
-
"""
|
| 147 |
-
chunk: List[int] = []
|
| 148 |
-
for tok in self._stream_tokens():
|
| 149 |
-
chunk.append(tok)
|
| 150 |
-
if len(chunk) == self.seq_len + 1:
|
| 151 |
-
yield torch.tensor(chunk, dtype=torch.long)
|
| 152 |
-
chunk = []
|
| 153 |
-
# Discard any trailing partial chunk
|
| 154 |
-
|
| 155 |
-
def __iter__(self) -> Iterator[torch.Tensor]:
|
| 156 |
-
"""Yields shuffled chunks from an internal buffer."""
|
| 157 |
-
buffer: List[torch.Tensor] = []
|
| 158 |
-
for chunk in self._chunk_tokens():
|
| 159 |
-
buffer.append(chunk)
|
| 160 |
-
if len(buffer) >= self.shuffle_buffer_size:
|
| 161 |
-
random.shuffle(buffer)
|
| 162 |
-
while len(buffer) > self.shuffle_buffer_size // 2:
|
| 163 |
-
yield buffer.pop()
|
| 164 |
-
# Flush remaining items
|
| 165 |
-
random.shuffle(buffer)
|
| 166 |
-
yield from buffer
|
| 167 |
-
|
| 168 |
-
# ---------------------------------------------------------------------------
|
| 169 |
-
# DataLoader factory
|
| 170 |
-
# ---------------------------------------------------------------------------
|
| 171 |
-
|
| 172 |
-
def create_dataloader(
|
| 173 |
-
config: SageConfig,
|
| 174 |
-
dataset_name: str = "HuggingFaceFW/fineweb-edu",
|
| 175 |
-
split: str = "train",
|
| 176 |
-
tokenizer: Optional[SageTokenizer] = None,
|
| 177 |
-
) -> DataLoader:
|
| 178 |
-
"""Creates a streaming DataLoader ready for the training loop."""
|
| 179 |
-
tok = tokenizer or SageTokenizer()
|
| 180 |
-
ds = StreamingTextDataset(
|
| 181 |
-
dataset_name=dataset_name,
|
| 182 |
-
split=split,
|
| 183 |
-
seq_len=config.max_seq_len,
|
| 184 |
-
tokenizer=tok,
|
| 185 |
-
)
|
| 186 |
-
return DataLoader(
|
| 187 |
-
ds,
|
| 188 |
-
batch_size=config.batch_size,
|
| 189 |
-
num_workers=2,
|
| 190 |
-
pin_memory=True,
|
| 191 |
-
drop_last=True,
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
# ---------------------------------------------------------------------------
|
| 195 |
-
# Instruction-tuning data helpers
|
| 196 |
-
# ---------------------------------------------------------------------------
|
| 197 |
-
|
| 198 |
-
INSTRUCTION_TEMPLATE = (
|
| 199 |
-
"### Instruction:\n{instruction}\n\n### Response:\n{response}"
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
def format_instruction_sample(instruction: str, response: str) -> str:
|
| 204 |
-
"""Formats an instruction/response pair into the chat template."""
|
| 205 |
-
return INSTRUCTION_TEMPLATE.format(
|
| 206 |
-
instruction=instruction.strip(),
|
| 207 |
-
response=response.strip(),
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
def create_instruction_batch(
|
| 212 |
-
samples: List[dict],
|
| 213 |
-
tokenizer: SageTokenizer,
|
| 214 |
-
max_len: int = 512,
|
| 215 |
-
) -> dict:
|
| 216 |
-
"""
|
| 217 |
-
Tokenize a list of {instruction, response} dicts and produce input_ids,
|
| 218 |
-
labels, and a loss_mask that zeros out the instruction portion.
|
| 219 |
-
|
| 220 |
-
Returns a dict with keys: input_ids, labels, loss_mask — all as tensors.
|
| 221 |
-
"""
|
| 222 |
-
all_input_ids: List[List[int]] = []
|
| 223 |
-
all_labels: List[List[int]] = []
|
| 224 |
-
all_masks: List[List[int]] = []
|
| 225 |
-
|
| 226 |
-
for sample in samples:
|
| 227 |
-
instruction_text = f"### Instruction:\n{sample['instruction'].strip()}\n\n### Response:\n"
|
| 228 |
-
response_text = sample["response"].strip()
|
| 229 |
-
full_text = instruction_text + response_text
|
| 230 |
-
|
| 231 |
-
instruction_tokens = tokenizer.encode(instruction_text)
|
| 232 |
-
full_tokens = tokenizer.encode(full_text, add_eos=True)
|
| 233 |
-
|
| 234 |
-
# Truncate to max_len
|
| 235 |
-
full_tokens = full_tokens[:max_len]
|
| 236 |
-
n_instruction = min(len(instruction_tokens), len(full_tokens))
|
| 237 |
-
|
| 238 |
-
# Labels are the same as input shifted by 1 (handled by caller),
|
| 239 |
-
# but we need a mask to zero out loss on instruction tokens.
|
| 240 |
-
mask = [0] * n_instruction + [1] * (len(full_tokens) - n_instruction)
|
| 241 |
-
|
| 242 |
-
# Pad to max_len
|
| 243 |
-
pad_len = max_len - len(full_tokens)
|
| 244 |
-
full_tokens += [tokenizer.pad_token_id] * pad_len
|
| 245 |
-
mask += [0] * pad_len
|
| 246 |
-
|
| 247 |
-
all_input_ids.append(full_tokens)
|
| 248 |
-
all_labels.append(full_tokens) # shift will be done in the loss fn
|
| 249 |
-
all_masks.append(mask)
|
| 250 |
-
|
| 251 |
-
return {
|
| 252 |
-
"input_ids": torch.tensor(all_input_ids, dtype=torch.long),
|
| 253 |
-
"labels": torch.tensor(all_labels, dtype=torch.long),
|
| 254 |
-
"loss_mask": torch.tensor(all_masks, dtype=torch.float32),
|
| 255 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/finetune.py
DELETED
|
@@ -1,268 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Fine-Tuning
|
| 3 |
-
================
|
| 4 |
-
Provides two fine-tuning modes:
|
| 5 |
-
|
| 6 |
-
1. **Instruction tuning** — trains on instruction/response pairs with loss
|
| 7 |
-
masked on the instruction portion.
|
| 8 |
-
2. **LoRA (Low-Rank Adaptation)** — injects small trainable matrices into
|
| 9 |
-
attention layers while keeping the base model frozen.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import math
|
| 13 |
-
import time
|
| 14 |
-
import copy
|
| 15 |
-
import torch
|
| 16 |
-
import torch.nn as nn
|
| 17 |
-
from torch.amp import GradScaler, autocast
|
| 18 |
-
from tqdm import tqdm
|
| 19 |
-
import wandb
|
| 20 |
-
from typing import Optional, List
|
| 21 |
-
|
| 22 |
-
from .config import SageConfig
|
| 23 |
-
from .model import SageModel, CausalSelfAttention
|
| 24 |
-
from .data import SageTokenizer, create_instruction_batch
|
| 25 |
-
from .train import create_optimizer, get_lr, set_lr
|
| 26 |
-
from .utils import setup_logger, save_checkpoint
|
| 27 |
-
|
| 28 |
-
logger = setup_logger("sage.finetune")
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# ===================================================================
|
| 32 |
-
# LoRA Implementation
|
| 33 |
-
# ===================================================================
|
| 34 |
-
|
| 35 |
-
class LoRALinear(nn.Module):
|
| 36 |
-
"""
|
| 37 |
-
Wraps an existing ``nn.Linear`` with a low-rank adapter (A @ B).
|
| 38 |
-
|
| 39 |
-
During fine-tuning only *A* and *B* are trained; the original weight
|
| 40 |
-
is frozen. After fine-tuning the adapter can be merged back into
|
| 41 |
-
the original weight for zero-overhead inference.
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
def __init__(self, original: nn.Linear, rank: int = 8, alpha: float = 16.0):
|
| 45 |
-
super().__init__()
|
| 46 |
-
self.original = original
|
| 47 |
-
self.rank = rank
|
| 48 |
-
self.alpha = alpha
|
| 49 |
-
self.scaling = alpha / rank
|
| 50 |
-
|
| 51 |
-
in_features = original.in_features
|
| 52 |
-
out_features = original.out_features
|
| 53 |
-
|
| 54 |
-
# Low-rank matrices
|
| 55 |
-
device, dtype = original.weight.device, original.weight.dtype
|
| 56 |
-
self.lora_A = nn.Parameter(torch.randn(in_features, rank, device=device, dtype=dtype) * 0.01)
|
| 57 |
-
self.lora_B = nn.Parameter(torch.zeros(rank, out_features, device=device, dtype=dtype))
|
| 58 |
-
|
| 59 |
-
# Freeze the original weight
|
| 60 |
-
self.original.weight.requires_grad = False
|
| 61 |
-
if self.original.bias is not None:
|
| 62 |
-
self.original.bias.requires_grad = False
|
| 63 |
-
|
| 64 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 65 |
-
"""original(x) + x @ A @ B * scaling"""
|
| 66 |
-
base_out = self.original(x)
|
| 67 |
-
lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling
|
| 68 |
-
return base_out + lora_out
|
| 69 |
-
|
| 70 |
-
def merge(self) -> nn.Linear:
|
| 71 |
-
"""Merge LoRA weights back into the original linear layer."""
|
| 72 |
-
merged = copy.deepcopy(self.original)
|
| 73 |
-
merged.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling
|
| 74 |
-
merged.weight.requires_grad = True
|
| 75 |
-
return merged
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# ---------------------------------------------------------------------------
|
| 79 |
-
# LoRA injection / removal helpers
|
| 80 |
-
# ---------------------------------------------------------------------------
|
| 81 |
-
|
| 82 |
-
def inject_lora(model: SageModel, rank: int = 8, alpha: float = 16.0) -> SageModel:
|
| 83 |
-
"""
|
| 84 |
-
Replace the Q, K, V, O projection layers in every attention block with
|
| 85 |
-
LoRA-wrapped versions. Returns the same model (mutated in-place).
|
| 86 |
-
"""
|
| 87 |
-
base_model = getattr(model, "module", model)
|
| 88 |
-
for layer in base_model.layers:
|
| 89 |
-
attn: CausalSelfAttention = layer.attn
|
| 90 |
-
attn.wq = LoRALinear(attn.wq, rank=rank, alpha=alpha)
|
| 91 |
-
attn.wk = LoRALinear(attn.wk, rank=rank, alpha=alpha)
|
| 92 |
-
attn.wv = LoRALinear(attn.wv, rank=rank, alpha=alpha)
|
| 93 |
-
attn.wo = LoRALinear(attn.wo, rank=rank, alpha=alpha)
|
| 94 |
-
|
| 95 |
-
# Log trainable parameter count
|
| 96 |
-
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 97 |
-
total = sum(p.numel() for p in model.parameters())
|
| 98 |
-
logger.info(
|
| 99 |
-
f"LoRA injected (rank={rank}). Trainable: {trainable:,} / {total:,} "
|
| 100 |
-
f"({100 * trainable / total:.2f}%)"
|
| 101 |
-
)
|
| 102 |
-
return model
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def merge_lora(model: SageModel) -> SageModel:
|
| 106 |
-
"""
|
| 107 |
-
Merge all LoRA adapters back into the base weights and replace the
|
| 108 |
-
LoRALinear wrappers with plain nn.Linear modules.
|
| 109 |
-
"""
|
| 110 |
-
base_model = getattr(model, "module", model)
|
| 111 |
-
for layer in base_model.layers:
|
| 112 |
-
attn: CausalSelfAttention = layer.attn
|
| 113 |
-
for name in ("wq", "wk", "wv", "wo"):
|
| 114 |
-
module = getattr(attn, name)
|
| 115 |
-
if isinstance(module, LoRALinear):
|
| 116 |
-
setattr(attn, name, module.merge())
|
| 117 |
-
logger.info("LoRA weights merged into base model.")
|
| 118 |
-
return model
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
# ===================================================================
|
| 122 |
-
# Instruction fine-tuning loop
|
| 123 |
-
# ===================================================================
|
| 124 |
-
|
| 125 |
-
def finetune_instruction(
|
| 126 |
-
model: SageModel,
|
| 127 |
-
config: SageConfig,
|
| 128 |
-
samples: List[dict],
|
| 129 |
-
total_steps: int = 200,
|
| 130 |
-
use_lora: bool = True,
|
| 131 |
-
lora_rank: int = 8,
|
| 132 |
-
tokenizer: Optional[SageTokenizer] = None,
|
| 133 |
-
) -> SageModel:
|
| 134 |
-
"""
|
| 135 |
-
Fine-tune the model on instruction/response pairs.
|
| 136 |
-
|
| 137 |
-
Parameters
|
| 138 |
-
----------
|
| 139 |
-
model : SageModel
|
| 140 |
-
config : SageConfig
|
| 141 |
-
samples : list[dict]
|
| 142 |
-
Each dict must contain ``instruction`` and ``response`` string keys.
|
| 143 |
-
total_steps : int
|
| 144 |
-
use_lora : bool
|
| 145 |
-
If True, inject LoRA adapters before training.
|
| 146 |
-
lora_rank : int
|
| 147 |
-
tokenizer : SageTokenizer, optional
|
| 148 |
-
|
| 149 |
-
Returns
|
| 150 |
-
-------
|
| 151 |
-
SageModel — the fine-tuned model (LoRA merged if applicable).
|
| 152 |
-
"""
|
| 153 |
-
# --- TURBO MODE: TF32 & COMPILE ---
|
| 154 |
-
if torch.cuda.is_available():
|
| 155 |
-
torch.set_float32_matmul_precision('high')
|
| 156 |
-
|
| 157 |
-
device = config.device
|
| 158 |
-
model = model.to(device)
|
| 159 |
-
|
| 160 |
-
# Wrap model with torch.compile for graph-level optimization
|
| 161 |
-
if hasattr(torch, "compile"):
|
| 162 |
-
logger.info("Turbo Mode: Compiling fine-tune engine...")
|
| 163 |
-
base = getattr(model, "module", model)
|
| 164 |
-
compiled_base = torch.compile(base, mode="reduce-overhead")
|
| 165 |
-
if hasattr(model, "module"):
|
| 166 |
-
model.module = compiled_base
|
| 167 |
-
else:
|
| 168 |
-
model = compiled_base
|
| 169 |
-
|
| 170 |
-
tok = tokenizer or SageTokenizer()
|
| 171 |
-
|
| 172 |
-
if use_lora:
|
| 173 |
-
model = inject_lora(model, rank=lora_rank)
|
| 174 |
-
|
| 175 |
-
# ------- W&B Logging -------
|
| 176 |
-
wandb.init(
|
| 177 |
-
project=config.project_name,
|
| 178 |
-
name=f"finetune-{time.strftime('%Y%m%d-%H%M')}",
|
| 179 |
-
config=config.__dict__,
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
optimizer = create_optimizer(model, config)
|
| 183 |
-
|
| 184 |
-
# AMP setup
|
| 185 |
-
use_amp = device.type == "cuda"
|
| 186 |
-
amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16
|
| 187 |
-
scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16))
|
| 188 |
-
|
| 189 |
-
model.train()
|
| 190 |
-
pbar = tqdm(range(total_steps), desc="Fine-tuning", unit="step")
|
| 191 |
-
accum_loss = 0.0
|
| 192 |
-
|
| 193 |
-
for step in pbar:
|
| 194 |
-
lr = get_lr(step, config, total_steps)
|
| 195 |
-
set_lr(optimizer, lr)
|
| 196 |
-
|
| 197 |
-
# Build a batch by sampling from the instruction dataset
|
| 198 |
-
batch_size = min(config.batch_size, len(samples))
|
| 199 |
-
import random
|
| 200 |
-
batch_samples = random.choices(samples, k=batch_size)
|
| 201 |
-
batch = create_instruction_batch(batch_samples, tok, max_len=config.max_seq_len)
|
| 202 |
-
|
| 203 |
-
input_ids = batch["input_ids"].to(device)
|
| 204 |
-
labels = batch["labels"].to(device)
|
| 205 |
-
loss_mask = batch["loss_mask"].to(device)
|
| 206 |
-
|
| 207 |
-
optimizer.zero_grad(set_to_none=True)
|
| 208 |
-
|
| 209 |
-
with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
|
| 210 |
-
logits, _ = model(input_ids)
|
| 211 |
-
# Shift: predict next token
|
| 212 |
-
shift_logits = logits[:, :-1, :].contiguous()
|
| 213 |
-
shift_labels = labels[:, 1:].contiguous()
|
| 214 |
-
shift_mask = loss_mask[:, 1:].contiguous()
|
| 215 |
-
|
| 216 |
-
# Compute per-token loss
|
| 217 |
-
per_token_loss = nn.functional.cross_entropy(
|
| 218 |
-
shift_logits.view(-1, shift_logits.size(-1)),
|
| 219 |
-
shift_labels.view(-1),
|
| 220 |
-
reduction="none",
|
| 221 |
-
)
|
| 222 |
-
per_token_loss = per_token_loss.view(shift_labels.size())
|
| 223 |
-
|
| 224 |
-
# Mask out instruction tokens so we only learn from responses
|
| 225 |
-
masked_loss = (per_token_loss * shift_mask).sum() / shift_mask.sum().clamp(min=1)
|
| 226 |
-
|
| 227 |
-
scaler.scale(masked_loss).backward()
|
| 228 |
-
scaler.unscale_(optimizer)
|
| 229 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
|
| 230 |
-
scaler.step(optimizer)
|
| 231 |
-
scaler.update()
|
| 232 |
-
|
| 233 |
-
accum_loss += masked_loss.item()
|
| 234 |
-
|
| 235 |
-
if (step + 1) % 10 == 0:
|
| 236 |
-
avg = accum_loss / 10
|
| 237 |
-
pbar.set_postfix(loss=f"{avg:.4f}", lr=f"{lr:.2e}")
|
| 238 |
-
logger.info(f"finetune step={step+1} | loss={avg:.4f}")
|
| 239 |
-
wandb.log({
|
| 240 |
-
"finetune/loss": avg,
|
| 241 |
-
"finetune/lr": lr,
|
| 242 |
-
}, step=step + 1)
|
| 243 |
-
accum_loss = 0.0
|
| 244 |
-
|
| 245 |
-
# Merge LoRA weights back for clean inference
|
| 246 |
-
if use_lora:
|
| 247 |
-
model = merge_lora(model)
|
| 248 |
-
|
| 249 |
-
save_checkpoint(model, None, total_steps, config.checkpoint_dir, filename="sage_finetuned.pt")
|
| 250 |
-
logger.info("Instruction fine-tuning complete. Checkpoint saved as sage_finetuned.pt")
|
| 251 |
-
wandb.finish()
|
| 252 |
-
return model
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# ---------------------------------------------------------------------------
|
| 256 |
-
# Demo instruction samples (used when no dataset is provided)
|
| 257 |
-
# ---------------------------------------------------------------------------
|
| 258 |
-
|
| 259 |
-
DEMO_INSTRUCTION_SAMPLES = [
|
| 260 |
-
{"instruction": "What is the capital of France?", "response": "The capital of France is Paris."},
|
| 261 |
-
{"instruction": "Explain gravity in simple terms.", "response": "Gravity is the force that pulls objects toward each other. The more mass an object has, the stronger its gravitational pull."},
|
| 262 |
-
{"instruction": "Write a short poem about the ocean.", "response": "Waves crash upon the sandy shore,\nThe ocean's song forevermore.\nDeep blue stretching to the sky,\nSeagulls dance and clouds float by."},
|
| 263 |
-
{"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."},
|
| 264 |
-
{"instruction": "Summarize photosynthesis.", "response": "Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen, providing energy for the plant."},
|
| 265 |
-
{"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than a year on Venus. It takes Venus 243 Earth days to rotate once on its axis but only 225 Earth days to orbit the Sun."},
|
| 266 |
-
{"instruction": "How do airplanes fly?", "response": "Airplanes fly by generating lift through their wings. Air moves faster over the curved top of the wing than the flat bottom, creating lower pressure above and higher pressure below, which pushes the wing upward."},
|
| 267 |
-
{"instruction": "What is machine learning?", "response": "Machine learning is a branch of artificial intelligence where computers learn patterns from data instead of being explicitly programmed, allowing them to make predictions or decisions."},
|
| 268 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/inference.py
DELETED
|
@@ -1,171 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Inference Engine
|
| 3 |
-
=====================
|
| 4 |
-
Text generation with greedy, temperature, top-k, and nucleus (top-p) sampling.
|
| 5 |
-
Supports KV-cache for O(1)-per-token generation and streaming output.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import sys
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
from typing import Optional, List
|
| 12 |
-
|
| 13 |
-
from .config import SageConfig
|
| 14 |
-
from .model import SageModel
|
| 15 |
-
from .data import SageTokenizer
|
| 16 |
-
from .utils import setup_logger
|
| 17 |
-
|
| 18 |
-
logger = setup_logger("sage.inference")
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
# ---------------------------------------------------------------------------
|
| 22 |
-
# Sampling helpers
|
| 23 |
-
# ---------------------------------------------------------------------------
|
| 24 |
-
|
| 25 |
-
def _top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
|
| 26 |
-
"""Zero out all logits outside the top-k highest values."""
|
| 27 |
-
if k <= 0 or k >= logits.size(-1):
|
| 28 |
-
return logits
|
| 29 |
-
values, _ = torch.topk(logits, k)
|
| 30 |
-
min_val = values[:, -1].unsqueeze(-1)
|
| 31 |
-
return torch.where(logits < min_val, torch.full_like(logits, float("-inf")), logits)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
|
| 35 |
-
"""Nucleus sampling: keep the smallest set of tokens whose cumulative
|
| 36 |
-
probability exceeds *p*."""
|
| 37 |
-
if p >= 1.0:
|
| 38 |
-
return logits
|
| 39 |
-
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
| 40 |
-
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 41 |
-
|
| 42 |
-
# Identify tokens to remove (cumulative prob exceeds p)
|
| 43 |
-
sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= p
|
| 44 |
-
sorted_logits[sorted_mask] = float("-inf")
|
| 45 |
-
|
| 46 |
-
# Scatter back to original order
|
| 47 |
-
logits = logits.scatter(1, sorted_idx, sorted_logits)
|
| 48 |
-
return logits
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def sample_next_token(
|
| 52 |
-
logits: torch.Tensor,
|
| 53 |
-
temperature: float = 0.8,
|
| 54 |
-
top_k: int = 50,
|
| 55 |
-
top_p: float = 0.9,
|
| 56 |
-
greedy: bool = False,
|
| 57 |
-
) -> torch.Tensor:
|
| 58 |
-
"""
|
| 59 |
-
Given raw logits for the last position, sample or greedily select the
|
| 60 |
-
next token.
|
| 61 |
-
|
| 62 |
-
Parameters
|
| 63 |
-
----------
|
| 64 |
-
logits : Tensor [batch, vocab]
|
| 65 |
-
temperature : float
|
| 66 |
-
top_k : int
|
| 67 |
-
top_p : float
|
| 68 |
-
greedy : bool — if True, ignore temperature/top-k/top-p and pick argmax.
|
| 69 |
-
|
| 70 |
-
Returns
|
| 71 |
-
-------
|
| 72 |
-
Tensor [batch, 1]
|
| 73 |
-
"""
|
| 74 |
-
if greedy:
|
| 75 |
-
return logits.argmax(dim=-1, keepdim=True)
|
| 76 |
-
|
| 77 |
-
logits = logits / max(temperature, 1e-8)
|
| 78 |
-
logits = _top_k_filter(logits, top_k)
|
| 79 |
-
logits = _top_p_filter(logits, top_p)
|
| 80 |
-
|
| 81 |
-
probs = F.softmax(logits, dim=-1)
|
| 82 |
-
return torch.multinomial(probs, num_samples=1)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# ---------------------------------------------------------------------------
|
| 86 |
-
# Main generation function
|
| 87 |
-
# ---------------------------------------------------------------------------
|
| 88 |
-
|
| 89 |
-
@torch.no_grad()
|
| 90 |
-
def generate(
|
| 91 |
-
model: SageModel,
|
| 92 |
-
tokenizer: SageTokenizer,
|
| 93 |
-
prompt: str,
|
| 94 |
-
max_new_tokens: int = 256,
|
| 95 |
-
temperature: float = 0.8,
|
| 96 |
-
top_k: int = 50,
|
| 97 |
-
top_p: float = 0.9,
|
| 98 |
-
greedy: bool = False,
|
| 99 |
-
stream: bool = True,
|
| 100 |
-
device: Optional[torch.device] = None,
|
| 101 |
-
) -> str:
|
| 102 |
-
"""
|
| 103 |
-
Generate text from *prompt* using the SAGE model.
|
| 104 |
-
|
| 105 |
-
Parameters
|
| 106 |
-
----------
|
| 107 |
-
model : SageModel
|
| 108 |
-
tokenizer : SageTokenizer
|
| 109 |
-
prompt : str
|
| 110 |
-
max_new_tokens : int
|
| 111 |
-
temperature, top_k, top_p : sampling hyper-parameters
|
| 112 |
-
greedy : bool — use argmax decoding
|
| 113 |
-
stream : bool — print tokens as they are generated
|
| 114 |
-
device : torch.device
|
| 115 |
-
|
| 116 |
-
Returns
|
| 117 |
-
-------
|
| 118 |
-
str — the complete generated text (prompt + new tokens).
|
| 119 |
-
"""
|
| 120 |
-
if device is None:
|
| 121 |
-
device = next(model.parameters()).device
|
| 122 |
-
|
| 123 |
-
base_model = getattr(model, "module", model)
|
| 124 |
-
base_model.eval()
|
| 125 |
-
|
| 126 |
-
# Encode prompt
|
| 127 |
-
prompt_tokens = tokenizer.encode(prompt)
|
| 128 |
-
if not prompt_tokens:
|
| 129 |
-
prompt_tokens = [tokenizer.eos_token_id]
|
| 130 |
-
|
| 131 |
-
input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
|
| 132 |
-
|
| 133 |
-
generated_tokens: List[int] = list(prompt_tokens)
|
| 134 |
-
kv_caches = None
|
| 135 |
-
|
| 136 |
-
# --- Prefill: run the full prompt through the model once ---
|
| 137 |
-
logits, kv_caches = base_model(input_ids)
|
| 138 |
-
next_logits = logits[:, -1, :]
|
| 139 |
-
|
| 140 |
-
for _ in range(max_new_tokens):
|
| 141 |
-
next_id = sample_next_token(
|
| 142 |
-
next_logits,
|
| 143 |
-
temperature=temperature,
|
| 144 |
-
top_k=top_k,
|
| 145 |
-
top_p=top_p,
|
| 146 |
-
greedy=greedy,
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
token_id = next_id.item()
|
| 150 |
-
|
| 151 |
-
# Stop on EOS
|
| 152 |
-
if token_id == tokenizer.eos_token_id:
|
| 153 |
-
break
|
| 154 |
-
|
| 155 |
-
generated_tokens.append(token_id)
|
| 156 |
-
|
| 157 |
-
# Stream output: decode and print only the new token
|
| 158 |
-
if stream:
|
| 159 |
-
token_str = tokenizer.decode([token_id])
|
| 160 |
-
print(token_str, end="", flush=True)
|
| 161 |
-
|
| 162 |
-
# --- Decode step: feed only the new token, reuse KV-cache ---
|
| 163 |
-
next_input = next_id.view(1, 1)
|
| 164 |
-
logits, kv_caches = base_model(next_input, kv_caches=kv_caches)
|
| 165 |
-
next_logits = logits[:, -1, :]
|
| 166 |
-
|
| 167 |
-
if stream:
|
| 168 |
-
print() # newline after streaming completes
|
| 169 |
-
|
| 170 |
-
base_model.train()
|
| 171 |
-
return tokenizer.decode(generated_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/memory.py
DELETED
|
@@ -1,240 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Memory & RAG Module
|
| 3 |
-
=========================
|
| 4 |
-
Provides:
|
| 5 |
-
- A FAISS-backed vector store for retrieval-augmented generation (RAG).
|
| 6 |
-
- A rolling conversation-history manager that truncates intelligently
|
| 7 |
-
to stay within the model's context window.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
import os
|
| 11 |
-
import numpy as np
|
| 12 |
-
import torch
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from typing import List, Optional, Tuple
|
| 15 |
-
|
| 16 |
-
from .data import SageTokenizer
|
| 17 |
-
from .utils import setup_logger
|
| 18 |
-
|
| 19 |
-
logger = setup_logger("sage.memory")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# ===================================================================
|
| 23 |
-
# Simple embedding helper (uses mean-pooled token embeddings)
|
| 24 |
-
# ===================================================================
|
| 25 |
-
|
| 26 |
-
def _embed_text(text: str, tokenizer: SageTokenizer, model: torch.nn.Module, device: torch.device) -> np.ndarray:
|
| 27 |
-
"""
|
| 28 |
-
Produce a fixed-length embedding for *text* by mean-pooling the
|
| 29 |
-
model's token embeddings. This is lightweight and avoids a full
|
| 30 |
-
forward pass — suitable for a small retrieval index.
|
| 31 |
-
"""
|
| 32 |
-
tokens = tokenizer.encode(text)
|
| 33 |
-
if not tokens:
|
| 34 |
-
# Return a zero vector when text is empty
|
| 35 |
-
d_model = model.wte.weight.shape[1]
|
| 36 |
-
return np.zeros(d_model, dtype=np.float32)
|
| 37 |
-
|
| 38 |
-
ids = torch.tensor([tokens], dtype=torch.long, device=device)
|
| 39 |
-
with torch.no_grad():
|
| 40 |
-
embeddings = model.wte(ids) # [1, seq_len, d_model]
|
| 41 |
-
mean_emb = embeddings.mean(dim=1) # [1, d_model]
|
| 42 |
-
# L2-normalize for cosine similarity in FAISS
|
| 43 |
-
mean_emb = F.normalize(mean_emb, p=2, dim=-1)
|
| 44 |
-
return mean_emb.squeeze(0).cpu().numpy()
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
# ===================================================================
|
| 48 |
-
# FAISS-backed Vector Store
|
| 49 |
-
# ===================================================================
|
| 50 |
-
|
| 51 |
-
class VectorStore:
|
| 52 |
-
"""
|
| 53 |
-
A lightweight document store backed by FAISS (Inner Product index,
|
| 54 |
-
which equals cosine similarity when vectors are L2-normalized).
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
-
def __init__(self, dim: int):
|
| 58 |
-
try:
|
| 59 |
-
import faiss
|
| 60 |
-
except ImportError:
|
| 61 |
-
raise ImportError(
|
| 62 |
-
"FAISS is required for RAG. Install it with: pip install faiss-cpu"
|
| 63 |
-
)
|
| 64 |
-
self.dim = dim
|
| 65 |
-
self.index = faiss.IndexFlatIP(dim) # inner-product (cosine after L2-norm)
|
| 66 |
-
self.documents: List[str] = []
|
| 67 |
-
logger.info(f"VectorStore initialized (dim={dim})")
|
| 68 |
-
|
| 69 |
-
def add(self, texts: List[str], embeddings: np.ndarray) -> None:
|
| 70 |
-
"""Add documents and their embeddings to the store."""
|
| 71 |
-
assert embeddings.shape[0] == len(texts)
|
| 72 |
-
assert embeddings.shape[1] == self.dim
|
| 73 |
-
self.index.add(embeddings.astype(np.float32))
|
| 74 |
-
self.documents.extend(texts)
|
| 75 |
-
logger.info(f"Added {len(texts)} documents. Total: {len(self.documents)}")
|
| 76 |
-
|
| 77 |
-
def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[Tuple[str, float]]:
|
| 78 |
-
"""Return the top-k most similar documents with their scores."""
|
| 79 |
-
if self.index.ntotal == 0:
|
| 80 |
-
return []
|
| 81 |
-
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
|
| 82 |
-
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
| 83 |
-
results = []
|
| 84 |
-
for score, idx in zip(scores[0], indices[0]):
|
| 85 |
-
if idx < 0:
|
| 86 |
-
continue
|
| 87 |
-
results.append((self.documents[idx], float(score)))
|
| 88 |
-
return results
|
| 89 |
-
|
| 90 |
-
@property
|
| 91 |
-
def size(self) -> int:
|
| 92 |
-
return self.index.ntotal
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
# ===================================================================
|
| 96 |
-
# RAG Manager
|
| 97 |
-
# ===================================================================
|
| 98 |
-
|
| 99 |
-
class RAGManager:
|
| 100 |
-
"""
|
| 101 |
-
High-level retrieval-augmented generation manager.
|
| 102 |
-
|
| 103 |
-
Call ``add_documents`` to ingest text, then ``retrieve_context`` at
|
| 104 |
-
inference time to prepend relevant chunks to the user prompt.
|
| 105 |
-
"""
|
| 106 |
-
|
| 107 |
-
def __init__(
|
| 108 |
-
self,
|
| 109 |
-
model: torch.nn.Module,
|
| 110 |
-
tokenizer: SageTokenizer,
|
| 111 |
-
device: torch.device,
|
| 112 |
-
chunk_size: int = 200,
|
| 113 |
-
chunk_overlap: int = 50,
|
| 114 |
-
):
|
| 115 |
-
self.model = model
|
| 116 |
-
self.tokenizer = tokenizer
|
| 117 |
-
self.device = device
|
| 118 |
-
self.chunk_size = chunk_size
|
| 119 |
-
self.chunk_overlap = chunk_overlap
|
| 120 |
-
|
| 121 |
-
d_model = model.wte.weight.shape[1]
|
| 122 |
-
self.store = VectorStore(dim=d_model)
|
| 123 |
-
self.enabled = False
|
| 124 |
-
|
| 125 |
-
def _chunk_text(self, text: str) -> List[str]:
|
| 126 |
-
"""Split text into overlapping word-level chunks."""
|
| 127 |
-
words = text.split()
|
| 128 |
-
chunks: List[str] = []
|
| 129 |
-
start = 0
|
| 130 |
-
while start < len(words):
|
| 131 |
-
end = start + self.chunk_size
|
| 132 |
-
chunk = " ".join(words[start:end])
|
| 133 |
-
chunks.append(chunk)
|
| 134 |
-
start += self.chunk_size - self.chunk_overlap
|
| 135 |
-
return chunks
|
| 136 |
-
|
| 137 |
-
def add_documents(self, texts: List[str]) -> None:
|
| 138 |
-
"""Chunk and embed documents, then add to the vector store."""
|
| 139 |
-
all_chunks: List[str] = []
|
| 140 |
-
for text in texts:
|
| 141 |
-
all_chunks.extend(self._chunk_text(text))
|
| 142 |
-
|
| 143 |
-
if not all_chunks:
|
| 144 |
-
logger.warning("No document chunks to add.")
|
| 145 |
-
return
|
| 146 |
-
|
| 147 |
-
embeddings = np.stack([
|
| 148 |
-
_embed_text(chunk, self.tokenizer, self.model, self.device)
|
| 149 |
-
for chunk in all_chunks
|
| 150 |
-
])
|
| 151 |
-
self.store.add(all_chunks, embeddings)
|
| 152 |
-
|
| 153 |
-
def retrieve_context(self, query: str, top_k: int = 3) -> str:
|
| 154 |
-
"""
|
| 155 |
-
Retrieve the top-k most relevant chunks for *query* and
|
| 156 |
-
concatenate them into a context string.
|
| 157 |
-
"""
|
| 158 |
-
if not self.enabled or self.store.size == 0:
|
| 159 |
-
return ""
|
| 160 |
-
|
| 161 |
-
q_emb = _embed_text(query, self.tokenizer, self.model, self.device)
|
| 162 |
-
results = self.store.search(q_emb, top_k=top_k)
|
| 163 |
-
|
| 164 |
-
if not results:
|
| 165 |
-
return ""
|
| 166 |
-
|
| 167 |
-
context_parts = [f"[Context {i+1}] {doc}" for i, (doc, _score) in enumerate(results)]
|
| 168 |
-
return "\n\n".join(context_parts) + "\n\n"
|
| 169 |
-
|
| 170 |
-
def toggle(self, on: bool) -> None:
|
| 171 |
-
self.enabled = on
|
| 172 |
-
state = "enabled" if on else "disabled"
|
| 173 |
-
logger.info(f"RAG {state}. Store contains {self.store.size} chunks.")
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
# ===================================================================
|
| 177 |
-
# Conversation History Manager
|
| 178 |
-
# ===================================================================
|
| 179 |
-
|
| 180 |
-
DEFAULT_SYSTEM_PROMPT = (
|
| 181 |
-
"You are SAGE, a high-quality reasoning assistant. "
|
| 182 |
-
"Your goal is to provide accurate, structured, and deep logical explanations.\n\n"
|
| 183 |
-
"CRITICAL GUIDELINES:\n"
|
| 184 |
-
"1. THINKING PHASE: You must ALWAYS start your response with a <thinking> section. "
|
| 185 |
-
"In this section, break down the user's request, identify key constraints, and plan your logical steps.\n"
|
| 186 |
-
"2. RESPONSE PHASE: After completing your internal reasoning, provide your final answer within <response> tags.\n"
|
| 187 |
-
"3. QUALITY: Prioritize step-by-step mathematical or logical derivation over short answers.\n"
|
| 188 |
-
"4. NO REPETITION: Avoid filler words or circular logic.\n\n"
|
| 189 |
-
"RESPONSE TEMPLATE:\n"
|
| 190 |
-
"<thinking>\n[Step-by-step logic here]\n</thinking>\n"
|
| 191 |
-
"<response>\n[Final clear answer here]\n</response>"
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
class ConversationHistory:
|
| 195 |
-
"""
|
| 196 |
-
Rolling conversation history that stays within a token budget.
|
| 197 |
-
|
| 198 |
-
Older turns are dropped when the history would exceed the context window.
|
| 199 |
-
"""
|
| 200 |
-
|
| 201 |
-
def __init__(self, tokenizer: SageTokenizer, max_tokens: int = 900):
|
| 202 |
-
self.tokenizer = tokenizer
|
| 203 |
-
self.max_tokens = max_tokens
|
| 204 |
-
self.turns: List[dict] = [] # [{"role": "user"/"assistant", "text": ...}, ...]
|
| 205 |
-
|
| 206 |
-
def add(self, role: str, text: str) -> None:
|
| 207 |
-
"""Record a new conversational turn."""
|
| 208 |
-
self.turns.append({"role": role, "text": text})
|
| 209 |
-
self._trim()
|
| 210 |
-
|
| 211 |
-
def _trim(self) -> None:
|
| 212 |
-
"""Drop oldest turns until the total token count is within budget."""
|
| 213 |
-
while self._total_tokens() > self.max_tokens and len(self.turns) > 1:
|
| 214 |
-
self.turns.pop(0)
|
| 215 |
-
|
| 216 |
-
def _total_tokens(self) -> int:
|
| 217 |
-
return sum(len(self.tokenizer.encode(t["text"])) for t in self.turns)
|
| 218 |
-
|
| 219 |
-
def build_prompt(self, new_user_message: str, rag_context: str = "") -> str:
|
| 220 |
-
"""
|
| 221 |
-
Assemble the full prompt from history + RAG context + new message.
|
| 222 |
-
"""
|
| 223 |
-
parts: List[str] = []
|
| 224 |
-
|
| 225 |
-
parts.append(DEFAULT_SYSTEM_PROMPT)
|
| 226 |
-
|
| 227 |
-
if rag_context:
|
| 228 |
-
parts.append(rag_context)
|
| 229 |
-
|
| 230 |
-
for turn in self.turns:
|
| 231 |
-
prefix = "User:" if turn["role"] == "user" else "SAGE:"
|
| 232 |
-
parts.append(f"{prefix} {turn['text']}")
|
| 233 |
-
|
| 234 |
-
parts.append(f"User: {new_user_message}")
|
| 235 |
-
parts.append("SAGE:")
|
| 236 |
-
|
| 237 |
-
return "\n".join(parts)
|
| 238 |
-
|
| 239 |
-
def clear(self) -> None:
|
| 240 |
-
self.turns.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/model.py
DELETED
|
@@ -1,267 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
import math
|
| 5 |
-
from typing import Optional, Tuple
|
| 6 |
-
from .config import SageConfig
|
| 7 |
-
|
| 8 |
-
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
|
| 9 |
-
"""Precomputes rotary positional embedding frequencies."""
|
| 10 |
-
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 11 |
-
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
| 12 |
-
freqs = torch.outer(t, freqs).float()
|
| 13 |
-
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 14 |
-
return freqs_cis
|
| 15 |
-
|
| 16 |
-
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 17 |
-
"""Applies rotary positional embeddings to queries and keys."""
|
| 18 |
-
# Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real)
|
| 19 |
-
if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2:
|
| 20 |
-
freqs_cis = torch.view_as_complex(freqs_cis)
|
| 21 |
-
|
| 22 |
-
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 23 |
-
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 24 |
-
|
| 25 |
-
# Reshape freqs_cis to broadcast with xq_ and xk_
|
| 26 |
-
# xq_, xk_ shape: [batch, seq_len, n_heads, dim_head//2]
|
| 27 |
-
# freqs_cis shape: [seq_len, dim_head//2]
|
| 28 |
-
freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2)
|
| 29 |
-
|
| 30 |
-
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
| 31 |
-
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
| 32 |
-
|
| 33 |
-
return xq_out.type_as(xq), xk_out.type_as(xk)
|
| 34 |
-
|
| 35 |
-
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 36 |
-
"""Repeat Key/Value heads n_rep times to match number of Query heads."""
|
| 37 |
-
if n_rep == 1:
|
| 38 |
-
return x
|
| 39 |
-
B, T, n_kv_heads, head_dim = x.size()
|
| 40 |
-
return (
|
| 41 |
-
x[:, :, :, None, :]
|
| 42 |
-
.expand(B, T, n_kv_heads, n_rep, head_dim)
|
| 43 |
-
.reshape(B, T, n_kv_heads * n_rep, head_dim)
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
class CausalSelfAttention(nn.Module):
|
| 47 |
-
def __init__(self, config: SageConfig):
|
| 48 |
-
super().__init__()
|
| 49 |
-
self.n_heads = config.n_heads
|
| 50 |
-
self.n_kv_heads = config.n_kv_heads
|
| 51 |
-
self.n_rep = self.n_heads // self.n_kv_heads
|
| 52 |
-
self.d_model = config.d_model
|
| 53 |
-
assert self.d_model % self.n_heads == 0
|
| 54 |
-
self.head_dim = self.d_model // self.n_heads
|
| 55 |
-
|
| 56 |
-
self.wq = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False)
|
| 57 |
-
self.wk = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False)
|
| 58 |
-
self.wv = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False)
|
| 59 |
-
self.wo = nn.Linear(self.d_model, self.d_model, bias=False)
|
| 60 |
-
|
| 61 |
-
self.resid_dropout = nn.Dropout(config.dropout)
|
| 62 |
-
|
| 63 |
-
# Flash attention handles causality via is_causal flag if seq_len > 1
|
| 64 |
-
|
| 65 |
-
def forward(
|
| 66 |
-
self,
|
| 67 |
-
x: torch.Tensor,
|
| 68 |
-
freqs_cis: torch.Tensor,
|
| 69 |
-
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
| 70 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 71 |
-
B, T, C = x.size() # batch, seq_len, d_model
|
| 72 |
-
q, k, v = self.wq(x), self.wk(x), self.wv(x)
|
| 73 |
-
|
| 74 |
-
q = q.view(B, T, self.n_heads, self.head_dim)
|
| 75 |
-
k = k.view(B, T, self.n_kv_heads, self.head_dim)
|
| 76 |
-
v = v.view(B, T, self.n_kv_heads, self.head_dim)
|
| 77 |
-
|
| 78 |
-
q, k = apply_rotary_emb(q, k, freqs_cis)
|
| 79 |
-
|
| 80 |
-
if kv_cache is not None:
|
| 81 |
-
# We are generating token by token
|
| 82 |
-
k_cache, v_cache = kv_cache
|
| 83 |
-
k = torch.cat([k_cache, k], dim=1)
|
| 84 |
-
v = torch.cat([v_cache, v], dim=1)
|
| 85 |
-
new_kv_cache = (k, v)
|
| 86 |
-
else:
|
| 87 |
-
new_kv_cache = None
|
| 88 |
-
|
| 89 |
-
# Repeat KV heads to match Q heads (GQA)
|
| 90 |
-
k = repeat_kv(k, self.n_rep)
|
| 91 |
-
v = repeat_kv(v, self.n_rep)
|
| 92 |
-
|
| 93 |
-
# Move heads to correct dimension: (B, n_heads, T, head_dim)
|
| 94 |
-
q = q.transpose(1, 2)
|
| 95 |
-
k = k.transpose(1, 2)
|
| 96 |
-
v = v.transpose(1, 2)
|
| 97 |
-
|
| 98 |
-
# Flash attention natively supported via scaled_dot_product_attention
|
| 99 |
-
is_causal = (kv_cache is None and T > 1)
|
| 100 |
-
try:
|
| 101 |
-
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.resid_dropout.p if self.training else 0.0, is_causal=is_causal)
|
| 102 |
-
except Exception:
|
| 103 |
-
# Manual attention fallback for older architectures (like P100 sm_60)
|
| 104 |
-
attn_weights = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
| 105 |
-
if is_causal:
|
| 106 |
-
# Use a causal mask
|
| 107 |
-
causal_mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T)
|
| 108 |
-
attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf'))
|
| 109 |
-
|
| 110 |
-
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 111 |
-
if self.training:
|
| 112 |
-
attn_weights = self.resid_dropout(attn_weights)
|
| 113 |
-
|
| 114 |
-
y = attn_weights @ v
|
| 115 |
-
|
| 116 |
-
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 117 |
-
y = self.resid_dropout(self.wo(y))
|
| 118 |
-
|
| 119 |
-
return y, new_kv_cache
|
| 120 |
-
|
| 121 |
-
class ExpertFFN(nn.Module):
|
| 122 |
-
def __init__(self, config: SageConfig):
|
| 123 |
-
super().__init__()
|
| 124 |
-
self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
| 125 |
-
self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
|
| 126 |
-
self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
| 127 |
-
self.dropout = nn.Dropout(config.dropout)
|
| 128 |
-
|
| 129 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 130 |
-
# SwiGLU activation structure
|
| 131 |
-
hidden = F.silu(self.w1(x)) * self.w3(x)
|
| 132 |
-
return self.dropout(self.w2(hidden))
|
| 133 |
-
|
| 134 |
-
class MoE(nn.Module):
|
| 135 |
-
def __init__(self, config: SageConfig):
|
| 136 |
-
super().__init__()
|
| 137 |
-
self.n_experts = config.n_experts
|
| 138 |
-
self.top_k = config.num_experts_per_tok
|
| 139 |
-
self.d_model = config.d_model
|
| 140 |
-
|
| 141 |
-
self.router = nn.Linear(self.d_model, self.n_experts, bias=False)
|
| 142 |
-
self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(self.n_experts)])
|
| 143 |
-
|
| 144 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 145 |
-
B, T, C = x.size()
|
| 146 |
-
x_flat = x.view(-1, C) # [B*T, C]
|
| 147 |
-
|
| 148 |
-
router_logits = self.router(x_flat) # [B*T, n_experts]
|
| 149 |
-
routing_weights = F.softmax(router_logits, dim=-1)
|
| 150 |
-
|
| 151 |
-
# Select Top K experts
|
| 152 |
-
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # [B*T, top_k]
|
| 153 |
-
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # re-normalize
|
| 154 |
-
|
| 155 |
-
final_out = torch.zeros_like(x_flat)
|
| 156 |
-
|
| 157 |
-
# Iterate over experts and compute their outputs
|
| 158 |
-
for i, expert in enumerate(self.experts):
|
| 159 |
-
# Find which tokens chose this expert
|
| 160 |
-
expert_mask = (selected_experts == i)
|
| 161 |
-
token_idx, kth_expert = torch.where(expert_mask)
|
| 162 |
-
|
| 163 |
-
if token_idx.shape[0] > 0:
|
| 164 |
-
expert_inputs = x_flat[token_idx]
|
| 165 |
-
expert_outputs = expert(expert_inputs)
|
| 166 |
-
|
| 167 |
-
# Apply router weight
|
| 168 |
-
weights = routing_weights[token_idx, kth_expert].unsqueeze(-1)
|
| 169 |
-
final_out[token_idx] += expert_outputs * weights
|
| 170 |
-
|
| 171 |
-
return final_out.view(B, T, C)
|
| 172 |
-
|
| 173 |
-
class TransformerBlock(nn.Module):
|
| 174 |
-
def __init__(self, config: SageConfig):
|
| 175 |
-
super().__init__()
|
| 176 |
-
self.norm1 = nn.LayerNorm(config.d_model)
|
| 177 |
-
self.attn = CausalSelfAttention(config)
|
| 178 |
-
self.norm2 = nn.LayerNorm(config.d_model)
|
| 179 |
-
self.moe = MoE(config)
|
| 180 |
-
|
| 181 |
-
def forward(
|
| 182 |
-
self,
|
| 183 |
-
x: torch.Tensor,
|
| 184 |
-
freqs_cis: torch.Tensor,
|
| 185 |
-
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
| 186 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 187 |
-
# Pre-LayerNorm architecture
|
| 188 |
-
h, new_kv_cache = self.attn(self.norm1(x), freqs_cis, kv_cache)
|
| 189 |
-
x = x + h
|
| 190 |
-
x = x + self.moe(self.norm2(x))
|
| 191 |
-
return x, new_kv_cache
|
| 192 |
-
|
| 193 |
-
class SageModel(nn.Module):
|
| 194 |
-
def __init__(self, config: SageConfig):
|
| 195 |
-
super().__init__()
|
| 196 |
-
self.config = config
|
| 197 |
-
|
| 198 |
-
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
| 199 |
-
self.drop = nn.Dropout(config.dropout)
|
| 200 |
-
|
| 201 |
-
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
|
| 202 |
-
|
| 203 |
-
self.ln_f = nn.LayerNorm(config.d_model)
|
| 204 |
-
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 205 |
-
|
| 206 |
-
# Weight tying
|
| 207 |
-
self.wte.weight = self.lm_head.weight
|
| 208 |
-
|
| 209 |
-
# Precompute RoPE frequencies
|
| 210 |
-
self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False)
|
| 211 |
-
|
| 212 |
-
self.apply(self._init_weights)
|
| 213 |
-
|
| 214 |
-
def _init_weights(self, module):
|
| 215 |
-
if isinstance(module, nn.Linear):
|
| 216 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 217 |
-
if module.bias is not None:
|
| 218 |
-
torch.nn.init.zeros_(module.bias)
|
| 219 |
-
elif isinstance(module, nn.Embedding):
|
| 220 |
-
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 221 |
-
elif isinstance(module, nn.LayerNorm):
|
| 222 |
-
torch.nn.init.zeros_(module.bias)
|
| 223 |
-
torch.nn.init.ones_(module.weight)
|
| 224 |
-
|
| 225 |
-
def forward(
|
| 226 |
-
self,
|
| 227 |
-
idx: torch.Tensor,
|
| 228 |
-
kv_caches: Optional[list] = None
|
| 229 |
-
) -> Tuple[torch.Tensor, Optional[list]]:
|
| 230 |
-
B, T = idx.size()
|
| 231 |
-
|
| 232 |
-
if kv_caches is not None:
|
| 233 |
-
# generating context, token is at specific position
|
| 234 |
-
start_pos = kv_caches[0][0].shape[1]
|
| 235 |
-
else:
|
| 236 |
-
start_pos = 0
|
| 237 |
-
|
| 238 |
-
freqs_cis = self.freqs_cis[start_pos : start_pos + T]
|
| 239 |
-
|
| 240 |
-
x = self.drop(self.wte(idx))
|
| 241 |
-
|
| 242 |
-
new_kv_caches = []
|
| 243 |
-
for i, layer in enumerate(self.layers):
|
| 244 |
-
kv_cache = kv_caches[i] if kv_caches else None
|
| 245 |
-
|
| 246 |
-
# Use gradient checkpointing during training
|
| 247 |
-
if self.training and kv_cache is None:
|
| 248 |
-
def create_custom_forward(module):
|
| 249 |
-
def custom_forward(x_in, freqs_cis_in):
|
| 250 |
-
return module(x_in, freqs_cis_in, None)
|
| 251 |
-
return custom_forward
|
| 252 |
-
|
| 253 |
-
x, new_kv_cache = torch.utils.checkpoint.checkpoint(
|
| 254 |
-
create_custom_forward(layer),
|
| 255 |
-
x, freqs_cis,
|
| 256 |
-
use_reentrant=False
|
| 257 |
-
)
|
| 258 |
-
else:
|
| 259 |
-
x, new_kv_cache = layer(x, freqs_cis, kv_cache)
|
| 260 |
-
|
| 261 |
-
if new_kv_cache is not None:
|
| 262 |
-
new_kv_caches.append(new_kv_cache)
|
| 263 |
-
|
| 264 |
-
x = self.ln_f(x)
|
| 265 |
-
logits = self.lm_head(x) # [B, T, vocab_size]
|
| 266 |
-
|
| 267 |
-
return logits, new_kv_caches if len(new_kv_caches) > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/optimize.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Optimization Layer
|
| 3 |
-
=======================
|
| 4 |
-
Post-training quantization (INT8), optional pruning, and knowledge-distillation
|
| 5 |
-
loss utilities.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.utils.prune as prune
|
| 11 |
-
from typing import Optional
|
| 12 |
-
|
| 13 |
-
from .model import SageModel
|
| 14 |
-
from .config import SageConfig
|
| 15 |
-
from .utils import setup_logger
|
| 16 |
-
|
| 17 |
-
logger = setup_logger("sage.optimize")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# ===================================================================
|
| 21 |
-
# INT8 Dynamic Quantization
|
| 22 |
-
# ===================================================================
|
| 23 |
-
|
| 24 |
-
def quantize_int8(model: SageModel) -> nn.Module:
|
| 25 |
-
"""
|
| 26 |
-
Apply dynamic INT8 quantization to all Linear layers in the model.
|
| 27 |
-
|
| 28 |
-
This reduces model size by ~2-4x and can speed up CPU inference.
|
| 29 |
-
The model is moved to CPU before quantization because PyTorch's
|
| 30 |
-
dynamic quantization only supports CPU tensors.
|
| 31 |
-
|
| 32 |
-
Returns
|
| 33 |
-
-------
|
| 34 |
-
nn.Module — the quantized model (on CPU).
|
| 35 |
-
"""
|
| 36 |
-
base_model = getattr(model, "module", model)
|
| 37 |
-
base_model = base_model.cpu().eval()
|
| 38 |
-
|
| 39 |
-
quantized = torch.quantization.quantize_dynamic(
|
| 40 |
-
base_model,
|
| 41 |
-
{nn.Linear}, # quantize all linear layers
|
| 42 |
-
dtype=torch.qint8,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
# Report size reduction
|
| 46 |
-
orig_size = sum(p.numel() * p.element_size() for p in base_model.parameters())
|
| 47 |
-
# Quantized parameters may not report element_size correctly,
|
| 48 |
-
# so we estimate based on INT8 = 1 byte per weight.
|
| 49 |
-
quant_size = sum(p.numel() for p in quantized.parameters()) # * 1 byte
|
| 50 |
-
logger.info(
|
| 51 |
-
f"Quantization complete. "
|
| 52 |
-
f"Original: {orig_size / 1e6:.1f} MB → Quantized: ~{quant_size / 1e6:.1f} MB (INT8)"
|
| 53 |
-
)
|
| 54 |
-
return quantized
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
# ===================================================================
|
| 58 |
-
# Weight Pruning
|
| 59 |
-
# ===================================================================
|
| 60 |
-
|
| 61 |
-
def prune_model(model: SageModel, amount: float = 0.3) -> SageModel:
|
| 62 |
-
"""
|
| 63 |
-
Apply unstructured L1 pruning to all Linear layers, removing the
|
| 64 |
-
*amount* fraction of weights with the smallest magnitude.
|
| 65 |
-
|
| 66 |
-
Parameters
|
| 67 |
-
----------
|
| 68 |
-
model : SageModel
|
| 69 |
-
amount : float
|
| 70 |
-
Fraction of weights to prune (0.0 – 1.0).
|
| 71 |
-
|
| 72 |
-
Returns
|
| 73 |
-
-------
|
| 74 |
-
SageModel — the pruned model (pruning masks are permanent after this call).
|
| 75 |
-
"""
|
| 76 |
-
pruned_count = 0
|
| 77 |
-
total_count = 0
|
| 78 |
-
|
| 79 |
-
base_model = getattr(model, "module", model)
|
| 80 |
-
for name, module in base_model.named_modules():
|
| 81 |
-
if isinstance(module, nn.Linear):
|
| 82 |
-
prune.l1_unstructured(module, name="weight", amount=amount)
|
| 83 |
-
prune.remove(module, "weight") # make the pruning permanent
|
| 84 |
-
pruned_count += (module.weight == 0).sum().item()
|
| 85 |
-
total_count += module.weight.numel()
|
| 86 |
-
|
| 87 |
-
sparsity = pruned_count / max(total_count, 1) * 100
|
| 88 |
-
logger.info(
|
| 89 |
-
f"Pruning complete. {pruned_count:,} / {total_count:,} weights zeroed "
|
| 90 |
-
f"({sparsity:.1f}% sparsity)"
|
| 91 |
-
)
|
| 92 |
-
return model
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
# ===================================================================
|
| 96 |
-
# Knowledge Distillation Loss
|
| 97 |
-
# ===================================================================
|
| 98 |
-
|
| 99 |
-
def distillation_loss(
|
| 100 |
-
student_logits: torch.Tensor,
|
| 101 |
-
teacher_logits: torch.Tensor,
|
| 102 |
-
labels: torch.Tensor,
|
| 103 |
-
temperature: float = 2.0,
|
| 104 |
-
alpha: float = 0.5,
|
| 105 |
-
ignore_index: int = -100,
|
| 106 |
-
) -> torch.Tensor:
|
| 107 |
-
"""
|
| 108 |
-
Combined knowledge-distillation loss.
|
| 109 |
-
|
| 110 |
-
``L = alpha * KL(softmax(teacher/T), softmax(student/T)) * T^2
|
| 111 |
-
+ (1 - alpha) * CE(student, labels)``
|
| 112 |
-
|
| 113 |
-
Parameters
|
| 114 |
-
----------
|
| 115 |
-
student_logits : Tensor [B, T, V]
|
| 116 |
-
teacher_logits : Tensor [B, T, V]
|
| 117 |
-
labels : Tensor [B, T]
|
| 118 |
-
temperature : float
|
| 119 |
-
alpha : float — weight for the distillation term (0 → pure CE, 1 → pure KD).
|
| 120 |
-
ignore_index : int — label value to ignore in cross-entropy.
|
| 121 |
-
|
| 122 |
-
Returns
|
| 123 |
-
-------
|
| 124 |
-
Tensor (scalar)
|
| 125 |
-
"""
|
| 126 |
-
# Soft targets
|
| 127 |
-
soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)
|
| 128 |
-
soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)
|
| 129 |
-
|
| 130 |
-
kd_loss = torch.nn.functional.kl_div(
|
| 131 |
-
soft_student.view(-1, soft_student.size(-1)),
|
| 132 |
-
soft_teacher.view(-1, soft_teacher.size(-1)),
|
| 133 |
-
reduction="batchmean",
|
| 134 |
-
) * (temperature ** 2)
|
| 135 |
-
|
| 136 |
-
# Hard-label cross-entropy
|
| 137 |
-
ce_loss = torch.nn.functional.cross_entropy(
|
| 138 |
-
student_logits.view(-1, student_logits.size(-1)),
|
| 139 |
-
labels.view(-1),
|
| 140 |
-
ignore_index=ignore_index,
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
return alpha * kd_loss + (1 - alpha) * ce_loss
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
# ===================================================================
|
| 147 |
-
# torch.compile wrapper (PyTorch 2.0+)
|
| 148 |
-
# ===================================================================
|
| 149 |
-
|
| 150 |
-
def try_compile(model: nn.Module) -> nn.Module:
|
| 151 |
-
"""
|
| 152 |
-
Attempt to compile the model with ``torch.compile`` for faster
|
| 153 |
-
execution. Falls back gracefully if compilation is not available.
|
| 154 |
-
"""
|
| 155 |
-
if hasattr(torch, "compile"):
|
| 156 |
-
try:
|
| 157 |
-
compiled = torch.compile(model)
|
| 158 |
-
logger.info("Model compiled with torch.compile for accelerated execution.")
|
| 159 |
-
return compiled
|
| 160 |
-
except Exception as e:
|
| 161 |
-
logger.warning(f"torch.compile failed ({e}). Using eager mode.")
|
| 162 |
-
else:
|
| 163 |
-
logger.info("torch.compile not available (requires PyTorch 2.0+). Using eager mode.")
|
| 164 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/train.py
DELETED
|
@@ -1,266 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
SAGE Training System
|
| 3 |
-
====================
|
| 4 |
-
Complete training loop with AdamW, cosine-decay LR schedule, mixed-precision
|
| 5 |
-
(AMP), gradient accumulation, gradient clipping, and checkpoint management.
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import math
|
| 9 |
-
import time
|
| 10 |
-
import torch
|
| 11 |
-
import torch.nn as nn
|
| 12 |
-
from torch.amp import GradScaler, autocast
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
import wandb
|
| 15 |
-
from typing import Optional
|
| 16 |
-
|
| 17 |
-
from .config import SageConfig
|
| 18 |
-
from .model import SageModel
|
| 19 |
-
from .data import SageTokenizer, create_dataloader
|
| 20 |
-
from .utils import setup_logger, save_checkpoint, load_checkpoint
|
| 21 |
-
|
| 22 |
-
logger = setup_logger("sage.train")
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# ---------------------------------------------------------------------------
|
| 26 |
-
# Learning-rate scheduler helpers
|
| 27 |
-
# ---------------------------------------------------------------------------
|
| 28 |
-
|
| 29 |
-
def get_lr(step: int, config: SageConfig, total_steps: int) -> float:
|
| 30 |
-
"""Cosine decay with linear warmup. Returns the learning rate for *step*."""
|
| 31 |
-
if step < config.warmup_steps:
|
| 32 |
-
# Linear warmup
|
| 33 |
-
return config.learning_rate * (step + 1) / config.warmup_steps
|
| 34 |
-
|
| 35 |
-
# Cosine decay phase
|
| 36 |
-
decay_steps = total_steps - config.warmup_steps
|
| 37 |
-
progress = (step - config.warmup_steps) / max(1, decay_steps)
|
| 38 |
-
coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 39 |
-
return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
|
| 43 |
-
"""Manually sets the learning rate for every parameter group."""
|
| 44 |
-
for pg in optimizer.param_groups:
|
| 45 |
-
pg["lr"] = lr
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
-
# Optimizer factory
|
| 50 |
-
# ---------------------------------------------------------------------------
|
| 51 |
-
|
| 52 |
-
def create_optimizer(model: SageModel, config: SageConfig) -> torch.optim.AdamW:
|
| 53 |
-
"""
|
| 54 |
-
Create an AdamW optimizer with weight-decay applied only to weight
|
| 55 |
-
matrices (not biases or LayerNorm parameters).
|
| 56 |
-
"""
|
| 57 |
-
decay_params = []
|
| 58 |
-
no_decay_params = []
|
| 59 |
-
|
| 60 |
-
for name, param in model.named_parameters():
|
| 61 |
-
if not param.requires_grad:
|
| 62 |
-
continue
|
| 63 |
-
# Biases and LayerNorm weights should not be decayed
|
| 64 |
-
if param.ndim == 1 or "bias" in name:
|
| 65 |
-
no_decay_params.append(param)
|
| 66 |
-
else:
|
| 67 |
-
decay_params.append(param)
|
| 68 |
-
|
| 69 |
-
param_groups = [
|
| 70 |
-
{"params": decay_params, "weight_decay": config.weight_decay},
|
| 71 |
-
{"params": no_decay_params, "weight_decay": 0.0},
|
| 72 |
-
]
|
| 73 |
-
|
| 74 |
-
# Enable Fused AdamW for 10% speedup if CUDA is active
|
| 75 |
-
use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames
|
| 76 |
-
optimizer = torch.optim.AdamW(
|
| 77 |
-
param_groups,
|
| 78 |
-
lr=config.learning_rate,
|
| 79 |
-
betas=(0.9, 0.95),
|
| 80 |
-
eps=1e-8,
|
| 81 |
-
fused=use_fused,
|
| 82 |
-
)
|
| 83 |
-
return optimizer
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
# ---------------------------------------------------------------------------
|
| 87 |
-
# Main training loop
|
| 88 |
-
# ---------------------------------------------------------------------------
|
| 89 |
-
|
| 90 |
-
def train(
|
| 91 |
-
model: SageModel,
|
| 92 |
-
config: SageConfig,
|
| 93 |
-
total_steps: int = 500,
|
| 94 |
-
dataset_name: str = "roneneldan/TinyStories",
|
| 95 |
-
resume: bool = True,
|
| 96 |
-
tokenizer: Optional[SageTokenizer] = None,
|
| 97 |
-
) -> SageModel:
|
| 98 |
-
"""
|
| 99 |
-
Run pre-training for *total_steps* gradient-update steps.
|
| 100 |
-
|
| 101 |
-
Parameters
|
| 102 |
-
----------
|
| 103 |
-
model : SageModel
|
| 104 |
-
The model to train (will be moved to config.device).
|
| 105 |
-
config : SageConfig
|
| 106 |
-
Hyperparameters.
|
| 107 |
-
total_steps : int
|
| 108 |
-
Number of optimizer steps to run.
|
| 109 |
-
dataset_name : str
|
| 110 |
-
HuggingFace dataset identifier.
|
| 111 |
-
resume : bool
|
| 112 |
-
If True, attempt to load the latest checkpoint before training.
|
| 113 |
-
tokenizer : SageTokenizer, optional
|
| 114 |
-
Tokenizer instance; one will be created if not supplied.
|
| 115 |
-
|
| 116 |
-
Returns
|
| 117 |
-
-------
|
| 118 |
-
SageModel
|
| 119 |
-
The trained model (on config.device).
|
| 120 |
-
"""
|
| 121 |
-
# --- TURBO MODE: TF32 & COMPILE ---
|
| 122 |
-
if torch.cuda.is_available():
|
| 123 |
-
torch.set_float32_matmul_precision('high')
|
| 124 |
-
|
| 125 |
-
device = config.device
|
| 126 |
-
model = model.to(device)
|
| 127 |
-
|
| 128 |
-
# Wrap model with torch.compile for graph-level optimization
|
| 129 |
-
if hasattr(torch, "compile"):
|
| 130 |
-
try:
|
| 131 |
-
logger.info("Turbo Mode: Compiling model graph...")
|
| 132 |
-
base = getattr(model, "module", model)
|
| 133 |
-
compiled_base = torch.compile(base, mode="reduce-overhead")
|
| 134 |
-
if hasattr(model, "module"):
|
| 135 |
-
model.module = compiled_base
|
| 136 |
-
else:
|
| 137 |
-
model = compiled_base
|
| 138 |
-
except (ValueError, RuntimeError, ImportError) as e:
|
| 139 |
-
# Graceful fallback: numpy compatibility issues or other compilation errors
|
| 140 |
-
logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}")
|
| 141 |
-
|
| 142 |
-
tok = tokenizer or SageTokenizer()
|
| 143 |
-
optimizer = create_optimizer(model, config)
|
| 144 |
-
|
| 145 |
-
# ------- resume from checkpoint if available -------
|
| 146 |
-
start_step = 0
|
| 147 |
-
if resume:
|
| 148 |
-
model, optimizer, start_step = load_checkpoint(
|
| 149 |
-
model, optimizer, config.checkpoint_dir, device=str(device)
|
| 150 |
-
)
|
| 151 |
-
if start_step >= total_steps:
|
| 152 |
-
logger.info("Checkpoint already at or past requested steps. Nothing to do.")
|
| 153 |
-
return model
|
| 154 |
-
|
| 155 |
-
# ------- mixed precision setup -------
|
| 156 |
-
use_amp = device.type == "cuda"
|
| 157 |
-
# prefer bf16 if the GPU supports it
|
| 158 |
-
amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16
|
| 159 |
-
scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16))
|
| 160 |
-
|
| 161 |
-
# ------- data loader -------
|
| 162 |
-
loader = create_dataloader(config, dataset_name=dataset_name, tokenizer=tok)
|
| 163 |
-
data_iter = iter(loader)
|
| 164 |
-
|
| 165 |
-
# ------- W&B Logging -------
|
| 166 |
-
wandb.init(
|
| 167 |
-
project=config.project_name,
|
| 168 |
-
name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}",
|
| 169 |
-
config=config.__dict__,
|
| 170 |
-
)
|
| 171 |
-
|
| 172 |
-
# ------- gradient checkpointing (saves VRAM) -------
|
| 173 |
-
base_model = getattr(model, "module", model)
|
| 174 |
-
if hasattr(base_model, "layers"):
|
| 175 |
-
for layer in base_model.layers:
|
| 176 |
-
layer: nn.Module
|
| 177 |
-
# PyTorch gradient checkpointing
|
| 178 |
-
try:
|
| 179 |
-
from torch.utils.checkpoint import checkpoint # noqa: F401
|
| 180 |
-
# We wrap the forward below instead, using it at call-site.
|
| 181 |
-
except ImportError:
|
| 182 |
-
pass
|
| 183 |
-
|
| 184 |
-
# ------- training loop -------
|
| 185 |
-
model.train()
|
| 186 |
-
accum_loss = 0.0
|
| 187 |
-
log_interval = 10
|
| 188 |
-
t0 = time.time()
|
| 189 |
-
|
| 190 |
-
pbar = tqdm(range(start_step, total_steps), desc="Training", unit="step")
|
| 191 |
-
micro_step = 0
|
| 192 |
-
|
| 193 |
-
for step in pbar:
|
| 194 |
-
# Update learning rate
|
| 195 |
-
lr = get_lr(step, config, total_steps)
|
| 196 |
-
set_lr(optimizer, lr)
|
| 197 |
-
|
| 198 |
-
# Accumulate gradients over multiple micro-batches
|
| 199 |
-
optimizer.zero_grad(set_to_none=True)
|
| 200 |
-
step_loss = 0.0
|
| 201 |
-
|
| 202 |
-
for micro in range(config.gradient_accumulation_steps):
|
| 203 |
-
try:
|
| 204 |
-
batch = next(data_iter)
|
| 205 |
-
except StopIteration:
|
| 206 |
-
# Restart the data stream when exhausted
|
| 207 |
-
data_iter = iter(loader)
|
| 208 |
-
batch = next(data_iter)
|
| 209 |
-
|
| 210 |
-
batch = batch.to(device)
|
| 211 |
-
inputs = batch[:, :-1] # all tokens except last
|
| 212 |
-
targets = batch[:, 1:] # all tokens except first
|
| 213 |
-
|
| 214 |
-
with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
|
| 215 |
-
logits, _ = model(inputs)
|
| 216 |
-
loss = nn.functional.cross_entropy(
|
| 217 |
-
logits.reshape(-1, logits.size(-1)),
|
| 218 |
-
targets.reshape(-1),
|
| 219 |
-
ignore_index=tok.pad_token_id,
|
| 220 |
-
)
|
| 221 |
-
# Scale loss by accumulation steps so the effective loss
|
| 222 |
-
# is independent of the number of micro-batches.
|
| 223 |
-
loss = loss / config.gradient_accumulation_steps
|
| 224 |
-
|
| 225 |
-
scaler.scale(loss).backward()
|
| 226 |
-
step_loss += loss.item()
|
| 227 |
-
|
| 228 |
-
# Gradient clipping (unscale first for correct norm computation)
|
| 229 |
-
scaler.unscale_(optimizer)
|
| 230 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
|
| 231 |
-
|
| 232 |
-
scaler.step(optimizer)
|
| 233 |
-
scaler.update()
|
| 234 |
-
|
| 235 |
-
accum_loss += step_loss
|
| 236 |
-
micro_step += 1
|
| 237 |
-
|
| 238 |
-
# ------- logging -------
|
| 239 |
-
if (step + 1) % log_interval == 0 or step == total_steps - 1:
|
| 240 |
-
avg_loss = accum_loss / log_interval
|
| 241 |
-
elapsed = time.time() - t0
|
| 242 |
-
perplexity = math.exp(min(avg_loss, 20)) # clamp to avoid overflow
|
| 243 |
-
pbar.set_postfix(
|
| 244 |
-
loss=f"{avg_loss:.4f}",
|
| 245 |
-
ppl=f"{perplexity:.2f}",
|
| 246 |
-
lr=f"{lr:.2e}",
|
| 247 |
-
elapsed=f"{elapsed:.1f}s",
|
| 248 |
-
)
|
| 249 |
-
logger.info(
|
| 250 |
-
f"step={step+1} | loss={avg_loss:.4f} | ppl={perplexity:.2f} | lr={lr:.2e}"
|
| 251 |
-
)
|
| 252 |
-
wandb.log({
|
| 253 |
-
"train/loss": avg_loss,
|
| 254 |
-
"train/perplexity": perplexity,
|
| 255 |
-
"train/lr": lr,
|
| 256 |
-
}, step=step + 1)
|
| 257 |
-
accum_loss = 0.0
|
| 258 |
-
|
| 259 |
-
# ------- checkpoint every 100 steps -------
|
| 260 |
-
if (step + 1) % 100 == 0 or step == total_steps - 1:
|
| 261 |
-
save_checkpoint(model, optimizer, step + 1, config.checkpoint_dir)
|
| 262 |
-
logger.info(f"Checkpoint saved at step {step + 1}")
|
| 263 |
-
|
| 264 |
-
logger.info("Training complete.")
|
| 265 |
-
wandb.finish()
|
| 266 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage/utils.py
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import logging
|
| 3 |
-
import torch
|
| 4 |
-
from typing import Optional, Tuple
|
| 5 |
-
|
| 6 |
-
def _get_logger(name: str) -> logging.Logger:
|
| 7 |
-
"""Simple logger getter to avoid circular imports."""
|
| 8 |
-
logger = logging.getLogger(name)
|
| 9 |
-
if not logger.handlers:
|
| 10 |
-
logger.setLevel(logging.INFO)
|
| 11 |
-
console_handler = logging.StreamHandler()
|
| 12 |
-
console_handler.setLevel(logging.INFO)
|
| 13 |
-
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
| 14 |
-
console_handler.setFormatter(formatter)
|
| 15 |
-
logger.addHandler(console_handler)
|
| 16 |
-
logger.propagate = False
|
| 17 |
-
return logger
|
| 18 |
-
|
| 19 |
-
def get_compatible_device() -> torch.device:
|
| 20 |
-
"""
|
| 21 |
-
Returns the best available device with CUDA compatibility checking.
|
| 22 |
-
|
| 23 |
-
Automatically detects GPU compute capability and falls back to CPU
|
| 24 |
-
if the current PyTorch installation doesn't support the GPU.
|
| 25 |
-
"""
|
| 26 |
-
logger = _get_logger("sage.device")
|
| 27 |
-
|
| 28 |
-
# Check CUDA availability and compatibility
|
| 29 |
-
if torch.cuda.is_available():
|
| 30 |
-
gpu_name = torch.cuda.get_device_name(0)
|
| 31 |
-
capability = torch.cuda.get_device_capability()
|
| 32 |
-
major, minor = capability
|
| 33 |
-
sm_version = f"sm_{major}{minor}"
|
| 34 |
-
|
| 35 |
-
logger.info(f"Detected GPU: {gpu_name} (CUDA Capability: {sm_version})")
|
| 36 |
-
|
| 37 |
-
# PyTorch 2.0+ minimum is sm_70, PyTorch 1.13 supports sm_60
|
| 38 |
-
# Check if we can actually run model operations (embedding, linear, etc.)
|
| 39 |
-
try:
|
| 40 |
-
# Test 1: Basic tensor operation
|
| 41 |
-
test_tensor = torch.zeros(2, 4).cuda()
|
| 42 |
-
_ = test_tensor + test_tensor
|
| 43 |
-
|
| 44 |
-
# Test 2: Embedding (this is where P100/sm_60 often fails)
|
| 45 |
-
import torch.nn as nn
|
| 46 |
-
test_emb = nn.Embedding(10, 8).cuda()
|
| 47 |
-
test_indices = torch.tensor([0, 1, 2], dtype=torch.long).cuda()
|
| 48 |
-
_ = test_emb(test_indices)
|
| 49 |
-
|
| 50 |
-
# Test 3: Linear layer
|
| 51 |
-
test_linear = nn.Linear(8, 4).cuda()
|
| 52 |
-
_ = test_linear(test_emb(test_indices))
|
| 53 |
-
|
| 54 |
-
logger.info(f"✅ GPU is compatible with current PyTorch")
|
| 55 |
-
return torch.device("cuda")
|
| 56 |
-
except RuntimeError as e:
|
| 57 |
-
if "no kernel image is available" in str(e).lower():
|
| 58 |
-
logger.warning(f"⚠️ GPU {sm_version} not supported by current PyTorch")
|
| 59 |
-
logger.warning(f" Current PyTorch supports: {torch.cuda.get_arch_list() or 'sm_70+'}")
|
| 60 |
-
logger.warning(f" Install compatible PyTorch:")
|
| 61 |
-
if major < 7:
|
| 62 |
-
logger.warning(f" !pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121")
|
| 63 |
-
else:
|
| 64 |
-
logger.warning(f" !pip install torch --index-url https://download.pytorch.org/whl/cu118")
|
| 65 |
-
logger.warning(f" Falling back to CPU...")
|
| 66 |
-
else:
|
| 67 |
-
raise
|
| 68 |
-
|
| 69 |
-
# Check MPS (Apple Silicon)
|
| 70 |
-
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 71 |
-
logger.info("Using Apple Silicon (MPS)")
|
| 72 |
-
return torch.device("mps")
|
| 73 |
-
|
| 74 |
-
logger.info("Using CPU")
|
| 75 |
-
return torch.device("cpu")
|
| 76 |
-
|
| 77 |
-
def setup_logger(name: str) -> logging.Logger:
|
| 78 |
-
"""Sets up a standardized logger for the SAGE system."""
|
| 79 |
-
logger = logging.getLogger(name)
|
| 80 |
-
if not logger.handlers:
|
| 81 |
-
logger.setLevel(logging.INFO)
|
| 82 |
-
console_handler = logging.StreamHandler()
|
| 83 |
-
console_handler.setLevel(logging.INFO)
|
| 84 |
-
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
| 85 |
-
console_handler.setFormatter(formatter)
|
| 86 |
-
logger.addHandler(console_handler)
|
| 87 |
-
# Prevent propagation to the root logger to avoid double printing
|
| 88 |
-
logger.propagate = False
|
| 89 |
-
return logger
|
| 90 |
-
|
| 91 |
-
def save_checkpoint(
|
| 92 |
-
model: torch.nn.Module,
|
| 93 |
-
optimizer: Optional[torch.optim.Optimizer],
|
| 94 |
-
step: int,
|
| 95 |
-
checkpoint_dir: str,
|
| 96 |
-
filename: str = "sage_latest.pt"
|
| 97 |
-
) -> str:
|
| 98 |
-
"""Saves the model and optimizer state to a checkpoint file."""
|
| 99 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 100 |
-
path = os.path.join(checkpoint_dir, filename)
|
| 101 |
-
|
| 102 |
-
base_model = getattr(model, "module", model)
|
| 103 |
-
checkpoint = {
|
| 104 |
-
'step': step,
|
| 105 |
-
'model_state_dict': base_model.state_dict(),
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
if optimizer is not None:
|
| 109 |
-
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
|
| 110 |
-
|
| 111 |
-
torch.save(checkpoint, path)
|
| 112 |
-
return path
|
| 113 |
-
|
| 114 |
-
def load_checkpoint(
|
| 115 |
-
model: torch.nn.Module,
|
| 116 |
-
optimizer: Optional[torch.optim.Optimizer],
|
| 117 |
-
checkpoint_dir: str,
|
| 118 |
-
filename: str = "sage_latest.pt",
|
| 119 |
-
device: str = "cpu"
|
| 120 |
-
) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer], int]:
|
| 121 |
-
"""Loads a checkpoint and restores the model and optimizer states."""
|
| 122 |
-
path = os.path.join(checkpoint_dir, filename)
|
| 123 |
-
|
| 124 |
-
if not os.path.exists(path):
|
| 125 |
-
logger = setup_logger("utils")
|
| 126 |
-
logger.warning(f"No checkpoint found at {path}. Starting from scratch.")
|
| 127 |
-
return model, optimizer, 0
|
| 128 |
-
|
| 129 |
-
# Load to CPU first to avoid VRAM spikes, then the module will be moved later if needed
|
| 130 |
-
checkpoint = torch.load(path, map_location=device)
|
| 131 |
-
|
| 132 |
-
base_model = getattr(model, "module", model)
|
| 133 |
-
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
| 134 |
-
|
| 135 |
-
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
|
| 136 |
-
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 137 |
-
|
| 138 |
-
step = checkpoint.get('step', 0)
|
| 139 |
-
|
| 140 |
-
logger = setup_logger("utils")
|
| 141 |
-
logger.info(f"Loaded checkpoint from {path} at step {step}")
|
| 142 |
-
|
| 143 |
-
return model, optimizer, step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sage_single.py
DELETED
|
@@ -1,824 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
SAGE — Self-Adaptive General Engine (Single-File Edition)
|
| 4 |
-
=========================================================
|
| 5 |
-
A complete mini-LLM in one file. Run with:
|
| 6 |
-
|
| 7 |
-
python sage_single.py
|
| 8 |
-
|
| 9 |
-
All architecture, data, training, inference, fine-tuning, quantization,
|
| 10 |
-
RAG, and CLI components are included below.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
import os
|
| 14 |
-
import re
|
| 15 |
-
import sys
|
| 16 |
-
import math
|
| 17 |
-
import copy
|
| 18 |
-
import time
|
| 19 |
-
import random
|
| 20 |
-
import logging
|
| 21 |
-
from dataclasses import dataclass
|
| 22 |
-
from typing import Iterator, List, Optional, Tuple
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
import torch
|
| 26 |
-
import torch.nn as nn
|
| 27 |
-
import torch.nn.functional as F
|
| 28 |
-
import torch.nn.utils.prune as prune
|
| 29 |
-
from torch.amp import GradScaler, autocast
|
| 30 |
-
from torch.utils.data import IterableDataset, DataLoader
|
| 31 |
-
from tqdm import tqdm
|
| 32 |
-
import tiktoken
|
| 33 |
-
import wandb
|
| 34 |
-
|
| 35 |
-
__version__ = "1.0.0"
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# ===================================================================
|
| 39 |
-
# Section 1 — Configuration
|
| 40 |
-
# ===================================================================
|
| 41 |
-
|
| 42 |
-
@dataclass
|
| 43 |
-
class SageConfig:
|
| 44 |
-
d_model: int = 512
|
| 45 |
-
n_heads: int = 8
|
| 46 |
-
n_kv_heads: int = 4
|
| 47 |
-
n_layers: int = 6
|
| 48 |
-
d_ff: int = 2048
|
| 49 |
-
n_experts: int = 4
|
| 50 |
-
num_experts_per_tok: int = 2
|
| 51 |
-
vocab_size: int = 100277
|
| 52 |
-
max_seq_len: int = 1024
|
| 53 |
-
dropout: float = 0.1
|
| 54 |
-
batch_size: int = 4
|
| 55 |
-
gradient_accumulation_steps: int = 16
|
| 56 |
-
learning_rate: float = 3e-4
|
| 57 |
-
min_learning_rate: float = 1e-5
|
| 58 |
-
warmup_steps: int = 100
|
| 59 |
-
weight_decay: float = 0.01
|
| 60 |
-
max_grad_norm: float = 1.0
|
| 61 |
-
checkpoint_dir: str = "checkpoints"
|
| 62 |
-
project_name: str = "sage-v2"
|
| 63 |
-
|
| 64 |
-
@property
|
| 65 |
-
def device(self):
|
| 66 |
-
if torch.cuda.is_available():
|
| 67 |
-
return torch.device("cuda")
|
| 68 |
-
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 69 |
-
return torch.device("mps")
|
| 70 |
-
return torch.device("cpu")
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# ===================================================================
|
| 74 |
-
# Section 2 — Logging & Checkpoint Utilities
|
| 75 |
-
# ===================================================================
|
| 76 |
-
|
| 77 |
-
def setup_logger(name: str) -> logging.Logger:
|
| 78 |
-
logger = logging.getLogger(name)
|
| 79 |
-
if not logger.handlers:
|
| 80 |
-
logger.setLevel(logging.INFO)
|
| 81 |
-
h = logging.StreamHandler()
|
| 82 |
-
h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s", datefmt="%H:%M:%S"))
|
| 83 |
-
logger.addHandler(h)
|
| 84 |
-
logger.propagate = False
|
| 85 |
-
return logger
|
| 86 |
-
|
| 87 |
-
logger = setup_logger("sage")
|
| 88 |
-
|
| 89 |
-
def save_checkpoint(model, optimizer, step, checkpoint_dir, filename="sage_latest.pt"):
|
| 90 |
-
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 91 |
-
path = os.path.join(checkpoint_dir, filename)
|
| 92 |
-
base = getattr(model, "module", model)
|
| 93 |
-
ckpt = {"step": step, "model_state_dict": base.state_dict()}
|
| 94 |
-
if optimizer is not None:
|
| 95 |
-
ckpt["optimizer_state_dict"] = optimizer.state_dict()
|
| 96 |
-
torch.save(ckpt, path)
|
| 97 |
-
return path
|
| 98 |
-
|
| 99 |
-
def load_checkpoint(model, optimizer, checkpoint_dir, filename="sage_latest.pt", device="cpu"):
|
| 100 |
-
path = os.path.join(checkpoint_dir, filename)
|
| 101 |
-
if not os.path.exists(path):
|
| 102 |
-
logger.warning(f"No checkpoint at {path}, starting fresh.")
|
| 103 |
-
return model, optimizer, 0
|
| 104 |
-
ckpt = torch.load(path, map_location=device)
|
| 105 |
-
base = getattr(model, "module", model)
|
| 106 |
-
base.load_state_dict(ckpt["model_state_dict"], strict=False)
|
| 107 |
-
if optimizer and "optimizer_state_dict" in ckpt:
|
| 108 |
-
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
| 109 |
-
step = ckpt.get("step", 0)
|
| 110 |
-
logger.info(f"Loaded checkpoint from {path} (step {step})")
|
| 111 |
-
return model, optimizer, step
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
# ===================================================================
|
| 115 |
-
# Section 3 — Tokenizer
|
| 116 |
-
# ===================================================================
|
| 117 |
-
|
| 118 |
-
class SageTokenizer:
|
| 119 |
-
def __init__(self, encoding_name="cl100k_base"):
|
| 120 |
-
self.enc = tiktoken.get_encoding(encoding_name)
|
| 121 |
-
self.eos_token_id = self.enc.n_vocab - 1
|
| 122 |
-
self.pad_token_id = self.enc.n_vocab - 2
|
| 123 |
-
self.vocab_size = self.enc.n_vocab
|
| 124 |
-
|
| 125 |
-
def encode(self, text, add_eos=False):
|
| 126 |
-
tokens = self.enc.encode(text, allowed_special="all")
|
| 127 |
-
if add_eos:
|
| 128 |
-
tokens.append(self.eos_token_id)
|
| 129 |
-
return tokens
|
| 130 |
-
|
| 131 |
-
def decode(self, tokens):
|
| 132 |
-
filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)]
|
| 133 |
-
return self.enc.decode(filtered)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
# ===================================================================
|
| 137 |
-
# Section 4 — Model Architecture (RoPE, Attention, MoE, Transformer)
|
| 138 |
-
# ===================================================================
|
| 139 |
-
|
| 140 |
-
def precompute_freqs_cis(dim, end, theta=10000.0):
|
| 141 |
-
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
|
| 142 |
-
t = torch.arange(end, dtype=torch.float32)
|
| 143 |
-
freqs = torch.outer(t, freqs)
|
| 144 |
-
return torch.polar(torch.ones_like(freqs), freqs)
|
| 145 |
-
|
| 146 |
-
def apply_rotary_emb(xq, xk, freqs_cis):
|
| 147 |
-
# Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real)
|
| 148 |
-
if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2:
|
| 149 |
-
freqs_cis = torch.view_as_complex(freqs_cis)
|
| 150 |
-
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
| 151 |
-
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
| 152 |
-
fc = freqs_cis.unsqueeze(0).unsqueeze(2)
|
| 153 |
-
xq_out = torch.view_as_real(xq_ * fc).flatten(3)
|
| 154 |
-
xk_out = torch.view_as_real(xk_ * fc).flatten(3)
|
| 155 |
-
return xq_out.type_as(xq), xk_out.type_as(xk)
|
| 156 |
-
|
| 157 |
-
def repeat_kv(x, n_rep):
|
| 158 |
-
if n_rep == 1: return x
|
| 159 |
-
B, T, n_kv_heads, head_dim = x.size()
|
| 160 |
-
return x[:, :, :, None, :].expand(B, T, n_kv_heads, n_rep, head_dim).reshape(B, T, n_kv_heads * n_rep, head_dim)
|
| 161 |
-
|
| 162 |
-
class CausalSelfAttention(nn.Module):
|
| 163 |
-
def __init__(self, config):
|
| 164 |
-
super().__init__()
|
| 165 |
-
self.n_heads = config.n_heads
|
| 166 |
-
self.n_kv_heads = config.n_kv_heads
|
| 167 |
-
self.n_rep = self.n_heads // self.n_kv_heads
|
| 168 |
-
self.d_model = config.d_model
|
| 169 |
-
self.head_dim = config.d_model // config.n_heads
|
| 170 |
-
self.wq = nn.Linear(config.d_model, config.n_heads * self.head_dim, bias=False)
|
| 171 |
-
self.wk = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False)
|
| 172 |
-
self.wv = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False)
|
| 173 |
-
self.wo = nn.Linear(config.d_model, config.d_model, bias=False)
|
| 174 |
-
self.resid_dropout = nn.Dropout(config.dropout)
|
| 175 |
-
|
| 176 |
-
def forward(self, x, freqs_cis, kv_cache=None):
|
| 177 |
-
B, T, C = x.size()
|
| 178 |
-
q, k, v = self.wq(x), self.wk(x), self.wv(x)
|
| 179 |
-
q = q.view(B, T, self.n_heads, self.head_dim)
|
| 180 |
-
k = k.view(B, T, self.n_kv_heads, self.head_dim)
|
| 181 |
-
v = v.view(B, T, self.n_kv_heads, self.head_dim)
|
| 182 |
-
q, k = apply_rotary_emb(q, k, freqs_cis)
|
| 183 |
-
if kv_cache is not None:
|
| 184 |
-
k = torch.cat([kv_cache[0], k], dim=1)
|
| 185 |
-
v = torch.cat([kv_cache[1], v], dim=1)
|
| 186 |
-
new_kv = (k, v)
|
| 187 |
-
else:
|
| 188 |
-
new_kv = None
|
| 189 |
-
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
| 190 |
-
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
| 191 |
-
is_causal = kv_cache is None and T > 1
|
| 192 |
-
try:
|
| 193 |
-
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0 if not self.training else 0.1, is_causal=is_causal)
|
| 194 |
-
except Exception:
|
| 195 |
-
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
|
| 196 |
-
if is_causal:
|
| 197 |
-
mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T)
|
| 198 |
-
attn = attn.masked_fill(mask == 0, float('-inf'))
|
| 199 |
-
attn = F.softmax(attn, dim=-1)
|
| 200 |
-
if self.training: attn = F.dropout(attn, p=0.1)
|
| 201 |
-
y = attn @ v
|
| 202 |
-
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 203 |
-
return self.resid_dropout(self.wo(y)), new_kv
|
| 204 |
-
|
| 205 |
-
class ExpertFFN(nn.Module):
|
| 206 |
-
def __init__(self, config):
|
| 207 |
-
super().__init__()
|
| 208 |
-
self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
| 209 |
-
self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False)
|
| 210 |
-
self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False)
|
| 211 |
-
self.dropout = nn.Dropout(config.dropout)
|
| 212 |
-
|
| 213 |
-
def forward(self, x):
|
| 214 |
-
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
| 215 |
-
|
| 216 |
-
class MoE(nn.Module):
|
| 217 |
-
def __init__(self, config):
|
| 218 |
-
super().__init__()
|
| 219 |
-
self.n_experts = config.n_experts
|
| 220 |
-
self.top_k = config.num_experts_per_tok
|
| 221 |
-
self.router = nn.Linear(config.d_model, config.n_experts, bias=False)
|
| 222 |
-
self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(config.n_experts)])
|
| 223 |
-
|
| 224 |
-
def forward(self, x):
|
| 225 |
-
B, T, C = x.size()
|
| 226 |
-
flat = x.view(-1, C)
|
| 227 |
-
weights = F.softmax(self.router(flat), dim=-1)
|
| 228 |
-
weights, indices = torch.topk(weights, self.top_k, dim=-1)
|
| 229 |
-
weights = weights / weights.sum(dim=-1, keepdim=True)
|
| 230 |
-
out = torch.zeros_like(flat)
|
| 231 |
-
for i, expert in enumerate(self.experts):
|
| 232 |
-
mask = (indices == i)
|
| 233 |
-
tok_idx, kth = torch.where(mask)
|
| 234 |
-
if tok_idx.shape[0] > 0:
|
| 235 |
-
out[tok_idx] += expert(flat[tok_idx]) * weights[tok_idx, kth].unsqueeze(-1)
|
| 236 |
-
return out.view(B, T, C)
|
| 237 |
-
|
| 238 |
-
class TransformerBlock(nn.Module):
|
| 239 |
-
def __init__(self, config):
|
| 240 |
-
super().__init__()
|
| 241 |
-
self.norm1 = nn.LayerNorm(config.d_model)
|
| 242 |
-
self.attn = CausalSelfAttention(config)
|
| 243 |
-
self.norm2 = nn.LayerNorm(config.d_model)
|
| 244 |
-
self.moe = MoE(config)
|
| 245 |
-
|
| 246 |
-
def forward(self, x, freqs_cis, kv_cache=None):
|
| 247 |
-
h, new_kv = self.attn(self.norm1(x), freqs_cis, kv_cache)
|
| 248 |
-
x = x + h
|
| 249 |
-
x = x + self.moe(self.norm2(x))
|
| 250 |
-
return x, new_kv
|
| 251 |
-
|
| 252 |
-
class SageModel(nn.Module):
|
| 253 |
-
def __init__(self, config):
|
| 254 |
-
super().__init__()
|
| 255 |
-
self.config = config
|
| 256 |
-
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
| 257 |
-
self.drop = nn.Dropout(config.dropout)
|
| 258 |
-
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
|
| 259 |
-
self.ln_f = nn.LayerNorm(config.d_model)
|
| 260 |
-
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # tied
|
| 261 |
-
self.wte.weight = self.lm_head.weight
|
| 262 |
-
self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False)
|
| 263 |
-
self.apply(self._init_weights)
|
| 264 |
-
|
| 265 |
-
def _init_weights(self, m):
|
| 266 |
-
if isinstance(m, nn.Linear):
|
| 267 |
-
nn.init.normal_(m.weight, std=0.02)
|
| 268 |
-
if m.bias is not None: nn.init.zeros_(m.bias)
|
| 269 |
-
elif isinstance(m, nn.Embedding):
|
| 270 |
-
nn.init.normal_(m.weight, std=0.02)
|
| 271 |
-
elif isinstance(m, nn.LayerNorm):
|
| 272 |
-
nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
|
| 273 |
-
|
| 274 |
-
def forward(self, idx, kv_caches=None):
|
| 275 |
-
B, T = idx.size()
|
| 276 |
-
start = kv_caches[0][0].shape[1] if kv_caches else 0
|
| 277 |
-
fc = self.freqs_cis[start:start + T]
|
| 278 |
-
x = self.drop(self.wte(idx))
|
| 279 |
-
new_kvs = []
|
| 280 |
-
for i, layer in enumerate(self.layers):
|
| 281 |
-
kv = kv_caches[i] if kv_caches else None
|
| 282 |
-
if self.training and kv is None:
|
| 283 |
-
def create_custom_forward(module):
|
| 284 |
-
def custom_forward(x_in, freqs_cis_in):
|
| 285 |
-
return module(x_in, freqs_cis_in, None)
|
| 286 |
-
return custom_forward
|
| 287 |
-
x, nkv = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, fc, use_reentrant=False)
|
| 288 |
-
else:
|
| 289 |
-
x, nkv = layer(x, fc, kv)
|
| 290 |
-
if nkv is not None: new_kvs.append(nkv)
|
| 291 |
-
return self.lm_head(self.ln_f(x)), new_kvs if new_kvs else None
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
# ===================================================================
|
| 295 |
-
# Section 5 — Data Pipeline
|
| 296 |
-
# ===================================================================
|
| 297 |
-
|
| 298 |
-
_HTML_RE = re.compile(r"<[^>]+>")
|
| 299 |
-
|
| 300 |
-
def clean_text(text):
|
| 301 |
-
text = _HTML_RE.sub("", text)
|
| 302 |
-
text = re.sub(r"[ \t]+", " ", text)
|
| 303 |
-
text = re.sub(r"\n{3,}", "\n\n", text)
|
| 304 |
-
return text.strip()
|
| 305 |
-
|
| 306 |
-
class StreamingTextDataset(IterableDataset):
|
| 307 |
-
def __init__(self, dataset_name="HuggingFaceFW/fineweb-edu", split="train", seq_len=512, tokenizer=None, buffer_size=1000, text_field="text"):
|
| 308 |
-
super().__init__()
|
| 309 |
-
self.dataset_name, self.split, self.seq_len = dataset_name, split, seq_len
|
| 310 |
-
self.tokenizer = tokenizer or SageTokenizer()
|
| 311 |
-
self.buffer_size, self.text_field = buffer_size, text_field
|
| 312 |
-
if "fineweb-edu" in dataset_name.lower(): self.text_field = "text"
|
| 313 |
-
elif "tinystories" in dataset_name.lower(): self.text_field = "text"
|
| 314 |
-
|
| 315 |
-
def _tokens(self):
|
| 316 |
-
from datasets import load_dataset
|
| 317 |
-
ds = load_dataset(self.dataset_name, split=self.split, streaming=True)
|
| 318 |
-
for s in ds:
|
| 319 |
-
raw = s.get(self.text_field, "")
|
| 320 |
-
if not raw or len(raw) < 50: continue
|
| 321 |
-
text = clean_text(raw)
|
| 322 |
-
yield from self.tokenizer.encode(text, add_eos=True)
|
| 323 |
-
|
| 324 |
-
def __iter__(self):
|
| 325 |
-
chunk, buf = [], []
|
| 326 |
-
for tok in self._tokens():
|
| 327 |
-
chunk.append(tok)
|
| 328 |
-
if len(chunk) == self.seq_len + 1:
|
| 329 |
-
buf.append(torch.tensor(chunk, dtype=torch.long))
|
| 330 |
-
chunk = []
|
| 331 |
-
if len(buf) >= self.buffer_size:
|
| 332 |
-
random.shuffle(buf)
|
| 333 |
-
while len(buf) > self.buffer_size // 2: yield buf.pop()
|
| 334 |
-
random.shuffle(buf)
|
| 335 |
-
yield from buf
|
| 336 |
-
|
| 337 |
-
def create_dataloader(config, dataset_name="HuggingFaceFW/fineweb-edu", tokenizer=None):
|
| 338 |
-
tok = tokenizer or SageTokenizer()
|
| 339 |
-
ds = StreamingTextDataset(dataset_name=dataset_name, seq_len=config.max_seq_len, tokenizer=tok)
|
| 340 |
-
return DataLoader(ds, batch_size=config.batch_size, num_workers=2, pin_memory=True, drop_last=True)
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
# ===================================================================
|
| 344 |
-
# Section 6 — Training
|
| 345 |
-
# ===================================================================
|
| 346 |
-
|
| 347 |
-
def get_lr(step, config, total_steps):
|
| 348 |
-
if step < config.warmup_steps:
|
| 349 |
-
return config.learning_rate * (step + 1) / config.warmup_steps
|
| 350 |
-
progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps)
|
| 351 |
-
coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 352 |
-
return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate)
|
| 353 |
-
|
| 354 |
-
def create_optimizer(model, config):
|
| 355 |
-
decay, no_decay = [], []
|
| 356 |
-
for n, p in model.named_parameters():
|
| 357 |
-
if not p.requires_grad: continue
|
| 358 |
-
(no_decay if p.ndim == 1 or "bias" in n else decay).append(p)
|
| 359 |
-
# Enable Fused AdamW for 10% speedup if CUDA is active
|
| 360 |
-
use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames
|
| 361 |
-
return torch.optim.AdamW([
|
| 362 |
-
{"params": decay, "weight_decay": config.weight_decay},
|
| 363 |
-
{"params": no_decay, "weight_decay": 0.0},
|
| 364 |
-
], lr=config.learning_rate, betas=(0.9, 0.95), fused=use_fused)
|
| 365 |
-
|
| 366 |
-
def train_model(model, config, total_steps=500, dataset_name="roneneldan/TinyStories", resume=True, tokenizer=None):
|
| 367 |
-
device = config.device
|
| 368 |
-
# --- TURBO MODE: TF32 & COMPILE ---
|
| 369 |
-
if torch.cuda.is_available():
|
| 370 |
-
torch.set_float32_matmul_precision('high')
|
| 371 |
-
|
| 372 |
-
model = model.to(device)
|
| 373 |
-
tok = tokenizer or SageTokenizer()
|
| 374 |
-
|
| 375 |
-
# Wrap model with torch.compile for graph-level optimization
|
| 376 |
-
# mode="reduce-overhead" is ideal for smaller-to-medium models like SAGE
|
| 377 |
-
if hasattr(torch, "compile"):
|
| 378 |
-
try:
|
| 379 |
-
logger.info("Turbo Mode: Compiling model graph...")
|
| 380 |
-
# Compile the base model (unwrapped from DataParallel if present)
|
| 381 |
-
base = getattr(model, "module", model)
|
| 382 |
-
compiled_base = torch.compile(base, mode="reduce-overhead")
|
| 383 |
-
if hasattr(model, "module"):
|
| 384 |
-
model.module = compiled_base
|
| 385 |
-
else:
|
| 386 |
-
model = compiled_base
|
| 387 |
-
except (ValueError, RuntimeError, ImportError) as e:
|
| 388 |
-
# Graceful fallback: numpy compatibility issues or other compilation errors
|
| 389 |
-
logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}")
|
| 390 |
-
# Continue with uncompiled model
|
| 391 |
-
|
| 392 |
-
opt = create_optimizer(model, config)
|
| 393 |
-
start_step = 0
|
| 394 |
-
if resume:
|
| 395 |
-
model, opt, start_step = load_checkpoint(model, opt, config.checkpoint_dir, device=str(device))
|
| 396 |
-
if start_step >= total_steps: return model
|
| 397 |
-
use_amp = device.type == "cuda"
|
| 398 |
-
amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16
|
| 399 |
-
scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16)
|
| 400 |
-
loader = create_dataloader(config, dataset_name, tok)
|
| 401 |
-
data_iter = iter(loader)
|
| 402 |
-
wandb.init(project=config.project_name, name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__)
|
| 403 |
-
model.train()
|
| 404 |
-
accum_loss, t0 = 0.0, time.time()
|
| 405 |
-
pbar = tqdm(range(start_step, total_steps), desc="Training")
|
| 406 |
-
for step in pbar:
|
| 407 |
-
lr = get_lr(step, config, total_steps)
|
| 408 |
-
for pg in opt.param_groups: pg["lr"] = lr
|
| 409 |
-
opt.zero_grad(set_to_none=True)
|
| 410 |
-
step_loss = 0.0
|
| 411 |
-
for _ in range(config.gradient_accumulation_steps):
|
| 412 |
-
try: batch = next(data_iter)
|
| 413 |
-
except StopIteration: data_iter = iter(loader); batch = next(data_iter)
|
| 414 |
-
batch = batch.to(device)
|
| 415 |
-
with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
|
| 416 |
-
logits, _ = model(batch[:, :-1])
|
| 417 |
-
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1), ignore_index=tok.pad_token_id)
|
| 418 |
-
loss = loss / config.gradient_accumulation_steps
|
| 419 |
-
scaler.scale(loss).backward()
|
| 420 |
-
step_loss += loss.item()
|
| 421 |
-
scaler.unscale_(opt)
|
| 422 |
-
nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
|
| 423 |
-
scaler.step(opt); scaler.update()
|
| 424 |
-
accum_loss += step_loss
|
| 425 |
-
if (step + 1) % 10 == 0:
|
| 426 |
-
avg = accum_loss / 10
|
| 427 |
-
pbar.set_postfix(loss=f"{avg:.4f}", ppl=f"{math.exp(min(avg,20)):.1f}", lr=f"{lr:.2e}")
|
| 428 |
-
wandb.log({"train/loss": avg, "train/perplexity": math.exp(min(avg, 20)), "train/lr": lr}, step=step + 1)
|
| 429 |
-
accum_loss = 0.0
|
| 430 |
-
if (step + 1) % 100 == 0:
|
| 431 |
-
save_checkpoint(model, opt, step + 1, config.checkpoint_dir)
|
| 432 |
-
save_checkpoint(model, opt, total_steps, config.checkpoint_dir)
|
| 433 |
-
logger.info("Training complete.")
|
| 434 |
-
wandb.finish()
|
| 435 |
-
return model
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
# ===================================================================
|
| 439 |
-
# Section 7 — Inference
|
| 440 |
-
# ===================================================================
|
| 441 |
-
|
| 442 |
-
def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, greedy=False):
|
| 443 |
-
if greedy: return logits.argmax(-1, keepdim=True)
|
| 444 |
-
logits = logits / max(temperature, 1e-8)
|
| 445 |
-
if 0 < top_k < logits.size(-1):
|
| 446 |
-
v, _ = torch.topk(logits, top_k)
|
| 447 |
-
logits[logits < v[:, -1:]] = float("-inf")
|
| 448 |
-
if top_p < 1.0:
|
| 449 |
-
sorted_l, sorted_i = torch.sort(logits, descending=True)
|
| 450 |
-
cum = torch.cumsum(F.softmax(sorted_l, -1), -1)
|
| 451 |
-
mask = cum - F.softmax(sorted_l, -1) >= top_p
|
| 452 |
-
sorted_l[mask] = float("-inf")
|
| 453 |
-
logits = logits.scatter(1, sorted_i, sorted_l)
|
| 454 |
-
return torch.multinomial(F.softmax(logits, -1), 1)
|
| 455 |
-
|
| 456 |
-
@torch.no_grad()
|
| 457 |
-
def generate(model, tokenizer, prompt, max_new=256, temperature=0.8, top_k=50, top_p=0.9, stream=True, device=None):
|
| 458 |
-
device = device or next(model.parameters()).device
|
| 459 |
-
base = getattr(model, "module", model)
|
| 460 |
-
base.eval()
|
| 461 |
-
ids = tokenizer.encode(prompt) or [tokenizer.eos_token_id]
|
| 462 |
-
inp = torch.tensor([ids], dtype=torch.long, device=device)
|
| 463 |
-
logits, kvs = base(inp)
|
| 464 |
-
gen = list(ids)
|
| 465 |
-
nl = logits[:, -1, :]
|
| 466 |
-
for _ in range(max_new):
|
| 467 |
-
nid = sample_next(nl, temperature, top_k, top_p)
|
| 468 |
-
tid = nid.item()
|
| 469 |
-
if tid == tokenizer.eos_token_id: break
|
| 470 |
-
gen.append(tid)
|
| 471 |
-
if stream: print(tokenizer.decode([tid]), end="", flush=True)
|
| 472 |
-
logits, kvs = base(nid.view(1, 1), kv_caches=kvs)
|
| 473 |
-
nl = logits[:, -1, :]
|
| 474 |
-
if stream: print()
|
| 475 |
-
base.train()
|
| 476 |
-
return tokenizer.decode(gen)
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
# ===================================================================
|
| 480 |
-
# Section 8 — LoRA Fine-tuning
|
| 481 |
-
# ===================================================================
|
| 482 |
-
|
| 483 |
-
class LoRALinear(nn.Module):
|
| 484 |
-
def __init__(self, original, rank=8, alpha=16.0):
|
| 485 |
-
super().__init__()
|
| 486 |
-
self.original = original
|
| 487 |
-
self.scaling = alpha / rank
|
| 488 |
-
device, dtype = original.weight.device, original.weight.dtype
|
| 489 |
-
self.lora_A = nn.Parameter(torch.randn(original.in_features, rank, device=device, dtype=dtype) * 0.01)
|
| 490 |
-
self.lora_B = nn.Parameter(torch.zeros(rank, original.out_features, device=device, dtype=dtype))
|
| 491 |
-
original.weight.requires_grad = False
|
| 492 |
-
if original.bias is not None: original.bias.requires_grad = False
|
| 493 |
-
|
| 494 |
-
def forward(self, x):
|
| 495 |
-
return self.original(x) + (x @ self.lora_A @ self.lora_B) * self.scaling
|
| 496 |
-
|
| 497 |
-
def merge(self):
|
| 498 |
-
m = copy.deepcopy(self.original)
|
| 499 |
-
m.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling
|
| 500 |
-
m.weight.requires_grad = True
|
| 501 |
-
return m
|
| 502 |
-
|
| 503 |
-
def inject_lora(model, rank=8, alpha=16.0):
|
| 504 |
-
base = getattr(model, "module", model)
|
| 505 |
-
for layer in base.layers:
|
| 506 |
-
a = layer.attn
|
| 507 |
-
for name in ("wq", "wk", "wv", "wo"):
|
| 508 |
-
setattr(a, name, LoRALinear(getattr(a, name), rank, alpha))
|
| 509 |
-
tp = sum(p.numel() for p in base.parameters() if p.requires_grad)
|
| 510 |
-
logger.info(f"LoRA injected (rank={rank}). Trainable params: {tp:,}")
|
| 511 |
-
return model
|
| 512 |
-
|
| 513 |
-
def merge_lora(model):
|
| 514 |
-
base = getattr(model, "module", model)
|
| 515 |
-
for layer in base.layers:
|
| 516 |
-
a = layer.attn
|
| 517 |
-
for name in ("wq", "wk", "wv", "wo"):
|
| 518 |
-
m = getattr(a, name)
|
| 519 |
-
if isinstance(m, LoRALinear): setattr(a, name, m.merge())
|
| 520 |
-
logger.info("LoRA merged.")
|
| 521 |
-
return model
|
| 522 |
-
|
| 523 |
-
INSTRUCTION_TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n{response}"
|
| 524 |
-
|
| 525 |
-
DEMO_SAMPLES = [
|
| 526 |
-
{"instruction": "What is the capital of France?", "response": "The capital of France is Paris."},
|
| 527 |
-
{"instruction": "Explain gravity simply.", "response": "Gravity pulls objects toward each other. More mass means stronger pull."},
|
| 528 |
-
{"instruction": "Write a short poem about the ocean.", "response": "Waves crash on sandy shore,\nThe ocean sings forevermore.\nDeep blue meets the sky,\nSeagulls dance and clouds float by."},
|
| 529 |
-
{"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."},
|
| 530 |
-
{"instruction": "Summarize photosynthesis.", "response": "Plants convert sunlight, water, and CO2 into glucose and oxygen."},
|
| 531 |
-
{"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than its year — 243 Earth days to rotate vs 225 to orbit the Sun."},
|
| 532 |
-
{"instruction": "How do airplanes fly?", "response": "Wings generate lift because air moves faster over the curved top, creating lower pressure above."},
|
| 533 |
-
{"instruction": "What is machine learning?", "response": "ML is AI where computers learn patterns from data instead of being explicitly programmed."},
|
| 534 |
-
]
|
| 535 |
-
|
| 536 |
-
def create_instruction_batch(samples, tokenizer, max_len=512):
|
| 537 |
-
all_ids, all_masks = [], []
|
| 538 |
-
for s in samples:
|
| 539 |
-
inst_text = f"### Instruction:\n{s['instruction'].strip()}\n\n### Response:\n"
|
| 540 |
-
full_text = inst_text + s["response"].strip()
|
| 541 |
-
inst_toks = tokenizer.encode(inst_text)
|
| 542 |
-
full_toks = tokenizer.encode(full_text, add_eos=True)[:max_len]
|
| 543 |
-
ni = min(len(inst_toks), len(full_toks))
|
| 544 |
-
mask = [0] * ni + [1] * (len(full_toks) - ni)
|
| 545 |
-
pad = max_len - len(full_toks)
|
| 546 |
-
full_toks += [tokenizer.pad_token_id] * pad
|
| 547 |
-
mask += [0] * pad
|
| 548 |
-
all_ids.append(full_toks); all_masks.append(mask)
|
| 549 |
-
return {"input_ids": torch.tensor(all_ids), "labels": torch.tensor(all_ids), "loss_mask": torch.tensor(all_masks, dtype=torch.float32)}
|
| 550 |
-
|
| 551 |
-
def finetune(model, config, samples=None, steps=200, use_lora=True, tokenizer=None):
|
| 552 |
-
device = config.device; model = model.to(device)
|
| 553 |
-
tok = tokenizer or SageTokenizer()
|
| 554 |
-
samples = samples or DEMO_SAMPLES
|
| 555 |
-
if use_lora: model = inject_lora(model)
|
| 556 |
-
opt = create_optimizer(model, config)
|
| 557 |
-
use_amp = device.type == "cuda"
|
| 558 |
-
amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16
|
| 559 |
-
scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16)
|
| 560 |
-
wandb.init(project=config.project_name, name=f"finetune-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__)
|
| 561 |
-
model.train(); accum = 0.0
|
| 562 |
-
for step in tqdm(range(steps), desc="Fine-tuning"):
|
| 563 |
-
lr = get_lr(step, config, steps)
|
| 564 |
-
for pg in opt.param_groups: pg["lr"] = lr
|
| 565 |
-
batch = create_instruction_batch(random.choices(samples, k=min(config.batch_size, len(samples))), tok, config.max_seq_len)
|
| 566 |
-
ids, labels, mask = batch["input_ids"].to(device), batch["labels"].to(device), batch["loss_mask"].to(device)
|
| 567 |
-
opt.zero_grad(set_to_none=True)
|
| 568 |
-
with autocast(device.type, dtype=amp_dtype, enabled=use_amp):
|
| 569 |
-
logits, _ = model(ids)
|
| 570 |
-
sl, slb, sm = logits[:, :-1, :].contiguous(), labels[:, 1:].contiguous(), mask[:, 1:].contiguous()
|
| 571 |
-
ptl = F.cross_entropy(sl.view(-1, sl.size(-1)), slb.view(-1), reduction="none").view(slb.size())
|
| 572 |
-
loss = (ptl * sm).sum() / sm.sum().clamp(min=1)
|
| 573 |
-
scaler.scale(loss).backward()
|
| 574 |
-
scaler.unscale_(opt); nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
|
| 575 |
-
scaler.step(opt); scaler.update()
|
| 576 |
-
accum += loss.item()
|
| 577 |
-
if (step + 1) % 10 == 0: accum = 0.0
|
| 578 |
-
if use_lora: model = merge_lora(model)
|
| 579 |
-
save_checkpoint(model, None, steps, config.checkpoint_dir, "sage_finetuned.pt")
|
| 580 |
-
logger.info("Fine-tuning complete.")
|
| 581 |
-
wandb.finish()
|
| 582 |
-
return model
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
# ===================================================================
|
| 586 |
-
# Section 9 — Optimization (Quantize / Prune)
|
| 587 |
-
# ===================================================================
|
| 588 |
-
|
| 589 |
-
def quantize_int8(model):
|
| 590 |
-
base = getattr(model, "module", model)
|
| 591 |
-
model = base.cpu().eval()
|
| 592 |
-
q = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
|
| 593 |
-
logger.info("INT8 quantization complete.")
|
| 594 |
-
return q
|
| 595 |
-
|
| 596 |
-
def prune_model(model, amount=0.3):
|
| 597 |
-
base = getattr(model, "module", model)
|
| 598 |
-
for _, m in base.named_modules():
|
| 599 |
-
if isinstance(m, nn.Linear):
|
| 600 |
-
prune.l1_unstructured(m, "weight", amount=amount)
|
| 601 |
-
prune.remove(m, "weight")
|
| 602 |
-
logger.info(f"Pruning complete ({amount*100:.0f}% sparsity target).")
|
| 603 |
-
return model
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
# ===================================================================
|
| 607 |
-
# Section 10 — RAG & Memory
|
| 608 |
-
# ===================================================================
|
| 609 |
-
|
| 610 |
-
def _embed(text, tokenizer, model, device):
|
| 611 |
-
toks = tokenizer.encode(text)
|
| 612 |
-
base = getattr(model, "module", model)
|
| 613 |
-
if not toks: return np.zeros(base.wte.weight.shape[1], dtype=np.float32)
|
| 614 |
-
with torch.no_grad():
|
| 615 |
-
emb = base.wte(torch.tensor([toks], device=device)).mean(1)
|
| 616 |
-
emb = F.normalize(emb, p=2, dim=-1)
|
| 617 |
-
return emb.squeeze(0).cpu().numpy()
|
| 618 |
-
|
| 619 |
-
class VectorStore:
|
| 620 |
-
def __init__(self, dim):
|
| 621 |
-
self.dim = dim; self.docs = []; self.index = None
|
| 622 |
-
try:
|
| 623 |
-
import faiss
|
| 624 |
-
self.index = faiss.IndexFlatIP(dim)
|
| 625 |
-
except ImportError:
|
| 626 |
-
logger.warning("FAISS not installed. RAG will use brute-force search.")
|
| 627 |
-
|
| 628 |
-
def add(self, texts, embeddings):
|
| 629 |
-
if self.index is not None:
|
| 630 |
-
self.index.add(embeddings.astype(np.float32))
|
| 631 |
-
else:
|
| 632 |
-
# Brute-force fallback
|
| 633 |
-
if not hasattr(self, '_embeddings'):
|
| 634 |
-
self._embeddings = []
|
| 635 |
-
self._embeddings.extend(embeddings.astype(np.float32))
|
| 636 |
-
self.docs.extend(texts)
|
| 637 |
-
|
| 638 |
-
def search(self, qemb, k=3):
|
| 639 |
-
if not self.docs: return []
|
| 640 |
-
k = min(k, len(self.docs))
|
| 641 |
-
if self.index is not None:
|
| 642 |
-
scores, idx = self.index.search(qemb.reshape(1, -1).astype(np.float32), k)
|
| 643 |
-
return [(self.docs[i], float(s)) for s, i in zip(scores[0], idx[0]) if i >= 0]
|
| 644 |
-
else:
|
| 645 |
-
# Brute-force cosine similarity
|
| 646 |
-
import numpy as np
|
| 647 |
-
qemb = qemb.reshape(1, -1).astype(np.float32)
|
| 648 |
-
embs = np.array(self._embeddings)
|
| 649 |
-
sims = np.dot(embs, qemb.T).flatten()
|
| 650 |
-
top_k = np.argsort(sims)[-k:][::-1]
|
| 651 |
-
return [(self.docs[i], float(sims[i])) for i in top_k]
|
| 652 |
-
|
| 653 |
-
@property
|
| 654 |
-
def size(self): return len(self.docs)
|
| 655 |
-
|
| 656 |
-
class RAGManager:
|
| 657 |
-
def __init__(self, model, tokenizer, device, chunk_size=200):
|
| 658 |
-
self.model, self.tokenizer, self.device = model, tokenizer, device
|
| 659 |
-
base = getattr(model, "module", model)
|
| 660 |
-
self.store = VectorStore(base.wte.weight.shape[1])
|
| 661 |
-
self.enabled = False
|
| 662 |
-
|
| 663 |
-
def add_documents(self, texts):
|
| 664 |
-
chunks = []
|
| 665 |
-
for t in texts:
|
| 666 |
-
words = t.split()
|
| 667 |
-
for i in range(0, len(words), 150):
|
| 668 |
-
chunks.append(" ".join(words[i:i+200]))
|
| 669 |
-
if chunks:
|
| 670 |
-
embs = np.stack([_embed(c, self.tokenizer, self.model, self.device) for c in chunks])
|
| 671 |
-
self.store.add(chunks, embs)
|
| 672 |
-
|
| 673 |
-
def retrieve(self, query, k=3):
|
| 674 |
-
if not self.enabled or not self.store.size: return ""
|
| 675 |
-
qe = _embed(query, self.tokenizer, self.model, self.device)
|
| 676 |
-
results = self.store.search(qe, k)
|
| 677 |
-
return "\n\n".join(f"[Context {i+1}] {d}" for i, (d, _) in enumerate(results)) + "\n\n" if results else ""
|
| 678 |
-
|
| 679 |
-
def toggle(self, on): self.enabled = on
|
| 680 |
-
|
| 681 |
-
DEFAULT_SYSTEM_PROMPT = (
|
| 682 |
-
"You are a high-quality reasoning assistant model.\n"
|
| 683 |
-
"You must ONLY learn from high-quality instruction and reasoning datasets.\n"
|
| 684 |
-
"You must IGNORE any previously trained low-quality or repetitive patterns.\n\n"
|
| 685 |
-
"Training preference rules:\n"
|
| 686 |
-
"1. Prioritize step-by-step reasoning over short or repetitive answers.\n"
|
| 687 |
-
"2. Always produce structured logical explanations when solving problems.\n"
|
| 688 |
-
"3. Avoid repetition, filler words, or looped phrases.\n"
|
| 689 |
-
"4. Prefer datasets with mathematical reasoning and high-quality instruction.\n"
|
| 690 |
-
"5. Do not imitate noisy conversational or corrupted text patterns.\n"
|
| 691 |
-
"6. Always prefer clarity, correctness, and structured reasoning.\n\n"
|
| 692 |
-
"Output behavior goal:\n"
|
| 693 |
-
"- Think in steps.\n"
|
| 694 |
-
"- Explain logic clearly.\n"
|
| 695 |
-
"- Produce final answer only after reasoning."
|
| 696 |
-
)
|
| 697 |
-
|
| 698 |
-
class ConversationHistory:
|
| 699 |
-
def __init__(self, tokenizer, max_tokens=900):
|
| 700 |
-
self.tokenizer, self.max_tokens, self.turns = tokenizer, max_tokens, []
|
| 701 |
-
|
| 702 |
-
def add(self, role, text):
|
| 703 |
-
self.turns.append({"role": role, "text": text})
|
| 704 |
-
while sum(len(self.tokenizer.encode(t["text"])) for t in self.turns) > self.max_tokens and len(self.turns) > 1:
|
| 705 |
-
self.turns.pop(0)
|
| 706 |
-
|
| 707 |
-
def build_prompt(self, msg, rag_ctx=""):
|
| 708 |
-
parts = [DEFAULT_SYSTEM_PROMPT]
|
| 709 |
-
if rag_ctx: parts.append(rag_ctx)
|
| 710 |
-
for t in self.turns:
|
| 711 |
-
parts.append(f"{'User' if t['role']=='user' else 'SAGE'}: {t['text']}")
|
| 712 |
-
parts += [f"User: {msg}", "SAGE:"]
|
| 713 |
-
return "\n\n".join(parts)
|
| 714 |
-
|
| 715 |
-
def clear(self): self.turns.clear()
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
# ===================================================================
|
| 719 |
-
# Section 11 — CLI
|
| 720 |
-
# ===================================================================
|
| 721 |
-
|
| 722 |
-
BANNER = r"""
|
| 723 |
-
╔══════════════════════════════════════════════════════════════╗
|
| 724 |
-
║ ███████ █████ ██████ ███████ ║
|
| 725 |
-
║ ██ ██ ██ ██ ██ ║
|
| 726 |
-
║ ███████ ███████ ██ ███ █████ ║
|
| 727 |
-
║ ██ ██ ██ ██ ██ ██ ║
|
| 728 |
-
║ ███████ ██ ██ ██████ ███████ ║
|
| 729 |
-
║ Self-Adaptive General Engine v{ver} ║
|
| 730 |
-
╚══════════════════════════════════════════════════════════════╝"""
|
| 731 |
-
|
| 732 |
-
HELP = """
|
| 733 |
-
/train [steps] Train (default 100)
|
| 734 |
-
/finetune [steps] Instruction-tune with LoRA (default 200)
|
| 735 |
-
/save Save checkpoint
|
| 736 |
-
/load Load checkpoint
|
| 737 |
-
/quantize INT8 quantization
|
| 738 |
-
/rag on|off|add Toggle or add docs for RAG
|
| 739 |
-
/clear Clear history
|
| 740 |
-
/help This message
|
| 741 |
-
/exit Quit
|
| 742 |
-
"""
|
| 743 |
-
|
| 744 |
-
def main():
|
| 745 |
-
config = SageConfig()
|
| 746 |
-
tok = SageTokenizer()
|
| 747 |
-
config.vocab_size = tok.vocab_size
|
| 748 |
-
print(" Initializing SAGE …")
|
| 749 |
-
model = SageModel(config).to(config.device)
|
| 750 |
-
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
|
| 751 |
-
print(f" Multi-GPU detected: {torch.cuda.device_count()} GPUs. Using DataParallel.")
|
| 752 |
-
model = nn.DataParallel(model)
|
| 753 |
-
model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
|
| 754 |
-
base = getattr(model, "module", model)
|
| 755 |
-
total = sum(p.numel() for p in base.parameters())
|
| 756 |
-
print(BANNER.format(ver=__version__))
|
| 757 |
-
print(f" Params: {total:,} ({total/1e6:.1f}M) | Context: {config.max_seq_len} | Device: {config.device}")
|
| 758 |
-
print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}")
|
| 759 |
-
if step: print(f" Resumed from step {step}")
|
| 760 |
-
print(" Type /help for commands.\n")
|
| 761 |
-
|
| 762 |
-
rag = RAGManager(model, tok, config.device)
|
| 763 |
-
hist = ConversationHistory(tok, config.max_seq_len - 128)
|
| 764 |
-
|
| 765 |
-
if len(sys.argv) > 1:
|
| 766 |
-
cmd = sys.argv[1].lower()
|
| 767 |
-
args = sys.argv[2:]
|
| 768 |
-
if cmd == "--train":
|
| 769 |
-
s = int(args[0]) if args else 100
|
| 770 |
-
train_model(model, config, s, tokenizer=tok)
|
| 771 |
-
return
|
| 772 |
-
elif cmd == "--finetune":
|
| 773 |
-
s = int(args[0]) if args else 200
|
| 774 |
-
finetune(model, config, steps=s, tokenizer=tok)
|
| 775 |
-
return
|
| 776 |
-
elif cmd == "--quantize":
|
| 777 |
-
quantize_int8(model)
|
| 778 |
-
return
|
| 779 |
-
else:
|
| 780 |
-
print(f" Unknown argument: {cmd}\n Usage: --train [steps] | --finetune [steps] | --quantize")
|
| 781 |
-
return
|
| 782 |
-
|
| 783 |
-
while True:
|
| 784 |
-
try: inp = input("You: ").strip()
|
| 785 |
-
except (EOFError, KeyboardInterrupt): print("\n Goodbye!"); break
|
| 786 |
-
if not inp: continue
|
| 787 |
-
|
| 788 |
-
if inp.startswith("/"):
|
| 789 |
-
parts = inp.split(); cmd = parts[0].lower(); args = parts[1:]
|
| 790 |
-
if cmd == "/exit": print(" Goodbye!"); break
|
| 791 |
-
elif cmd == "/help": print(HELP)
|
| 792 |
-
elif cmd == "/train":
|
| 793 |
-
s = int(args[0]) if args else 100
|
| 794 |
-
model = train_model(model, config, s, tokenizer=tok)
|
| 795 |
-
print("\n Sample:"); generate(model, tok, "Once upon a time", max_new=80, device=config.device); print()
|
| 796 |
-
elif cmd == "/finetune":
|
| 797 |
-
s = int(args[0]) if args else 200
|
| 798 |
-
model = finetune(model, config, steps=s, tokenizer=tok)
|
| 799 |
-
print("\n Sample:"); generate(model, tok, "### Instruction:\nWhat is gravity?\n\n### Response:\n", max_new=100, device=config.device); print()
|
| 800 |
-
elif cmd == "/save": print(f" Saved to {save_checkpoint(model, None, 0, config.checkpoint_dir)}")
|
| 801 |
-
elif cmd == "/load":
|
| 802 |
-
model, _, s = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device))
|
| 803 |
-
model = model.to(config.device); rag.model = model; print(f" Loaded (step {s})")
|
| 804 |
-
elif cmd == "/quantize": model = quantize_int8(model); rag.model = model
|
| 805 |
-
elif cmd == "/rag":
|
| 806 |
-
if not args: print(f" RAG {'on' if rag.enabled else 'off'} ({rag.store.size} chunks)")
|
| 807 |
-
elif args[0] == "on": rag.toggle(True); print(" RAG on.")
|
| 808 |
-
elif args[0] == "off": rag.toggle(False); print(" RAG off.")
|
| 809 |
-
elif args[0] == "add" and len(args) > 1: rag.add_documents([" ".join(args[1:])]); print(f" Added. {rag.store.size} chunks.")
|
| 810 |
-
else: print(" /rag on|off|add <text>")
|
| 811 |
-
elif cmd == "/clear": hist.clear(); print(" Cleared.")
|
| 812 |
-
else: print(f" Unknown: {cmd}")
|
| 813 |
-
continue
|
| 814 |
-
|
| 815 |
-
ctx = rag.retrieve(inp)
|
| 816 |
-
prompt = hist.build_prompt(inp, ctx)
|
| 817 |
-
hist.add("user", inp)
|
| 818 |
-
print("SAGE: ", end="", flush=True)
|
| 819 |
-
resp = generate(model, tok, prompt, max_new=256, stream=True, device=config.device)
|
| 820 |
-
reply = resp.split("SAGE:")[-1].strip() if "SAGE:" in resp else resp[len(prompt):].strip()
|
| 821 |
-
hist.add("assistant", reply)
|
| 822 |
-
|
| 823 |
-
if __name__ == "__main__":
|
| 824 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/run_data_pipeline.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
python -m tokenizer.train_tokenizer "$@"
|
scripts/run_eval.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
python - <<'PY'
|
| 5 |
+
from eval.benchmarks import run_registered_benchmarks
|
| 6 |
+
from model.model import SageTransformer
|
| 7 |
+
from model.config import ModelConfig
|
| 8 |
+
|
| 9 |
+
model = SageTransformer(ModelConfig())
|
| 10 |
+
for result in run_registered_benchmarks(model):
|
| 11 |
+
print(result)
|
| 12 |
+
PY
|
scripts/run_serve.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
uvicorn serve.server:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8000}" "$@"
|
scripts/run_serve_cpu.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
uvicorn serve.server_cpu:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8001}" "$@"
|