diff --git a/.gitattributes b/.gitattributes index 706b1733cec4a3062dc71910ff19fd6912033408..f0a55ba2a8c1da056c2c7f7d801c4328c9ccb3d0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -5,7 +5,7 @@ # Git files .git/* .gitignore - +hf_push.py # Python virtual environments .venv/* venv/* diff --git a/.gitignore b/.gitignore index aa74a95afe1e407168768c23cb541591311eefb5..b1c22b188ccfcb05901a3b84d5343faa39508378 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ __pycache__/ *.py[cod] *$py.class +wandb/ + # C extensions *.so @@ -25,8 +27,6 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST -hf_push.py -.hugging_face_ignore # PyInstaller # Usually these files are written by a python script, before a-one-file pack @@ -112,6 +112,12 @@ celerybeat.pid # Sage Project Specific checkpoints/ +runs/ +tokenizer/*.model +tokenizer/*.vocab +tokenizer/training_corpus.txt +data/raw/ +data/processed/ .venv/ .env .DS_Store diff --git a/README.md b/README.md index 0eccecce84678e420bdc80f070df5b8d130c117a..90346463a5f1391780de09292b1fa3ddb59851ca 100644 --- a/README.md +++ b/README.md @@ -1,163 +1,455 @@ -# SAGE โ€” Self-Adaptive General Engine +# SAGE 1B + +SAGE is a root-level rewrite of this repository into a production-style dense language model project. The current baseline is a 1B-class decoder-only transformer with RMSNorm, RoPE, grouped-query attention, SwiGLU, SentencePiece, resumable training, Parquet-backed datasets, and FastAPI serving. + +This README is written as a practical operator guide. It tells you: + +- what the project contains +- what is already implemented +- what commands to run +- what files are inputs and outputs +- what parts are scaffolding versus fully wired + +## What SAGE Is + +SAGE is organized into these layers: + +1. `tokenizer/` + Trains and validates a SentencePiece tokenizer. +2. `data/` + Handles raw corpus ingest, filtering, deduplication, sharding, and packed datasets. +3. `model/` + Implements the dense decoder-only transformer. +4. `train/` + Handles optimizer setup, scheduler, hardware detection, checkpoints, and the training loop. +5. `eval/` + Provides perplexity evaluation and benchmark harness registration. +6. `serve/` + Exposes FastAPI servers and quantization helpers. + +## Current Baseline + +| Component | Value | +| --- | --- | +| Layers | 24 | +| d_model | 2048 | +| Attention heads | 16 | +| KV heads | 8 | +| Head dim | 128 | +| FFN dim | 5632 | +| Context length | 4096 | +| Vocab size | 50000 | +| Norm | RMSNorm | +| Positional encoding | RoPE | +| Attention | GQA + SDPA | +| Activation | SwiGLU | +| Weight tying | Enabled | + +## Repository Layout -**SAGE** is a senior-grade, production-structured Large Language Model (LLM) system built entirely from scratch using Python and PyTorch. It implements modern transformer architectures including Mixture of Experts (MoE), Rotary Positional Embeddings (RoPE), and Low-Rank Adaptation (LoRA). +```text +configs/ + model/ model YAMLs for 1B, 3B, 7B + data/ corpus mix and shard config + train/ LR, checkpoint, and logging schedule +data/ + ingest.py raw source registry and streaming helpers + filter.py license/lang/PII/safety/quality filtering + dedup.py exact and near-duplicate removal + shard.py tokenization + parquet shard writing + manifest + dataset.py packed iterable dataset with resume skip() +tokenizer/ + train_tokenizer.py + validate_tokenizer.py +model/ + config.py + rmsnorm.py + rope.py + attention.py + mlp.py + block.py + model.py +train/ + loss.py + optimizer.py + checkpoint.py + distributed.py + hardware.py + trainer.py +eval/ + perplexity.py + benchmarks.py + long_context.py + regression.py +serve/ + kv_cache.py + quantize.py + server.py + server_cpu.py +scripts/ + run_data_pipeline.sh + run_training.sh + run_eval.sh + run_serve.sh + run_serve_cpu.sh + run_validate_tokenizer.sh +tests/ +``` -Designed to be both educational and functional, SAGE can be trained, fine-tuned, quantized, and deployed on a single consumer GPU (e.g., NVIDIA T4 with 16GB VRAM). +## What Is Fully Working vs. What Is Scaffolded ---- +### Working now -## โ˜๏ธ Cloud Quickstart (Kaggle / Colab) -Running SAGE in the cloud? Check out the **[Kaggle & Colab Quickstart Guide](file:///c:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/SAGE_KAGGLE_GUIDE.md)** for one-click setup and a premium interactive chat interface. +- tokenizer training +- tokenizer validation +- data filtering and dedup helpers +- packed dataset logic +- dense transformer forward pass +- checkpoint save and resume +- hardware detection +- trainer entrypoint +- FastAPI health and basic generate endpoint +- unit and smoke tests ---- +### Scaffolded but not yet a full production runner -## ๐Ÿš€ Key Features +- benchmark execution against downloaded external datasets +- a single end-to-end corpus build command that downloads and preprocesses public corpora automatically +- production-grade multi-node launch tooling +- real llama.cpp server wiring beyond availability checks -- **Decoder-Only Transformer**: A GPT-style architecture with pre-layer normalization. -- **Mixture of Experts (MoE)**: Efficient scaling with a learned router selecting top-k experts per token. -- **Rotary Positional Embeddings (RoPE)**: Enhanced long-sequence generalization. -- **KV-Cache Inference**: O(1) time-per-token generation for high-speed response. -- **Retrieval-Augmented Generation (RAG)**: Integration with FAISS for document-based context lookup. -- **Efficient Fine-Tuning**: Support for LoRA and instruction tuning with loss masking. -- **Post-Training Quantization**: INT8 support to reduce memory footprint. -- **Interactive CLI**: A full REPL (Read-Eval-Print Loop) for chatting and system management. +That means the core codebase is real, but you still need to provide your own corpus files and Parquet shards before running a training job. ---- +## Install -## ๐Ÿ“‚ Project Structure +Create and activate a virtual environment, then install dependencies: -```text -sage/ -โ”œโ”€โ”€ model.py # Core architecture (Transformer, MoE, RoPE, Attention) -โ”œโ”€โ”€ data.py # Tokenization (tiktoken) & Streaming Datasets (HuggingFace) -โ”œโ”€โ”€ train.py # Pre-training loop with AdamW, AMP, and Cosine Decay -โ”œโ”€โ”€ inference.py # Text generation (Greedy, Temp, Top-k, Top-p sampling) -โ”œโ”€โ”€ finetune.py # LoRA implementation & Instruction Tuning -โ”œโ”€โ”€ optimize.py # INT8 Quantization & Pruning utilities -โ”œโ”€โ”€ memory.py # RAG Vector Store & Conversation History -โ”œโ”€โ”€ cli.py # Interactive Terminal Interface -โ”œโ”€โ”€ utils.py # Logging, Checkpointing, and Helper functions -โ”œโ”€โ”€ config.py # Central Hyperparameter Configuration -โ””โ”€โ”€ requirements.txt # System dependencies -sage_single.py # Consolidated single-file version for easy portability +```bash +pip install -r requirements.txt ``` ---- +Recommended optional extras: + +- `sentencepiece` is required for tokenizer training and validation +- `bitsandbytes` is useful for 8-bit experiments +- `llama.cpp` or `llama-cpp-python` is needed for the CPU serving path + +## Quick Start -## ๐Ÿ› ๏ธ Installation & Setup +If you want the shortest path to verifying the repo: -### 1. Requirements -Ensure you have Python 3.9+ and a CUDA-compatible GPU (recommended). +1. Install dependencies. +2. Run tests. +3. Start the FastAPI server. ```bash -# Clone the repository (GitHub) -git clone https://github.com/er-del/sage.git -cd sage +pytest -q +uvicorn serve.server:app --host 127.0.0.1 --port 8000 +``` -# OR Clone from Hugging Face -git clone https://huggingface.co/sage002/sage -cd sage +Then check: -# Install dependencies -pip install -r requirements.txt +```bash +curl http://127.0.0.1:8000/health +``` + +## Command Reference + +The detailed command guide is in [docs/COMMANDS.md](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/docs/COMMANDS.md:1). The most important commands are below. + +### 1. Train tokenizer + +Cross-platform Python command: + +```bash +python -m tokenizer.train_tokenizer \ + --input data/raw/general_web.txt data/raw/code.txt \ + --model-prefix tokenizer/tokenizer \ + --vocab-size 50000 +``` + +Linux/macOS/WSL wrapper: + +```bash +bash scripts/run_data_pipeline.sh \ + --input data/raw/general_web.txt data/raw/code.txt \ + --model-prefix tokenizer/tokenizer \ + --vocab-size 50000 +``` + +Outputs: + +- `tokenizer/tokenizer.model` +- `tokenizer/tokenizer.vocab` +- `tokenizer/training_corpus.txt` + +### 2. Validate tokenizer + +```bash +python - <<'PY' +from tokenizer.validate_tokenizer import validate_model_file +validate_model_file("tokenizer/tokenizer.model") +print("tokenizer ok") +PY ``` -### 2. Dependencies -- **PyTorch**: Core deep learning framework. -- **tiktoken**: Fast BPE tokenization (OpenAI's cl100k_base). -- **datasets**: For streaming training data from HuggingFace. -- **faiss-cpu**: For vector-based retrieval (RAG). -- **tqdm**: Progress bars for training. -- **bitsandbytes**: (Optional) For advanced quantization. +Or: ---- +```bash +bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model +``` -## ๐ŸŽฎ Getting Started +### 3. Train the model -### Launching the CLI -You can run the modular version or the single-file version: +Training expects existing Parquet shards. Example: ```bash -# Modular version -python -m sage.cli +python -m train.trainer \ + --model-config configs/model/1b.yaml \ + --schedule-config configs/train/schedule.yaml \ + --train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \ + --validation-shards data/processed/shard-00002.parquet \ + --output-dir runs/sage-1b +``` + +Useful options: -# Single-file version -python sage_single.py +- `--steps 100` for a short smoke run +- `--disable-wandb` to disable offline W&B logging + +Example smoke run: + +```bash +python -m train.trainer \ + --train-shards data/processed/shard-00000.parquet \ + --validation-shards data/processed/shard-00001.parquet \ + --output-dir runs/smoke \ + --steps 20 \ + --disable-wandb ``` -### Basic Chat -Once launched, simply type your message to chat with SAGE. The system uses a rolling conversation history to maintain context. +### 4. Run evaluation harness ---- +```bash +bash scripts/run_eval.sh +``` -## ๐Ÿ‘จโ€๐Ÿซ Training SAGE +This currently prints the registered benchmark surfaces. It is a harness check, not a full benchmark download-and-run pipeline. -SAGE supports real-time training either directly from the interactive REPL or via simple one-liner CLI commands (useful for background scripts). +### 5. Start the GPU server -### Non-Interactive "One-Liner" Commands -If you want to bypass the chat interface and just run a training job, pass the command as a CLI argument: ```bash -python sage_single.py --train 100 # Pre-train for 100 steps -python sage_single.py --finetune 200 # Instruction-tune for 200 steps -python sage_single.py --quantize # Apply INT8 quantization +uvicorn serve.server:app --host 0.0.0.0 --port 8000 ``` -### Interactive REPL Commands -If you are inside the chat interface, use the slash commands: +Or: -### /train [steps] -Run pre-training using the `TinyStories` dataset (default). -- `/train 100` โ€” Trains for 100 steps and saves a checkpoint. +```bash +bash scripts/run_serve.sh +``` -### /finetune [steps] -Perform instruction fine-tuning using LoRA adapters. -- `/finetune 200` โ€” Trains on instruction/response pairs and merges weights. +### 6. Start the CPU server ---- +```bash +uvicorn serve.server_cpu:app --host 0.0.0.0 --port 8001 +``` -## ๐Ÿง  Advanced Commands +Or: -| Command | Action | -| :--- | :--- | -| `/save` | Manually save the current model checkpoint. | -| `/load` | Reload the latest checkpoint from the `checkpoints` directory. | -| `/quantize` | Convert model weights to INT8 (CPU) for reduced memory usage. | -| `/rag on` | Enable Retrieval-Augmented Generation. | -| `/rag add ` | Add new knowledge to SAGE's retrieval database. | -| `/clear` | Clear the current conversation history. | -| `/help` | Show the list of available commands. | -| `/exit` | Exit the program cleanly. | +```bash +bash scripts/run_serve_cpu.sh +``` + +### 7. Call the generate endpoint ---- +The current server takes token IDs directly, not raw text strings. -## ๐Ÿ—๏ธ Architecture Details +```bash +curl -X POST http://127.0.0.1:8000/generate \ + -H "Content-Type: application/json" \ + -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}" +``` + +Response shape: + +```json +{ + "tokens": [1, 42, 99, 123, 456] +} +``` -### Mixture of Experts (MoE) -SAGE swaps standard FFN layers for MoE blocks. Each block contains 4 experts, where exactly 2 are activated per token via a learned linear router. This allows for higher total capacity without increasing the computational cost per token. +## How Training Works -### Rotary Positional Embeddings (RoPE) -Positions are encoded via complex-valued rotations of query and key vectors. This allows SAGE to better handle sequences longer than what it was trained on compared to absolute position embeddings. +The training flow is: -### Inference Engine -Generation supports: -- **Temperature**: Adjusts randomness. -- **Top-k**: Limits sampling to the most likely 'k' tokens. -- **Top-p (Nucleus)**: Limits sampling to a cumulative probability threshold. -- **KV-Caching**: Caches Attention keys and values to avoid redundant computation. +1. load model config from `configs/model/*.yaml` +2. load schedule config from `configs/train/schedule.yaml` +3. detect hardware in `train/hardware.py` +4. build optimizer and cosine scheduler +5. load latest checkpoint if one exists +6. call `PackedDataset.skip()` so resume does not replay already-trained batches +7. run forward/backward with autocast on CUDA or MPS +8. clip gradients +9. log metrics to `metrics.jsonl` and optionally offline W&B +10. run validation perplexity at eval intervals +11. save checkpoint every configured interval ---- +Important output files during training: -## ๐Ÿค— Hugging Face Model Hub +- `runs//metrics.jsonl` +- `runs//ckpt_step_0001000.pt` +- later checkpoints in the same folder + +## How Data Is Expected to Look + +### Raw text files for tokenizer training + +Simple UTF-8 text files are enough: + +```text +This is a training document. +This is another one. +``` -This project is actively maintained on Hugging Face. You can find pre-trained checkpoints, datasets, and community discussions here: +### Raw JSONL records for ingest/filter work + +The ingest layer assumes records like: + +```json +{"text": "example text"} +``` + +### Processed Parquet shards for training + +The trainer expects Parquet rows with at least: + +- `tokens` +- `split` + +The sharding helper writes: + +- `id` +- `text` +- `tokens` +- `domain_tag` +- `quality_tier` +- `lang` +- `token_count` +- `split` + +## Main Config Files + +### [configs/model/1b.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/model/1b.yaml:1) + +Controls the model shape: + +- layers +- hidden size +- heads +- KV heads +- FFN size +- vocab size +- context length + +### [configs/data/mix.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/data/mix.yaml:1) + +Controls corpus weights and split ratios. + +### [configs/train/schedule.yaml](C:/Users/Lenovo/OneDrive/Desktop/Documents/LLM_MOdel/configs/train/schedule.yaml:1) + +Controls: + +- total token target +- LR schedule +- warmup +- checkpoint interval +- log interval +- eval interval + +## Common Workflows + +### Workflow A: verify the repo + +```bash +pip install -r requirements.txt +pytest -q +``` + +### Workflow B: train tokenizer only + +```bash +python -m tokenizer.train_tokenizer --input data/raw/general_web.txt --model-prefix tokenizer/tokenizer +python - <<'PY' +from tokenizer.validate_tokenizer import validate_model_file +validate_model_file("tokenizer/tokenizer.model") +print("ok") +PY +``` + +### Workflow C: smoke-train on local shards + +```bash +python -m train.trainer \ + --train-shards data/processed/shard-00000.parquet \ + --validation-shards data/processed/shard-00001.parquet \ + --output-dir runs/smoke \ + --steps 20 \ + --disable-wandb +``` + +### Workflow D: serve locally + +```bash +uvicorn serve.server:app --host 127.0.0.1 --port 8000 +curl http://127.0.0.1:8000/health +``` + +## Troubleshooting + +### `No training shards provided` + +You launched the trainer without `--train-shards`. The trainer is working as designed, but it needs Parquet shard paths. + +### `ModuleNotFoundError: sentencepiece` + +Install dependencies: + +```bash +pip install -r requirements.txt +``` + +### FastAPI starts but generate is not useful + +That is expected right now if you have not trained or loaded a checkpoint. The server instantiates the model architecture, but it does not yet load a trained checkpoint automatically. + +### CPU server says llama.cpp is unavailable + +Install `llama.cpp` or `llama-cpp-python`. The current CPU server is a readiness surface, not a bundled llama.cpp runtime. + +## Tests + +Run the full suite: + +```bash +pytest -q +``` -๐Ÿ”— **[huggingface.co/sage002/sage](https://huggingface.co/sage002/sage)** +Coverage areas: -**Developed by Antigravity AI Systems.** +- tokenizer roundtrip validation +- model shapes +- attention math +- data filtering and packing +- checkpoint roundtrip +- hardware summaries +- FastAPI health endpoints ---- +## Next Practical Step -## ๐Ÿ“œ Disclaimer -SAGE is an experimental engine. While architecturally complete, the quality of generated responses depends heavily on the amount of training data and compute steps provided. +If you want the fastest real progress from here, the next step is: -**Developed by Antigravity AI Systems.** +1. prepare a small local corpus +2. train the tokenizer +3. write Parquet shards with `data/shard.py` +4. run a `--steps 20` smoke training job +5. only then start extending benchmark or serving behavior diff --git a/SAGE_KAGGLE_GUIDE.md b/SAGE_KAGGLE_GUIDE.md deleted file mode 100644 index 64dc37ee00a2775034a0e2a78fb1e2368658e21f..0000000000000000000000000000000000000000 --- a/SAGE_KAGGLE_GUIDE.md +++ /dev/null @@ -1,146 +0,0 @@ -# ๐Ÿช SAGE: Kaggle & Colab Quickstart Guide - -Welcome to the **Self-Adaptive General Engine (SAGE)**. This guide will help you get SAGE v2 running on a cloud environment (like Kaggle's 2x T4 or Google Colab) in under 5 minutes. - ---- - -## ๐Ÿ› ๏ธ Step 1: Environment Setup - -Run this cell first to install dependencies and fix any common binary incompatibilities (like the Numpy/Torch mismatch). - -```python -# Install PyTorch 2.1 with CUDA 12.1 (supports Tesla P100 sm_60) -!pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121 - -# Install other dependencies -!pip install "numpy<2.0.0" --force-reinstall -!pip install bitsandbytes tqdm tiktoken faiss-cpu datasets wandb --upgrade - -print("โœ… Environment ready. Please RESTART YOUR KERNEL now if this is your first run.") -``` - ---- - -## ๐Ÿ”‘ Step 2: Weights & Biases Logging (Optional but Recommended) - -To track your training progress with professional charts: - -1. Get your API Key from [wandb.ai/authorize](https://wandb.ai/authorize). -2. Add it to your Kaggle **Secrets** with the label `WANDB_API_KEY`. -3. Run this: - -```python -import wandb -from kaggle_secrets import UserSecretsClient -try: - user_secrets = UserSecretsClient() - wandb.login(key=user_secrets.get_secret("WANDB_API_KEY")) -except: - import os - os.environ["WANDB_MODE"] = "offline" - print("โš ๏ธ W&B Secret not found. Running in offline mode.") -``` - ---- - -## ๐Ÿ’ฌ Step 3: Launch the SAGE Chat Interface - -This is a premium, multi-GPU enabled chat widget. Paste this into a cell to start interacting with SAGE. - -**Note:** SAGE automatically detects GPU compatibility and falls back to CPU if needed. - -```python -import sys, os, torch, random -import torch.nn as nn -import ipywidgets as widgets -from IPython.display import display, HTML - -# Verify SAGE is accessible (debugging import issues) -if not os.path.exists('sage/__init__.py'): - print("โŒ ERROR: sage/ folder not found in current directory!") - print(" Make sure you've cloned the repo: !git clone https://github.com/er-del/sage.git") - raise ImportError("sage module not found") - -# Add current directory to path if needed -if '' not in sys.path and '.' not in sys.path: - sys.path.insert(0, '') - -# Import SAGE -from sage import SageModel, SageConfig, SageTokenizer, generate, ConversationHistory, train as train_model, finetune -from sage import __version__ as sage_version - -# Verify import worked -print(f"โœ… SAGE v{sage_version} loaded successfully") - -# -- Initialization -- -config = SageConfig() -# Note: config.device automatically checks GPU compatibility and falls back to CPU if needed -device = config.device -print(f"๐Ÿ–ฅ๏ธ Using device: {device}") - -tokenizer = SageTokenizer() -history = ConversationHistory(tokenizer, max_tokens=1024) -model = SageModel(config) - -# -- Multi-GPU Logic (only if CUDA is actually being used) -- -if device.type == "cuda": - gpu_count = torch.cuda.device_count() - if gpu_count > 1: - print(f"๐Ÿš€ Multi-GPU active: {gpu_count} GPUs.") - model = nn.DataParallel(model) -model = model.to(device) - -# -- Load Weights -- -ckpt_path = "checkpoints/sage_latest.pt" -if os.path.exists(ckpt_path): - base_model = getattr(model, "module", model) - ckpt = torch.load(ckpt_path, map_location=device) - base_model.load_state_dict(ckpt['model_state_dict']) - print("โœ… Weights loaded from checkpoint.") -else: - print("โš ๏ธ RANDOM WEIGHTS (Type /train to begin learning).") - -# -- Render UI -- -chat_display = widgets.Output(layout={'border': '1px solid #444', 'height': '450px', 'overflow_y': 'scroll', 'padding': '10px'}) -text_input = widgets.Text(placeholder="Chat or type /train 1000...", layout={'width': '80%'}) -send_button = widgets.Button(description="Send", button_style='primary', layout={'width': '18%'}) -display(HTML("")) - -def on_send(_=None): - user_text = text_input.value.strip() - if not user_text: return - text_input.value = "" - with chat_display: - if user_text.startswith("/train"): - steps = int(user_text.split()[1]) if len(user_text.split()) > 1 else 100 - print(f"๐Ÿš€ TRAINING STARTING ({steps} steps)...") - train_model(model, config, total_steps=steps) - print("โœ… DONE.") - return - display(HTML(f'
You: {user_text}
')) - response = generate(model, tokenizer, history.build_prompt(user_text), stream=False) - res = response.split("SAGE:")[-1].split("")[0].replace("", "").strip() - history.add("user", user_text); history.add("assistant", res) - display(HTML(f'
SAGE: {res}
')) - -text_input.on_submit(on_send); send_button.on_click(lambda b: on_send()) -display(chat_display, widgets.HBox([text_input, send_button])) -``` - ---- - -## ๐ŸŽฎ Command Cheat Sheet - -| Command | Action | -| :--- | :--- | -| `/train ` | Starts pre-training (Base knowledge). Recommended: 5000+ | -| `/clear` | Resets the conversation history. | -| `/finetune ` | (Coming Soon) Starts instruction fine-tuning. | - ---- - -## ๐Ÿ’ก Pro Tips for T4 GPUs - -1. **Batch Size**: The default `batch_size=4` with `gradient_accumulation=16` is perfect for a 2x T4 setup (32GB VRAM total). -2. **Persistence**: Kaggle outputs are deleted when the session ends. Make sure to **download** the `checkpoints/` folder or sync it to **Hugging Face** regularly. -3. **Patience**: Loss will fluctuate. Look for a steady downward trend on your W&B dashboard! diff --git a/SAGE_V3_ROADMAP.md b/SAGE_V3_ROADMAP.md deleted file mode 100644 index 467eeb4fb55e0e81b991be7ad2a52fe126c1aab5..0000000000000000000000000000000000000000 --- a/SAGE_V3_ROADMAP.md +++ /dev/null @@ -1,52 +0,0 @@ -# SAGE v3: The "Long-Vision" Roadmap ๐Ÿ—บ๏ธ - -This document outlines the high-impact architectural upgrades that will transform SAGE into a multi-thousand token reasoning assistant with multimedia capabilities. - ---- - -## ๐Ÿ—๏ธ 1. Long-Context Scaling (YaRN / RoPE-Interpolation) - -**Goal**: Increase SAGE's maximum comprehension from 1,024 to **4,096+ tokens**. - -### Technical Strategy: -Currently, our `freqs_cis` are precomputed for a fixed window. In v3, we will implement **NTK-Aware Interpolation**. -- **Implementation**: We will add a `scaling_factor` to `SageConfig`. -- **Logic**: During inference, if the sequence length exceeds the original training window, we will "stretch" the rotary frequencies dynamically rather than letting them overflow. -- **Benefit**: SAGE can read entire source code files or long essays without "losing its mind" at the 1,025th token. - ---- - -## ๐Ÿ“‚ 2. Interactive RAG (Kaggle UI Integration) - -**Goal**: Allow users to "Upload and Chat" with any file instantly in the notebook. - -### Technical Strategy: -- **Widget Update**: Add a `widgets.FileUpload` component to the Kaggle chat interface. -- **Auto-Ingestion**: When a file is uploaded, a background hook will: - 1. Parse the text (PDF, `.py`, `.md`). - 2. Chunk it into 200-token segments. - 3. Generate embeddings and add them to the **FAISS Vector Store**. -- **Real-time Recall**: SAGE will automatically pull context from these uploaded files using the `retrieve_context` logic we've already built. - ---- - -## ๐Ÿ‘๏ธ 3. Multimodal Foundation (Vision Projection) - -**Goal**: Let SAGE "see" images. - -### Technical Strategy: -Since SAGE is a small, efficient model (133M), it is the perfect candidate for a **Vision-Language Model (VLM)**. -- **Architecture**: We will add a frozen **CLIP-ViT** image encoder. -- **The Bridge**: We will implement a `VisionProjector` (a simple 2-layer MLP) that converts CLIP image embeddings (e.g., 768-dim) into SAGE token embeddings (512-dim). -- **Outcome**: You will be able to provide an image URL and a prompt like "What is in this picture?", and SAGE will respond based on the visual tokens. - ---- - -## โšก 4. Training Stability: LayerNorm Tuning - -To support these advanced features, we will move to **RMSNorm** (Root Mean Square Layer Normalization) for even faster convergence and better numerical stability on the double-T4 setup. - ---- - -### Which one first? -We can begin implementing **RoPE Scaling** immediately to give SAGE a massive context boost without needing new weights. Just let me know when you're ready! diff --git a/configs/data/mix.yaml b/configs/data/mix.yaml new file mode 100644 index 0000000000000000000000000000000000000000..53e64c4c22fc10589c2764adf175a6d6ffdddfc8 --- /dev/null +++ b/configs/data/mix.yaml @@ -0,0 +1,25 @@ +data_sources: + general_web: + weight_percent: 55 + quality_tiers: [high, medium] + code: + weight_percent: 15 + quality_tiers: [high, medium] + math_science: + weight_percent: 12 + quality_tiers: [high, medium] + books_longform: + weight_percent: 10 + quality_tiers: [high, medium] + multilingual: + weight_percent: 5 + quality_tiers: [high, medium] + synthetic: + weight_percent: 3 + quality_tiers: [high] +splits: + train: 0.989 + validation: 0.01 + test: 0.001 +shard_size_bytes: 2147483648 +format: parquet diff --git a/configs/model/1b.yaml b/configs/model/1b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cc76ca287d2bc5b230d2ade41f5c344b8ebfc263 --- /dev/null +++ b/configs/model/1b.yaml @@ -0,0 +1,13 @@ +name: sage-1b +num_layers: 24 +d_model: 2048 +num_attn_heads: 16 +num_kv_heads: 8 +head_dim: 128 +ffn_hidden_dim: 5632 +vocab_size: 50000 +context_length: 4096 +rope_base_frequency: 500000 +rope_scaling_factor: 1.0 +dropout: 0.0 +tie_word_embeddings: true diff --git a/configs/model/3b.yaml b/configs/model/3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a928606086bfba715b131ef9dc09b959ff05dd11 --- /dev/null +++ b/configs/model/3b.yaml @@ -0,0 +1,13 @@ +name: sage-3b +num_layers: 28 +d_model: 3072 +num_attn_heads: 24 +num_kv_heads: 8 +head_dim: 128 +ffn_hidden_dim: 8192 +vocab_size: 50000 +context_length: 8192 +rope_base_frequency: 500000 +rope_scaling_factor: 1.0 +dropout: 0.0 +tie_word_embeddings: true diff --git a/configs/model/7b.yaml b/configs/model/7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd71fb8c40bcc8b411ffe145548f7bfdc1a270e9 --- /dev/null +++ b/configs/model/7b.yaml @@ -0,0 +1,13 @@ +name: sage-7b +num_layers: 32 +d_model: 4096 +num_attn_heads: 32 +num_kv_heads: 8 +head_dim: 128 +ffn_hidden_dim: 11008 +vocab_size: 50000 +context_length: 8192 +rope_base_frequency: 500000 +rope_scaling_factor: 1.0 +dropout: 0.0 +tie_word_embeddings: true diff --git a/configs/train/schedule.yaml b/configs/train/schedule.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8dfc532fdbe1eb2f29dc272b534c87cfa9d67d7d --- /dev/null +++ b/configs/train/schedule.yaml @@ -0,0 +1,14 @@ +run_name: sage-1b-pretrain +total_tokens: 50000000000 +effective_batch_tokens: 2000000 +peak_learning_rate: 3.0e-4 +min_learning_rate: 3.0e-5 +warmup_steps: 2000 +weight_decay: 0.1 +betas: [0.9, 0.95] +adam_eps: 1.0e-8 +gradient_clip_norm: 1.0 +checkpoint_interval: 1000 +log_interval: 10 +eval_interval: 1000 +seed: 42 diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c854b1a3ded90824cb7df6da2050a6b8fb06b7d1 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1 @@ +"""Data pipeline modules for SAGE.""" diff --git a/data/dataset.py b/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c69028e33dced043e8985148a5b5ada3ef239385 --- /dev/null +++ b/data/dataset.py @@ -0,0 +1,92 @@ +"""Packed training dataset with deterministic resume support.""" + +from __future__ import annotations + +import random +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Iterator + +import torch +from torch.utils.data import IterableDataset + +try: + import pyarrow.parquet as pq +except ImportError: # pragma: no cover - optional at import time + pq = None + + +@dataclass(frozen=True) +class DatasetConfig: + """Configuration for packing token streams into training batches.""" + + shard_paths: tuple[str, ...] + context_length: int + split: str = "train" + seed: int = 42 + + +class PackedDataset(IterableDataset): + """Iterate packed token sequences with document-boundary masks.""" + + def __init__(self, config: DatasetConfig): + super().__init__() + self.config = config + self._skip = 0 + + def skip(self, n_batches: int) -> None: + """Fast-forward the iterator by discarding the first n batches.""" + self._skip = max(0, int(n_batches)) + + def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: + skipped = 0 + for batch in self._generate(): + if skipped < self._skip: + skipped += 1 + continue + yield batch + + def _generate(self) -> Iterator[dict[str, torch.Tensor]]: + token_buffer: list[int] = [] + boundary_buffer: list[int] = [] + for row in self._iter_rows(): + tokens = list(row["tokens"]) + if len(tokens) < 2: + continue + token_buffer.extend(tokens) + boundary_buffer.extend([0] * (len(tokens) - 1) + [1]) + while len(token_buffer) >= self.config.context_length + 1: + window_tokens = token_buffer[: self.config.context_length + 1] + window_boundaries = boundary_buffer[: self.config.context_length + 1] + yield pack_sequence(window_tokens, window_boundaries) + token_buffer = token_buffer[self.config.context_length :] + boundary_buffer = boundary_buffer[self.config.context_length :] + + def _iter_rows(self) -> Iterator[dict[str, object]]: + if pq is None: + raise ImportError("pyarrow is required to read parquet shards.") + shard_paths = [Path(path) for path in self.config.shard_paths] + rng = random.Random(self.config.seed) + shard_paths = shard_paths[:] + rng.shuffle(shard_paths) + for path in shard_paths: + table = pq.read_table(path, columns=["tokens", "split"]) + rows = table.to_pylist() + for row in rows: + if row["split"] != self.config.split: + continue + yield row + + +def pack_sequence(tokens: list[int], boundaries: list[int]) -> dict[str, torch.Tensor]: + """Turn one packed token window into model-ready tensors.""" + input_ids = torch.tensor(tokens[:-1], dtype=torch.long) + labels = torch.tensor(tokens[1:], dtype=torch.long) + loss_mask = torch.ones_like(input_ids, dtype=torch.float32) + attention_document_mask = torch.tensor(boundaries[:-1], dtype=torch.long) + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "document_boundaries": attention_document_mask, + } diff --git a/data/dedup.py b/data/dedup.py new file mode 100644 index 0000000000000000000000000000000000000000..e799682e17a9d75df811d4caf24f268b813b6168 --- /dev/null +++ b/data/dedup.py @@ -0,0 +1,59 @@ +"""Exact and near-duplicate detection helpers.""" + +from __future__ import annotations + +import hashlib +import re +from collections import defaultdict +from typing import Iterable + + +TOKEN_RE = re.compile(r"\w+") + + +def exact_content_hash(text: str) -> str: + """Return an exact content hash.""" + return hashlib.sha1(text.encode("utf-8")).hexdigest() + + +def shingles(text: str, n: int = 5) -> set[str]: + """Build token shingles for near-duplicate detection.""" + tokens = TOKEN_RE.findall(text.lower()) + if len(tokens) < n: + return {" ".join(tokens)} if tokens else set() + return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)} + + +def jaccard_similarity(left: str, right: str, n: int = 5) -> float: + """Compute shingle-level Jaccard similarity.""" + left_set = shingles(left, n) + right_set = shingles(right, n) + if not left_set and not right_set: + return 1.0 + if not left_set or not right_set: + return 0.0 + return len(left_set & right_set) / len(left_set | right_set) + + +def deduplicate_records(records: Iterable[dict[str, object]], near_dup_threshold: float = 0.92) -> list[dict[str, object]]: + """Drop exact and near-duplicate records.""" + exact_seen: set[str] = set() + buckets: dict[str, list[dict[str, object]]] = defaultdict(list) + kept: list[dict[str, object]] = [] + for record in records: + text = str(record["text"]) + digest = exact_content_hash(text) + if digest in exact_seen: + continue + signature = digest[:8] + near_duplicate = False + for candidate in buckets[signature]: + if jaccard_similarity(text, str(candidate["text"])) >= near_dup_threshold: + near_duplicate = True + break + if near_duplicate: + continue + exact_seen.add(digest) + buckets[signature].append(record) + kept.append(record) + return kept diff --git a/data/filter.py b/data/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b361b30dd3bc9346c22f910591ffb37208dba1c9 --- /dev/null +++ b/data/filter.py @@ -0,0 +1,135 @@ +"""Corpus filtering, safety, and quality heuristics.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Iterable + + +EMAIL_RE = re.compile(r"\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Z]{2,}\b", re.IGNORECASE) +PHONE_RE = re.compile(r"(?:(?:\+?\d{1,3})?[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?){2}\d{4}") +SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b") +HTML_RE = re.compile(r"<[^>]+>") +MULTISPACE_RE = re.compile(r"[ \t]+") +NSFW_TERMS = {"porn", "explicit sex", "rape"} +HATE_TERMS = {"kill all", "ethnic cleansing"} +ALLOWED_LICENSES = {"permissive", "restricted"} +ALLOWED_LANGS = {"en", "es", "fr", "de", "hi", "zh", "ar", "pt"} + + +@dataclass(frozen=True) +class FilterConfig: + """Policy controls for the filtering pipeline.""" + + minimum_chars: int = 200 + maximum_chars: int = 200_000 + minimum_alpha_ratio: float = 0.45 + minimum_quality_score: float = 0.20 + language_confidence_threshold: float = 0.65 + + +def normalize_text(text: str) -> str: + """Strip tags and normalize whitespace.""" + text = HTML_RE.sub(" ", text) + text = MULTISPACE_RE.sub(" ", text) + return text.strip() + + +def detect_language(text: str) -> tuple[str, float]: + """Use a light heuristic to assign a language code.""" + ascii_ratio = sum(ch.isascii() for ch in text) / max(len(text), 1) + devanagari = sum("\u0900" <= ch <= "\u097f" for ch in text) + cjk = sum("\u4e00" <= ch <= "\u9fff" for ch in text) + arabic = sum("\u0600" <= ch <= "\u06ff" for ch in text) + if cjk > 8: + return "zh", 0.95 + if arabic > 8: + return "ar", 0.95 + if devanagari > 8: + return "hi", 0.95 + if ascii_ratio > 0.9: + return "en", 0.80 + return "unknown", 0.40 + + +def quality_score(text: str) -> float: + """Score text using length, punctuation, and alphabetic density.""" + if not text: + return 0.0 + alpha_ratio = sum(ch.isalpha() for ch in text) / len(text) + punct_ratio = sum(ch in ".,;:!?()[]{}" for ch in text) / len(text) + line_count = text.count("\n") + 1 + score = min(len(text) / 4000.0, 1.0) * 0.4 + alpha_ratio * 0.4 + min(punct_ratio * 8.0, 1.0) * 0.2 + if line_count < 2 and len(text) > 10_000: + score *= 0.85 + return round(score, 4) + + +def quality_tier(score: float) -> str: + """Map a numeric score to a quality tier.""" + if score >= 0.70: + return "high" + if score >= 0.40: + return "medium" + return "low" + + +def strip_pii(text: str) -> str: + """Mask basic email, phone, and SSN patterns.""" + text = EMAIL_RE.sub("[EMAIL]", text) + text = PHONE_RE.sub("[PHONE]", text) + text = SSN_RE.sub("[SSN]", text) + return text + + +def passes_safety_filter(text: str) -> bool: + """Reject obviously unsafe content with simple keyword checks.""" + lower = text.lower() + if any(term in lower for term in NSFW_TERMS): + return False + if any(term in lower for term in HATE_TERMS): + return False + return True + + +def license_allowed(category: str) -> bool: + """Return whether the source license category is allowed.""" + return category in ALLOWED_LICENSES + + +def filter_record(record: dict[str, object], config: FilterConfig = FilterConfig()) -> dict[str, object] | None: + """Apply the full filter pipeline to one record.""" + if not license_allowed(str(record.get("license_category", ""))): + return None + text = normalize_text(str(record.get("text", ""))) + if not (config.minimum_chars <= len(text) <= config.maximum_chars): + return None + lang, confidence = detect_language(text) + if lang not in ALLOWED_LANGS or confidence < config.language_confidence_threshold: + return None + text = strip_pii(text) + if not passes_safety_filter(text): + return None + score = quality_score(text) + if score < config.minimum_quality_score: + return None + return { + **record, + "text": text, + "lang": lang, + "lang_confidence": confidence, + "quality_score": score, + "quality_tier": quality_tier(score), + "token_count_estimate": max(1, len(text) // 4), + } + + +def filter_corpus(records: Iterable[dict[str, object]], config: FilterConfig = FilterConfig()) -> list[dict[str, object]]: + """Filter a corpus in memory.""" + kept: list[dict[str, object]] = [] + for record in records: + filtered = filter_record(record, config) + if filtered is not None: + kept.append(filtered) + return kept diff --git a/data/ingest.py b/data/ingest.py new file mode 100644 index 0000000000000000000000000000000000000000..99859df81f3e026fa1f4a8b5111ca456351830b7 --- /dev/null +++ b/data/ingest.py @@ -0,0 +1,79 @@ +"""Raw corpus ingestion utilities.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Iterator + + +@dataclass(frozen=True) +class SourceSpec: + """Describes one raw corpus source.""" + + name: str + domain_tag: str + quality_tier: str + license_category: str + estimated_tokens: int + path: str + text_key: str = "text" + + +SOURCE_REGISTRY: tuple[SourceSpec, ...] = ( + SourceSpec("general_web", "general", "medium", "permissive", 20_000_000_000, "data/raw/general_web.jsonl"), + SourceSpec("code", "code", "high", "permissive", 8_000_000_000, "data/raw/code.jsonl"), + SourceSpec("math_science", "math", "high", "permissive", 4_000_000_000, "data/raw/math_science.jsonl"), + SourceSpec("books_longform", "general", "high", "restricted", 5_000_000_000, "data/raw/books.jsonl"), + SourceSpec("multilingual", "multilingual", "medium", "permissive", 3_000_000_000, "data/raw/multilingual.jsonl"), + SourceSpec("synthetic", "reasoning", "high", "permissive", 1_000_000_000, "data/raw/synthetic.jsonl"), +) + + +def iter_jsonl(path: Path, text_key: str = "text") -> Iterator[dict[str, object]]: + """Yield JSONL records from disk.""" + with path.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + payload = json.loads(line) + text = payload.get(text_key) + if not isinstance(text, str) or not text.strip(): + continue + yield payload + + +def stream_source(spec: SourceSpec) -> Iterator[dict[str, object]]: + """Yield normalized records for one configured source.""" + path = Path(spec.path) + if not path.exists(): + return iter(()) + return ( + { + "id": stable_record_id(spec.name, record[spec.text_key]), + "text": record[spec.text_key], + "domain_tag": spec.domain_tag, + "quality_tier": spec.quality_tier, + "license_category": spec.license_category, + "source_name": spec.name, + } + for record in iter_jsonl(path, spec.text_key) + ) + + +def stream_all_sources(sources: Iterable[SourceSpec] = SOURCE_REGISTRY) -> Iterator[dict[str, object]]: + """Yield records from every source in the registry.""" + for spec in sources: + yield from stream_source(spec) + + +def stable_record_id(source_name: str, text: str) -> str: + """Hash a source+text pair into a stable content id.""" + digest = hashlib.sha256() + digest.update(source_name.encode("utf-8")) + digest.update(b"\0") + digest.update(text.encode("utf-8")) + return digest.hexdigest() diff --git a/data/shard.py b/data/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..55a0c4d56763d517b2ec0fe54360c11f95b5c39d --- /dev/null +++ b/data/shard.py @@ -0,0 +1,95 @@ +"""Tokenization, manifesting, and Parquet sharding.""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable + +try: + import pyarrow as pa + import pyarrow.parquet as pq +except ImportError: # pragma: no cover - optional at import time + pa = None + pq = None + + +SCHEMA_COLUMNS = ("id", "text", "tokens", "domain_tag", "quality_tier", "lang", "token_count", "split") + + +@dataclass(frozen=True) +class ShardConfig: + """Parameters for Parquet shard writing.""" + + output_dir: str + shard_size: int = 2048 + validation_ratio: float = 0.01 + test_ratio: float = 0.001 + + +def assign_split(record_id: str, validation_ratio: float, test_ratio: float) -> str: + """Assign a deterministic split from the content id.""" + value = int(record_id[:8], 16) / 0xFFFFFFFF + if value < test_ratio: + return "test" + if value < test_ratio + validation_ratio: + return "validation" + return "train" + + +def build_manifest(shard_paths: Iterable[Path]) -> dict[str, object]: + """Create a manifest describing shard files.""" + shard_paths = list(shard_paths) + digest = hashlib.sha256() + for path in shard_paths: + digest.update(path.name.encode("utf-8")) + digest.update(str(path.stat().st_size).encode("utf-8")) + return { + "format": "parquet", + "schema": list(SCHEMA_COLUMNS), + "shards": [path.name for path in shard_paths], + "dataset_hash": digest.hexdigest(), + } + + +def write_shards(records: Iterable[dict[str, object]], tokenizer, config: ShardConfig) -> dict[str, object]: + """Write tokenized records to Parquet shards and emit a manifest.""" + if pa is None or pq is None: + raise ImportError("pyarrow is required to write parquet shards.") + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + buffer: list[dict[str, object]] = [] + shard_paths: list[Path] = [] + shard_index = 0 + for record in records: + tokens = tokenizer.encode(str(record["text"]), out_type=int) + row = { + "id": str(record["id"]), + "text": str(record["text"]), + "tokens": tokens, + "domain_tag": str(record["domain_tag"]), + "quality_tier": str(record["quality_tier"]), + "lang": str(record["lang"]), + "token_count": len(tokens), + "split": assign_split(str(record["id"]), config.validation_ratio, config.test_ratio), + } + buffer.append(row) + if len(buffer) >= config.shard_size: + shard_paths.append(_flush_shard(output_dir, shard_index, buffer)) + shard_index += 1 + buffer = [] + if buffer: + shard_paths.append(_flush_shard(output_dir, shard_index, buffer)) + manifest = build_manifest(shard_paths) + (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8") + return manifest + + +def _flush_shard(output_dir: Path, shard_index: int, rows: list[dict[str, object]]) -> Path: + """Persist one Parquet shard.""" + table = pa.table({column: [row[column] for row in rows] for column in SCHEMA_COLUMNS}) + path = output_dir / f"shard-{shard_index:05d}.parquet" + pq.write_table(table, path) + return path diff --git a/docs/COMMANDS.md b/docs/COMMANDS.md new file mode 100644 index 0000000000000000000000000000000000000000..b59ef90265206325375802d5aca47e344b514819 --- /dev/null +++ b/docs/COMMANDS.md @@ -0,0 +1,84 @@ +# SAGE Commands + +This file is the short command-only reference for the repo. + +## Install + +```bash +pip install -r requirements.txt +``` + +## Run tests + +```bash +pytest -q +``` + +## Train tokenizer + +```bash +python -m tokenizer.train_tokenizer \ + --input data/raw/general_web.txt data/raw/code.txt \ + --model-prefix tokenizer/tokenizer \ + --vocab-size 50000 +``` + +## Validate tokenizer + +```bash +bash scripts/run_validate_tokenizer.sh tokenizer/tokenizer.model +``` + +## Start a short training smoke run + +```bash +python -m train.trainer \ + --train-shards data/processed/shard-00000.parquet \ + --validation-shards data/processed/shard-00001.parquet \ + --output-dir runs/smoke \ + --steps 20 \ + --disable-wandb +``` + +## Start full training + +```bash +python -m train.trainer \ + --model-config configs/model/1b.yaml \ + --schedule-config configs/train/schedule.yaml \ + --train-shards data/processed/shard-00000.parquet data/processed/shard-00001.parquet \ + --validation-shards data/processed/shard-00002.parquet \ + --output-dir runs/sage-1b +``` + +## Run eval harness + +```bash +bash scripts/run_eval.sh +``` + +## Start GPU server + +```bash +bash scripts/run_serve.sh +``` + +## Start CPU server + +```bash +bash scripts/run_serve_cpu.sh +``` + +## Check server health + +```bash +curl http://127.0.0.1:8000/health +``` + +## Generate tokens from the API + +```bash +curl -X POST http://127.0.0.1:8000/generate \ + -H "Content-Type: application/json" \ + -d "{\"input_ids\": [1, 42, 99], \"max_new_tokens\": 8}" +``` diff --git a/docs/flow_llm.mmd b/docs/flow_llm.mmd new file mode 100644 index 0000000000000000000000000000000000000000..e0a4ae4e1e0df2e3abae1dce63c86a0cec5ffdd1 --- /dev/null +++ b/docs/flow_llm.mmd @@ -0,0 +1,140 @@ +flowchart TB + + %% ========================================================= + %% SAGE - Simplified Operational Flow + %% This file is pure Mermaid so .mmd renderers can open it. + %% ========================================================= + + user["You / Operator"] + + subgraph inputs["Inputs"] + raw["Raw text / JSONL corpus"] + cfg_model["configs/model/1b.yaml"] + cfg_train["configs/train/schedule.yaml"] + cfg_data["configs/data/mix.yaml"] + end + + subgraph tokenizer["Tokenizer Stage"] + tok_train["tokenizer/train_tokenizer.py"] + tok_validate["tokenizer/validate_tokenizer.py"] + tok_model["tokenizer.model + tokenizer.vocab"] + end + + subgraph prep["Data Preparation Stage"] + ingest["data/ingest.py"] + filter["data/filter.py"] + dedup["data/dedup.py"] + shard["data/shard.py"] + parquet["Parquet shards + manifest.json"] + packed["data/dataset.py
PackedDataset"] + end + + subgraph model["Model Stage"] + model_cfg["model/config.py"] + rms["RMSNorm"] + rope["RoPE"] + attn["GQA Attention + SDPA"] + mlp["SwiGLU MLP"] + blocks["Transformer Blocks x24"] + sage["SageTransformer"] + end + + subgraph train["Training Stage"] + hw["train/hardware.py"] + opt["train/optimizer.py"] + loss["train/loss.py"] + ckpt["train/checkpoint.py"] + trainer["train/trainer.py"] + metrics["runs//metrics.jsonl"] + saves["runs//ckpt_step_xxxxxxx.pt"] + end + + subgraph evals["Evaluation Stage"] + ppl["eval/perplexity.py"] + bench["eval/benchmarks.py"] + longctx["eval/long_context.py"] + regress["eval/regression.py"] + end + + subgraph serving["Serving Stage"] + kv["serve/kv_cache.py"] + quant["serve/quantize.py"] + api["serve/server.py"] + cpu["serve/server_cpu.py"] + health["/health"] + generate["/generate"] + end + + user --> raw + user --> cfg_model + user --> cfg_train + user --> cfg_data + + raw --> tok_train + tok_train --> tok_model + tok_model --> tok_validate + + raw --> ingest + cfg_data --> ingest + ingest --> filter + filter --> dedup + dedup --> shard + tok_model --> shard + shard --> parquet + parquet --> packed + + cfg_model --> model_cfg + model_cfg --> rms + model_cfg --> rope + model_cfg --> attn + model_cfg --> mlp + rms --> blocks + rope --> attn + attn --> blocks + mlp --> blocks + blocks --> sage + + packed --> trainer + cfg_train --> trainer + cfg_model --> trainer + hw --> trainer + opt --> trainer + loss --> trainer + ckpt --> trainer + sage --> trainer + + trainer --> metrics + trainer --> saves + trainer --> ppl + + sage --> ppl + sage --> bench + sage --> longctx + ppl --> regress + bench --> regress + longctx --> regress + + sage --> kv + sage --> quant + sage --> api + quant --> cpu + kv --> api + api --> health + api --> generate + cpu --> health + + classDef input fill:#0f172a,stroke:#93c5fd,color:#ffffff + classDef token fill:#1d4ed8,stroke:#bfdbfe,color:#ffffff + classDef prep fill:#0f766e,stroke:#99f6e4,color:#ffffff + classDef model fill:#581c87,stroke:#d8b4fe,color:#ffffff + classDef train fill:#92400e,stroke:#fde68a,color:#ffffff + classDef eval fill:#991b1b,stroke:#fecaca,color:#ffffff + classDef serve fill:#166534,stroke:#bbf7d0,color:#ffffff + + class user,raw,cfg_model,cfg_train,cfg_data input + class tok_train,tok_validate,tok_model token + class ingest,filter,dedup,shard,parquet,packed prep + class model_cfg,rms,rope,attn,mlp,blocks,sage model + class hw,opt,loss,ckpt,trainer,metrics,saves train + class ppl,bench,longctx,regress eval + class kv,quant,api,cpu,health,generate serve diff --git a/docs/llm_Arch.mmd b/docs/llm_Arch.mmd new file mode 100644 index 0000000000000000000000000000000000000000..997f9dc38d09fd6b6c228844feac831659f2b5b7 --- /dev/null +++ b/docs/llm_Arch.mmd @@ -0,0 +1,202 @@ +--- +title: SAGE 1B System Architecture +--- +flowchart TB + + %% ========================================================= + %% SAGE 1B - End-to-End Architecture and Flow Overview + %% ========================================================= + + user["Developer / Operator"] + + subgraph repo["SAGE Repository"] + direction TB + + subgraph configs["configs/"] + cfg_model["model/1b.yaml
24L, 2048 d_model, 16Q / 8KV, 4096 ctx"] + cfg_data["data/mix.yaml
corpus weights + split ratios"] + cfg_train["train/schedule.yaml
LR, warmup, checkpoints, logging"] + end + + subgraph tokenizer["tokenizer/"] + tok_train["train_tokenizer.py
SentencePiece BPE training"] + tok_validate["validate_tokenizer.py
roundtrip + edge-case checks"] + tok_model["tokenizer.model / tokenizer.vocab"] + end + + subgraph data_layer["data/"] + ingest["ingest.py
source registry + raw record streaming"] + filter["filter.py
license, lang, PII, safety, quality"] + dedup["dedup.py
exact + near-duplicate removal"] + shard["shard.py
tokenize -> parquet shards + manifest"] + dataset["dataset.py
PackedDataset + skip(n_batches)"] + end + + subgraph model_layer["model/"] + model_cfg["config.py
ModelConfig"] + rmsnorm["rmsnorm.py
pre-norm RMSNorm"] + rope["rope.py
RoPE cache + apply_rope"] + attn["attention.py
fused QKV + GQA + SDPA"] + mlp["mlp.py
SwiGLU FFN"] + block["block.py
Transformer block"] + full_model["model.py
SageTransformer"] + end + + subgraph train_layer["train/"] + hw["hardware.py
device / dtype / batch routing"] + dist["distributed.py
single / DDP / FSDP strategy"] + opt["optimizer.py
AdamW + cosine schedule"] + loss["loss.py
masked next-token cross entropy"] + ckpt["checkpoint.py
save / prune / resume"] + trainer["trainer.py
main training loop"] + end + + subgraph eval_layer["eval/"] + ppl["perplexity.py
validation loss + perplexity"] + benches["benchmarks.py
benchmark harness registry"] + longctx["long_context.py
needle-in-haystack probes"] + regress["regression.py
checkpoint metric comparison"] + end + + subgraph serve_layer["serve/"] + kv["kv_cache.py
cache container"] + quant["quantize.py
int8 export + GGUF command helper"] + gpu_api["server.py
FastAPI GPU server"] + cpu_api["server_cpu.py
FastAPI CPU readiness surface"] + end + + subgraph scripts["scripts/"] + s_data["run_data_pipeline.sh"] + s_train["run_training.sh"] + s_eval["run_eval.sh"] + s_serve["run_serve.sh / run_serve_cpu.sh"] + end + + subgraph outputs["Runtime Outputs"] + raw["Raw text / JSONL corpora"] + parquet["Parquet shards + manifest.json"] + runs["runs//metrics.jsonl"] + checkpoints["runs//ckpt_step_xxxxxxx.pt"] + api_out["/health + /generate responses"] + end + end + + %% ========================================================= + %% Top-level usage + %% ========================================================= + + user --> s_data + user --> s_train + user --> s_eval + user --> s_serve + + s_data --> tok_train + s_train --> trainer + s_eval --> benches + s_serve --> gpu_api + s_serve --> cpu_api + + %% ========================================================= + %% Tokenizer flow + %% ========================================================= + + raw --> tok_train + tok_train --> tok_model + tok_model --> tok_validate + + %% ========================================================= + %% Data preparation flow + %% ========================================================= + + raw --> ingest + ingest --> filter + filter --> dedup + dedup --> shard + tok_model --> shard + shard --> parquet + parquet --> dataset + + cfg_data --> ingest + cfg_data --> filter + cfg_data --> shard + + %% ========================================================= + %% Model construction flow + %% ========================================================= + + cfg_model --> model_cfg + model_cfg --> rmsnorm + model_cfg --> rope + model_cfg --> attn + model_cfg --> mlp + rmsnorm --> block + rope --> attn + attn --> block + mlp --> block + block --> full_model + + %% ========================================================= + %% Training flow + %% ========================================================= + + cfg_train --> opt + cfg_train --> trainer + cfg_train --> ckpt + model_cfg --> full_model + dataset --> trainer + full_model --> trainer + hw --> trainer + dist --> trainer + opt --> trainer + loss --> trainer + ckpt --> trainer + + trainer --> runs + trainer --> checkpoints + trainer --> ppl + + %% ========================================================= + %% Evaluation flow + %% ========================================================= + + full_model --> ppl + full_model --> benches + full_model --> longctx + ppl --> regress + benches --> regress + longctx --> regress + + %% ========================================================= + %% Serving flow + %% ========================================================= + + full_model --> kv + full_model --> quant + full_model --> gpu_api + quant --> cpu_api + kv --> gpu_api + hw --> gpu_api + gpu_api --> api_out + cpu_api --> api_out + + %% ========================================================= + %% Visual grouping + %% ========================================================= + + classDef config fill:#1f2937,stroke:#93c5fd,color:#ffffff + classDef pipeline fill:#0f766e,stroke:#5eead4,color:#ffffff + classDef model fill:#4c1d95,stroke:#c4b5fd,color:#ffffff + classDef train fill:#92400e,stroke:#fcd34d,color:#ffffff + classDef eval fill:#7f1d1d,stroke:#fca5a5,color:#ffffff + classDef serve fill:#065f46,stroke:#86efac,color:#ffffff + classDef io fill:#111827,stroke:#9ca3af,color:#ffffff + classDef actor fill:#2563eb,stroke:#bfdbfe,color:#ffffff + + class user actor + class cfg_model,cfg_data,cfg_train config + class tok_train,tok_validate,ingest,filter,dedup,shard,dataset pipeline + class model_cfg,rmsnorm,rope,attn,mlp,block,full_model model + class hw,dist,opt,loss,ckpt,trainer train + class ppl,benches,longctx,regress eval + class kv,quant,gpu_api,cpu_api serve + class raw,parquet,runs,checkpoints,api_out,tok_model,s_data,s_train,s_eval,s_serve io diff --git a/eval/__init__.py b/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b59ba94aa085b4d0cb76f07fddfdbeb864c5c39c --- /dev/null +++ b/eval/__init__.py @@ -0,0 +1 @@ +"""Evaluation helpers for SAGE.""" diff --git a/eval/benchmarks.py b/eval/benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7d06e2608778fc3831b2f2677662ace569273c --- /dev/null +++ b/eval/benchmarks.py @@ -0,0 +1,42 @@ +"""Benchmark harness registration for SAGE.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class BenchmarkResult: + """A normalized benchmark result.""" + + name: str + status: str + score: float | None + detail: str + + +BENCHMARKS = ( + "hellaswag", + "winogrande", + "arc_easy", + "arc_challenge", + "gsm8k", + "math", + "humaneval", + "mbpp", +) + + +def run_registered_benchmarks(model, tokenizer=None) -> list[BenchmarkResult]: + """Return a lightweight result set for the configured benchmarks.""" + _ = model + _ = tokenizer + return [ + BenchmarkResult( + name=name, + status="skipped", + score=None, + detail="Benchmark harness registered; dataset/task execution is external to unit tests.", + ) + for name in BENCHMARKS + ] diff --git a/eval/long_context.py b/eval/long_context.py new file mode 100644 index 0000000000000000000000000000000000000000..04dee76b6f7874a1d91b172cd13a1c952ab8248e --- /dev/null +++ b/eval/long_context.py @@ -0,0 +1,24 @@ +"""Long-context retrieval evaluation helpers.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class RetrievalProbe: + """A synthetic retrieval probe for long-context checks.""" + + prompt: str + needle: str + expected_index: int + + +def build_needle_in_haystack_probe(context_length: int) -> RetrievalProbe: + """Create a deterministic retrieval prompt for smoke tests.""" + needle = "SAGE_LONG_CONTEXT_NEEDLE" + haystack = ["token"] * max(context_length - 16, 16) + insert_at = min(len(haystack) // 2, max(context_length // 4, 1)) + haystack.insert(insert_at, needle) + prompt = " ".join(haystack) + return RetrievalProbe(prompt=prompt, needle=needle, expected_index=insert_at) diff --git a/eval/perplexity.py b/eval/perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..f512b9b49ce80eb2fbb2791dd81d7d32fe0797aa --- /dev/null +++ b/eval/perplexity.py @@ -0,0 +1,39 @@ +"""Validation perplexity evaluation.""" + +from __future__ import annotations + +import math + +import torch + +from train.loss import masked_cross_entropy + + +@torch.no_grad() +def evaluate_perplexity( + model: torch.nn.Module, + dataloader, + device: torch.device, + dtype: torch.dtype | None = None, + max_batches: int = 16, +) -> dict[str, float]: + """Evaluate average loss and perplexity on a validation loader.""" + model.eval() + losses: list[float] = [] + for index, batch in enumerate(dataloader): + if index >= max_batches: + break + input_ids = batch["input_ids"].to(device) + labels = batch["labels"].to(device) + loss_mask = batch["loss_mask"].to(device) + if dtype is not None and device.type != "cpu": + with torch.amp.autocast(device_type=device.type, dtype=dtype): + logits, _ = model(input_ids) + loss = masked_cross_entropy(logits, labels, loss_mask) + else: + logits, _ = model(input_ids) + loss = masked_cross_entropy(logits, labels, loss_mask) + losses.append(float(loss)) + model.train() + mean_loss = sum(losses) / max(len(losses), 1) + return {"loss": mean_loss, "perplexity": math.exp(min(mean_loss, 20.0))} diff --git a/eval/regression.py b/eval/regression.py new file mode 100644 index 0000000000000000000000000000000000000000..6148c87a964ead914315ef5885fe32b291264eb0 --- /dev/null +++ b/eval/regression.py @@ -0,0 +1,15 @@ +"""Checkpoint-to-checkpoint regression checks.""" + +from __future__ import annotations + + +def compare_metrics(previous: dict[str, float], current: dict[str, float], threshold: float = 0.005) -> dict[str, object]: + """Flag metric drops larger than the configured threshold.""" + regressions: list[str] = [] + for key, prev_value in previous.items(): + curr_value = current.get(key) + if curr_value is None: + continue + if curr_value < prev_value * (1.0 - threshold): + regressions.append(key) + return {"regressions": regressions, "passed": not regressions} diff --git a/hf_push.py b/hf_push.py new file mode 100644 index 0000000000000000000000000000000000000000..daa294a3dc8e6e1355d208e62ec1d66b336c90fc --- /dev/null +++ b/hf_push.py @@ -0,0 +1,40 @@ +"""Upload the current SAGE repository contents to the Hugging Face Hub.""" + +from __future__ import annotations + +from huggingface_hub import HfApi + + +REPO_ID = "sage002/sage" + + +def main() -> None: + """Replace the remote Hugging Face repo contents with the local folder state.""" + api = HfApi() + print(f"Syncing current repository to {REPO_ID}...") + api.upload_folder( + folder_path=".", + repo_id=REPO_ID, + repo_type="model", + ignore_patterns=[ + ".git/*", + ".venv/*", + "__pycache__/*", + "*.pyc", + "checkpoints/*", + "runs/*", + "wandb/*", + "data/raw/*", + "data/processed/*", + "tokenizer/*.model", + "tokenizer/*.vocab", + "tokenizer/training_corpus.txt", + ], + delete_patterns="*", + commit_message="feat: rewrite SAGE 1B architecture and replace legacy repo contents", + ) + print("Sync complete.") + + +if __name__ == "__main__": + main() diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..899776f1981644f2ee307fea754d64415ffd1634 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1 @@ +"""Model architecture for SAGE.""" diff --git a/model/attention.py b/model/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd7537dc55bf5208b30911f66ff19139d668cd1 --- /dev/null +++ b/model/attention.py @@ -0,0 +1,76 @@ +"""Grouped-query attention with SDPA and KV-cache support.""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from model.config import ModelConfig +from model.rope import apply_rope + + +def repeat_kv(x: torch.Tensor, num_groups: int) -> torch.Tensor: + """Expand KV heads to match the number of query heads.""" + if num_groups == 1: + return x + batch, kv_heads, seq_len, head_dim = x.shape + x = x[:, :, None, :, :].expand(batch, kv_heads, num_groups, seq_len, head_dim) + return x.reshape(batch, kv_heads * num_groups, seq_len, head_dim) + + +class GQAAttention(nn.Module): + """Fused-QKV grouped-query attention.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + self.num_heads = config.num_attn_heads + self.num_kv_heads = config.num_kv_heads + self.head_dim = config.head_dim + self.num_groups = self.num_heads // self.num_kv_heads + qkv_dim = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim + self.qkv_proj = nn.Linear(config.d_model, qkv_dim, bias=False) + self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False) + self.dropout = config.dropout + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Compute causal self-attention and return an updated KV cache.""" + batch_size, seq_len, _ = hidden_states.shape + qkv = self.qkv_proj(hidden_states) + q_end = self.num_heads * self.head_dim + k_end = q_end + self.num_kv_heads * self.head_dim + q, k, v = qkv.split((q_end, self.num_kv_heads * self.head_dim, self.num_kv_heads * self.head_dim), dim=-1) + + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q_rope, k_rope = apply_rope(q, repeat_kv(k, self.num_groups), cos, sin) + k = k_rope[:, :: self.num_groups, :, :] + + if past_key_value is not None: + past_key, past_value = past_key_value + k = torch.cat([past_key, k], dim=-2) + v = torch.cat([past_value, v], dim=-2) + + expanded_k = repeat_kv(k, self.num_groups) + expanded_v = repeat_kv(v, self.num_groups) + attn_output = F.scaled_dot_product_attention( + q_rope, + expanded_k, + expanded_v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0.0, + is_causal=past_key_value is None, + ) + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.config.d_model) + return self.out_proj(attn_output), (k, v) diff --git a/model/block.py b/model/block.py new file mode 100644 index 0000000000000000000000000000000000000000..1639e721493ed7154b0c125fc298a3fb66b81baa --- /dev/null +++ b/model/block.py @@ -0,0 +1,37 @@ +"""Transformer block for the dense SAGE model.""" + +from __future__ import annotations + +from typing import Optional + +import torch +from torch import nn + +from model.attention import GQAAttention +from model.config import ModelConfig +from model.mlp import SwiGLUMLP +from model.rmsnorm import RMSNorm + + +class TransformerBlock(nn.Module): + """Pre-norm transformer block with attention and SwiGLU.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.norm1 = RMSNorm(config.d_model, eps=config.rms_norm_eps) + self.attn = GQAAttention(config) + self.norm2 = RMSNorm(config.d_model, eps=config.rms_norm_eps) + self.mlp = SwiGLUMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Forward pass with residual connections.""" + attn_output, present = self.attn(self.norm1(hidden_states), cos, sin, past_key_value=past_key_value) + hidden_states = hidden_states + attn_output + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states, present diff --git a/model/config.py b/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7636110d014a61d5ed1b5e58f4e66f7df6637d --- /dev/null +++ b/model/config.py @@ -0,0 +1,48 @@ +"""Model configuration for SAGE.""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import yaml + + +@dataclass +class ModelConfig: + """Configuration for the dense SAGE decoder-only transformer.""" + + name: str = "sage-1b" + num_layers: int = 24 + d_model: int = 2048 + num_attn_heads: int = 16 + num_kv_heads: int = 8 + head_dim: int = 128 + ffn_hidden_dim: int = 5632 + vocab_size: int = 50_000 + context_length: int = 4096 + rope_base_frequency: int = 500_000 + rope_scaling_factor: float = 1.0 + dropout: float = 0.0 + tie_word_embeddings: bool = True + rms_norm_eps: float = 1.0e-5 + initializer_range: float = 0.02 + + def __post_init__(self) -> None: + if self.num_attn_heads * self.head_dim != self.d_model: + raise ValueError("num_attn_heads * head_dim must equal d_model.") + if self.num_attn_heads % self.num_kv_heads != 0: + raise ValueError("num_attn_heads must be divisible by num_kv_heads.") + if self.ffn_hidden_dim % 256 != 0: + raise ValueError("ffn_hidden_dim must be a multiple of 256.") + + @classmethod + def from_yaml(cls, path: str | Path) -> "ModelConfig": + """Load a config from YAML.""" + payload = yaml.safe_load(Path(path).read_text(encoding="utf-8")) + return cls(**payload) + + def to_dict(self) -> dict[str, Any]: + """Serialize the config to a dict.""" + return asdict(self) diff --git a/model/mlp.py b/model/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f68f18ffc65ec323e3e99ff62e5f395c420507 --- /dev/null +++ b/model/mlp.py @@ -0,0 +1,23 @@ +"""SwiGLU feed-forward module.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn + +from model.config import ModelConfig + + +class SwiGLUMLP(nn.Module): + """Bias-free SwiGLU feed-forward network.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False) + self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=False) + self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply SwiGLU and project back to the model width.""" + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..924bfe9d99c8d6bf9c46573ae1f7ff12d9df5e3e --- /dev/null +++ b/model/model.py @@ -0,0 +1,73 @@ +"""Full dense decoder-only transformer model for SAGE.""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +from torch import nn + +from model.block import TransformerBlock +from model.config import ModelConfig +from model.rope import build_rope_cache +from model.rmsnorm import RMSNorm + + +class SageTransformer(nn.Module): + """A dense Llama-style decoder-only transformer.""" + + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)]) + self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps) + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + if config.tie_word_embeddings: + self.lm_head.weight = self.embed_tokens.weight + cos, sin = build_rope_cache( + seq_len=config.context_length, + head_dim=config.head_dim, + base_frequency=config.rope_base_frequency, + scaling_factor=config.rope_scaling_factor, + ) + self.register_buffer("rope_cos", cos, persistent=False) + self.register_buffer("rope_sin", sin, persistent=False) + self._reset_parameters() + + def _reset_parameters(self) -> None: + """Apply scaled initialization to the model.""" + embed_std = 1.0 / math.sqrt(self.config.d_model) + nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=embed_std) + for module in self.modules(): + if not isinstance(module, nn.Linear): + continue + std = self.config.initializer_range + if module is self.lm_head and self.config.tie_word_embeddings: + continue + if module.out_features == self.config.d_model: + std = std / math.sqrt(2 * self.config.num_layers) + nn.init.normal_(module.weight, mean=0.0, std=std) + + def forward( + self, + input_ids: torch.Tensor, + past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: + """Return logits and the updated KV cache.""" + batch_size, seq_len = input_ids.shape + hidden_states = self.embed_tokens(input_ids) + past_key_values = past_key_values or [None] * self.config.num_layers + start = 0 + if past_key_values[0] is not None: + start = past_key_values[0][0].size(-2) + cos = self.rope_cos[start : start + seq_len].to(hidden_states.device) + sin = self.rope_sin[start : start + seq_len].to(hidden_states.device) + presents: list[tuple[torch.Tensor, torch.Tensor]] = [] + for layer, past in zip(self.layers, past_key_values): + hidden_states, present = layer(hidden_states, cos, sin, past_key_value=past) + presents.append(present) + hidden_states = self.norm(hidden_states) + logits = self.lm_head(hidden_states) + return logits, presents diff --git a/model/rmsnorm.py b/model/rmsnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8e020779654c5de33ed672f03fc1404728f05598 --- /dev/null +++ b/model/rmsnorm.py @@ -0,0 +1,23 @@ +"""RMSNorm implementation used by SAGE.""" + +from __future__ import annotations + +import torch +from torch import nn + + +class RMSNorm(nn.Module): + """Root mean square normalization with float32 accumulation.""" + + def __init__(self, dim: int, eps: float = 1.0e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Normalize the last dimension and cast back to the input dtype.""" + if x.ndim < 2: + raise ValueError("RMSNorm expects at least 2 dimensions.") + variance = x.float().pow(2).mean(dim=-1, keepdim=True) + normalized = x.float() * torch.rsqrt(variance + self.eps) + return (normalized.to(dtype=x.dtype)) * self.weight diff --git a/model/rope.py b/model/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a5c4501d7cedca68a451331f592f683ad3df99 --- /dev/null +++ b/model/rope.py @@ -0,0 +1,57 @@ +"""Rotary positional embedding helpers.""" + +from __future__ import annotations + +import torch + + +def _scaled_positions(seq_len: int, scaling_factor: float, device: torch.device) -> torch.Tensor: + """Apply a simple YaRN-style position scaling factor.""" + positions = torch.arange(seq_len, device=device, dtype=torch.float32) + if scaling_factor > 1.0: + positions = positions / scaling_factor + return positions + + +def build_rope_cache( + seq_len: int, + head_dim: int, + base_frequency: int = 500_000, + scaling_factor: float = 1.0, + device: torch.device | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Precompute cosine and sine tables for RoPE.""" + if head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE.") + device = device or torch.device("cpu") + positions = _scaled_positions(seq_len, scaling_factor, device) + inv_freq = 1.0 / (base_frequency ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)) + freqs = torch.outer(positions, inv_freq) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate the last dimension in pairs.""" + even = x[..., ::2] + odd = x[..., 1::2] + rotated = torch.stack((-odd, even), dim=-1) + return rotated.flatten(start_dim=-2) + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to query and key tensors.""" + if q.shape != k.shape: + raise ValueError("q and k must share the same shape for RoPE application.") + seq_len = q.size(-2) + cos = cos[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1) + sin = sin[:seq_len].unsqueeze(0).unsqueeze(0).repeat_interleave(2, dim=-1) + q_out = (q * cos) + (rotate_half(q) * sin) + k_out = (k * cos) + (rotate_half(k) * sin) + return q_out, k_out diff --git a/requirements.txt b/requirements.txt index 4eecbcc3eab614721ec71246bacd1b10163e898b..9c1a18500f1563237a4a5e9437034081163e5fcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,27 +1,13 @@ -# SAGE - Self-Adaptive General Engine -# ====================================== -# Core dependencies - -# PyTorch - GPU compatibility notes: -# - For Tesla P100 (sm_60), V100, T4, A100: torch>=2.1.0 -# - For older GPUs (sm_60): use torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121 -# - The code auto-detects GPU compatibility and falls back to CPU if needed torch>=2.1.0 - -# Tokenization & Data -tiktoken>=0.5.1 -datasets>=2.14.0 - -# Vector search (for RAG) -faiss-cpu>=1.7.4 - -# Utilities -tqdm>=4.66.1 -numpy<2.0.0 - -# Quantization (optional GPU support) -bitsandbytes>=0.41.0 - -# Model Hub & Experiment Tracking -huggingface_hub>=0.20.0 -wandb>=0.16.0 +fastapi>=0.110.0 +uvicorn>=0.29.0 +python-multipart>=0.0.9 +pydantic>=2.7.0 +pyyaml>=6.0.1 +sentencepiece>=0.2.0 +pyarrow>=16.0.0 +psutil>=5.9.8 +wandb>=0.17.0 +pytest>=8.2.0 +httpx>=0.27.0 +bitsandbytes>=0.43.0 diff --git a/sage/__init__.py b/sage/__init__.py deleted file mode 100644 index ad3e43525d4936e00aabb76949f8c0176e0fa0d6..0000000000000000000000000000000000000000 --- a/sage/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -SAGE โ€” Self-Adaptive General Engine -A complete mini-LLM system built from scratch. -""" - -__version__ = "1.0.0" - -from .model import SageModel -from .config import SageConfig -from .data import SageTokenizer -from .inference import generate -from .memory import ConversationHistory, RAGManager -from .train import train -from .finetune import finetune_instruction as finetune -from .utils import get_compatible_device diff --git a/sage/cli.py b/sage/cli.py deleted file mode 100644 index 6c992d46242f6935ab2532f1fc9b2b4bc15a1045..0000000000000000000000000000000000000000 --- a/sage/cli.py +++ /dev/null @@ -1,299 +0,0 @@ -""" -SAGE CLI โ€” Interactive Terminal Interface -========================================== -Provides a REPL with slash-commands for training, fine-tuning, quantization, -RAG toggling, and real-time chat with streaming output. -""" - -import sys -import os -import torch -from typing import Optional - -from .config import SageConfig -from .model import SageModel -from .data import SageTokenizer -from .train import train -from .inference import generate -from .finetune import finetune_instruction, DEMO_INSTRUCTION_SAMPLES -from .optimize import quantize_int8 -from .memory import RAGManager, ConversationHistory -from .utils import setup_logger, save_checkpoint, load_checkpoint -from . import __version__ - -logger = setup_logger("sage.cli") - -# =================================================================== -# Banner -# =================================================================== - -BANNER = r""" -โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•— -โ•‘ โ•‘ -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ โ•‘ -โ•‘ Self-Adaptive General Engine v{version} โ•‘ -โ•‘ โ•‘ -โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ• -""" - - -def print_banner(model: SageModel, config: SageConfig) -> None: - """Display startup banner with model statistics.""" - base_model = getattr(model, "module", model) - total_params = sum(p.numel() for p in base_model.parameters()) - trainable_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad) - - print(BANNER.format(version=__version__)) - print(f" Model params : {total_params:,} ({total_params/1e6:.1f}M)") - print(f" Trainable : {trainable_params:,}") - print(f" Context length: {config.max_seq_len}") - print(f" Device : {config.device}") - print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}") - print() - print(" Type /help for commands, or start chatting!\n") - - -# =================================================================== -# Help text -# =================================================================== - -HELP_TEXT = """ -Available Commands: - /train [steps] Train the model (default: 100 steps) - /finetune [steps] Instruction-tune with LoRA (default: 200 steps) - /save Save current model checkpoint - /load Load latest checkpoint - /quantize Quantize model to INT8 (CPU only) - /rag on|off Enable/disable retrieval-augmented generation - /rag add Add a document for RAG retrieval - /clear Clear conversation history - /help Show this message - /exit Exit SAGE -""" - - -# =================================================================== -# Command handlers -# =================================================================== - -def handle_train(model, config, tokenizer, args): - """Handle /train [steps]""" - steps = 100 - if args: - try: - steps = int(args[0]) - except ValueError: - print(f" Invalid step count: {args[0]}") - return model - - print(f"\n Starting training for {steps} steps โ€ฆ\n") - model = train(model, config, total_steps=steps, tokenizer=tokenizer, resume=True) - - # Show a quick sample after training - print("\n --- Sample generation after training ---") - generate(model, tokenizer, "Once upon a time", max_new_tokens=80, stream=True, device=config.device) - print() - return model - - -def handle_finetune(model, config, tokenizer, args): - """Handle /finetune [steps]""" - steps = 200 - if args: - try: - steps = int(args[0]) - except ValueError: - print(f" Invalid step count: {args[0]}") - return model - - print(f"\n Starting instruction fine-tuning for {steps} steps (LoRA) โ€ฆ\n") - model = finetune_instruction( - model, config, - samples=DEMO_INSTRUCTION_SAMPLES, - total_steps=steps, - use_lora=True, - tokenizer=tokenizer, - ) - - print("\n --- Sample after fine-tuning ---") - prompt = "### Instruction:\nWhat is the speed of light?\n\n### Response:\n" - generate(model, tokenizer, prompt, max_new_tokens=100, stream=True, device=config.device) - print() - return model - - -def handle_save(model, config): - """Handle /save""" - path = save_checkpoint(model, None, 0, config.checkpoint_dir) - print(f" Model saved to {path}") - - -def handle_load(model, config): - """Handle /load""" - model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device)) - model = model.to(config.device) - print(f" Model loaded (step {step})") - return model - - -def handle_quantize(model): - """Handle /quantize""" - print(" Quantizing model to INT8 (model will be on CPU) โ€ฆ") - model = quantize_int8(model) - print(" Quantization complete.") - return model - - -def handle_rag(rag_manager: RAGManager, args): - """Handle /rag on|off|add """ - if not args: - state = "enabled" if rag_manager.enabled else "disabled" - print(f" RAG is currently {state} ({rag_manager.store.size} chunks indexed)") - return - - subcmd = args[0].lower() - if subcmd == "on": - rag_manager.toggle(True) - print(" RAG enabled.") - elif subcmd == "off": - rag_manager.toggle(False) - print(" RAG disabled.") - elif subcmd == "add": - text = " ".join(args[1:]) - if text: - rag_manager.add_documents([text]) - print(f" Document added. Store now has {rag_manager.store.size} chunks.") - else: - print(" Usage: /rag add ") - else: - print(" Usage: /rag on|off|add ") - - -# =================================================================== -# Main REPL -# =================================================================== - -def main() -> None: - """Entry point for the SAGE interactive CLI.""" - config = SageConfig() - tokenizer = SageTokenizer() - - # Ensure vocab_size matches the tokenizer - config.vocab_size = tokenizer.vocab_size - - print(" Initializing SAGE model โ€ฆ") - model = SageModel(config) - model = model.to(config.device) - - if torch.cuda.is_available() and torch.cuda.device_count() > 1: - print(f" Multi-GPU detected! Wrapping model in DataParallel across {torch.cuda.device_count()} GPUs.") - model = torch.nn.DataParallel(model) - - # Attempt to load existing checkpoint - model, _, loaded_step = load_checkpoint( - model, None, config.checkpoint_dir, device=str(config.device) - ) - if loaded_step > 0: - print(f" Resumed from checkpoint at step {loaded_step}") - - print_banner(model, config) - - # Initialize RAG and conversation history - rag_manager = RAGManager(model, tokenizer, config.device) - history = ConversationHistory(tokenizer, max_tokens=config.max_seq_len - 128) - - # ---------- One-liner CLI arguments ---------- - if len(sys.argv) > 1: - cmd = sys.argv[1].lower() - args = sys.argv[2:] - if cmd == "--train": - handle_train(model, config, tokenizer, args) - elif cmd == "--finetune": - handle_finetune(model, config, tokenizer, args) - elif cmd == "--quantize": - handle_quantize(model) - else: - print(f" Unknown argument: {cmd}") - print(" Usage: --train [steps] | --finetune [steps] | --quantize") - return - - # ---------- REPL loop ---------- - while True: - try: - user_input = input("You: ").strip() - except (EOFError, KeyboardInterrupt): - print("\n Goodbye!") - break - - if not user_input: - continue - - # ---------- Slash commands ---------- - if user_input.startswith("/"): - parts = user_input.split() - cmd = parts[0].lower() - args = parts[1:] - - if cmd == "/exit": - print(" Goodbye!") - break - elif cmd == "/help": - print(HELP_TEXT) - elif cmd == "/train": - model = handle_train(model, config, tokenizer, args) - elif cmd == "/finetune": - model = handle_finetune(model, config, tokenizer, args) - elif cmd == "/save": - handle_save(model, config) - elif cmd == "/load": - model = handle_load(model, config) - # Re-attach to RAG manager since model changed - rag_manager.model = model - elif cmd == "/quantize": - model = handle_quantize(model) - rag_manager.model = model - elif cmd == "/rag": - handle_rag(rag_manager, args) - elif cmd == "/clear": - history.clear() - print(" Conversation history cleared.") - else: - print(f" Unknown command: {cmd}. Type /help for a list.") - continue - - # ---------- Chat mode ---------- - # Build prompt with history and optional RAG context - rag_context = rag_manager.retrieve_context(user_input) - prompt = history.build_prompt(user_input, rag_context=rag_context) - - history.add("user", user_input) - - print("SAGE: ", end="", flush=True) - response = generate( - model, - tokenizer, - prompt, - max_new_tokens=256, - temperature=0.8, - top_k=50, - top_p=0.9, - stream=True, - device=config.device, - ) - - # Extract only the SAGE response part from the full generation - if "SAGE:" in response: - reply = response.split("SAGE:")[-1].strip() - else: - reply = response[len(prompt):].strip() - - history.add("assistant", reply) - - -if __name__ == "__main__": - main() diff --git a/sage/config.py b/sage/config.py deleted file mode 100644 index d1f185bc80c84431164e0235d002f7d3d5132127..0000000000000000000000000000000000000000 --- a/sage/config.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from dataclasses import dataclass, field -from typing import Any - -@dataclass -class SageConfig: - # Model dimensions corresponding to T4 (16GB VRAM) fit - d_model: int = 512 - n_heads: int = 8 - n_kv_heads: int = 4 # GQA: must divide n_heads - n_layers: int = 6 - d_ff: int = 2048 - - # MoE (Mixture of Experts) config - n_experts: int = 4 - num_experts_per_tok: int = 2 - - # Vocabulary and sequence parameters - vocab_size: int = 100277 # Default for tiktoken "cl100k_base" - max_seq_len: int = 1024 - - # Regularization - dropout: float = 0.1 - - # Training Loop defaults - batch_size: int = 4 - gradient_accumulation_steps: int = 16 - learning_rate: float = 3e-4 - min_learning_rate: float = 1e-5 - warmup_steps: int = 100 - weight_decay: float = 0.01 - max_grad_norm: float = 1.0 - - # Checkpointing and path details - checkpoint_dir: str = "checkpoints" - project_name: str = "sage-v2" - - # Cache for device (set on first access) - _device: Any = field(default=None, repr=False) - - @property - def device(self): - """Returns the best available device with CUDA compatibility checking.""" - if self._device is None: - # Import here to avoid circular imports - from .utils import get_compatible_device - self._device = get_compatible_device() - return self._device diff --git a/sage/data.py b/sage/data.py deleted file mode 100644 index 4c80b5d9c584ad1ebdf8fdf8a2f65feebab433cd..0000000000000000000000000000000000000000 --- a/sage/data.py +++ /dev/null @@ -1,255 +0,0 @@ -""" -SAGE Data Pipeline -================== -Handles tokenization (tiktoken), streaming dataset loading from HuggingFace, -text cleaning, chunking into fixed-length sequences, and batched DataLoader -construction with shuffle buffering. -""" - -import re -import random -import tiktoken -import torch -from torch.utils.data import IterableDataset, DataLoader -from typing import Iterator, List, Optional -from .config import SageConfig -from .utils import setup_logger - -logger = setup_logger("sage.data") - -# --------------------------------------------------------------------------- -# Tokenizer wrapper -# --------------------------------------------------------------------------- - -class SageTokenizer: - """Thin wrapper around tiktoken providing encode/decode and special tokens.""" - - def __init__(self, encoding_name: str = "cl100k_base"): - self.enc = tiktoken.get_encoding(encoding_name) - self.encoding_name = encoding_name - - # Use the last token in the vocabulary as the EOS sentinel. - # tiktoken doesn't expose a dedicated EOS, so we pick one that - # won't collide with real text. - self.eos_token_id: int = self.enc.n_vocab - 1 - self.pad_token_id: int = self.enc.n_vocab - 2 - self.vocab_size: int = self.enc.n_vocab - - def encode(self, text: str, add_eos: bool = False) -> List[int]: - """Encode text to token IDs.""" - tokens = self.enc.encode(text, allowed_special="all") - if add_eos: - tokens.append(self.eos_token_id) - return tokens - - def decode(self, tokens: List[int]) -> str: - """Decode token IDs back to text, filtering out special sentinel IDs.""" - # Filter out our custom pad/eos sentinels before decoding - filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)] - return self.enc.decode(filtered) - -# --------------------------------------------------------------------------- -# Text cleaning -# --------------------------------------------------------------------------- - -_HTML_TAG_RE = re.compile(r"<[^>]+>") -_MULTI_SPACE_RE = re.compile(r"[ \t]+") -_MULTI_NEWLINE_RE = re.compile(r"\n{3,}") - - -def clean_text(text: str) -> str: - """Strip HTML tags, collapse whitespace, and trim to reasonable length.""" - text = _HTML_TAG_RE.sub("", text) # remove HTML tags - text = _MULTI_SPACE_RE.sub(" ", text) # collapse horizontal whitespace - text = _MULTI_NEWLINE_RE.sub("\n\n", text) # collapse vertical whitespace - return text.strip() - -# --------------------------------------------------------------------------- -# Streaming iterable dataset -# --------------------------------------------------------------------------- - -class StreamingTextDataset(IterableDataset): - """ - An IterableDataset that streams data from HuggingFace ``datasets``, - tokenizes on the fly, and yields fixed-length chunks. - - It maintains an internal shuffle buffer so that consecutive chunks are - not always from the same document. - """ - - def __init__( - self, - dataset_name: str = "HuggingFaceFW/fineweb-edu", - split: str = "train", - seq_len: int = 512, - tokenizer: Optional[SageTokenizer] = None, - shuffle_buffer_size: int = 1000, - text_field: str = "text", - min_doc_len: int = 50, - max_doc_len: int = 50000, - ): - super().__init__() - self.dataset_name = dataset_name - self.split = split - self.seq_len = seq_len - self.tokenizer = tokenizer or SageTokenizer() - self.shuffle_buffer_size = shuffle_buffer_size - self.text_field = text_field - self.min_doc_len = min_doc_len - self.max_doc_len = max_doc_len - - # Auto-adjust configuration based on popular datasets - if "fineweb-edu" in dataset_name.lower(): - self.text_field = "text" - self.split = "train" if split == "train" else split - elif "tinystories" in dataset_name.lower(): - self.text_field = "text" - - def _stream_tokens(self) -> Iterator[int]: - """Yields individual token IDs from the HuggingFace dataset stream.""" - try: - from datasets import load_dataset - except ImportError: - raise ImportError( - "The 'datasets' library is required. Install it with: " - "pip install datasets" - ) - - logger.info( - f"Streaming dataset '{self.dataset_name}' (split={self.split}) โ€ฆ" - ) - ds = load_dataset( - self.dataset_name, - split=self.split, - streaming=True, - ) - - for sample in ds: - raw = sample.get(self.text_field, "") - if not raw: - continue - - text = clean_text(raw) - - # Filter documents that are too short or too long - if len(text) < self.min_doc_len or len(text) > self.max_doc_len: - continue - - tokens = self.tokenizer.encode(text, add_eos=True) - yield from tokens - - def _chunk_tokens(self) -> Iterator[torch.Tensor]: - """Groups raw token stream into fixed-length chunks of (seq_len + 1). - - The extra token is needed so that input = chunk[:-1] and - target = chunk[1:] for next-token-prediction. - """ - chunk: List[int] = [] - for tok in self._stream_tokens(): - chunk.append(tok) - if len(chunk) == self.seq_len + 1: - yield torch.tensor(chunk, dtype=torch.long) - chunk = [] - # Discard any trailing partial chunk - - def __iter__(self) -> Iterator[torch.Tensor]: - """Yields shuffled chunks from an internal buffer.""" - buffer: List[torch.Tensor] = [] - for chunk in self._chunk_tokens(): - buffer.append(chunk) - if len(buffer) >= self.shuffle_buffer_size: - random.shuffle(buffer) - while len(buffer) > self.shuffle_buffer_size // 2: - yield buffer.pop() - # Flush remaining items - random.shuffle(buffer) - yield from buffer - -# --------------------------------------------------------------------------- -# DataLoader factory -# --------------------------------------------------------------------------- - -def create_dataloader( - config: SageConfig, - dataset_name: str = "HuggingFaceFW/fineweb-edu", - split: str = "train", - tokenizer: Optional[SageTokenizer] = None, -) -> DataLoader: - """Creates a streaming DataLoader ready for the training loop.""" - tok = tokenizer or SageTokenizer() - ds = StreamingTextDataset( - dataset_name=dataset_name, - split=split, - seq_len=config.max_seq_len, - tokenizer=tok, - ) - return DataLoader( - ds, - batch_size=config.batch_size, - num_workers=2, - pin_memory=True, - drop_last=True, - ) - -# --------------------------------------------------------------------------- -# Instruction-tuning data helpers -# --------------------------------------------------------------------------- - -INSTRUCTION_TEMPLATE = ( - "### Instruction:\n{instruction}\n\n### Response:\n{response}" -) - - -def format_instruction_sample(instruction: str, response: str) -> str: - """Formats an instruction/response pair into the chat template.""" - return INSTRUCTION_TEMPLATE.format( - instruction=instruction.strip(), - response=response.strip(), - ) - - -def create_instruction_batch( - samples: List[dict], - tokenizer: SageTokenizer, - max_len: int = 512, -) -> dict: - """ - Tokenize a list of {instruction, response} dicts and produce input_ids, - labels, and a loss_mask that zeros out the instruction portion. - - Returns a dict with keys: input_ids, labels, loss_mask โ€” all as tensors. - """ - all_input_ids: List[List[int]] = [] - all_labels: List[List[int]] = [] - all_masks: List[List[int]] = [] - - for sample in samples: - instruction_text = f"### Instruction:\n{sample['instruction'].strip()}\n\n### Response:\n" - response_text = sample["response"].strip() - full_text = instruction_text + response_text - - instruction_tokens = tokenizer.encode(instruction_text) - full_tokens = tokenizer.encode(full_text, add_eos=True) - - # Truncate to max_len - full_tokens = full_tokens[:max_len] - n_instruction = min(len(instruction_tokens), len(full_tokens)) - - # Labels are the same as input shifted by 1 (handled by caller), - # but we need a mask to zero out loss on instruction tokens. - mask = [0] * n_instruction + [1] * (len(full_tokens) - n_instruction) - - # Pad to max_len - pad_len = max_len - len(full_tokens) - full_tokens += [tokenizer.pad_token_id] * pad_len - mask += [0] * pad_len - - all_input_ids.append(full_tokens) - all_labels.append(full_tokens) # shift will be done in the loss fn - all_masks.append(mask) - - return { - "input_ids": torch.tensor(all_input_ids, dtype=torch.long), - "labels": torch.tensor(all_labels, dtype=torch.long), - "loss_mask": torch.tensor(all_masks, dtype=torch.float32), - } diff --git a/sage/finetune.py b/sage/finetune.py deleted file mode 100644 index a8b324ef91cc4d26b5cf57a254327b957f1dc5a1..0000000000000000000000000000000000000000 --- a/sage/finetune.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -SAGE Fine-Tuning -================ -Provides two fine-tuning modes: - -1. **Instruction tuning** โ€” trains on instruction/response pairs with loss - masked on the instruction portion. -2. **LoRA (Low-Rank Adaptation)** โ€” injects small trainable matrices into - attention layers while keeping the base model frozen. -""" - -import math -import time -import copy -import torch -import torch.nn as nn -from torch.amp import GradScaler, autocast -from tqdm import tqdm -import wandb -from typing import Optional, List - -from .config import SageConfig -from .model import SageModel, CausalSelfAttention -from .data import SageTokenizer, create_instruction_batch -from .train import create_optimizer, get_lr, set_lr -from .utils import setup_logger, save_checkpoint - -logger = setup_logger("sage.finetune") - - -# =================================================================== -# LoRA Implementation -# =================================================================== - -class LoRALinear(nn.Module): - """ - Wraps an existing ``nn.Linear`` with a low-rank adapter (A @ B). - - During fine-tuning only *A* and *B* are trained; the original weight - is frozen. After fine-tuning the adapter can be merged back into - the original weight for zero-overhead inference. - """ - - def __init__(self, original: nn.Linear, rank: int = 8, alpha: float = 16.0): - super().__init__() - self.original = original - self.rank = rank - self.alpha = alpha - self.scaling = alpha / rank - - in_features = original.in_features - out_features = original.out_features - - # Low-rank matrices - device, dtype = original.weight.device, original.weight.dtype - self.lora_A = nn.Parameter(torch.randn(in_features, rank, device=device, dtype=dtype) * 0.01) - self.lora_B = nn.Parameter(torch.zeros(rank, out_features, device=device, dtype=dtype)) - - # Freeze the original weight - self.original.weight.requires_grad = False - if self.original.bias is not None: - self.original.bias.requires_grad = False - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """original(x) + x @ A @ B * scaling""" - base_out = self.original(x) - lora_out = (x @ self.lora_A @ self.lora_B) * self.scaling - return base_out + lora_out - - def merge(self) -> nn.Linear: - """Merge LoRA weights back into the original linear layer.""" - merged = copy.deepcopy(self.original) - merged.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling - merged.weight.requires_grad = True - return merged - - -# --------------------------------------------------------------------------- -# LoRA injection / removal helpers -# --------------------------------------------------------------------------- - -def inject_lora(model: SageModel, rank: int = 8, alpha: float = 16.0) -> SageModel: - """ - Replace the Q, K, V, O projection layers in every attention block with - LoRA-wrapped versions. Returns the same model (mutated in-place). - """ - base_model = getattr(model, "module", model) - for layer in base_model.layers: - attn: CausalSelfAttention = layer.attn - attn.wq = LoRALinear(attn.wq, rank=rank, alpha=alpha) - attn.wk = LoRALinear(attn.wk, rank=rank, alpha=alpha) - attn.wv = LoRALinear(attn.wv, rank=rank, alpha=alpha) - attn.wo = LoRALinear(attn.wo, rank=rank, alpha=alpha) - - # Log trainable parameter count - trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) - total = sum(p.numel() for p in model.parameters()) - logger.info( - f"LoRA injected (rank={rank}). Trainable: {trainable:,} / {total:,} " - f"({100 * trainable / total:.2f}%)" - ) - return model - - -def merge_lora(model: SageModel) -> SageModel: - """ - Merge all LoRA adapters back into the base weights and replace the - LoRALinear wrappers with plain nn.Linear modules. - """ - base_model = getattr(model, "module", model) - for layer in base_model.layers: - attn: CausalSelfAttention = layer.attn - for name in ("wq", "wk", "wv", "wo"): - module = getattr(attn, name) - if isinstance(module, LoRALinear): - setattr(attn, name, module.merge()) - logger.info("LoRA weights merged into base model.") - return model - - -# =================================================================== -# Instruction fine-tuning loop -# =================================================================== - -def finetune_instruction( - model: SageModel, - config: SageConfig, - samples: List[dict], - total_steps: int = 200, - use_lora: bool = True, - lora_rank: int = 8, - tokenizer: Optional[SageTokenizer] = None, -) -> SageModel: - """ - Fine-tune the model on instruction/response pairs. - - Parameters - ---------- - model : SageModel - config : SageConfig - samples : list[dict] - Each dict must contain ``instruction`` and ``response`` string keys. - total_steps : int - use_lora : bool - If True, inject LoRA adapters before training. - lora_rank : int - tokenizer : SageTokenizer, optional - - Returns - ------- - SageModel โ€” the fine-tuned model (LoRA merged if applicable). - """ - # --- TURBO MODE: TF32 & COMPILE --- - if torch.cuda.is_available(): - torch.set_float32_matmul_precision('high') - - device = config.device - model = model.to(device) - - # Wrap model with torch.compile for graph-level optimization - if hasattr(torch, "compile"): - logger.info("Turbo Mode: Compiling fine-tune engine...") - base = getattr(model, "module", model) - compiled_base = torch.compile(base, mode="reduce-overhead") - if hasattr(model, "module"): - model.module = compiled_base - else: - model = compiled_base - - tok = tokenizer or SageTokenizer() - - if use_lora: - model = inject_lora(model, rank=lora_rank) - - # ------- W&B Logging ------- - wandb.init( - project=config.project_name, - name=f"finetune-{time.strftime('%Y%m%d-%H%M')}", - config=config.__dict__, - ) - - optimizer = create_optimizer(model, config) - - # AMP setup - use_amp = device.type == "cuda" - amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 - scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16)) - - model.train() - pbar = tqdm(range(total_steps), desc="Fine-tuning", unit="step") - accum_loss = 0.0 - - for step in pbar: - lr = get_lr(step, config, total_steps) - set_lr(optimizer, lr) - - # Build a batch by sampling from the instruction dataset - batch_size = min(config.batch_size, len(samples)) - import random - batch_samples = random.choices(samples, k=batch_size) - batch = create_instruction_batch(batch_samples, tok, max_len=config.max_seq_len) - - input_ids = batch["input_ids"].to(device) - labels = batch["labels"].to(device) - loss_mask = batch["loss_mask"].to(device) - - optimizer.zero_grad(set_to_none=True) - - with autocast(device.type, dtype=amp_dtype, enabled=use_amp): - logits, _ = model(input_ids) - # Shift: predict next token - shift_logits = logits[:, :-1, :].contiguous() - shift_labels = labels[:, 1:].contiguous() - shift_mask = loss_mask[:, 1:].contiguous() - - # Compute per-token loss - per_token_loss = nn.functional.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1), - reduction="none", - ) - per_token_loss = per_token_loss.view(shift_labels.size()) - - # Mask out instruction tokens so we only learn from responses - masked_loss = (per_token_loss * shift_mask).sum() / shift_mask.sum().clamp(min=1) - - scaler.scale(masked_loss).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) - scaler.step(optimizer) - scaler.update() - - accum_loss += masked_loss.item() - - if (step + 1) % 10 == 0: - avg = accum_loss / 10 - pbar.set_postfix(loss=f"{avg:.4f}", lr=f"{lr:.2e}") - logger.info(f"finetune step={step+1} | loss={avg:.4f}") - wandb.log({ - "finetune/loss": avg, - "finetune/lr": lr, - }, step=step + 1) - accum_loss = 0.0 - - # Merge LoRA weights back for clean inference - if use_lora: - model = merge_lora(model) - - save_checkpoint(model, None, total_steps, config.checkpoint_dir, filename="sage_finetuned.pt") - logger.info("Instruction fine-tuning complete. Checkpoint saved as sage_finetuned.pt") - wandb.finish() - return model - - -# --------------------------------------------------------------------------- -# Demo instruction samples (used when no dataset is provided) -# --------------------------------------------------------------------------- - -DEMO_INSTRUCTION_SAMPLES = [ - {"instruction": "What is the capital of France?", "response": "The capital of France is Paris."}, - {"instruction": "Explain gravity in simple terms.", "response": "Gravity is the force that pulls objects toward each other. The more mass an object has, the stronger its gravitational pull."}, - {"instruction": "Write a short poem about the ocean.", "response": "Waves crash upon the sandy shore,\nThe ocean's song forevermore.\nDeep blue stretching to the sky,\nSeagulls dance and clouds float by."}, - {"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."}, - {"instruction": "Summarize photosynthesis.", "response": "Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen, providing energy for the plant."}, - {"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than a year on Venus. It takes Venus 243 Earth days to rotate once on its axis but only 225 Earth days to orbit the Sun."}, - {"instruction": "How do airplanes fly?", "response": "Airplanes fly by generating lift through their wings. Air moves faster over the curved top of the wing than the flat bottom, creating lower pressure above and higher pressure below, which pushes the wing upward."}, - {"instruction": "What is machine learning?", "response": "Machine learning is a branch of artificial intelligence where computers learn patterns from data instead of being explicitly programmed, allowing them to make predictions or decisions."}, -] diff --git a/sage/inference.py b/sage/inference.py deleted file mode 100644 index dfad8b96b9d7b1904bc9bc6a195b15a7be6637bf..0000000000000000000000000000000000000000 --- a/sage/inference.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -SAGE Inference Engine -===================== -Text generation with greedy, temperature, top-k, and nucleus (top-p) sampling. -Supports KV-cache for O(1)-per-token generation and streaming output. -""" - -import sys -import torch -import torch.nn.functional as F -from typing import Optional, List - -from .config import SageConfig -from .model import SageModel -from .data import SageTokenizer -from .utils import setup_logger - -logger = setup_logger("sage.inference") - - -# --------------------------------------------------------------------------- -# Sampling helpers -# --------------------------------------------------------------------------- - -def _top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor: - """Zero out all logits outside the top-k highest values.""" - if k <= 0 or k >= logits.size(-1): - return logits - values, _ = torch.topk(logits, k) - min_val = values[:, -1].unsqueeze(-1) - return torch.where(logits < min_val, torch.full_like(logits, float("-inf")), logits) - - -def _top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor: - """Nucleus sampling: keep the smallest set of tokens whose cumulative - probability exceeds *p*.""" - if p >= 1.0: - return logits - sorted_logits, sorted_idx = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) - - # Identify tokens to remove (cumulative prob exceeds p) - sorted_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= p - sorted_logits[sorted_mask] = float("-inf") - - # Scatter back to original order - logits = logits.scatter(1, sorted_idx, sorted_logits) - return logits - - -def sample_next_token( - logits: torch.Tensor, - temperature: float = 0.8, - top_k: int = 50, - top_p: float = 0.9, - greedy: bool = False, -) -> torch.Tensor: - """ - Given raw logits for the last position, sample or greedily select the - next token. - - Parameters - ---------- - logits : Tensor [batch, vocab] - temperature : float - top_k : int - top_p : float - greedy : bool โ€” if True, ignore temperature/top-k/top-p and pick argmax. - - Returns - ------- - Tensor [batch, 1] - """ - if greedy: - return logits.argmax(dim=-1, keepdim=True) - - logits = logits / max(temperature, 1e-8) - logits = _top_k_filter(logits, top_k) - logits = _top_p_filter(logits, top_p) - - probs = F.softmax(logits, dim=-1) - return torch.multinomial(probs, num_samples=1) - - -# --------------------------------------------------------------------------- -# Main generation function -# --------------------------------------------------------------------------- - -@torch.no_grad() -def generate( - model: SageModel, - tokenizer: SageTokenizer, - prompt: str, - max_new_tokens: int = 256, - temperature: float = 0.8, - top_k: int = 50, - top_p: float = 0.9, - greedy: bool = False, - stream: bool = True, - device: Optional[torch.device] = None, -) -> str: - """ - Generate text from *prompt* using the SAGE model. - - Parameters - ---------- - model : SageModel - tokenizer : SageTokenizer - prompt : str - max_new_tokens : int - temperature, top_k, top_p : sampling hyper-parameters - greedy : bool โ€” use argmax decoding - stream : bool โ€” print tokens as they are generated - device : torch.device - - Returns - ------- - str โ€” the complete generated text (prompt + new tokens). - """ - if device is None: - device = next(model.parameters()).device - - base_model = getattr(model, "module", model) - base_model.eval() - - # Encode prompt - prompt_tokens = tokenizer.encode(prompt) - if not prompt_tokens: - prompt_tokens = [tokenizer.eos_token_id] - - input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=device) - - generated_tokens: List[int] = list(prompt_tokens) - kv_caches = None - - # --- Prefill: run the full prompt through the model once --- - logits, kv_caches = base_model(input_ids) - next_logits = logits[:, -1, :] - - for _ in range(max_new_tokens): - next_id = sample_next_token( - next_logits, - temperature=temperature, - top_k=top_k, - top_p=top_p, - greedy=greedy, - ) - - token_id = next_id.item() - - # Stop on EOS - if token_id == tokenizer.eos_token_id: - break - - generated_tokens.append(token_id) - - # Stream output: decode and print only the new token - if stream: - token_str = tokenizer.decode([token_id]) - print(token_str, end="", flush=True) - - # --- Decode step: feed only the new token, reuse KV-cache --- - next_input = next_id.view(1, 1) - logits, kv_caches = base_model(next_input, kv_caches=kv_caches) - next_logits = logits[:, -1, :] - - if stream: - print() # newline after streaming completes - - base_model.train() - return tokenizer.decode(generated_tokens) diff --git a/sage/memory.py b/sage/memory.py deleted file mode 100644 index 913edfdb277e74995b916b25fcaab8431055fd68..0000000000000000000000000000000000000000 --- a/sage/memory.py +++ /dev/null @@ -1,240 +0,0 @@ -""" -SAGE Memory & RAG Module -========================= -Provides: - - A FAISS-backed vector store for retrieval-augmented generation (RAG). - - A rolling conversation-history manager that truncates intelligently - to stay within the model's context window. -""" - -import os -import numpy as np -import torch -import torch.nn.functional as F -from typing import List, Optional, Tuple - -from .data import SageTokenizer -from .utils import setup_logger - -logger = setup_logger("sage.memory") - - -# =================================================================== -# Simple embedding helper (uses mean-pooled token embeddings) -# =================================================================== - -def _embed_text(text: str, tokenizer: SageTokenizer, model: torch.nn.Module, device: torch.device) -> np.ndarray: - """ - Produce a fixed-length embedding for *text* by mean-pooling the - model's token embeddings. This is lightweight and avoids a full - forward pass โ€” suitable for a small retrieval index. - """ - tokens = tokenizer.encode(text) - if not tokens: - # Return a zero vector when text is empty - d_model = model.wte.weight.shape[1] - return np.zeros(d_model, dtype=np.float32) - - ids = torch.tensor([tokens], dtype=torch.long, device=device) - with torch.no_grad(): - embeddings = model.wte(ids) # [1, seq_len, d_model] - mean_emb = embeddings.mean(dim=1) # [1, d_model] - # L2-normalize for cosine similarity in FAISS - mean_emb = F.normalize(mean_emb, p=2, dim=-1) - return mean_emb.squeeze(0).cpu().numpy() - - -# =================================================================== -# FAISS-backed Vector Store -# =================================================================== - -class VectorStore: - """ - A lightweight document store backed by FAISS (Inner Product index, - which equals cosine similarity when vectors are L2-normalized). - """ - - def __init__(self, dim: int): - try: - import faiss - except ImportError: - raise ImportError( - "FAISS is required for RAG. Install it with: pip install faiss-cpu" - ) - self.dim = dim - self.index = faiss.IndexFlatIP(dim) # inner-product (cosine after L2-norm) - self.documents: List[str] = [] - logger.info(f"VectorStore initialized (dim={dim})") - - def add(self, texts: List[str], embeddings: np.ndarray) -> None: - """Add documents and their embeddings to the store.""" - assert embeddings.shape[0] == len(texts) - assert embeddings.shape[1] == self.dim - self.index.add(embeddings.astype(np.float32)) - self.documents.extend(texts) - logger.info(f"Added {len(texts)} documents. Total: {len(self.documents)}") - - def search(self, query_embedding: np.ndarray, top_k: int = 3) -> List[Tuple[str, float]]: - """Return the top-k most similar documents with their scores.""" - if self.index.ntotal == 0: - return [] - query_embedding = query_embedding.reshape(1, -1).astype(np.float32) - scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal)) - results = [] - for score, idx in zip(scores[0], indices[0]): - if idx < 0: - continue - results.append((self.documents[idx], float(score))) - return results - - @property - def size(self) -> int: - return self.index.ntotal - - -# =================================================================== -# RAG Manager -# =================================================================== - -class RAGManager: - """ - High-level retrieval-augmented generation manager. - - Call ``add_documents`` to ingest text, then ``retrieve_context`` at - inference time to prepend relevant chunks to the user prompt. - """ - - def __init__( - self, - model: torch.nn.Module, - tokenizer: SageTokenizer, - device: torch.device, - chunk_size: int = 200, - chunk_overlap: int = 50, - ): - self.model = model - self.tokenizer = tokenizer - self.device = device - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - d_model = model.wte.weight.shape[1] - self.store = VectorStore(dim=d_model) - self.enabled = False - - def _chunk_text(self, text: str) -> List[str]: - """Split text into overlapping word-level chunks.""" - words = text.split() - chunks: List[str] = [] - start = 0 - while start < len(words): - end = start + self.chunk_size - chunk = " ".join(words[start:end]) - chunks.append(chunk) - start += self.chunk_size - self.chunk_overlap - return chunks - - def add_documents(self, texts: List[str]) -> None: - """Chunk and embed documents, then add to the vector store.""" - all_chunks: List[str] = [] - for text in texts: - all_chunks.extend(self._chunk_text(text)) - - if not all_chunks: - logger.warning("No document chunks to add.") - return - - embeddings = np.stack([ - _embed_text(chunk, self.tokenizer, self.model, self.device) - for chunk in all_chunks - ]) - self.store.add(all_chunks, embeddings) - - def retrieve_context(self, query: str, top_k: int = 3) -> str: - """ - Retrieve the top-k most relevant chunks for *query* and - concatenate them into a context string. - """ - if not self.enabled or self.store.size == 0: - return "" - - q_emb = _embed_text(query, self.tokenizer, self.model, self.device) - results = self.store.search(q_emb, top_k=top_k) - - if not results: - return "" - - context_parts = [f"[Context {i+1}] {doc}" for i, (doc, _score) in enumerate(results)] - return "\n\n".join(context_parts) + "\n\n" - - def toggle(self, on: bool) -> None: - self.enabled = on - state = "enabled" if on else "disabled" - logger.info(f"RAG {state}. Store contains {self.store.size} chunks.") - - -# =================================================================== -# Conversation History Manager -# =================================================================== - -DEFAULT_SYSTEM_PROMPT = ( - "You are SAGE, a high-quality reasoning assistant. " - "Your goal is to provide accurate, structured, and deep logical explanations.\n\n" - "CRITICAL GUIDELINES:\n" - "1. THINKING PHASE: You must ALWAYS start your response with a section. " - "In this section, break down the user's request, identify key constraints, and plan your logical steps.\n" - "2. RESPONSE PHASE: After completing your internal reasoning, provide your final answer within tags.\n" - "3. QUALITY: Prioritize step-by-step mathematical or logical derivation over short answers.\n" - "4. NO REPETITION: Avoid filler words or circular logic.\n\n" - "RESPONSE TEMPLATE:\n" - "\n[Step-by-step logic here]\n\n" - "\n[Final clear answer here]\n" -) - -class ConversationHistory: - """ - Rolling conversation history that stays within a token budget. - - Older turns are dropped when the history would exceed the context window. - """ - - def __init__(self, tokenizer: SageTokenizer, max_tokens: int = 900): - self.tokenizer = tokenizer - self.max_tokens = max_tokens - self.turns: List[dict] = [] # [{"role": "user"/"assistant", "text": ...}, ...] - - def add(self, role: str, text: str) -> None: - """Record a new conversational turn.""" - self.turns.append({"role": role, "text": text}) - self._trim() - - def _trim(self) -> None: - """Drop oldest turns until the total token count is within budget.""" - while self._total_tokens() > self.max_tokens and len(self.turns) > 1: - self.turns.pop(0) - - def _total_tokens(self) -> int: - return sum(len(self.tokenizer.encode(t["text"])) for t in self.turns) - - def build_prompt(self, new_user_message: str, rag_context: str = "") -> str: - """ - Assemble the full prompt from history + RAG context + new message. - """ - parts: List[str] = [] - - parts.append(DEFAULT_SYSTEM_PROMPT) - - if rag_context: - parts.append(rag_context) - - for turn in self.turns: - prefix = "User:" if turn["role"] == "user" else "SAGE:" - parts.append(f"{prefix} {turn['text']}") - - parts.append(f"User: {new_user_message}") - parts.append("SAGE:") - - return "\n".join(parts) - - def clear(self) -> None: - self.turns.clear() diff --git a/sage/model.py b/sage/model.py deleted file mode 100644 index 1c4a2e619dbbd3f7985feffd25b8a5c2e2d83664..0000000000000000000000000000000000000000 --- a/sage/model.py +++ /dev/null @@ -1,267 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -from typing import Optional, Tuple -from .config import SageConfig - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor: - """Precomputes rotary positional embedding frequencies.""" - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs).float() - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - -def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Applies rotary positional embeddings to queries and keys.""" - # Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real) - if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2: - freqs_cis = torch.view_as_complex(freqs_cis) - - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - - # Reshape freqs_cis to broadcast with xq_ and xk_ - # xq_, xk_ shape: [batch, seq_len, n_heads, dim_head//2] - # freqs_cis shape: [seq_len, dim_head//2] - freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2) - - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - - return xq_out.type_as(xq), xk_out.type_as(xk) - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """Repeat Key/Value heads n_rep times to match number of Query heads.""" - if n_rep == 1: - return x - B, T, n_kv_heads, head_dim = x.size() - return ( - x[:, :, :, None, :] - .expand(B, T, n_kv_heads, n_rep, head_dim) - .reshape(B, T, n_kv_heads * n_rep, head_dim) - ) - -class CausalSelfAttention(nn.Module): - def __init__(self, config: SageConfig): - super().__init__() - self.n_heads = config.n_heads - self.n_kv_heads = config.n_kv_heads - self.n_rep = self.n_heads // self.n_kv_heads - self.d_model = config.d_model - assert self.d_model % self.n_heads == 0 - self.head_dim = self.d_model // self.n_heads - - self.wq = nn.Linear(self.d_model, self.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(self.d_model, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(self.d_model, self.d_model, bias=False) - - self.resid_dropout = nn.Dropout(config.dropout) - - # Flash attention handles causality via is_causal flag if seq_len > 1 - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - B, T, C = x.size() # batch, seq_len, d_model - q, k, v = self.wq(x), self.wk(x), self.wv(x) - - q = q.view(B, T, self.n_heads, self.head_dim) - k = k.view(B, T, self.n_kv_heads, self.head_dim) - v = v.view(B, T, self.n_kv_heads, self.head_dim) - - q, k = apply_rotary_emb(q, k, freqs_cis) - - if kv_cache is not None: - # We are generating token by token - k_cache, v_cache = kv_cache - k = torch.cat([k_cache, k], dim=1) - v = torch.cat([v_cache, v], dim=1) - new_kv_cache = (k, v) - else: - new_kv_cache = None - - # Repeat KV heads to match Q heads (GQA) - k = repeat_kv(k, self.n_rep) - v = repeat_kv(v, self.n_rep) - - # Move heads to correct dimension: (B, n_heads, T, head_dim) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Flash attention natively supported via scaled_dot_product_attention - is_causal = (kv_cache is None and T > 1) - try: - y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.resid_dropout.p if self.training else 0.0, is_causal=is_causal) - except Exception: - # Manual attention fallback for older architectures (like P100 sm_60) - attn_weights = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - if is_causal: - # Use a causal mask - causal_mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T) - attn_weights = attn_weights.masked_fill(causal_mask == 0, float('-inf')) - - attn_weights = F.softmax(attn_weights, dim=-1) - if self.training: - attn_weights = self.resid_dropout(attn_weights) - - y = attn_weights @ v - - y = y.transpose(1, 2).contiguous().view(B, T, C) - y = self.resid_dropout(self.wo(y)) - - return y, new_kv_cache - -class ExpertFFN(nn.Module): - def __init__(self, config: SageConfig): - super().__init__() - self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False) - self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # SwiGLU activation structure - hidden = F.silu(self.w1(x)) * self.w3(x) - return self.dropout(self.w2(hidden)) - -class MoE(nn.Module): - def __init__(self, config: SageConfig): - super().__init__() - self.n_experts = config.n_experts - self.top_k = config.num_experts_per_tok - self.d_model = config.d_model - - self.router = nn.Linear(self.d_model, self.n_experts, bias=False) - self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(self.n_experts)]) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, T, C = x.size() - x_flat = x.view(-1, C) # [B*T, C] - - router_logits = self.router(x_flat) # [B*T, n_experts] - routing_weights = F.softmax(router_logits, dim=-1) - - # Select Top K experts - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # [B*T, top_k] - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) # re-normalize - - final_out = torch.zeros_like(x_flat) - - # Iterate over experts and compute their outputs - for i, expert in enumerate(self.experts): - # Find which tokens chose this expert - expert_mask = (selected_experts == i) - token_idx, kth_expert = torch.where(expert_mask) - - if token_idx.shape[0] > 0: - expert_inputs = x_flat[token_idx] - expert_outputs = expert(expert_inputs) - - # Apply router weight - weights = routing_weights[token_idx, kth_expert].unsqueeze(-1) - final_out[token_idx] += expert_outputs * weights - - return final_out.view(B, T, C) - -class TransformerBlock(nn.Module): - def __init__(self, config: SageConfig): - super().__init__() - self.norm1 = nn.LayerNorm(config.d_model) - self.attn = CausalSelfAttention(config) - self.norm2 = nn.LayerNorm(config.d_model) - self.moe = MoE(config) - - def forward( - self, - x: torch.Tensor, - freqs_cis: torch.Tensor, - kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - # Pre-LayerNorm architecture - h, new_kv_cache = self.attn(self.norm1(x), freqs_cis, kv_cache) - x = x + h - x = x + self.moe(self.norm2(x)) - return x, new_kv_cache - -class SageModel(nn.Module): - def __init__(self, config: SageConfig): - super().__init__() - self.config = config - - self.wte = nn.Embedding(config.vocab_size, config.d_model) - self.drop = nn.Dropout(config.dropout) - - self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) - - self.ln_f = nn.LayerNorm(config.d_model) - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) - - # Weight tying - self.wte.weight = self.lm_head.weight - - # Precompute RoPE frequencies - self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False) - - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.LayerNorm): - torch.nn.init.zeros_(module.bias) - torch.nn.init.ones_(module.weight) - - def forward( - self, - idx: torch.Tensor, - kv_caches: Optional[list] = None - ) -> Tuple[torch.Tensor, Optional[list]]: - B, T = idx.size() - - if kv_caches is not None: - # generating context, token is at specific position - start_pos = kv_caches[0][0].shape[1] - else: - start_pos = 0 - - freqs_cis = self.freqs_cis[start_pos : start_pos + T] - - x = self.drop(self.wte(idx)) - - new_kv_caches = [] - for i, layer in enumerate(self.layers): - kv_cache = kv_caches[i] if kv_caches else None - - # Use gradient checkpointing during training - if self.training and kv_cache is None: - def create_custom_forward(module): - def custom_forward(x_in, freqs_cis_in): - return module(x_in, freqs_cis_in, None) - return custom_forward - - x, new_kv_cache = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), - x, freqs_cis, - use_reentrant=False - ) - else: - x, new_kv_cache = layer(x, freqs_cis, kv_cache) - - if new_kv_cache is not None: - new_kv_caches.append(new_kv_cache) - - x = self.ln_f(x) - logits = self.lm_head(x) # [B, T, vocab_size] - - return logits, new_kv_caches if len(new_kv_caches) > 0 else None diff --git a/sage/optimize.py b/sage/optimize.py deleted file mode 100644 index 87dfecfbf9d8482ad979bc1ed023002571d44ea5..0000000000000000000000000000000000000000 --- a/sage/optimize.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -SAGE Optimization Layer -======================= -Post-training quantization (INT8), optional pruning, and knowledge-distillation -loss utilities. -""" - -import torch -import torch.nn as nn -import torch.nn.utils.prune as prune -from typing import Optional - -from .model import SageModel -from .config import SageConfig -from .utils import setup_logger - -logger = setup_logger("sage.optimize") - - -# =================================================================== -# INT8 Dynamic Quantization -# =================================================================== - -def quantize_int8(model: SageModel) -> nn.Module: - """ - Apply dynamic INT8 quantization to all Linear layers in the model. - - This reduces model size by ~2-4x and can speed up CPU inference. - The model is moved to CPU before quantization because PyTorch's - dynamic quantization only supports CPU tensors. - - Returns - ------- - nn.Module โ€” the quantized model (on CPU). - """ - base_model = getattr(model, "module", model) - base_model = base_model.cpu().eval() - - quantized = torch.quantization.quantize_dynamic( - base_model, - {nn.Linear}, # quantize all linear layers - dtype=torch.qint8, - ) - - # Report size reduction - orig_size = sum(p.numel() * p.element_size() for p in base_model.parameters()) - # Quantized parameters may not report element_size correctly, - # so we estimate based on INT8 = 1 byte per weight. - quant_size = sum(p.numel() for p in quantized.parameters()) # * 1 byte - logger.info( - f"Quantization complete. " - f"Original: {orig_size / 1e6:.1f} MB โ†’ Quantized: ~{quant_size / 1e6:.1f} MB (INT8)" - ) - return quantized - - -# =================================================================== -# Weight Pruning -# =================================================================== - -def prune_model(model: SageModel, amount: float = 0.3) -> SageModel: - """ - Apply unstructured L1 pruning to all Linear layers, removing the - *amount* fraction of weights with the smallest magnitude. - - Parameters - ---------- - model : SageModel - amount : float - Fraction of weights to prune (0.0 โ€“ 1.0). - - Returns - ------- - SageModel โ€” the pruned model (pruning masks are permanent after this call). - """ - pruned_count = 0 - total_count = 0 - - base_model = getattr(model, "module", model) - for name, module in base_model.named_modules(): - if isinstance(module, nn.Linear): - prune.l1_unstructured(module, name="weight", amount=amount) - prune.remove(module, "weight") # make the pruning permanent - pruned_count += (module.weight == 0).sum().item() - total_count += module.weight.numel() - - sparsity = pruned_count / max(total_count, 1) * 100 - logger.info( - f"Pruning complete. {pruned_count:,} / {total_count:,} weights zeroed " - f"({sparsity:.1f}% sparsity)" - ) - return model - - -# =================================================================== -# Knowledge Distillation Loss -# =================================================================== - -def distillation_loss( - student_logits: torch.Tensor, - teacher_logits: torch.Tensor, - labels: torch.Tensor, - temperature: float = 2.0, - alpha: float = 0.5, - ignore_index: int = -100, -) -> torch.Tensor: - """ - Combined knowledge-distillation loss. - - ``L = alpha * KL(softmax(teacher/T), softmax(student/T)) * T^2 - + (1 - alpha) * CE(student, labels)`` - - Parameters - ---------- - student_logits : Tensor [B, T, V] - teacher_logits : Tensor [B, T, V] - labels : Tensor [B, T] - temperature : float - alpha : float โ€” weight for the distillation term (0 โ†’ pure CE, 1 โ†’ pure KD). - ignore_index : int โ€” label value to ignore in cross-entropy. - - Returns - ------- - Tensor (scalar) - """ - # Soft targets - soft_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1) - soft_teacher = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1) - - kd_loss = torch.nn.functional.kl_div( - soft_student.view(-1, soft_student.size(-1)), - soft_teacher.view(-1, soft_teacher.size(-1)), - reduction="batchmean", - ) * (temperature ** 2) - - # Hard-label cross-entropy - ce_loss = torch.nn.functional.cross_entropy( - student_logits.view(-1, student_logits.size(-1)), - labels.view(-1), - ignore_index=ignore_index, - ) - - return alpha * kd_loss + (1 - alpha) * ce_loss - - -# =================================================================== -# torch.compile wrapper (PyTorch 2.0+) -# =================================================================== - -def try_compile(model: nn.Module) -> nn.Module: - """ - Attempt to compile the model with ``torch.compile`` for faster - execution. Falls back gracefully if compilation is not available. - """ - if hasattr(torch, "compile"): - try: - compiled = torch.compile(model) - logger.info("Model compiled with torch.compile for accelerated execution.") - return compiled - except Exception as e: - logger.warning(f"torch.compile failed ({e}). Using eager mode.") - else: - logger.info("torch.compile not available (requires PyTorch 2.0+). Using eager mode.") - return model diff --git a/sage/train.py b/sage/train.py deleted file mode 100644 index 3f790babea91796b19ec207539bff7019286ea1b..0000000000000000000000000000000000000000 --- a/sage/train.py +++ /dev/null @@ -1,266 +0,0 @@ -""" -SAGE Training System -==================== -Complete training loop with AdamW, cosine-decay LR schedule, mixed-precision -(AMP), gradient accumulation, gradient clipping, and checkpoint management. -""" - -import math -import time -import torch -import torch.nn as nn -from torch.amp import GradScaler, autocast -from tqdm import tqdm -import wandb -from typing import Optional - -from .config import SageConfig -from .model import SageModel -from .data import SageTokenizer, create_dataloader -from .utils import setup_logger, save_checkpoint, load_checkpoint - -logger = setup_logger("sage.train") - - -# --------------------------------------------------------------------------- -# Learning-rate scheduler helpers -# --------------------------------------------------------------------------- - -def get_lr(step: int, config: SageConfig, total_steps: int) -> float: - """Cosine decay with linear warmup. Returns the learning rate for *step*.""" - if step < config.warmup_steps: - # Linear warmup - return config.learning_rate * (step + 1) / config.warmup_steps - - # Cosine decay phase - decay_steps = total_steps - config.warmup_steps - progress = (step - config.warmup_steps) / max(1, decay_steps) - coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) - return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate) - - -def set_lr(optimizer: torch.optim.Optimizer, lr: float) -> None: - """Manually sets the learning rate for every parameter group.""" - for pg in optimizer.param_groups: - pg["lr"] = lr - - -# --------------------------------------------------------------------------- -# Optimizer factory -# --------------------------------------------------------------------------- - -def create_optimizer(model: SageModel, config: SageConfig) -> torch.optim.AdamW: - """ - Create an AdamW optimizer with weight-decay applied only to weight - matrices (not biases or LayerNorm parameters). - """ - decay_params = [] - no_decay_params = [] - - for name, param in model.named_parameters(): - if not param.requires_grad: - continue - # Biases and LayerNorm weights should not be decayed - if param.ndim == 1 or "bias" in name: - no_decay_params.append(param) - else: - decay_params.append(param) - - param_groups = [ - {"params": decay_params, "weight_decay": config.weight_decay}, - {"params": no_decay_params, "weight_decay": 0.0}, - ] - - # Enable Fused AdamW for 10% speedup if CUDA is active - use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames - optimizer = torch.optim.AdamW( - param_groups, - lr=config.learning_rate, - betas=(0.9, 0.95), - eps=1e-8, - fused=use_fused, - ) - return optimizer - - -# --------------------------------------------------------------------------- -# Main training loop -# --------------------------------------------------------------------------- - -def train( - model: SageModel, - config: SageConfig, - total_steps: int = 500, - dataset_name: str = "roneneldan/TinyStories", - resume: bool = True, - tokenizer: Optional[SageTokenizer] = None, -) -> SageModel: - """ - Run pre-training for *total_steps* gradient-update steps. - - Parameters - ---------- - model : SageModel - The model to train (will be moved to config.device). - config : SageConfig - Hyperparameters. - total_steps : int - Number of optimizer steps to run. - dataset_name : str - HuggingFace dataset identifier. - resume : bool - If True, attempt to load the latest checkpoint before training. - tokenizer : SageTokenizer, optional - Tokenizer instance; one will be created if not supplied. - - Returns - ------- - SageModel - The trained model (on config.device). - """ - # --- TURBO MODE: TF32 & COMPILE --- - if torch.cuda.is_available(): - torch.set_float32_matmul_precision('high') - - device = config.device - model = model.to(device) - - # Wrap model with torch.compile for graph-level optimization - if hasattr(torch, "compile"): - try: - logger.info("Turbo Mode: Compiling model graph...") - base = getattr(model, "module", model) - compiled_base = torch.compile(base, mode="reduce-overhead") - if hasattr(model, "module"): - model.module = compiled_base - else: - model = compiled_base - except (ValueError, RuntimeError, ImportError) as e: - # Graceful fallback: numpy compatibility issues or other compilation errors - logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}") - - tok = tokenizer or SageTokenizer() - optimizer = create_optimizer(model, config) - - # ------- resume from checkpoint if available ------- - start_step = 0 - if resume: - model, optimizer, start_step = load_checkpoint( - model, optimizer, config.checkpoint_dir, device=str(device) - ) - if start_step >= total_steps: - logger.info("Checkpoint already at or past requested steps. Nothing to do.") - return model - - # ------- mixed precision setup ------- - use_amp = device.type == "cuda" - # prefer bf16 if the GPU supports it - amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 - scaler = GradScaler("cuda", enabled=(use_amp and amp_dtype == torch.float16)) - - # ------- data loader ------- - loader = create_dataloader(config, dataset_name=dataset_name, tokenizer=tok) - data_iter = iter(loader) - - # ------- W&B Logging ------- - wandb.init( - project=config.project_name, - name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}", - config=config.__dict__, - ) - - # ------- gradient checkpointing (saves VRAM) ------- - base_model = getattr(model, "module", model) - if hasattr(base_model, "layers"): - for layer in base_model.layers: - layer: nn.Module - # PyTorch gradient checkpointing - try: - from torch.utils.checkpoint import checkpoint # noqa: F401 - # We wrap the forward below instead, using it at call-site. - except ImportError: - pass - - # ------- training loop ------- - model.train() - accum_loss = 0.0 - log_interval = 10 - t0 = time.time() - - pbar = tqdm(range(start_step, total_steps), desc="Training", unit="step") - micro_step = 0 - - for step in pbar: - # Update learning rate - lr = get_lr(step, config, total_steps) - set_lr(optimizer, lr) - - # Accumulate gradients over multiple micro-batches - optimizer.zero_grad(set_to_none=True) - step_loss = 0.0 - - for micro in range(config.gradient_accumulation_steps): - try: - batch = next(data_iter) - except StopIteration: - # Restart the data stream when exhausted - data_iter = iter(loader) - batch = next(data_iter) - - batch = batch.to(device) - inputs = batch[:, :-1] # all tokens except last - targets = batch[:, 1:] # all tokens except first - - with autocast(device.type, dtype=amp_dtype, enabled=use_amp): - logits, _ = model(inputs) - loss = nn.functional.cross_entropy( - logits.reshape(-1, logits.size(-1)), - targets.reshape(-1), - ignore_index=tok.pad_token_id, - ) - # Scale loss by accumulation steps so the effective loss - # is independent of the number of micro-batches. - loss = loss / config.gradient_accumulation_steps - - scaler.scale(loss).backward() - step_loss += loss.item() - - # Gradient clipping (unscale first for correct norm computation) - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) - - scaler.step(optimizer) - scaler.update() - - accum_loss += step_loss - micro_step += 1 - - # ------- logging ------- - if (step + 1) % log_interval == 0 or step == total_steps - 1: - avg_loss = accum_loss / log_interval - elapsed = time.time() - t0 - perplexity = math.exp(min(avg_loss, 20)) # clamp to avoid overflow - pbar.set_postfix( - loss=f"{avg_loss:.4f}", - ppl=f"{perplexity:.2f}", - lr=f"{lr:.2e}", - elapsed=f"{elapsed:.1f}s", - ) - logger.info( - f"step={step+1} | loss={avg_loss:.4f} | ppl={perplexity:.2f} | lr={lr:.2e}" - ) - wandb.log({ - "train/loss": avg_loss, - "train/perplexity": perplexity, - "train/lr": lr, - }, step=step + 1) - accum_loss = 0.0 - - # ------- checkpoint every 100 steps ------- - if (step + 1) % 100 == 0 or step == total_steps - 1: - save_checkpoint(model, optimizer, step + 1, config.checkpoint_dir) - logger.info(f"Checkpoint saved at step {step + 1}") - - logger.info("Training complete.") - wandb.finish() - return model diff --git a/sage/utils.py b/sage/utils.py deleted file mode 100644 index 2181e29989c9a16d9bef0fcab353dab9090e161d..0000000000000000000000000000000000000000 --- a/sage/utils.py +++ /dev/null @@ -1,143 +0,0 @@ -import os -import logging -import torch -from typing import Optional, Tuple - -def _get_logger(name: str) -> logging.Logger: - """Simple logger getter to avoid circular imports.""" - logger = logging.getLogger(name) - if not logger.handlers: - logger.setLevel(logging.INFO) - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - logger.propagate = False - return logger - -def get_compatible_device() -> torch.device: - """ - Returns the best available device with CUDA compatibility checking. - - Automatically detects GPU compute capability and falls back to CPU - if the current PyTorch installation doesn't support the GPU. - """ - logger = _get_logger("sage.device") - - # Check CUDA availability and compatibility - if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(0) - capability = torch.cuda.get_device_capability() - major, minor = capability - sm_version = f"sm_{major}{minor}" - - logger.info(f"Detected GPU: {gpu_name} (CUDA Capability: {sm_version})") - - # PyTorch 2.0+ minimum is sm_70, PyTorch 1.13 supports sm_60 - # Check if we can actually run model operations (embedding, linear, etc.) - try: - # Test 1: Basic tensor operation - test_tensor = torch.zeros(2, 4).cuda() - _ = test_tensor + test_tensor - - # Test 2: Embedding (this is where P100/sm_60 often fails) - import torch.nn as nn - test_emb = nn.Embedding(10, 8).cuda() - test_indices = torch.tensor([0, 1, 2], dtype=torch.long).cuda() - _ = test_emb(test_indices) - - # Test 3: Linear layer - test_linear = nn.Linear(8, 4).cuda() - _ = test_linear(test_emb(test_indices)) - - logger.info(f"โœ… GPU is compatible with current PyTorch") - return torch.device("cuda") - except RuntimeError as e: - if "no kernel image is available" in str(e).lower(): - logger.warning(f"โš ๏ธ GPU {sm_version} not supported by current PyTorch") - logger.warning(f" Current PyTorch supports: {torch.cuda.get_arch_list() or 'sm_70+'}") - logger.warning(f" Install compatible PyTorch:") - if major < 7: - logger.warning(f" !pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu121") - else: - logger.warning(f" !pip install torch --index-url https://download.pytorch.org/whl/cu118") - logger.warning(f" Falling back to CPU...") - else: - raise - - # Check MPS (Apple Silicon) - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - logger.info("Using Apple Silicon (MPS)") - return torch.device("mps") - - logger.info("Using CPU") - return torch.device("cpu") - -def setup_logger(name: str) -> logging.Logger: - """Sets up a standardized logger for the SAGE system.""" - logger = logging.getLogger(name) - if not logger.handlers: - logger.setLevel(logging.INFO) - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - # Prevent propagation to the root logger to avoid double printing - logger.propagate = False - return logger - -def save_checkpoint( - model: torch.nn.Module, - optimizer: Optional[torch.optim.Optimizer], - step: int, - checkpoint_dir: str, - filename: str = "sage_latest.pt" -) -> str: - """Saves the model and optimizer state to a checkpoint file.""" - os.makedirs(checkpoint_dir, exist_ok=True) - path = os.path.join(checkpoint_dir, filename) - - base_model = getattr(model, "module", model) - checkpoint = { - 'step': step, - 'model_state_dict': base_model.state_dict(), - } - - if optimizer is not None: - checkpoint['optimizer_state_dict'] = optimizer.state_dict() - - torch.save(checkpoint, path) - return path - -def load_checkpoint( - model: torch.nn.Module, - optimizer: Optional[torch.optim.Optimizer], - checkpoint_dir: str, - filename: str = "sage_latest.pt", - device: str = "cpu" -) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer], int]: - """Loads a checkpoint and restores the model and optimizer states.""" - path = os.path.join(checkpoint_dir, filename) - - if not os.path.exists(path): - logger = setup_logger("utils") - logger.warning(f"No checkpoint found at {path}. Starting from scratch.") - return model, optimizer, 0 - - # Load to CPU first to avoid VRAM spikes, then the module will be moved later if needed - checkpoint = torch.load(path, map_location=device) - - base_model = getattr(model, "module", model) - base_model.load_state_dict(checkpoint['model_state_dict'], strict=False) - - if optimizer is not None and 'optimizer_state_dict' in checkpoint: - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - step = checkpoint.get('step', 0) - - logger = setup_logger("utils") - logger.info(f"Loaded checkpoint from {path} at step {step}") - - return model, optimizer, step diff --git a/sage_single.py b/sage_single.py deleted file mode 100644 index c6131306a895ac2f258c126285676cf1db0a8247..0000000000000000000000000000000000000000 --- a/sage_single.py +++ /dev/null @@ -1,824 +0,0 @@ -#!/usr/bin/env python3 -""" -SAGE โ€” Self-Adaptive General Engine (Single-File Edition) -========================================================= -A complete mini-LLM in one file. Run with: - - python sage_single.py - -All architecture, data, training, inference, fine-tuning, quantization, -RAG, and CLI components are included below. -""" - -import os -import re -import sys -import math -import copy -import time -import random -import logging -from dataclasses import dataclass -from typing import Iterator, List, Optional, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.utils.prune as prune -from torch.amp import GradScaler, autocast -from torch.utils.data import IterableDataset, DataLoader -from tqdm import tqdm -import tiktoken -import wandb - -__version__ = "1.0.0" - - -# =================================================================== -# Section 1 โ€” Configuration -# =================================================================== - -@dataclass -class SageConfig: - d_model: int = 512 - n_heads: int = 8 - n_kv_heads: int = 4 - n_layers: int = 6 - d_ff: int = 2048 - n_experts: int = 4 - num_experts_per_tok: int = 2 - vocab_size: int = 100277 - max_seq_len: int = 1024 - dropout: float = 0.1 - batch_size: int = 4 - gradient_accumulation_steps: int = 16 - learning_rate: float = 3e-4 - min_learning_rate: float = 1e-5 - warmup_steps: int = 100 - weight_decay: float = 0.01 - max_grad_norm: float = 1.0 - checkpoint_dir: str = "checkpoints" - project_name: str = "sage-v2" - - @property - def device(self): - if torch.cuda.is_available(): - return torch.device("cuda") - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - return torch.device("mps") - return torch.device("cpu") - - -# =================================================================== -# Section 2 โ€” Logging & Checkpoint Utilities -# =================================================================== - -def setup_logger(name: str) -> logging.Logger: - logger = logging.getLogger(name) - if not logger.handlers: - logger.setLevel(logging.INFO) - h = logging.StreamHandler() - h.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s", datefmt="%H:%M:%S")) - logger.addHandler(h) - logger.propagate = False - return logger - -logger = setup_logger("sage") - -def save_checkpoint(model, optimizer, step, checkpoint_dir, filename="sage_latest.pt"): - os.makedirs(checkpoint_dir, exist_ok=True) - path = os.path.join(checkpoint_dir, filename) - base = getattr(model, "module", model) - ckpt = {"step": step, "model_state_dict": base.state_dict()} - if optimizer is not None: - ckpt["optimizer_state_dict"] = optimizer.state_dict() - torch.save(ckpt, path) - return path - -def load_checkpoint(model, optimizer, checkpoint_dir, filename="sage_latest.pt", device="cpu"): - path = os.path.join(checkpoint_dir, filename) - if not os.path.exists(path): - logger.warning(f"No checkpoint at {path}, starting fresh.") - return model, optimizer, 0 - ckpt = torch.load(path, map_location=device) - base = getattr(model, "module", model) - base.load_state_dict(ckpt["model_state_dict"], strict=False) - if optimizer and "optimizer_state_dict" in ckpt: - optimizer.load_state_dict(ckpt["optimizer_state_dict"]) - step = ckpt.get("step", 0) - logger.info(f"Loaded checkpoint from {path} (step {step})") - return model, optimizer, step - - -# =================================================================== -# Section 3 โ€” Tokenizer -# =================================================================== - -class SageTokenizer: - def __init__(self, encoding_name="cl100k_base"): - self.enc = tiktoken.get_encoding(encoding_name) - self.eos_token_id = self.enc.n_vocab - 1 - self.pad_token_id = self.enc.n_vocab - 2 - self.vocab_size = self.enc.n_vocab - - def encode(self, text, add_eos=False): - tokens = self.enc.encode(text, allowed_special="all") - if add_eos: - tokens.append(self.eos_token_id) - return tokens - - def decode(self, tokens): - filtered = [t for t in tokens if t not in (self.eos_token_id, self.pad_token_id)] - return self.enc.decode(filtered) - - -# =================================================================== -# Section 4 โ€” Model Architecture (RoPE, Attention, MoE, Transformer) -# =================================================================== - -def precompute_freqs_cis(dim, end, theta=10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim)) - t = torch.arange(end, dtype=torch.float32) - freqs = torch.outer(t, freqs) - return torch.polar(torch.ones_like(freqs), freqs) - -def apply_rotary_emb(xq, xk, freqs_cis): - # Ensure freqs_cis is complex (DataParallel can sometimes replicate it as real) - if not torch.is_complex(freqs_cis) and freqs_cis.shape[-1] == 2: - freqs_cis = torch.view_as_complex(freqs_cis) - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - fc = freqs_cis.unsqueeze(0).unsqueeze(2) - xq_out = torch.view_as_real(xq_ * fc).flatten(3) - xk_out = torch.view_as_real(xk_ * fc).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - -def repeat_kv(x, n_rep): - if n_rep == 1: return x - B, T, n_kv_heads, head_dim = x.size() - return x[:, :, :, None, :].expand(B, T, n_kv_heads, n_rep, head_dim).reshape(B, T, n_kv_heads * n_rep, head_dim) - -class CausalSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - self.n_heads = config.n_heads - self.n_kv_heads = config.n_kv_heads - self.n_rep = self.n_heads // self.n_kv_heads - self.d_model = config.d_model - self.head_dim = config.d_model // config.n_heads - self.wq = nn.Linear(config.d_model, config.n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(config.d_model, config.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(config.d_model, config.d_model, bias=False) - self.resid_dropout = nn.Dropout(config.dropout) - - def forward(self, x, freqs_cis, kv_cache=None): - B, T, C = x.size() - q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = q.view(B, T, self.n_heads, self.head_dim) - k = k.view(B, T, self.n_kv_heads, self.head_dim) - v = v.view(B, T, self.n_kv_heads, self.head_dim) - q, k = apply_rotary_emb(q, k, freqs_cis) - if kv_cache is not None: - k = torch.cat([kv_cache[0], k], dim=1) - v = torch.cat([kv_cache[1], v], dim=1) - new_kv = (k, v) - else: - new_kv = None - k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) - is_causal = kv_cache is None and T > 1 - try: - y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0 if not self.training else 0.1, is_causal=is_causal) - except Exception: - attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) - if is_causal: - mask = torch.tril(torch.ones(T, T, device=q.device)).view(1, 1, T, T) - attn = attn.masked_fill(mask == 0, float('-inf')) - attn = F.softmax(attn, dim=-1) - if self.training: attn = F.dropout(attn, p=0.1) - y = attn @ v - y = y.transpose(1, 2).contiguous().view(B, T, C) - return self.resid_dropout(self.wo(y)), new_kv - -class ExpertFFN(nn.Module): - def __init__(self, config): - super().__init__() - self.w1 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.w2 = nn.Linear(config.d_ff, config.d_model, bias=False) - self.w3 = nn.Linear(config.d_model, config.d_ff, bias=False) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) - -class MoE(nn.Module): - def __init__(self, config): - super().__init__() - self.n_experts = config.n_experts - self.top_k = config.num_experts_per_tok - self.router = nn.Linear(config.d_model, config.n_experts, bias=False) - self.experts = nn.ModuleList([ExpertFFN(config) for _ in range(config.n_experts)]) - - def forward(self, x): - B, T, C = x.size() - flat = x.view(-1, C) - weights = F.softmax(self.router(flat), dim=-1) - weights, indices = torch.topk(weights, self.top_k, dim=-1) - weights = weights / weights.sum(dim=-1, keepdim=True) - out = torch.zeros_like(flat) - for i, expert in enumerate(self.experts): - mask = (indices == i) - tok_idx, kth = torch.where(mask) - if tok_idx.shape[0] > 0: - out[tok_idx] += expert(flat[tok_idx]) * weights[tok_idx, kth].unsqueeze(-1) - return out.view(B, T, C) - -class TransformerBlock(nn.Module): - def __init__(self, config): - super().__init__() - self.norm1 = nn.LayerNorm(config.d_model) - self.attn = CausalSelfAttention(config) - self.norm2 = nn.LayerNorm(config.d_model) - self.moe = MoE(config) - - def forward(self, x, freqs_cis, kv_cache=None): - h, new_kv = self.attn(self.norm1(x), freqs_cis, kv_cache) - x = x + h - x = x + self.moe(self.norm2(x)) - return x, new_kv - -class SageModel(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.wte = nn.Embedding(config.vocab_size, config.d_model) - self.drop = nn.Dropout(config.dropout) - self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) - self.ln_f = nn.LayerNorm(config.d_model) - self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # tied - self.wte.weight = self.lm_head.weight - self.register_buffer("freqs_cis", precompute_freqs_cis(config.d_model // config.n_heads, config.max_seq_len * 2), persistent=False) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - if m.bias is not None: nn.init.zeros_(m.bias) - elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) - elif isinstance(m, nn.LayerNorm): - nn.init.ones_(m.weight); nn.init.zeros_(m.bias) - - def forward(self, idx, kv_caches=None): - B, T = idx.size() - start = kv_caches[0][0].shape[1] if kv_caches else 0 - fc = self.freqs_cis[start:start + T] - x = self.drop(self.wte(idx)) - new_kvs = [] - for i, layer in enumerate(self.layers): - kv = kv_caches[i] if kv_caches else None - if self.training and kv is None: - def create_custom_forward(module): - def custom_forward(x_in, freqs_cis_in): - return module(x_in, freqs_cis_in, None) - return custom_forward - x, nkv = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, fc, use_reentrant=False) - else: - x, nkv = layer(x, fc, kv) - if nkv is not None: new_kvs.append(nkv) - return self.lm_head(self.ln_f(x)), new_kvs if new_kvs else None - - -# =================================================================== -# Section 5 โ€” Data Pipeline -# =================================================================== - -_HTML_RE = re.compile(r"<[^>]+>") - -def clean_text(text): - text = _HTML_RE.sub("", text) - text = re.sub(r"[ \t]+", " ", text) - text = re.sub(r"\n{3,}", "\n\n", text) - return text.strip() - -class StreamingTextDataset(IterableDataset): - def __init__(self, dataset_name="HuggingFaceFW/fineweb-edu", split="train", seq_len=512, tokenizer=None, buffer_size=1000, text_field="text"): - super().__init__() - self.dataset_name, self.split, self.seq_len = dataset_name, split, seq_len - self.tokenizer = tokenizer or SageTokenizer() - self.buffer_size, self.text_field = buffer_size, text_field - if "fineweb-edu" in dataset_name.lower(): self.text_field = "text" - elif "tinystories" in dataset_name.lower(): self.text_field = "text" - - def _tokens(self): - from datasets import load_dataset - ds = load_dataset(self.dataset_name, split=self.split, streaming=True) - for s in ds: - raw = s.get(self.text_field, "") - if not raw or len(raw) < 50: continue - text = clean_text(raw) - yield from self.tokenizer.encode(text, add_eos=True) - - def __iter__(self): - chunk, buf = [], [] - for tok in self._tokens(): - chunk.append(tok) - if len(chunk) == self.seq_len + 1: - buf.append(torch.tensor(chunk, dtype=torch.long)) - chunk = [] - if len(buf) >= self.buffer_size: - random.shuffle(buf) - while len(buf) > self.buffer_size // 2: yield buf.pop() - random.shuffle(buf) - yield from buf - -def create_dataloader(config, dataset_name="HuggingFaceFW/fineweb-edu", tokenizer=None): - tok = tokenizer or SageTokenizer() - ds = StreamingTextDataset(dataset_name=dataset_name, seq_len=config.max_seq_len, tokenizer=tok) - return DataLoader(ds, batch_size=config.batch_size, num_workers=2, pin_memory=True, drop_last=True) - - -# =================================================================== -# Section 6 โ€” Training -# =================================================================== - -def get_lr(step, config, total_steps): - if step < config.warmup_steps: - return config.learning_rate * (step + 1) / config.warmup_steps - progress = (step - config.warmup_steps) / max(1, total_steps - config.warmup_steps) - coeff = 0.5 * (1.0 + math.cos(math.pi * progress)) - return config.min_learning_rate + coeff * (config.learning_rate - config.min_learning_rate) - -def create_optimizer(model, config): - decay, no_decay = [], [] - for n, p in model.named_parameters(): - if not p.requires_grad: continue - (no_decay if p.ndim == 1 or "bias" in n else decay).append(p) - # Enable Fused AdamW for 10% speedup if CUDA is active - use_fused = torch.cuda.is_available() and 'fused' in torch.optim.AdamW.__init__.__code__.co_varnames - return torch.optim.AdamW([ - {"params": decay, "weight_decay": config.weight_decay}, - {"params": no_decay, "weight_decay": 0.0}, - ], lr=config.learning_rate, betas=(0.9, 0.95), fused=use_fused) - -def train_model(model, config, total_steps=500, dataset_name="roneneldan/TinyStories", resume=True, tokenizer=None): - device = config.device - # --- TURBO MODE: TF32 & COMPILE --- - if torch.cuda.is_available(): - torch.set_float32_matmul_precision('high') - - model = model.to(device) - tok = tokenizer or SageTokenizer() - - # Wrap model with torch.compile for graph-level optimization - # mode="reduce-overhead" is ideal for smaller-to-medium models like SAGE - if hasattr(torch, "compile"): - try: - logger.info("Turbo Mode: Compiling model graph...") - # Compile the base model (unwrapped from DataParallel if present) - base = getattr(model, "module", model) - compiled_base = torch.compile(base, mode="reduce-overhead") - if hasattr(model, "module"): - model.module = compiled_base - else: - model = compiled_base - except (ValueError, RuntimeError, ImportError) as e: - # Graceful fallback: numpy compatibility issues or other compilation errors - logger.warning(f"torch.compile failed ({type(e).__name__}), proceeding without optimization: {str(e)[:100]}") - # Continue with uncompiled model - - opt = create_optimizer(model, config) - start_step = 0 - if resume: - model, opt, start_step = load_checkpoint(model, opt, config.checkpoint_dir, device=str(device)) - if start_step >= total_steps: return model - use_amp = device.type == "cuda" - amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16 - scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16) - loader = create_dataloader(config, dataset_name, tok) - data_iter = iter(loader) - wandb.init(project=config.project_name, name=f"pretrain-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__) - model.train() - accum_loss, t0 = 0.0, time.time() - pbar = tqdm(range(start_step, total_steps), desc="Training") - for step in pbar: - lr = get_lr(step, config, total_steps) - for pg in opt.param_groups: pg["lr"] = lr - opt.zero_grad(set_to_none=True) - step_loss = 0.0 - for _ in range(config.gradient_accumulation_steps): - try: batch = next(data_iter) - except StopIteration: data_iter = iter(loader); batch = next(data_iter) - batch = batch.to(device) - with autocast(device.type, dtype=amp_dtype, enabled=use_amp): - logits, _ = model(batch[:, :-1]) - loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), batch[:, 1:].reshape(-1), ignore_index=tok.pad_token_id) - loss = loss / config.gradient_accumulation_steps - scaler.scale(loss).backward() - step_loss += loss.item() - scaler.unscale_(opt) - nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) - scaler.step(opt); scaler.update() - accum_loss += step_loss - if (step + 1) % 10 == 0: - avg = accum_loss / 10 - pbar.set_postfix(loss=f"{avg:.4f}", ppl=f"{math.exp(min(avg,20)):.1f}", lr=f"{lr:.2e}") - wandb.log({"train/loss": avg, "train/perplexity": math.exp(min(avg, 20)), "train/lr": lr}, step=step + 1) - accum_loss = 0.0 - if (step + 1) % 100 == 0: - save_checkpoint(model, opt, step + 1, config.checkpoint_dir) - save_checkpoint(model, opt, total_steps, config.checkpoint_dir) - logger.info("Training complete.") - wandb.finish() - return model - - -# =================================================================== -# Section 7 โ€” Inference -# =================================================================== - -def sample_next(logits, temperature=0.8, top_k=50, top_p=0.9, greedy=False): - if greedy: return logits.argmax(-1, keepdim=True) - logits = logits / max(temperature, 1e-8) - if 0 < top_k < logits.size(-1): - v, _ = torch.topk(logits, top_k) - logits[logits < v[:, -1:]] = float("-inf") - if top_p < 1.0: - sorted_l, sorted_i = torch.sort(logits, descending=True) - cum = torch.cumsum(F.softmax(sorted_l, -1), -1) - mask = cum - F.softmax(sorted_l, -1) >= top_p - sorted_l[mask] = float("-inf") - logits = logits.scatter(1, sorted_i, sorted_l) - return torch.multinomial(F.softmax(logits, -1), 1) - -@torch.no_grad() -def generate(model, tokenizer, prompt, max_new=256, temperature=0.8, top_k=50, top_p=0.9, stream=True, device=None): - device = device or next(model.parameters()).device - base = getattr(model, "module", model) - base.eval() - ids = tokenizer.encode(prompt) or [tokenizer.eos_token_id] - inp = torch.tensor([ids], dtype=torch.long, device=device) - logits, kvs = base(inp) - gen = list(ids) - nl = logits[:, -1, :] - for _ in range(max_new): - nid = sample_next(nl, temperature, top_k, top_p) - tid = nid.item() - if tid == tokenizer.eos_token_id: break - gen.append(tid) - if stream: print(tokenizer.decode([tid]), end="", flush=True) - logits, kvs = base(nid.view(1, 1), kv_caches=kvs) - nl = logits[:, -1, :] - if stream: print() - base.train() - return tokenizer.decode(gen) - - -# =================================================================== -# Section 8 โ€” LoRA Fine-tuning -# =================================================================== - -class LoRALinear(nn.Module): - def __init__(self, original, rank=8, alpha=16.0): - super().__init__() - self.original = original - self.scaling = alpha / rank - device, dtype = original.weight.device, original.weight.dtype - self.lora_A = nn.Parameter(torch.randn(original.in_features, rank, device=device, dtype=dtype) * 0.01) - self.lora_B = nn.Parameter(torch.zeros(rank, original.out_features, device=device, dtype=dtype)) - original.weight.requires_grad = False - if original.bias is not None: original.bias.requires_grad = False - - def forward(self, x): - return self.original(x) + (x @ self.lora_A @ self.lora_B) * self.scaling - - def merge(self): - m = copy.deepcopy(self.original) - m.weight.data += (self.lora_B.T @ self.lora_A.T).T * self.scaling - m.weight.requires_grad = True - return m - -def inject_lora(model, rank=8, alpha=16.0): - base = getattr(model, "module", model) - for layer in base.layers: - a = layer.attn - for name in ("wq", "wk", "wv", "wo"): - setattr(a, name, LoRALinear(getattr(a, name), rank, alpha)) - tp = sum(p.numel() for p in base.parameters() if p.requires_grad) - logger.info(f"LoRA injected (rank={rank}). Trainable params: {tp:,}") - return model - -def merge_lora(model): - base = getattr(model, "module", model) - for layer in base.layers: - a = layer.attn - for name in ("wq", "wk", "wv", "wo"): - m = getattr(a, name) - if isinstance(m, LoRALinear): setattr(a, name, m.merge()) - logger.info("LoRA merged.") - return model - -INSTRUCTION_TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n{response}" - -DEMO_SAMPLES = [ - {"instruction": "What is the capital of France?", "response": "The capital of France is Paris."}, - {"instruction": "Explain gravity simply.", "response": "Gravity pulls objects toward each other. More mass means stronger pull."}, - {"instruction": "Write a short poem about the ocean.", "response": "Waves crash on sandy shore,\nThe ocean sings forevermore.\nDeep blue meets the sky,\nSeagulls dance and clouds float by."}, - {"instruction": "What is 15 times 12?", "response": "15 times 12 equals 180."}, - {"instruction": "Summarize photosynthesis.", "response": "Plants convert sunlight, water, and CO2 into glucose and oxygen."}, - {"instruction": "Tell me a fun fact about space.", "response": "A day on Venus is longer than its year โ€” 243 Earth days to rotate vs 225 to orbit the Sun."}, - {"instruction": "How do airplanes fly?", "response": "Wings generate lift because air moves faster over the curved top, creating lower pressure above."}, - {"instruction": "What is machine learning?", "response": "ML is AI where computers learn patterns from data instead of being explicitly programmed."}, -] - -def create_instruction_batch(samples, tokenizer, max_len=512): - all_ids, all_masks = [], [] - for s in samples: - inst_text = f"### Instruction:\n{s['instruction'].strip()}\n\n### Response:\n" - full_text = inst_text + s["response"].strip() - inst_toks = tokenizer.encode(inst_text) - full_toks = tokenizer.encode(full_text, add_eos=True)[:max_len] - ni = min(len(inst_toks), len(full_toks)) - mask = [0] * ni + [1] * (len(full_toks) - ni) - pad = max_len - len(full_toks) - full_toks += [tokenizer.pad_token_id] * pad - mask += [0] * pad - all_ids.append(full_toks); all_masks.append(mask) - return {"input_ids": torch.tensor(all_ids), "labels": torch.tensor(all_ids), "loss_mask": torch.tensor(all_masks, dtype=torch.float32)} - -def finetune(model, config, samples=None, steps=200, use_lora=True, tokenizer=None): - device = config.device; model = model.to(device) - tok = tokenizer or SageTokenizer() - samples = samples or DEMO_SAMPLES - if use_lora: model = inject_lora(model) - opt = create_optimizer(model, config) - use_amp = device.type == "cuda" - amp_dtype = torch.bfloat16 if use_amp and torch.cuda.is_bf16_supported() else torch.float16 - scaler = GradScaler("cuda", enabled=use_amp and amp_dtype == torch.float16) - wandb.init(project=config.project_name, name=f"finetune-{time.strftime('%Y%m%d-%H%M')}", config=config.__dict__) - model.train(); accum = 0.0 - for step in tqdm(range(steps), desc="Fine-tuning"): - lr = get_lr(step, config, steps) - for pg in opt.param_groups: pg["lr"] = lr - batch = create_instruction_batch(random.choices(samples, k=min(config.batch_size, len(samples))), tok, config.max_seq_len) - ids, labels, mask = batch["input_ids"].to(device), batch["labels"].to(device), batch["loss_mask"].to(device) - opt.zero_grad(set_to_none=True) - with autocast(device.type, dtype=amp_dtype, enabled=use_amp): - logits, _ = model(ids) - sl, slb, sm = logits[:, :-1, :].contiguous(), labels[:, 1:].contiguous(), mask[:, 1:].contiguous() - ptl = F.cross_entropy(sl.view(-1, sl.size(-1)), slb.view(-1), reduction="none").view(slb.size()) - loss = (ptl * sm).sum() / sm.sum().clamp(min=1) - scaler.scale(loss).backward() - scaler.unscale_(opt); nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) - scaler.step(opt); scaler.update() - accum += loss.item() - if (step + 1) % 10 == 0: accum = 0.0 - if use_lora: model = merge_lora(model) - save_checkpoint(model, None, steps, config.checkpoint_dir, "sage_finetuned.pt") - logger.info("Fine-tuning complete.") - wandb.finish() - return model - - -# =================================================================== -# Section 9 โ€” Optimization (Quantize / Prune) -# =================================================================== - -def quantize_int8(model): - base = getattr(model, "module", model) - model = base.cpu().eval() - q = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8) - logger.info("INT8 quantization complete.") - return q - -def prune_model(model, amount=0.3): - base = getattr(model, "module", model) - for _, m in base.named_modules(): - if isinstance(m, nn.Linear): - prune.l1_unstructured(m, "weight", amount=amount) - prune.remove(m, "weight") - logger.info(f"Pruning complete ({amount*100:.0f}% sparsity target).") - return model - - -# =================================================================== -# Section 10 โ€” RAG & Memory -# =================================================================== - -def _embed(text, tokenizer, model, device): - toks = tokenizer.encode(text) - base = getattr(model, "module", model) - if not toks: return np.zeros(base.wte.weight.shape[1], dtype=np.float32) - with torch.no_grad(): - emb = base.wte(torch.tensor([toks], device=device)).mean(1) - emb = F.normalize(emb, p=2, dim=-1) - return emb.squeeze(0).cpu().numpy() - -class VectorStore: - def __init__(self, dim): - self.dim = dim; self.docs = []; self.index = None - try: - import faiss - self.index = faiss.IndexFlatIP(dim) - except ImportError: - logger.warning("FAISS not installed. RAG will use brute-force search.") - - def add(self, texts, embeddings): - if self.index is not None: - self.index.add(embeddings.astype(np.float32)) - else: - # Brute-force fallback - if not hasattr(self, '_embeddings'): - self._embeddings = [] - self._embeddings.extend(embeddings.astype(np.float32)) - self.docs.extend(texts) - - def search(self, qemb, k=3): - if not self.docs: return [] - k = min(k, len(self.docs)) - if self.index is not None: - scores, idx = self.index.search(qemb.reshape(1, -1).astype(np.float32), k) - return [(self.docs[i], float(s)) for s, i in zip(scores[0], idx[0]) if i >= 0] - else: - # Brute-force cosine similarity - import numpy as np - qemb = qemb.reshape(1, -1).astype(np.float32) - embs = np.array(self._embeddings) - sims = np.dot(embs, qemb.T).flatten() - top_k = np.argsort(sims)[-k:][::-1] - return [(self.docs[i], float(sims[i])) for i in top_k] - - @property - def size(self): return len(self.docs) - -class RAGManager: - def __init__(self, model, tokenizer, device, chunk_size=200): - self.model, self.tokenizer, self.device = model, tokenizer, device - base = getattr(model, "module", model) - self.store = VectorStore(base.wte.weight.shape[1]) - self.enabled = False - - def add_documents(self, texts): - chunks = [] - for t in texts: - words = t.split() - for i in range(0, len(words), 150): - chunks.append(" ".join(words[i:i+200])) - if chunks: - embs = np.stack([_embed(c, self.tokenizer, self.model, self.device) for c in chunks]) - self.store.add(chunks, embs) - - def retrieve(self, query, k=3): - if not self.enabled or not self.store.size: return "" - qe = _embed(query, self.tokenizer, self.model, self.device) - results = self.store.search(qe, k) - return "\n\n".join(f"[Context {i+1}] {d}" for i, (d, _) in enumerate(results)) + "\n\n" if results else "" - - def toggle(self, on): self.enabled = on - -DEFAULT_SYSTEM_PROMPT = ( - "You are a high-quality reasoning assistant model.\n" - "You must ONLY learn from high-quality instruction and reasoning datasets.\n" - "You must IGNORE any previously trained low-quality or repetitive patterns.\n\n" - "Training preference rules:\n" - "1. Prioritize step-by-step reasoning over short or repetitive answers.\n" - "2. Always produce structured logical explanations when solving problems.\n" - "3. Avoid repetition, filler words, or looped phrases.\n" - "4. Prefer datasets with mathematical reasoning and high-quality instruction.\n" - "5. Do not imitate noisy conversational or corrupted text patterns.\n" - "6. Always prefer clarity, correctness, and structured reasoning.\n\n" - "Output behavior goal:\n" - "- Think in steps.\n" - "- Explain logic clearly.\n" - "- Produce final answer only after reasoning." -) - -class ConversationHistory: - def __init__(self, tokenizer, max_tokens=900): - self.tokenizer, self.max_tokens, self.turns = tokenizer, max_tokens, [] - - def add(self, role, text): - self.turns.append({"role": role, "text": text}) - while sum(len(self.tokenizer.encode(t["text"])) for t in self.turns) > self.max_tokens and len(self.turns) > 1: - self.turns.pop(0) - - def build_prompt(self, msg, rag_ctx=""): - parts = [DEFAULT_SYSTEM_PROMPT] - if rag_ctx: parts.append(rag_ctx) - for t in self.turns: - parts.append(f"{'User' if t['role']=='user' else 'SAGE'}: {t['text']}") - parts += [f"User: {msg}", "SAGE:"] - return "\n\n".join(parts) - - def clear(self): self.turns.clear() - - -# =================================================================== -# Section 11 โ€” CLI -# =================================================================== - -BANNER = r""" -โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•— -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ•‘ -โ•‘ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ โ•‘ -โ•‘ Self-Adaptive General Engine v{ver} โ•‘ -โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•""" - -HELP = """ - /train [steps] Train (default 100) - /finetune [steps] Instruction-tune with LoRA (default 200) - /save Save checkpoint - /load Load checkpoint - /quantize INT8 quantization - /rag on|off|add Toggle or add docs for RAG - /clear Clear history - /help This message - /exit Quit -""" - -def main(): - config = SageConfig() - tok = SageTokenizer() - config.vocab_size = tok.vocab_size - print(" Initializing SAGE โ€ฆ") - model = SageModel(config).to(config.device) - if torch.cuda.is_available() and torch.cuda.device_count() > 1: - print(f" Multi-GPU detected: {torch.cuda.device_count()} GPUs. Using DataParallel.") - model = nn.DataParallel(model) - model, _, step = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device)) - base = getattr(model, "module", model) - total = sum(p.numel() for p in base.parameters()) - print(BANNER.format(ver=__version__)) - print(f" Params: {total:,} ({total/1e6:.1f}M) | Context: {config.max_seq_len} | Device: {config.device}") - print(f" Layers: {config.n_layers} | Heads: {config.n_heads} | Experts: {config.n_experts}") - if step: print(f" Resumed from step {step}") - print(" Type /help for commands.\n") - - rag = RAGManager(model, tok, config.device) - hist = ConversationHistory(tok, config.max_seq_len - 128) - - if len(sys.argv) > 1: - cmd = sys.argv[1].lower() - args = sys.argv[2:] - if cmd == "--train": - s = int(args[0]) if args else 100 - train_model(model, config, s, tokenizer=tok) - return - elif cmd == "--finetune": - s = int(args[0]) if args else 200 - finetune(model, config, steps=s, tokenizer=tok) - return - elif cmd == "--quantize": - quantize_int8(model) - return - else: - print(f" Unknown argument: {cmd}\n Usage: --train [steps] | --finetune [steps] | --quantize") - return - - while True: - try: inp = input("You: ").strip() - except (EOFError, KeyboardInterrupt): print("\n Goodbye!"); break - if not inp: continue - - if inp.startswith("/"): - parts = inp.split(); cmd = parts[0].lower(); args = parts[1:] - if cmd == "/exit": print(" Goodbye!"); break - elif cmd == "/help": print(HELP) - elif cmd == "/train": - s = int(args[0]) if args else 100 - model = train_model(model, config, s, tokenizer=tok) - print("\n Sample:"); generate(model, tok, "Once upon a time", max_new=80, device=config.device); print() - elif cmd == "/finetune": - s = int(args[0]) if args else 200 - model = finetune(model, config, steps=s, tokenizer=tok) - print("\n Sample:"); generate(model, tok, "### Instruction:\nWhat is gravity?\n\n### Response:\n", max_new=100, device=config.device); print() - elif cmd == "/save": print(f" Saved to {save_checkpoint(model, None, 0, config.checkpoint_dir)}") - elif cmd == "/load": - model, _, s = load_checkpoint(model, None, config.checkpoint_dir, device=str(config.device)) - model = model.to(config.device); rag.model = model; print(f" Loaded (step {s})") - elif cmd == "/quantize": model = quantize_int8(model); rag.model = model - elif cmd == "/rag": - if not args: print(f" RAG {'on' if rag.enabled else 'off'} ({rag.store.size} chunks)") - elif args[0] == "on": rag.toggle(True); print(" RAG on.") - elif args[0] == "off": rag.toggle(False); print(" RAG off.") - elif args[0] == "add" and len(args) > 1: rag.add_documents([" ".join(args[1:])]); print(f" Added. {rag.store.size} chunks.") - else: print(" /rag on|off|add ") - elif cmd == "/clear": hist.clear(); print(" Cleared.") - else: print(f" Unknown: {cmd}") - continue - - ctx = rag.retrieve(inp) - prompt = hist.build_prompt(inp, ctx) - hist.add("user", inp) - print("SAGE: ", end="", flush=True) - resp = generate(model, tok, prompt, max_new=256, stream=True, device=config.device) - reply = resp.split("SAGE:")[-1].strip() if "SAGE:" in resp else resp[len(prompt):].strip() - hist.add("assistant", reply) - -if __name__ == "__main__": - main() diff --git a/scripts/run_data_pipeline.sh b/scripts/run_data_pipeline.sh new file mode 100644 index 0000000000000000000000000000000000000000..4232db5443437ffdf4fa9c7803439e4b93fd9a08 --- /dev/null +++ b/scripts/run_data_pipeline.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail + +python -m tokenizer.train_tokenizer "$@" diff --git a/scripts/run_eval.sh b/scripts/run_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..32a1b24b0aac16ba2a66f7b150d929240f4ed5c7 --- /dev/null +++ b/scripts/run_eval.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -euo pipefail + +python - <<'PY' +from eval.benchmarks import run_registered_benchmarks +from model.model import SageTransformer +from model.config import ModelConfig + +model = SageTransformer(ModelConfig()) +for result in run_registered_benchmarks(model): + print(result) +PY diff --git a/scripts/run_serve.sh b/scripts/run_serve.sh new file mode 100644 index 0000000000000000000000000000000000000000..caac639990fa01edd9c3e1f0a3f6a1e590aaddd2 --- /dev/null +++ b/scripts/run_serve.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail + +uvicorn serve.server:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8000}" "$@" diff --git a/scripts/run_serve_cpu.sh b/scripts/run_serve_cpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..334f17d86f92e65e550184ca048833aa1fd638ac --- /dev/null +++ b/scripts/run_serve_cpu.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail + +uvicorn serve.server_cpu:app --host "${HOST:-0.0.0.0}" --port "${PORT:-8001}" "$@" diff --git a/scripts/run_training.sh b/scripts/run_training.sh new file mode 100644 index 0000000000000000000000000000000000000000..86d93040e05c092af757c5f91e9442e3c4fffab4 --- /dev/null +++ b/scripts/run_training.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +set -euo pipefail + +python -m train.trainer "$@" diff --git a/scripts/run_validate_tokenizer.sh b/scripts/run_validate_tokenizer.sh new file mode 100644 index 0000000000000000000000000000000000000000..8a8b528de5e8baa028d3091b530fef1093fe1216 --- /dev/null +++ b/scripts/run_validate_tokenizer.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +set -euo pipefail + +MODEL_PATH="${1:-tokenizer/tokenizer.model}" + +python - "$MODEL_PATH" <<'PY' +import sys +from tokenizer.validate_tokenizer import validate_model_file + +validate_model_file(sys.argv[1]) +print("tokenizer ok") +PY diff --git a/serve/__init__.py b/serve/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9de1b2930d82554201a10e63211021f97adafbc3 --- /dev/null +++ b/serve/__init__.py @@ -0,0 +1 @@ +"""Serving helpers for SAGE.""" diff --git a/serve/kv_cache.py b/serve/kv_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..10121096e0f1e64f59bb097be589d20bb5ad8c41 --- /dev/null +++ b/serve/kv_cache.py @@ -0,0 +1,23 @@ +"""KV-cache helpers for inference.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass +class KVCache: + """Stores per-layer key/value tensors.""" + + entries: list[tuple[torch.Tensor, torch.Tensor]] + + @classmethod + def empty(cls, num_layers: int) -> "KVCache": + """Create an empty cache placeholder.""" + return cls(entries=[None] * num_layers) # type: ignore[list-item] + + def append(self, layer_index: int, key: torch.Tensor, value: torch.Tensor) -> None: + """Store one layer's key/value pair.""" + self.entries[layer_index] = (key, value) diff --git a/serve/quantize.py b/serve/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bffdf5d10564027ab162318bd36b90e8cd24ee --- /dev/null +++ b/serve/quantize.py @@ -0,0 +1,24 @@ +"""Post-training quantization entry points.""" + +from __future__ import annotations + +from pathlib import Path + +import torch + + +def export_int8_state_dict(model: torch.nn.Module, output_path: str) -> str: + """Save a dynamic-int8 quantized model state dict for CPU experiments.""" + quantized = torch.quantization.quantize_dynamic(model.cpu(), {torch.nn.Linear}, dtype=torch.qint8) + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(quantized.state_dict(), path) + return str(path) + + +def gguf_conversion_command(checkpoint_dir: str, output_path: str) -> str: + """Return a llama.cpp conversion command string.""" + return ( + f"python llama.cpp/convert_hf_to_gguf.py {checkpoint_dir} " + f"--outfile {output_path} --outtype f16" + ) diff --git a/serve/server.py b/serve/server.py new file mode 100644 index 0000000000000000000000000000000000000000..94c365012429a6cc21b8d59c36ca8d93db4de06a --- /dev/null +++ b/serve/server.py @@ -0,0 +1,59 @@ +"""GPU-oriented FastAPI server for SAGE.""" + +from __future__ import annotations + +from typing import Optional + +import torch +from fastapi import FastAPI +from pydantic import BaseModel + +from model.config import ModelConfig +from model.model import SageTransformer +from serve.kv_cache import KVCache +from train.hardware import HardwareConfig + + +app = FastAPI(title="SAGE Server") +_MODEL: SageTransformer | None = None +_TOKENIZER = None + + +class GenerationRequest(BaseModel): + """Request schema for text generation.""" + + input_ids: list[int] + max_new_tokens: int = 32 + + +def get_model() -> SageTransformer: + """Lazily create the model for server startup.""" + global _MODEL + if _MODEL is None: + _MODEL = SageTransformer(ModelConfig()) + _MODEL.eval() + return _MODEL + + +@app.get("/health") +def health() -> dict[str, object]: + """Return basic health and hardware information.""" + hw = HardwareConfig(model_size_b=1.0, context_length=4096) + return {"status": "ok", "hardware": hw.summary()} + + +@app.post("/generate") +def generate(request: GenerationRequest) -> dict[str, object]: + """Generate continuation token ids from an input token list.""" + model = get_model() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + input_ids = torch.tensor([request.input_ids], dtype=torch.long, device=device) + generated = list(request.input_ids) + cache: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None + for _ in range(request.max_new_tokens): + logits, cache = model(input_ids[:, -1:] if cache is not None else input_ids, past_key_values=cache) + next_token = int(torch.argmax(logits[:, -1, :], dim=-1).item()) + generated.append(next_token) + input_ids = torch.tensor([[next_token]], dtype=torch.long, device=device) + return {"tokens": generated} diff --git a/serve/server_cpu.py b/serve/server_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..29dbdb476dded59394834388cd9354b55f46ce8a --- /dev/null +++ b/serve/server_cpu.py @@ -0,0 +1,16 @@ +"""CPU and llama.cpp serving helpers.""" + +from __future__ import annotations + +import shutil + +from fastapi import FastAPI + + +app = FastAPI(title="SAGE CPU Server") + + +@app.get("/health") +def health() -> dict[str, object]: + """Report llama.cpp availability for CPU serving.""" + return {"status": "ok", "llama_cpp_available": shutil.which("llama-server") is not None} diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..86a1a5ac91b9983f901e5f0fa7ffd0f2812a2bec --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,7 @@ +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7622aef208c4d780966008e2ba9edcd9a568627b --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,22 @@ +import torch +import torch.nn.functional as F + +from model.attention import repeat_kv + + +def test_repeat_kv_shape() -> None: + x = torch.randn(2, 2, 5, 8) + repeated = repeat_kv(x, 4) + assert repeated.shape == (2, 8, 5, 8) + + +def test_sdpa_matches_reference_shape() -> None: + q = torch.randn(1, 2, 4, 8) + k = torch.randn(1, 2, 4, 8) + v = torch.randn(1, 2, 4, 8) + result = F.scaled_dot_product_attention(q, k, v, is_causal=True) + reference_scores = (q @ k.transpose(-2, -1)) / (8 ** 0.5) + mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1) + reference_scores = reference_scores.masked_fill(mask, float("-inf")) + reference = torch.softmax(reference_scores, dim=-1) @ v + assert torch.allclose(result, reference, atol=1e-5, rtol=1e-5) diff --git a/tests/test_data_pipeline.py b/tests/test_data_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..a36e7c1a942229529dad8596f0b4495974399303 --- /dev/null +++ b/tests/test_data_pipeline.py @@ -0,0 +1,33 @@ +from data.dataset import pack_sequence +from data.dedup import deduplicate_records +from data.filter import filter_record + + +def test_filter_record_masks_pii() -> None: + record = { + "id": "1", + "text": "Contact me at person@example.com. This document contains enough text to pass the minimum length requirement. " * 4, + "license_category": "permissive", + "domain_tag": "general", + "quality_tier": "high", + } + filtered = filter_record(record) + assert filtered is not None + assert "[EMAIL]" in filtered["text"] + + +def test_deduplicate_records_removes_exact_duplicates() -> None: + records = [ + {"text": "same text", "id": "1"}, + {"text": "same text", "id": "2"}, + {"text": "different text", "id": "3"}, + ] + kept = deduplicate_records(records) + assert len(kept) == 2 + + +def test_pack_sequence_shapes() -> None: + packed = pack_sequence([1, 2, 3, 4, 5], [0, 0, 1, 0, 1]) + assert packed["input_ids"].tolist() == [1, 2, 3, 4] + assert packed["labels"].tolist() == [2, 3, 4, 5] + assert packed["document_boundaries"].tolist() == [0, 0, 1, 0] diff --git a/tests/test_model_shapes.py b/tests/test_model_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..42c693ffa7c3e8578271cb6c1e08ae700c26379d --- /dev/null +++ b/tests/test_model_shapes.py @@ -0,0 +1,23 @@ +import torch + +from model.config import ModelConfig +from model.model import SageTransformer + + +def test_model_forward_shape_and_weight_tying() -> None: + config = ModelConfig( + num_layers=2, + d_model=64, + num_attn_heads=4, + num_kv_heads=2, + head_dim=16, + ffn_hidden_dim=256, + vocab_size=128, + context_length=32, + ) + model = SageTransformer(config) + input_ids = torch.randint(0, config.vocab_size, (2, 8)) + logits, cache = model(input_ids) + assert logits.shape == (2, 8, config.vocab_size) + assert len(cache) == config.num_layers + assert model.embed_tokens.weight.data_ptr() == model.lm_head.weight.data_ptr() diff --git a/tests/test_servers.py b/tests/test_servers.py new file mode 100644 index 0000000000000000000000000000000000000000..6fef4099faa910426ba228cbc6e22142a46f7841 --- /dev/null +++ b/tests/test_servers.py @@ -0,0 +1,20 @@ +from fastapi.testclient import TestClient + +from serve.server import app as gpu_app +from serve.server_cpu import app as cpu_app + + +def test_gpu_server_health() -> None: + client = TestClient(gpu_app) + response = client.get("/health") + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ok" + + +def test_cpu_server_health() -> None: + client = TestClient(cpu_app) + response = client.get("/health") + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ok" diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..93435e141292038ca4eed35eaf780b9a27b5a73c --- /dev/null +++ b/tests/test_tokenizer.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import pytest + + +def test_validation_suite_roundtrip(tmp_path: Path) -> None: + spm = pytest.importorskip("sentencepiece") + from tokenizer.validate_tokenizer import run_validation_suite + + corpus = tmp_path / "corpus.txt" + corpus.write_text( + "def add(a, b):\n return a + b\n" + "\\int_0^1 x^2 dx\n" + "English เคนเคฟเคจเฅเคฆเฅ€ ุงู„ุนุฑุจูŠุฉ ไธญๆ–‡\n" + "๐Ÿ˜€ tabs\tand spaces\n", + encoding="utf-8", + ) + prefix = tmp_path / "spm" + spm.SentencePieceTrainer.train( + input=str(corpus), + model_prefix=str(prefix), + model_type="bpe", + vocab_size=300, + character_coverage=0.9995, + byte_fallback=True, + bos_id=0, + eos_id=1, + pad_id=2, + unk_id=3, + user_defined_symbols=["[INST]", "[/INST]"], + split_digits=False, + split_by_unicode_script=False, + remove_extra_whitespaces=False, + normalization_rule_name="identity", + ) + results = run_validation_suite(str(prefix) + ".model") + assert all(result.passed for result in results), results diff --git a/tests/test_train_stack.py b/tests/test_train_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..7471cc42492469f89786cdee11650ef078656021 --- /dev/null +++ b/tests/test_train_stack.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import torch + +from model.config import ModelConfig +from model.model import SageTransformer +from train.checkpoint import load_latest_checkpoint, save_checkpoint +from train.hardware import HardwareConfig +from train.optimizer import ScheduleConfig, create_optimizer, create_scheduler + + +def test_checkpoint_roundtrip(tmp_path: Path) -> None: + config = ModelConfig( + num_layers=2, + d_model=64, + num_attn_heads=4, + num_kv_heads=2, + head_dim=16, + ffn_hidden_dim=256, + vocab_size=128, + context_length=32, + ) + model = SageTransformer(config) + schedule = ScheduleConfig(total_steps=8) + optimizer = create_optimizer(model, schedule) + scheduler = create_scheduler(optimizer, schedule) + scaler = torch.amp.GradScaler("cuda", enabled=False) + path = save_checkpoint(model, optimizer, scheduler, scaler, 3, {"name": "test"}, str(tmp_path)) + assert Path(path).exists() + resumed_step = load_latest_checkpoint(model, optimizer, scheduler, scaler, str(tmp_path), "cpu") + assert resumed_step == 3 + + +def test_hardware_summary_shape() -> None: + summary = HardwareConfig(model_size_b=1.0, context_length=4096).summary() + assert "device" in summary + assert "effective_batch_tokens" in summary diff --git a/tokenizer/__init__.py b/tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad983f02c0307366cf942afa952aac6892c0d7ca --- /dev/null +++ b/tokenizer/__init__.py @@ -0,0 +1 @@ +"""Tokenizer training and validation helpers.""" diff --git a/tokenizer/train_tokenizer.py b/tokenizer/train_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..39c5d3f84fc2bd58efabbb90b2636b874101e97e --- /dev/null +++ b/tokenizer/train_tokenizer.py @@ -0,0 +1,69 @@ +"""SentencePiece tokenizer training for SAGE.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Iterable + +import sentencepiece as spm + + +DEFAULT_SPECIAL_TOKENS = ("", "", "", "", "[INST]", "[/INST]") + + +def write_training_text(corpus_paths: Iterable[str], output_path: str) -> str: + """Concatenate corpus text into a plain-text file for SentencePiece.""" + output = Path(output_path) + output.parent.mkdir(parents=True, exist_ok=True) + with output.open("w", encoding="utf-8") as sink: + for path in corpus_paths: + with Path(path).open("r", encoding="utf-8") as source: + for line in source: + line = line.strip() + if line: + sink.write(line) + sink.write("\n") + return str(output) + + +def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None: + """Train a byte-fallback SentencePiece BPE model.""" + spm.SentencePieceTrainer.train( + input=input_path, + model_prefix=model_prefix, + model_type="bpe", + vocab_size=vocab_size, + character_coverage=0.9995, + byte_fallback=True, + bos_id=0, + eos_id=1, + pad_id=2, + unk_id=3, + user_defined_symbols=list(DEFAULT_SPECIAL_TOKENS[4:]), + split_digits=False, + split_by_unicode_script=False, + remove_extra_whitespaces=False, + normalization_rule_name="identity", + ) + + +def build_argparser() -> argparse.ArgumentParser: + """Build the CLI parser.""" + parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.") + parser.add_argument("--input", nargs="+", required=True, help="Plain-text corpus files.") + parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.") + parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.") + parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.") + return parser + + +def main() -> None: + """Train the tokenizer from CLI arguments.""" + args = build_argparser().parse_args() + training_text = write_training_text(args.input, args.training_text) + train_sentencepiece(training_text, args.model_prefix, args.vocab_size) + + +if __name__ == "__main__": + main() diff --git a/tokenizer/validate_tokenizer.py b/tokenizer/validate_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d47adec618ab3c0a80e8cbf3e3f56e3d21ddbafc --- /dev/null +++ b/tokenizer/validate_tokenizer.py @@ -0,0 +1,55 @@ +"""Validation checks for the SentencePiece tokenizer.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +import sentencepiece as spm + + +@dataclass(frozen=True) +class ValidationResult: + """One tokenizer validation outcome.""" + + name: str + passed: bool + detail: str + + +def load_processor(model_path: str) -> spm.SentencePieceProcessor: + """Load a SentencePiece processor.""" + processor = spm.SentencePieceProcessor() + processor.load(model_path) + return processor + + +def validate_roundtrip(processor: spm.SentencePieceProcessor, text: str, name: str) -> ValidationResult: + """Ensure encode->decode preserves the original string.""" + pieces = processor.encode(text, out_type=int) + decoded = processor.decode(pieces) + return ValidationResult(name, decoded == text, f"expected={text!r} got={decoded!r}") + + +def run_validation_suite(model_path: str) -> list[ValidationResult]: + """Run the required tokenizer smoke tests.""" + processor = load_processor(model_path) + samples = { + "python": "def add(a, b):\n return a += b if a == b else a != b\n", + "latex": r"\int_0^\infty e^{-x^2} dx = \frac{\sqrt{\pi}}{2}", + "whitespace": "if True:\n\tprint('tabs')\n print('spaces')\n", + "emoji": "Rare bytes: ๐Ÿ˜€ โš™๏ธ โˆ‘", + "multilingual": "English เคนเคฟเคจเฅเคฆเฅ€ ุงู„ุนุฑุจูŠุฉ ไธญๆ–‡", + } + return [validate_roundtrip(processor, text, name) for name, text in samples.items()] + + +def validate_model_file(model_path: str) -> None: + """Raise on validation failure.""" + if not Path(model_path).exists(): + raise FileNotFoundError(model_path) + results = run_validation_suite(model_path) + failed = [result for result in results if not result.passed] + if failed: + details = "\n".join(f"{item.name}: {item.detail}" for item in failed) + raise AssertionError(f"Tokenizer validation failed:\n{details}") diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d45ac0c5e7e4890e6d9f2e6afa6787358d9a117 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1 @@ +"""Training stack for SAGE.""" diff --git a/train/checkpoint.py b/train/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f4d7469103057653003f110e526c00768320f6ad --- /dev/null +++ b/train/checkpoint.py @@ -0,0 +1,73 @@ +"""Checkpoint save, prune, and resume utilities.""" + +from __future__ import annotations + +import glob +import os +from pathlib import Path +from typing import Any + +import torch + + +def save_checkpoint( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler.LambdaLR, + scaler: torch.amp.GradScaler | None, + step: int, + config: dict[str, Any], + output_dir: str, + keep: int = 5, +) -> str: + """Persist a resumable training checkpoint.""" + Path(output_dir).mkdir(parents=True, exist_ok=True) + path = Path(output_dir) / f"ckpt_step_{step:07d}.pt" + torch.save( + { + "step": step, + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "scaler": scaler.state_dict() if scaler is not None else None, + "rng_cpu": torch.get_rng_state(), + "rng_gpu": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None, + "config": config, + }, + path, + ) + _prune_old_checkpoints(output_dir, keep=keep) + return str(path) + + +def load_latest_checkpoint( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer | None, + scheduler: torch.optim.lr_scheduler.LambdaLR | None, + scaler: torch.amp.GradScaler | None, + output_dir: str, + device: str | torch.device, +) -> int: + """Load the most recent checkpoint and return the step to resume from.""" + checkpoints = sorted(glob.glob(os.path.join(output_dir, "ckpt_step_*.pt"))) + if not checkpoints: + return 0 + checkpoint = torch.load(checkpoints[-1], map_location=device) + model.load_state_dict(checkpoint["model"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scheduler is not None: + scheduler.load_state_dict(checkpoint["scheduler"]) + if scaler is not None and checkpoint.get("scaler") is not None: + scaler.load_state_dict(checkpoint["scaler"]) + torch.set_rng_state(checkpoint["rng_cpu"]) + if checkpoint.get("rng_gpu") is not None and torch.cuda.is_available(): + torch.cuda.set_rng_state_all(checkpoint["rng_gpu"]) + return int(checkpoint["step"]) + + +def _prune_old_checkpoints(output_dir: str, keep: int = 5) -> None: + """Keep only the most recent checkpoints.""" + checkpoints = sorted(glob.glob(os.path.join(output_dir, "ckpt_step_*.pt"))) + for stale in checkpoints[:-keep]: + os.remove(stale) diff --git a/train/distributed.py b/train/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..0f8ca2a2f892ece6eedbb15a716d28d2a554c4e8 --- /dev/null +++ b/train/distributed.py @@ -0,0 +1,38 @@ +"""Distributed and strategy routing helpers.""" + +from __future__ import annotations + +import os + +import torch + + +def get_training_strategy(model_size_b: float) -> dict[str, object]: + """Choose a training mode based on the visible hardware.""" + n_gpus = torch.cuda.device_count() + world_size = int(os.environ.get("WORLD_SIZE", "1")) + n_nodes = max(1, world_size // max(n_gpus, 1)) if n_gpus else 1 + has_cuda = torch.cuda.is_available() + has_mps = torch.backends.mps.is_available() + + if not has_cuda and not has_mps: + return {"mode": "cpu", "backend": None, "tp": 1, "pp": 1, "zero": 0} + if has_mps: + return {"mode": "mps-single", "backend": None, "tp": 1, "pp": 1, "zero": 0} + + vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if n_nodes > 1: + if model_size_b <= 1.0: + return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 2} + return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 3} + if n_gpus > 1: + if model_size_b <= 1.0: + return {"mode": "ddp", "backend": "nccl", "tp": 1, "pp": 1, "zero": 1} + return {"mode": "fsdp", "backend": "nccl", "tp": 2, "pp": 1, "zero": 2} + if vram_gb >= 40: + return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 0} + if vram_gb >= 24: + return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 1} + if vram_gb >= 16: + return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 2} + return {"mode": "single", "backend": None, "tp": 1, "pp": 1, "zero": 3} diff --git a/train/hardware.py b/train/hardware.py new file mode 100644 index 0000000000000000000000000000000000000000..dae0f019543e6e5d3e1d33a25f72543eefba08cc --- /dev/null +++ b/train/hardware.py @@ -0,0 +1,97 @@ +"""Hardware detection and runtime configuration.""" + +from __future__ import annotations + +from dataclasses import dataclass +import ctypes +import os +import platform + +import torch + +from train.distributed import get_training_strategy + +try: + import psutil # type: ignore +except ImportError: # pragma: no cover - optional dependency + psutil = None + + +@dataclass +class HardwareConfig: + """Detect hardware and derive runtime decisions.""" + + model_size_b: float + context_length: int + + def __post_init__(self) -> None: + self.device, self.dtype = self._detect_device_dtype() + self.n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + self.vram_gb = self._get_vram() + self.ram_gb = self._get_ram_gb() + self.strategy = get_training_strategy(self.model_size_b) + self.micro_batch = self._pick_micro_batch() + self.grad_accum = self._pick_grad_accum() + self.use_amp = self.device != "cpu" + self.use_flash_attn = self.device == "cuda" + self.use_qlora = False + + def _detect_device_dtype(self) -> tuple[str, torch.dtype]: + if torch.cuda.is_available(): + dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + return "cuda", dtype + if torch.backends.mps.is_available(): + return "mps", torch.bfloat16 + return "cpu", torch.float32 + + def _get_vram(self) -> float: + if not torch.cuda.is_available(): + return 0.0 + return torch.cuda.get_device_properties(0).total_memory / 1e9 + + def _get_ram_gb(self) -> float: + if psutil is not None: + return psutil.virtual_memory().total / 1e9 + if platform.system() == "Windows": + kernel32 = ctypes.windll.kernel32 + c_ulonglong = ctypes.c_ulonglong + mem_kb = c_ulonglong() + kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem_kb)) + return (mem_kb.value * 1024) / 1e9 + if hasattr(os, "sysconf"): + pages = os.sysconf("SC_PHYS_PAGES") + page_size = os.sysconf("SC_PAGE_SIZE") + return (pages * page_size) / 1e9 + return 0.0 + + def _pick_micro_batch(self) -> int: + if self.device == "cpu": + return 1 + if self.vram_gb >= 80: + return 8 + if self.vram_gb >= 40: + return 4 + if self.vram_gb >= 24: + return 2 + return 1 + + def _pick_grad_accum(self) -> int: + target_tokens = 2_000_000 + tokens_per_micro = self.micro_batch * self.context_length * max(self.n_gpus, 1) + return max(1, target_tokens // max(tokens_per_micro, 1)) + + def summary(self) -> dict[str, object]: + """Return a JSON-safe hardware summary.""" + effective_batch = self.micro_batch * self.grad_accum * self.context_length * max(self.n_gpus, 1) + return { + "device": self.device, + "dtype": str(self.dtype), + "n_gpus": self.n_gpus, + "vram_gb": round(self.vram_gb, 2), + "ram_gb": round(self.ram_gb, 2), + "strategy": self.strategy, + "micro_batch": self.micro_batch, + "grad_accum": self.grad_accum, + "effective_batch_tokens": effective_batch, + "use_flash_attn": self.use_flash_attn, + } diff --git a/train/loss.py b/train/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..daccf09858374dd67f2d4609d6a168fa58851dda --- /dev/null +++ b/train/loss.py @@ -0,0 +1,15 @@ +"""Loss helpers for packed next-token prediction.""" + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +def masked_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """Compute next-token cross entropy with an explicit mask.""" + vocab = logits.size(-1) + loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction="none") + loss = loss * loss_mask.view(-1) + denom = torch.clamp(loss_mask.sum(), min=1.0) + return loss.sum() / denom diff --git a/train/optimizer.py b/train/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..50061aea21ec603c70c21a0601be7bb20776c1dc --- /dev/null +++ b/train/optimizer.py @@ -0,0 +1,58 @@ +"""Optimizer and scheduler factories.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class ScheduleConfig: + """Training schedule settings.""" + + peak_learning_rate: float = 3.0e-4 + min_learning_rate: float = 3.0e-5 + warmup_steps: int = 2000 + weight_decay: float = 0.1 + betas: tuple[float, float] = (0.9, 0.95) + adam_eps: float = 1.0e-8 + total_steps: int = 25_000 + + +def create_optimizer(model: torch.nn.Module, config: ScheduleConfig) -> torch.optim.Optimizer: + """Create an AdamW optimizer with correct weight-decay exclusions.""" + decay: list[torch.nn.Parameter] = [] + no_decay: list[torch.nn.Parameter] = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + if param.ndim == 1 or "norm" in name: + no_decay.append(param) + else: + decay.append(param) + return torch.optim.AdamW( + [ + {"params": decay, "weight_decay": config.weight_decay}, + {"params": no_decay, "weight_decay": 0.0}, + ], + lr=config.peak_learning_rate, + betas=config.betas, + eps=config.adam_eps, + ) + + +def lr_lambda(current_step: int, config: ScheduleConfig) -> float: + """Warm up linearly and then decay with cosine.""" + if current_step < config.warmup_steps: + return float(current_step + 1) / float(max(1, config.warmup_steps)) + progress = (current_step - config.warmup_steps) / float(max(1, config.total_steps - config.warmup_steps)) + cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) + floor = config.min_learning_rate / config.peak_learning_rate + return floor + (1.0 - floor) * cosine + + +def create_scheduler(optimizer: torch.optim.Optimizer, config: ScheduleConfig) -> torch.optim.lr_scheduler.LambdaLR: + """Create the training LR scheduler.""" + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: lr_lambda(step, config)) diff --git a/train/trainer.py b/train/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c33558a37715e6e9804637bd79b764489df53347 --- /dev/null +++ b/train/trainer.py @@ -0,0 +1,227 @@ +"""Main training loop for SAGE.""" + +from __future__ import annotations + +import argparse +import json +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.data import DataLoader +import yaml + +from data.dataset import DatasetConfig, PackedDataset +from eval.perplexity import evaluate_perplexity +from model.config import ModelConfig +from model.model import SageTransformer +from train.checkpoint import load_latest_checkpoint, save_checkpoint +from train.hardware import HardwareConfig +from train.loss import masked_cross_entropy +from train.optimizer import ScheduleConfig, create_optimizer, create_scheduler + + +@dataclass +class TrainerConfig: + """High-level trainer settings.""" + + output_dir: str = "runs/default" + checkpoint_interval: int = 1000 + log_interval: int = 10 + eval_interval: int = 1000 + total_steps: int = 25_000 + seed: int = 42 + use_wandb: bool = True + + +def collate_batch(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]: + """Stack packed dataset examples into a batch.""" + keys = batch[0].keys() + return {key: torch.stack([item[key] for item in batch], dim=0) for key in keys} + + +def create_dataloader(dataset: PackedDataset, batch_size: int) -> DataLoader: + """Create the training DataLoader.""" + return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_batch) + + +def train( + model: SageTransformer, + train_dataset: PackedDataset, + validation_dataset: PackedDataset | None, + model_config: ModelConfig, + schedule_config: ScheduleConfig, + trainer_config: TrainerConfig, +) -> dict[str, object]: + """Run the training loop and return the final summary.""" + torch.manual_seed(trainer_config.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(trainer_config.seed) + hw = HardwareConfig(model_size_b=1.0, context_length=model_config.context_length) + device = torch.device(hw.device) + model = model.to(device) + optimizer = create_optimizer(model, schedule_config) + scheduler = create_scheduler(optimizer, schedule_config) + scaler = torch.amp.GradScaler("cuda", enabled=(hw.device == "cuda" and hw.dtype == torch.float16)) + start_step = load_latest_checkpoint(model, optimizer, scheduler, scaler, trainer_config.output_dir, device) + train_dataset.skip(start_step * hw.grad_accum) + train_loader = create_dataloader(train_dataset, batch_size=hw.micro_batch) + train_iter = iter(train_loader) + + Path(trainer_config.output_dir).mkdir(parents=True, exist_ok=True) + metrics_path = Path(trainer_config.output_dir) / "metrics.jsonl" + tokens_seen = start_step * hw.micro_batch * model_config.context_length + last_log_time = time.perf_counter() + wandb_run = _init_wandb(trainer_config, model_config, schedule_config, hw.summary()) + + model.train() + for step in range(start_step, trainer_config.total_steps): + optimizer.zero_grad(set_to_none=True) + step_loss = 0.0 + for _ in range(hw.grad_accum): + try: + batch = next(train_iter) + except StopIteration: + train_iter = iter(train_loader) + batch = next(train_iter) + input_ids = batch["input_ids"].to(device) + labels = batch["labels"].to(device) + loss_mask = batch["loss_mask"].to(device) + if hw.use_amp: + with torch.amp.autocast(device_type=hw.device, dtype=hw.dtype): + logits, _ = model(input_ids) + loss = masked_cross_entropy(logits, labels, loss_mask) / hw.grad_accum + else: + logits, _ = model(input_ids) + loss = masked_cross_entropy(logits, labels, loss_mask) / hw.grad_accum + scaler.scale(loss).backward() + step_loss += loss.item() + tokens_seen += int(input_ids.numel()) + + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + scaler.step(optimizer) + scaler.update() + scheduler.step() + + if (step + 1) % trainer_config.log_interval == 0: + now = time.perf_counter() + elapsed = max(now - last_log_time, 1.0e-6) + tokens_per_second = (hw.micro_batch * hw.grad_accum * model_config.context_length) / elapsed + metrics = { + "step": step + 1, + "loss": step_loss, + "learning_rate": scheduler.get_last_lr()[0], + "tokens_seen": tokens_seen, + "tokens_per_second": tokens_per_second, + "grad_norm": float(grad_norm), + } + with metrics_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(metrics) + "\n") + if wandb_run is not None: + wandb_run.log(metrics, step=step + 1) + last_log_time = now + + if (step + 1) % trainer_config.eval_interval == 0 and validation_dataset is not None: + val_loader = create_dataloader(validation_dataset, batch_size=1) + evaluation = evaluate_perplexity(model, val_loader, device=device, dtype=hw.dtype if hw.use_amp else None) + with metrics_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps({"step": step + 1, **evaluation}) + "\n") + if wandb_run is not None: + wandb_run.log(evaluation, step=step + 1) + + if (step + 1) % trainer_config.checkpoint_interval == 0: + save_checkpoint( + model=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + step=step + 1, + config={"model": model_config.to_dict(), "schedule": asdict(schedule_config), "trainer": asdict(trainer_config)}, + output_dir=trainer_config.output_dir, + ) + + if wandb_run is not None: + wandb_run.finish() + return {"output_dir": trainer_config.output_dir, "tokens_seen": tokens_seen, "hardware": hw.summary()} + + +def _init_wandb( + trainer_config: TrainerConfig, + model_config: ModelConfig, + schedule_config: ScheduleConfig, + hardware_summary: dict[str, object], +): + """Start a wandb run when available and enabled.""" + if not trainer_config.use_wandb: + return None + try: + import wandb + except ImportError: + return None + return wandb.init( + project="sage-llm", + name=Path(trainer_config.output_dir).name, + config={ + "model": model_config.to_dict(), + "schedule": asdict(schedule_config), + "trainer": asdict(trainer_config), + "hardware": hardware_summary, + }, + mode="offline", + ) + + +def build_argparser() -> argparse.ArgumentParser: + """Build the trainer CLI.""" + parser = argparse.ArgumentParser(description="Train the SAGE dense language model.") + parser.add_argument("--model-config", default="configs/model/1b.yaml") + parser.add_argument("--schedule-config", default="configs/train/schedule.yaml") + parser.add_argument("--train-shards", nargs="+", default=[]) + parser.add_argument("--validation-shards", nargs="*", default=[]) + parser.add_argument("--output-dir", default="runs/default") + parser.add_argument("--steps", type=int, default=None) + parser.add_argument("--disable-wandb", action="store_true") + return parser + + +def main(argv: Optional[list[str]] = None) -> None: + """CLI entrypoint for local training runs.""" + parser = build_argparser() + args = parser.parse_args(argv) + model_config = ModelConfig.from_yaml(args.model_config) + schedule_payload = yaml.safe_load(Path(args.schedule_config).read_text(encoding="utf-8")) + schedule = ScheduleConfig( + peak_learning_rate=schedule_payload["peak_learning_rate"], + min_learning_rate=schedule_payload["min_learning_rate"], + warmup_steps=schedule_payload["warmup_steps"], + weight_decay=schedule_payload["weight_decay"], + betas=tuple(schedule_payload["betas"]), + adam_eps=schedule_payload["adam_eps"], + total_steps=args.steps or schedule_payload["total_steps"] if "total_steps" in schedule_payload else (args.steps or 25_000), + ) + trainer_config = TrainerConfig( + output_dir=args.output_dir, + checkpoint_interval=schedule_payload.get("checkpoint_interval", 1000), + log_interval=schedule_payload.get("log_interval", 10), + eval_interval=schedule_payload.get("eval_interval", 1000), + total_steps=args.steps or schedule_payload.get("total_steps", 25_000), + seed=schedule_payload.get("seed", 42), + use_wandb=not args.disable_wandb, + ) + if not args.train_shards: + print("No training shards provided. The trainer entrypoint is configured correctly but requires shard paths to run.") + return + train_dataset = PackedDataset(DatasetConfig(tuple(args.train_shards), model_config.context_length, split="train")) + validation_dataset = None + if args.validation_shards: + validation_dataset = PackedDataset(DatasetConfig(tuple(args.validation_shards), model_config.context_length, split="validation")) + model = SageTransformer(model_config) + summary = train(model, train_dataset, validation_dataset, model_config, schedule, trainer_config) + print(json.dumps(summary, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/wandb/offline-run-20260412_222732-y2wgbhki/files/requirements.txt b/wandb/offline-run-20260412_222732-y2wgbhki/files/requirements.txt deleted file mode 100644 index abca98b2912243395c84f73cca7592aa4ecda5fe..0000000000000000000000000000000000000000 --- a/wandb/offline-run-20260412_222732-y2wgbhki/files/requirements.txt +++ /dev/null @@ -1,118 +0,0 @@ -absl-py==2.3.1 -aiohappyeyeballs==2.6.1 -aiohttp==3.13.5 -aiosignal==1.4.0 -annotated-doc==0.0.4 -annotated-types==0.7.0 -anyio==4.13.0 -async-timeout==5.0.1 -attrs==25.4.0 -bitsandbytes==0.49.2 -blinker==1.9.0 -certifi==2026.2.25 -cffi==2.0.0 -charset-normalizer==3.4.4 -click==8.3.1 -colorama==0.4.6 -contourpy==1.3.2 -cycler==0.12.1 -datasets==4.8.4 -dill==0.4.1 -e==1.4.5 -exceptiongroup==1.3.1 -facenet-pytorch==2.6.0 -face-alignment==1.4.1 -faiss-cpu==1.13.2 -fastapi==0.135.3 -filelock==3.25.2 -Flask==3.1.2 -flask-cors==6.0.1 -flatbuffers==25.9.23 -fonttools==4.60.1 -frozenlist==1.8.0 -fsspec==2026.2.0 -gitdb==4.0.12 -GitPython==3.1.46 -glm==0.4.4 -h11==0.16.0 -hf-xet==1.4.3 -httpcore==1.0.9 -httpx==0.28.1 -huggingface_hub==1.10.1 -idna==3.11 -ImageIO==2.37.3 -iniconfig==2.3.0 -itsdangerous==2.2.0 -jax==0.6.2 -jaxlib==0.6.2 -Jinja2==3.1.6 -kiwisolver==1.4.9 -lazy-loader==0.5 -llvmlite==0.47.0 -markdown-it-py==4.0.0 -MarkupSafe==3.0.3 -matplotlib==3.10.7 -mdurl==0.1.2 -mediapipe==0.10.14 -ml_dtypes==0.5.4 -mpmath==1.3.0 -multidict==6.7.1 -multiprocess==0.70.19 -networkx==3.4.2 -numba==0.65.0 -numpy==1.26.4 -opencv-contrib-python==4.12.0.88 -opencv-python==4.13.0.92 -opt_einsum==3.4.0 -packaging==25.0 -pandas==2.3.3 -pillow==12.2.0 -pip==26.0.1 -pixel-permute==2.0.0 -platformdirs==4.9.6 -pluggy==1.6.0 -propcache==0.4.1 -protobuf==4.25.8 -pyarrow==23.0.1 -pycparser==2.23 -pydantic==2.12.5 -pydantic_core==2.41.5 -Pygments==2.20.0 -PyOpenGL==3.1.10 -pyparsing==3.2.5 -pytest==9.0.3 -python-dateutil==2.9.0.post0 -python-multipart==0.0.26 -pytz==2026.1.post1 -PyYAML==6.0.3 -regex==2026.4.4 -reportlab==4.4.7 -requests==2.33.1 -rich==15.0.0 -scikit-image==0.25.2 -scipy==1.15.3 -sentry-sdk==2.57.0 -setuptools==81.0.0 -shellingham==1.5.4 -six==1.17.0 -smmap==5.0.3 -sounddevice==0.5.3 -starlette==1.0.0 -sympy==1.14.0 -tifffile==2025.5.10 -tiktoken==0.12.0 -tomli==2.4.1 -torch==2.11.0 -torchvision==0.17.2 -tqdm==4.67.3 -typer==0.24.1 -typing_extensions==4.15.0 -typing-inspection==0.4.2 -tzdata==2026.1 -urllib3==2.6.3 -uvicorn==0.44.0 -wandb==0.25.1 -Werkzeug==3.1.4 -wheel==0.46.3 -xxhash==3.6.0 -yarl==1.23.0 diff --git a/wandb/offline-run-20260412_222732-y2wgbhki/run-y2wgbhki.wandb b/wandb/offline-run-20260412_222732-y2wgbhki/run-y2wgbhki.wandb deleted file mode 100644 index 4df564b03433f928f5faea0f915520838bbd46f8..0000000000000000000000000000000000000000 Binary files a/wandb/offline-run-20260412_222732-y2wgbhki/run-y2wgbhki.wandb and /dev/null differ