Restructure to src/ layout with attention, per-layer MoE, and working chat
Browse files- Add GQA + RoPE attention (AttentionLayer) using config's n_heads/n_kv_heads/head_dim
- Wire 28 transformer layers (attention + shared MoE FFN per layer) in forward()
- Fix expert dispatch weight bug (per-token weights, not flattened scalar)
- Fix chat() stub: tokenize with GPT-2, generate, decode new tokens only
- Fix GRPO reward: remove proxy metrics, correctness + CoT bonus only
- Add load balance loss to GRPO train_step to prevent expert collapse
- Fix JSONLibrary recall() write-back and add 1000-entry cap per category
- Add tests/test_core.py with 11 tests covering all core components
- Add .gitignore excluding venv/, outputs/, data/, .vscode/, .claude/
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- .claude/settings.local.json +0 -15
- .gitattributes +0 -35
- .gitignore +11 -51
- DESCRIPTION.MD +0 -0
- README.md +25 -88
- config.py +0 -213
- configs/memory.yaml +16 -0
- configs/model.yaml +63 -0
- configs/model_15b.yaml +77 -0
- configs/sandbox.yaml +24 -0
- configs/training.yaml +28 -0
- inference/__init__.py +0 -1
- inference/api.py +0 -148
- inference/chat.py +0 -59
- inference/chat_simple.py +0 -54
- inference/daemon.py +0 -173
- inference/engine.py +0 -406
- memory/__init__.py +0 -1
- memory/database.py +0 -379
- memory/vector_store.py +0 -109
- model/__init__.py +0 -7
- model/base.py +0 -152
- model/echo.py +0 -71
- model/ensemble.py +0 -346
- model/expert.py +0 -45
- model/herald.py +0 -62
- model/lazy_expert_loader.py +0 -120
- model/sentinel.py +0 -90
- requirements.txt +14 -4
- scripts/01_download_15b_data.py +112 -0
- scripts/01_download_7b_150gb.py +272 -0
- scripts/01_download_stem_data.py +144 -0
- scripts/04_train.py +310 -0
- scripts/04_train_5090_optimized.py +146 -0
- scripts/04_train_stem.py +134 -0
- scripts/04_train_universal.py +426 -0
- scripts/05_grpo_train.py +325 -0
- scripts/07_run_shorekeeper.py +104 -0
- scripts/09_run_tests.py +70 -0
- scripts/full_training_loop.py +0 -40
- scripts/push_to_github.py +0 -30
- scripts/quick_test.py +0 -48
- scripts/run_training.py +0 -54
- scripts/run_training.sh +0 -112
- src/__init__.py +1 -0
- src/council/__init__.py +4 -0
- src/council/attention.py +62 -0
- src/council/base_expert.py +27 -0
- src/council/experts.py +73 -0
- src/council/sentinel.py +48 -0
.claude/settings.local.json
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"permissions": {
|
| 3 |
-
"allow": [
|
| 4 |
-
"Bash(du -sh /Users/georjanorellana/Downloads/shorekeeper/data/raw/*)",
|
| 5 |
-
"Bash(.venv/bin/pip install:*)",
|
| 6 |
-
"Bash(ls /home/albedogames/shorekeeper/.venv/bin/pip*)",
|
| 7 |
-
"Bash(ls /home/albedogames/shorekeeper/venv/bin/pip*)",
|
| 8 |
-
"Bash(/home/albedogames/shorekeeper/.venv/bin/pip install:*)",
|
| 9 |
-
"Bash(python3 -c \"import sys; print\\(sys.executable\\)\")",
|
| 10 |
-
"Bash(/home/albedogames/shorekeeper/.venv/bin/python3 -m pip install psutil -q)",
|
| 11 |
-
"Bash(/home/albedogames/shorekeeper/.venv/bin/python3 -c \"import psutil; print\\(''psutil OK''\\)\")",
|
| 12 |
-
"Bash(.venv/bin/python -c \":*)"
|
| 13 |
-
]
|
| 14 |
-
}
|
| 15 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitattributes
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
CHANGED
|
@@ -1,56 +1,16 @@
|
|
| 1 |
-
# Python
|
| 2 |
-
__pycache__/
|
| 3 |
-
*.py[cod]
|
| 4 |
-
*$py.class
|
| 5 |
-
*.so
|
| 6 |
-
.Python
|
| 7 |
-
env/
|
| 8 |
venv/
|
| 9 |
.venv/
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
.
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
parts/
|
| 19 |
-
sdist/
|
| 20 |
-
var/
|
| 21 |
-
wheels/
|
| 22 |
*.egg-info/
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# Large data files (too big for GitHub)
|
| 27 |
-
data/raw/
|
| 28 |
-
data/processed/
|
| 29 |
-
|
| 30 |
-
# Model checkpoints (too big for GitHub)
|
| 31 |
-
checkpoints/
|
| 32 |
-
|
| 33 |
-
# Runtime databases
|
| 34 |
-
memory_db/*.db
|
| 35 |
-
memory_store/
|
| 36 |
-
|
| 37 |
-
# Logs
|
| 38 |
-
logs/
|
| 39 |
-
|
| 40 |
-
# macOS resource forks
|
| 41 |
-
._*
|
| 42 |
.DS_Store
|
| 43 |
-
|
| 44 |
-
# IDE / Editor
|
| 45 |
.vscode/
|
| 46 |
-
.
|
| 47 |
-
*.swp
|
| 48 |
-
*.swo
|
| 49 |
-
*~
|
| 50 |
-
|
| 51 |
-
# Misc
|
| 52 |
-
*.log
|
| 53 |
-
.env
|
| 54 |
-
|
| 55 |
-
# Tokenizer training output (keep in repo if small)
|
| 56 |
-
# tokenizer/shorekeeper_tok/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
venv/
|
| 2 |
.venv/
|
| 3 |
+
outputs/
|
| 4 |
+
data/
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*.pth
|
| 8 |
+
*.pt
|
| 9 |
+
*.bin
|
| 10 |
+
.env
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
*.egg-info/
|
| 12 |
+
dist/
|
| 13 |
+
build/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
.DS_Store
|
|
|
|
|
|
|
| 15 |
.vscode/
|
| 16 |
+
.claude/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DESCRIPTION.MD
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -1,91 +1,28 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
| 20 |
|
| 21 |
## Quick Start
|
| 22 |
|
| 23 |
-
``
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
python3 tokenizer/train_tokenizer.py
|
| 28 |
-
python3 scripts/full_training_loop.py
|
| 29 |
-
python3 inference/chat.py
|
| 30 |
-
|
| 31 |
-
# Cross-platform daemon (Linux/macOS UNIX socket by default)
|
| 32 |
-
python3 inference/daemon.py --mode auto
|
| 33 |
-
# If Windows, run:
|
| 34 |
-
python3 inference/daemon.py --mode tcp --host 127.0.0.1 --port 8500
|
| 35 |
-
|
| 36 |
-
# GitHub Migration
|
| 37 |
-
python3 scripts/push_to_github.py
|
| 38 |
-
```
|
| 39 |
-
|
| 40 |
-
## Manual Steps
|
| 41 |
-
|
| 42 |
-
```bash
|
| 43 |
-
pip install torch tokenizers
|
| 44 |
-
|
| 45 |
-
python data/generate_sample_data.py # create sample data
|
| 46 |
-
python tokenizer/train_tokenizer.py # train BPE tokenizer
|
| 47 |
-
python data/processor.py # tokenize -> tensors
|
| 48 |
-
python memory/database.py # init SQLite
|
| 49 |
-
# Cross-platform quick runner:
|
| 50 |
-
python scripts/run_training.py
|
| 51 |
-
|
| 52 |
-
python training/train_base.py # phase 1: shared base
|
| 53 |
-
python training/train_expert.py --all # phase 2: all experts
|
| 54 |
-
python training/train_ensemble.py # phase 6: end-to-end
|
| 55 |
-
|
| 56 |
-
python inference/chat.py # launch chat
|
| 57 |
-
```
|
| 58 |
-
|
| 59 |
-
## Scale Up (Real Training)
|
| 60 |
-
|
| 61 |
-
- `VOCAB_SIZE`: 32000
|
| 62 |
-
- `BASE_CONFIG`: 1024 dim / 12 layers / 16 heads
|
| 63 |
-
- `n_positions`: 2048
|
| 64 |
-
- `TRAIN_CONFIG max_steps`: 50000
|
| 65 |
-
- **Device**: MPS (Metal) for Mac Silicon
|
| 66 |
-
|
| 67 |
-
## Chat Commands
|
| 68 |
-
|
| 69 |
-
```
|
| 70 |
-
/route <query> show routing without generating
|
| 71 |
-
/expert <name> force specific expert
|
| 72 |
-
/routing on|off toggle routing display
|
| 73 |
-
/incidents Sentinel incident log
|
| 74 |
-
/reset clear session memory
|
| 75 |
-
/experts list expert names
|
| 76 |
-
/exit quit
|
| 77 |
-
```
|
| 78 |
-
|
| 79 |
-
## Expert Roster
|
| 80 |
-
|
| 81 |
-
| Name | Domain | Named After |
|
| 82 |
-
|------|--------|-------------|
|
| 83 |
-
| calcharo | Security / CVE / Network threats | The Calamity |
|
| 84 |
-
| rover | Code / Debug / Architecture | The Explorer |
|
| 85 |
-
| resonance | Logic / Reasoning / Causation | The Force |
|
| 86 |
-
| tacet | Threat Intel / IOC / APT | The Discord |
|
| 87 |
-
| jianxin | Linux / OS / CUDA / systemd | Calm Mastery |
|
| 88 |
-
| verina | Conversation / NLP / Interface | Healer |
|
| 89 |
-
| sentinel | Self-Monitor / Drift Detection | The Watchman |
|
| 90 |
-
| herald | Router (always active) | The Messenger |
|
| 91 |
-
| echo | Memory (always active) | Resonant Imprint |
|
|
|
|
| 1 |
+
# SHOREKEEPER-4B
|
| 2 |
+
|
| 3 |
+
**A 4B parameter reasoning model with 12 specialized experts and infinite memory.**
|
| 4 |
+
|
| 5 |
+
## The Council of Experts
|
| 6 |
+
|
| 7 |
+
| Expert | Role | Specialty |
|
| 8 |
+
|--------|------|-----------|
|
| 9 |
+
| **Sentinel** | Router | Decides which experts activate |
|
| 10 |
+
| Asmoday | Code Architect | Python, algorithms, debugging |
|
| 11 |
+
| Istaroth | Systems | OS, networking, Docker |
|
| 12 |
+
| Ronova | Reasoning | Math, logic, step-by-step |
|
| 13 |
+
| Naberius | Memory | JSON library retrieval |
|
| 14 |
+
| Phanes | Creation | Writing, generation, creativity |
|
| 15 |
+
| Barbeloth | Analysis | Data, patterns, insights |
|
| 16 |
+
| Tacet | Silence | Noise filtering, summarization |
|
| 17 |
+
| Abby | Empathy | User context, preferences |
|
| 18 |
+
| Reindoter | Validation | Testing, verification |
|
| 19 |
+
| Zestial | Vision | Code visualization, graphs |
|
| 20 |
+
| Alice | Exploration | Novel solutions, experimentation |
|
| 21 |
+
| Rover | Execution | Terminal commands, sandbox |
|
| 22 |
|
| 23 |
## Quick Start
|
| 24 |
|
| 25 |
+
`'`bash
|
| 26 |
+
pip install -r requirements.txt
|
| 27 |
+
python scripts/07_run_shorekeeper.py
|
| 28 |
+
`'`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.py
DELETED
|
@@ -1,213 +0,0 @@
|
|
| 1 |
-
# config.py
|
| 2 |
-
# Central configuration for all Shorekeeper components.
|
| 3 |
-
# Every constant, path, and hyperparameter lives here.
|
| 4 |
-
# Import with: from config import TRAIN_CONFIG, CHECKPOINT_DIR, etc.
|
| 5 |
-
|
| 6 |
-
import os
|
| 7 |
-
import torch
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
# ── PROJECT ROOT ──────────────────────────────────────────────────────
|
| 12 |
-
PROJECT_ROOT = Path(__file__).parent
|
| 13 |
-
|
| 14 |
-
# ── DATA PATHS ────────────────────────────────────────────────────────
|
| 15 |
-
# If using external drive with symlinks (recommended):
|
| 16 |
-
# data/raw → /mnt/shorekeeper_data/raw
|
| 17 |
-
# data/processed → /mnt/shorekeeper_data/processed
|
| 18 |
-
# checkpoints/ → /mnt/shorekeeper_data/checkpoints
|
| 19 |
-
# If not using symlinks, these paths just live on the main drive.
|
| 20 |
-
RAW_DATA_DIR = PROJECT_ROOT / "data" / "raw"
|
| 21 |
-
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
|
| 22 |
-
CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"
|
| 23 |
-
LOG_DIR = PROJECT_ROOT / "logs"
|
| 24 |
-
MEMORY_DIR = PROJECT_ROOT / "memory_store"
|
| 25 |
-
|
| 26 |
-
# Auto-create all directories on import
|
| 27 |
-
for _d in [RAW_DATA_DIR, PROCESSED_DIR, CHECKPOINT_DIR, LOG_DIR, MEMORY_DIR]:
|
| 28 |
-
_d.mkdir(parents=True, exist_ok=True)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# ── VOCABULARY ────────────────────────────────────────────────────────
|
| 32 |
-
VOCAB_SIZE = 50_257 # GPT-2 compatible vocabulary size
|
| 33 |
-
|
| 34 |
-
# Special tokens and their IDs (assigned during tokenizer training)
|
| 35 |
-
# These MUST match what train_tokenizer.py produces.
|
| 36 |
-
# DO NOT change these after the tokenizer is trained.
|
| 37 |
-
SPECIAL_TOKENS = {
|
| 38 |
-
"[PAD]": 0, # Padding token (ignored in loss computation)
|
| 39 |
-
"[UNK]": 1, # Unknown token (should be rare with BPE)
|
| 40 |
-
"[BOS]": 2, # Beginning of sequence
|
| 41 |
-
"[EOS]": 3, # End of sequence
|
| 42 |
-
"[SEP]": 4, # Separator (used between context and query by Echo)
|
| 43 |
-
"[MASK]": 5, # Masked token (reserved for future MLM training)
|
| 44 |
-
"[SYSTEM]": 6, # System prompt marker
|
| 45 |
-
"[USER]": 7, # User turn marker
|
| 46 |
-
"[ASSISTANT]": 8, # Assistant turn marker
|
| 47 |
-
"[MEMORY]": 9, # Memory context injection marker
|
| 48 |
-
"[SECURITY]": 10, # calcharo expert domain marker
|
| 49 |
-
"[CODE]": 11, # rover expert domain marker
|
| 50 |
-
"[REASON]": 12, # resonance expert domain marker
|
| 51 |
-
"[SYSTEM2]": 13, # jianxin expert domain marker
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
# ── BASE MODEL CONFIG ─────────────────────────────────────────────────
|
| 56 |
-
# SharedBase: the 500M shared transformer backbone.
|
| 57 |
-
# Every expert builds on top of these representations.
|
| 58 |
-
BASE_CONFIG = {
|
| 59 |
-
"n_embd": 2048, # Hidden state dimension
|
| 60 |
-
# This dimension flows through ALL components
|
| 61 |
-
"n_head": 16, # Attention heads (n_embd / n_head = 128 per head)
|
| 62 |
-
"n_layer": 8, # Transformer layers in the shared base (8 fits in 12 GB VRAM)
|
| 63 |
-
"n_positions": 2048, # Maximum sequence length (context window)
|
| 64 |
-
"dropout": 0.1, # Dropout rate (applied during training, disabled at inference)
|
| 65 |
-
"vocab_size": VOCAB_SIZE,
|
| 66 |
-
}
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# ── EXPERT CONFIGS ────────────────────────────────────────────────────
|
| 70 |
-
# Each expert has its own number of transformer layers.
|
| 71 |
-
# Experts with more layers have more capacity for their domain.
|
| 72 |
-
# The n_embd MUST match BASE_CONFIG["n_embd"] = 2048.
|
| 73 |
-
EXPERT_NAMES = ["calcharo", "rover", "resonance", "tacet", "jianxin", "verina"]
|
| 74 |
-
|
| 75 |
-
EXPERT_CONFIGS = {
|
| 76 |
-
# Heavy experts (8 layers): high-value domains with most training data
|
| 77 |
-
"calcharo": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
|
| 78 |
-
"rover": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
|
| 79 |
-
"resonance": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
|
| 80 |
-
# Medium experts (6 layers): specialized domains
|
| 81 |
-
"tacet": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
|
| 82 |
-
"jianxin": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
|
| 83 |
-
"verina": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
|
| 84 |
-
# Monitoring expert (4 layers): sentinel only needs classification capacity
|
| 85 |
-
"sentinel": {"n_layer": 4, "n_embd": 2048, "n_head": 16},
|
| 86 |
-
}
|
| 87 |
-
|
| 88 |
-
# Herald and Echo configs (these are routing/memory modules, not generative)
|
| 89 |
-
HERALD_CONFIG = {"n_layer": 2, "n_embd": 2048, "n_head": 16, "n_experts": len(EXPERT_NAMES), "top_k": 2}
|
| 90 |
-
ECHO_CONFIG = {"n_layer": 1, "n_embd": 2048, "n_head": 16, "max_memory_tokens": 512}
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
# ── TRAINING CONFIG ───────────────────────────────────────────────────
|
| 94 |
-
TRAIN_CONFIG = {
|
| 95 |
-
|
| 96 |
-
# Phase 1: Base pre-training
|
| 97 |
-
"base_lr": 3e-4, # Peak learning rate for base model
|
| 98 |
-
"base_max_steps": 100_000,# Total training steps (100k)
|
| 99 |
-
"base_warmup": 2_000, # LR warmup steps
|
| 100 |
-
"base_batch_size": 1, # Mini-batch size per GPU (reduced for 12GB VRAM)
|
| 101 |
-
"base_grad_accum": 32, # Gradient accumulation
|
| 102 |
-
# Effective batch = 1 * 32 = 32 sequences
|
| 103 |
-
# Phase 2: Expert fine-tuning
|
| 104 |
-
"expert_lr": 1e-4, # Lower LR for fine-tuning
|
| 105 |
-
"expert_max_steps": 50_000,
|
| 106 |
-
"expert_warmup": 1_000,
|
| 107 |
-
"expert_batch_size":2, # Expert heads are lighter (no base layers)
|
| 108 |
-
"expert_grad_accum":16,
|
| 109 |
-
|
| 110 |
-
# Phase 3: Herald routing training
|
| 111 |
-
"herald_lr": 1e-4,
|
| 112 |
-
"herald_max_steps": 20_000,
|
| 113 |
-
"herald_warmup": 500,
|
| 114 |
-
"herald_batch_size": 32, # Routing examples are short, bigger batches
|
| 115 |
-
"herald_grad_accum": 2,
|
| 116 |
-
|
| 117 |
-
# Phase 4: Sentinel training
|
| 118 |
-
"sentinel_lr": 5e-5,
|
| 119 |
-
"sentinel_max_steps": 15_000,
|
| 120 |
-
"sentinel_warmup": 500,
|
| 121 |
-
"sentinel_batch_size": 16,
|
| 122 |
-
"sentinel_grad_accum": 2,
|
| 123 |
-
|
| 124 |
-
# Phase 6: Ensemble fine-tuning
|
| 125 |
-
"ensemble_lr": 5e-5, # Very low LR — preserve pre-trained knowledge
|
| 126 |
-
"ensemble_max_steps": 30_000,
|
| 127 |
-
"ensemble_warmup": 1_000,
|
| 128 |
-
"ensemble_batch_size": 2, # Small batch — full ensemble is huge
|
| 129 |
-
"ensemble_grad_accum": 16,
|
| 130 |
-
|
| 131 |
-
# Shared across all phases
|
| 132 |
-
"weight_decay": 0.1,
|
| 133 |
-
"grad_clip": 1.0, # Gradient norm clipping threshold
|
| 134 |
-
"beta1": 0.9, # AdamW beta1
|
| 135 |
-
"beta2": 0.95, # AdamW beta2 (0.95 for LLMs, not 0.999)
|
| 136 |
-
"epsilon": 1e-8, # AdamW epsilon
|
| 137 |
-
|
| 138 |
-
# Logging and checkpointing
|
| 139 |
-
"log_interval": 100, # Log every N steps
|
| 140 |
-
"eval_interval": 2000, # Evaluate on validation set every N steps
|
| 141 |
-
"save_interval": 5000, # Save checkpoint every N steps
|
| 142 |
-
"keep_last_n": 3, # Keep only last N step checkpoints
|
| 143 |
-
}
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
# ── MEMORY OPTIMIZATION ───────────────────────────────────────────────
|
| 147 |
-
# Settings to reduce VRAM usage on RTX 3060 (12GB)
|
| 148 |
-
MEMORY_OPT = {
|
| 149 |
-
"gradient_checkpointing": True, # Trade compute for memory
|
| 150 |
-
# Recomputes activations during backward
|
| 151 |
-
# ~30% slower, ~40% less VRAM
|
| 152 |
-
"use_bf16": True, # bfloat16 (better than fp16 for stability)
|
| 153 |
-
# Requires Ampere+ GPU (RTX 3060 supports it)
|
| 154 |
-
"use_compile":False, # torch.compile (PyTorch 2.0+)
|
| 155 |
-
# Enable for faster training after debugging
|
| 156 |
-
"cpu_offload":False, # CPU offload for optimizer states (not needed at 8 layers)
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
# ── INFERENCE CONFIG ──────────────────────────────────────────────────
|
| 161 |
-
INFER_CONFIG = {
|
| 162 |
-
"max_new_tokens": 512, # Maximum tokens to generate per response
|
| 163 |
-
"temperature": 0.8, # Sampling temperature (0=greedy, 1=random)
|
| 164 |
-
"top_p": 0.9, # Nucleus sampling: keep tokens summing to 90% prob
|
| 165 |
-
"top_k": 50, # Top-K sampling: only consider top 50 tokens
|
| 166 |
-
"repetition_penalty": 1.1, # Penalize repeated tokens (1.0 = no penalty)
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
# ── SENTINEL CONFIG ───────────────────────────────────────────────────
|
| 171 |
-
SENTINEL_CONFIG = {
|
| 172 |
-
"flag_threshold": 0.5, # Risk score above this → FLAG (log, continue)
|
| 173 |
-
"block_threshold": 0.8, # Risk score above this → BLOCK (replace output)
|
| 174 |
-
"window_size": 10, # Rolling window for drift pattern detection
|
| 175 |
-
"hidden_dim": 512, # Sentinel classification head hidden size
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
# ── DEVICE / DTYPE ────────────────────────────────────────────────────
|
| 180 |
-
# Auto-detect: CUDA (NVIDIA) > MPS (Apple Silicon) > CPU
|
| 181 |
-
if torch.cuda.is_available():
|
| 182 |
-
DEVICE = "cuda"
|
| 183 |
-
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 184 |
-
DEVICE = "mps"
|
| 185 |
-
else:
|
| 186 |
-
DEVICE = "cpu"
|
| 187 |
-
|
| 188 |
-
# bfloat16 is preferred on NVIDIA Ampere+ (RTX 3060 supports it).
|
| 189 |
-
# MPS uses float16. CPU falls back to float32 for numerical stability.
|
| 190 |
-
if DEVICE == "cuda" and MEMORY_OPT["use_bf16"]:
|
| 191 |
-
DTYPE = torch.bfloat16
|
| 192 |
-
elif DEVICE == "cuda":
|
| 193 |
-
DTYPE = torch.float16
|
| 194 |
-
elif DEVICE == "mps":
|
| 195 |
-
DTYPE = torch.float16 # MPS supports float16 (bfloat16 support is limited)
|
| 196 |
-
else:
|
| 197 |
-
DTYPE = torch.float32
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# ── SMOKE TEST MODE ───────────────────────────────────────────────────
|
| 201 |
-
# Set USE_TEST_CONFIG=True to run with tiny dimensions for quick sanity check
|
| 202 |
-
# Run: USE_TEST_CONFIG=1 python training/train_base.py
|
| 203 |
-
USE_TEST_CONFIG = os.environ.get("USE_TEST_CONFIG", "0") == "1"
|
| 204 |
-
|
| 205 |
-
if USE_TEST_CONFIG:
|
| 206 |
-
print("[config] SMOKE TEST MODE — tiny dimensions")
|
| 207 |
-
BASE_CONFIG.update({"n_embd": 64, "n_head": 4, "n_layer": 2, "n_positions": 128})
|
| 208 |
-
for k in EXPERT_CONFIGS:
|
| 209 |
-
EXPERT_CONFIGS[k].update({"n_layer": 1, "n_embd": 64, "n_head": 4})
|
| 210 |
-
TRAIN_CONFIG.update({
|
| 211 |
-
"base_max_steps": 100, "base_batch_size": 2, "base_grad_accum": 1,
|
| 212 |
-
"expert_max_steps": 50, "log_interval": 10, "eval_interval": 50,
|
| 213 |
-
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs/memory.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
memory:
|
| 2 |
+
type: "json_library"
|
| 3 |
+
path: "./data/json_library/"
|
| 4 |
+
embedding_model: "all-MiniLM-L6-v2"
|
| 5 |
+
embedding_dim: 384
|
| 6 |
+
max_entries: null
|
| 7 |
+
auto_summarize_threshold: null
|
| 8 |
+
default_recall_limit: 10
|
| 9 |
+
relevance_threshold: 0.7
|
| 10 |
+
categories:
|
| 11 |
+
- "user_preferences"
|
| 12 |
+
- "project_context"
|
| 13 |
+
- "conversation_history"
|
| 14 |
+
- "important_facts"
|
| 15 |
+
- "code_patterns"
|
| 16 |
+
- "learned_skills"
|
configs/model.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
name: "SHOREKEEPER-4B"
|
| 3 |
+
version: "1.0.0"
|
| 4 |
+
|
| 5 |
+
dim: 3072
|
| 6 |
+
n_layers: 28
|
| 7 |
+
n_heads: 24
|
| 8 |
+
n_kv_heads: 6
|
| 9 |
+
head_dim: 128
|
| 10 |
+
vocab_size: 50304
|
| 11 |
+
seq_len: 8192
|
| 12 |
+
|
| 13 |
+
n_experts: 12
|
| 14 |
+
n_activated: 2
|
| 15 |
+
expert_dim: 2048
|
| 16 |
+
|
| 17 |
+
experts:
|
| 18 |
+
router: "Sentinel"
|
| 19 |
+
members:
|
| 20 |
+
- name: "Asmoday"
|
| 21 |
+
role: "code"
|
| 22 |
+
specialization: "python_development"
|
| 23 |
+
- name: "Istaroth"
|
| 24 |
+
role: "systems"
|
| 25 |
+
specialization: "os_networking"
|
| 26 |
+
- name: "Ronova"
|
| 27 |
+
role: "reasoning"
|
| 28 |
+
specialization: "math_logic"
|
| 29 |
+
- name: "Naberius"
|
| 30 |
+
role: "memory"
|
| 31 |
+
specialization: "retrieval"
|
| 32 |
+
- name: "Phanes"
|
| 33 |
+
role: "creation"
|
| 34 |
+
specialization: "writing"
|
| 35 |
+
- name: "Barbeloth"
|
| 36 |
+
role: "analysis"
|
| 37 |
+
specialization: "data_patterns"
|
| 38 |
+
- name: "Tacet"
|
| 39 |
+
role: "silence"
|
| 40 |
+
specialization: "filtering"
|
| 41 |
+
- name: "Abby"
|
| 42 |
+
role: "empathy"
|
| 43 |
+
specialization: "user_context"
|
| 44 |
+
- name: "Reindoter"
|
| 45 |
+
role: "validation"
|
| 46 |
+
specialization: "testing"
|
| 47 |
+
- name: "Zestial"
|
| 48 |
+
role: "vision"
|
| 49 |
+
specialization: "visualization"
|
| 50 |
+
- name: "Alice"
|
| 51 |
+
role: "exploration"
|
| 52 |
+
specialization: "novelty"
|
| 53 |
+
- name: "Rover"
|
| 54 |
+
role: "execution"
|
| 55 |
+
specialization: "terminal"
|
| 56 |
+
|
| 57 |
+
rope_theta: 1000000.0
|
| 58 |
+
|
| 59 |
+
quantization:
|
| 60 |
+
bits: 4
|
| 61 |
+
type: "nf4"
|
| 62 |
+
double_quant: true
|
| 63 |
+
compute_dtype: "bfloat16"
|
configs/model_15b.yaml
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
name: "SHOREKEEPER-15B"
|
| 3 |
+
version: "2.0.0"
|
| 4 |
+
|
| 5 |
+
# 15B architecture
|
| 6 |
+
dim: 6144
|
| 7 |
+
n_layers: 48
|
| 8 |
+
n_heads: 48
|
| 9 |
+
n_kv_heads: 12 # MLA compression
|
| 10 |
+
head_dim: 128
|
| 11 |
+
vocab_size: 100352
|
| 12 |
+
seq_len: 8192
|
| 13 |
+
|
| 14 |
+
# MoE Council - 16 experts for 15B
|
| 15 |
+
n_experts: 16
|
| 16 |
+
n_activated: 2
|
| 17 |
+
expert_dim: 4096
|
| 18 |
+
|
| 19 |
+
experts:
|
| 20 |
+
router: "Sentinel"
|
| 21 |
+
members:
|
| 22 |
+
- name: "Asmoday"
|
| 23 |
+
role: "code"
|
| 24 |
+
specialization: "python_development"
|
| 25 |
+
- name: "Istaroth"
|
| 26 |
+
role: "systems"
|
| 27 |
+
specialization: "os_networking"
|
| 28 |
+
- name: "Ronova"
|
| 29 |
+
role: "reasoning"
|
| 30 |
+
specialization: "math_logic"
|
| 31 |
+
- name: "Naberius"
|
| 32 |
+
role: "memory"
|
| 33 |
+
specialization: "retrieval"
|
| 34 |
+
- name: "Phanes"
|
| 35 |
+
role: "creation"
|
| 36 |
+
specialization: "writing"
|
| 37 |
+
- name: "Barbeloth"
|
| 38 |
+
role: "analysis"
|
| 39 |
+
specialization: "data_patterns"
|
| 40 |
+
- name: "Tacet"
|
| 41 |
+
role: "silence"
|
| 42 |
+
specialization: "filtering"
|
| 43 |
+
- name: "Abby"
|
| 44 |
+
role: "empathy"
|
| 45 |
+
specialization: "user_context"
|
| 46 |
+
- name: "Reindoter"
|
| 47 |
+
role: "validation"
|
| 48 |
+
specialization: "testing"
|
| 49 |
+
- name: "Zestial"
|
| 50 |
+
role: "vision"
|
| 51 |
+
specialization: "visualization"
|
| 52 |
+
- name: "Alice"
|
| 53 |
+
role: "exploration"
|
| 54 |
+
specialization: "novelty"
|
| 55 |
+
- name: "Rover"
|
| 56 |
+
role: "execution"
|
| 57 |
+
specialization: "terminal"
|
| 58 |
+
- name: "Echo"
|
| 59 |
+
role: "reflection"
|
| 60 |
+
specialization: "self_improvement"
|
| 61 |
+
- name: "Sentinel"
|
| 62 |
+
role: "router"
|
| 63 |
+
specialization: "gatekeeper"
|
| 64 |
+
- name: "Phantom"
|
| 65 |
+
role: "speculation"
|
| 66 |
+
specialization: "what_if_analysis"
|
| 67 |
+
- name: "Aegis"
|
| 68 |
+
role: "safety"
|
| 69 |
+
specialization: "alignment"
|
| 70 |
+
|
| 71 |
+
rope_theta: 1000000.0
|
| 72 |
+
|
| 73 |
+
quantization:
|
| 74 |
+
bits: 4
|
| 75 |
+
type: "nf4"
|
| 76 |
+
double_quant: true
|
| 77 |
+
compute_dtype: "bfloat16"
|
configs/sandbox.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sandbox:
|
| 2 |
+
type: "docker"
|
| 3 |
+
image: "ubuntu:22.04"
|
| 4 |
+
name: "shorekeeper_sandbox"
|
| 5 |
+
memory_limit: "4g"
|
| 6 |
+
cpu_limit: 2.0
|
| 7 |
+
gpu_access: false
|
| 8 |
+
external_drive_path: "/mnt/shorekeeper_drive"
|
| 9 |
+
container_mount: "/shorekeeper_projects"
|
| 10 |
+
x11_socket: "/tmp/.X11-unix"
|
| 11 |
+
display_env: ":0"
|
| 12 |
+
allowed_commands:
|
| 13 |
+
- "python3"
|
| 14 |
+
- "pip"
|
| 15 |
+
- "git"
|
| 16 |
+
- "ls"
|
| 17 |
+
- "cat"
|
| 18 |
+
- "mkdir"
|
| 19 |
+
- "touch"
|
| 20 |
+
- "echo"
|
| 21 |
+
gui_frameworks:
|
| 22 |
+
- "tkinter"
|
| 23 |
+
- "pyqt5"
|
| 24 |
+
- "matplotlib"
|
configs/training.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
training:
|
| 2 |
+
batch_size: 2
|
| 3 |
+
gradient_accumulation: 16
|
| 4 |
+
learning_rate: 3e-4
|
| 5 |
+
min_lr: 3e-5
|
| 6 |
+
warmup_steps: 2000
|
| 7 |
+
total_steps: 250000
|
| 8 |
+
weight_decay: 0.1
|
| 9 |
+
beta1: 0.9
|
| 10 |
+
beta2: 0.95
|
| 11 |
+
grad_clip: 1.0
|
| 12 |
+
|
| 13 |
+
checkpoint:
|
| 14 |
+
save_every_steps: 5000
|
| 15 |
+
keep_last_n: 3
|
| 16 |
+
keep_best_n: 2
|
| 17 |
+
max_space_gb: 50.0
|
| 18 |
+
save_optimizer: true
|
| 19 |
+
save_scheduler: true
|
| 20 |
+
save_experts_only: false
|
| 21 |
+
checkpoint_dir: "./outputs/checkpoints"
|
| 22 |
+
resume_from: null
|
| 23 |
+
|
| 24 |
+
grpo:
|
| 25 |
+
group_size: 8
|
| 26 |
+
epsilon: 0.2
|
| 27 |
+
beta: 0.04
|
| 28 |
+
learning_rate: 1e-6
|
inference/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# inference package
|
|
|
|
|
|
inference/api.py
DELETED
|
@@ -1,148 +0,0 @@
|
|
| 1 |
-
# inference/api.py
|
| 2 |
-
# FastAPI REST API for Shorekeeper.
|
| 3 |
-
# Allows external processes (ALIE IDE, web UIs, scripts) to call the model.
|
| 4 |
-
|
| 5 |
-
# Usage:
|
| 6 |
-
# pip install fastapi uvicorn
|
| 7 |
-
# python inference/api.py
|
| 8 |
-
# curl -X POST http://localhost:8000/generate \
|
| 9 |
-
# -H 'Content-Type: application/json' \
|
| 10 |
-
# -d '{"prompt": "Hello, explain SQL injection"}'
|
| 11 |
-
|
| 12 |
-
import sys
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
from typing import Optional
|
| 15 |
-
|
| 16 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
-
|
| 18 |
-
try:
|
| 19 |
-
from fastapi import FastAPI, HTTPException
|
| 20 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
-
from pydantic import BaseModel
|
| 22 |
-
import uvicorn
|
| 23 |
-
except ImportError:
|
| 24 |
-
print('Install: pip install fastapi uvicorn')
|
| 25 |
-
sys.exit(1)
|
| 26 |
-
|
| 27 |
-
from inference.engine import ShorekeeperEngine
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────
|
| 31 |
-
|
| 32 |
-
class GenerateRequest(BaseModel):
|
| 33 |
-
prompt: str
|
| 34 |
-
max_new_tokens: Optional[int] = None
|
| 35 |
-
temperature: Optional[float] = None
|
| 36 |
-
top_p: Optional[float] = None
|
| 37 |
-
top_k: Optional[int] = None
|
| 38 |
-
session_id: Optional[str] = None
|
| 39 |
-
|
| 40 |
-
class GenerateResponse(BaseModel):
|
| 41 |
-
text: str
|
| 42 |
-
experts_used: list
|
| 43 |
-
routing: dict
|
| 44 |
-
sentinel: Optional[dict]
|
| 45 |
-
blocked: bool
|
| 46 |
-
latency_ms: float
|
| 47 |
-
n_tokens: int
|
| 48 |
-
|
| 49 |
-
class MemorySearchRequest(BaseModel):
|
| 50 |
-
query: str
|
| 51 |
-
limit: int = 5
|
| 52 |
-
|
| 53 |
-
class KnowledgeRequest(BaseModel):
|
| 54 |
-
key: str
|
| 55 |
-
value: str
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
# ── APP SETUP ─────────────────────────────────────────────────────────
|
| 59 |
-
|
| 60 |
-
app = FastAPI(title='Shorekeeper API', version='2.0.0')
|
| 61 |
-
engine: ShorekeeperEngine = None # Initialized on startup
|
| 62 |
-
|
| 63 |
-
app.add_middleware(
|
| 64 |
-
CORSMiddleware,
|
| 65 |
-
allow_origins=['*'], # In production: restrict to known origins
|
| 66 |
-
allow_methods=['*'],
|
| 67 |
-
allow_headers=['*'],
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
@app.on_event('startup')
|
| 72 |
-
async def startup():
|
| 73 |
-
global engine
|
| 74 |
-
print('[API] Loading Shorekeeper engine...')
|
| 75 |
-
engine = ShorekeeperEngine(use_memory=True)
|
| 76 |
-
print('[API] Ready.')
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
# ── ENDPOINTS ─────────────────────────────────────────────────────────
|
| 80 |
-
|
| 81 |
-
@app.get('/health')
|
| 82 |
-
async def health():
|
| 83 |
-
return {'status': 'ok', 'model_loaded': engine is not None}
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
@app.post('/generate', response_model=GenerateResponse)
|
| 87 |
-
async def generate(req: GenerateRequest):
|
| 88 |
-
if engine is None:
|
| 89 |
-
raise HTTPException(503, 'Engine not loaded')
|
| 90 |
-
if not req.prompt.strip():
|
| 91 |
-
raise HTTPException(400, 'Empty prompt')
|
| 92 |
-
try:
|
| 93 |
-
result = engine.generate(
|
| 94 |
-
prompt = req.prompt,
|
| 95 |
-
max_new_tokens = req.max_new_tokens,
|
| 96 |
-
temperature = req.temperature,
|
| 97 |
-
top_p = req.top_p,
|
| 98 |
-
top_k = req.top_k,
|
| 99 |
-
)
|
| 100 |
-
# Convert SentinelReport to dict for JSON serialization
|
| 101 |
-
sentinel_dict = None
|
| 102 |
-
if result['sentinel']:
|
| 103 |
-
sr = result['sentinel']
|
| 104 |
-
sentinel_dict = {
|
| 105 |
-
'verdict': sr.verdict,
|
| 106 |
-
'overall_risk': round(sr.overall_risk, 4),
|
| 107 |
-
'drift_score': round(sr.drift_score, 4),
|
| 108 |
-
'refusal_score': round(sr.refusal_score, 4),
|
| 109 |
-
'hallucination_score':round(sr.hallucination_score, 4),
|
| 110 |
-
}
|
| 111 |
-
return GenerateResponse(
|
| 112 |
-
text = result['text'],
|
| 113 |
-
experts_used = result['experts_used'],
|
| 114 |
-
routing = result['routing'],
|
| 115 |
-
sentinel = sentinel_dict,
|
| 116 |
-
blocked = result['blocked'],
|
| 117 |
-
latency_ms = result['latency_ms'],
|
| 118 |
-
n_tokens = result['n_tokens'],
|
| 119 |
-
)
|
| 120 |
-
except Exception as e:
|
| 121 |
-
raise HTTPException(500, str(e))
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
@app.post('/memory/search')
|
| 125 |
-
async def search_memory(req: MemorySearchRequest):
|
| 126 |
-
if not engine or not engine.db:
|
| 127 |
-
raise HTTPException(503, 'Memory not available')
|
| 128 |
-
results = engine.db.search_conversations(req.query, limit=req.limit)
|
| 129 |
-
return {'results': results}
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
@app.post('/knowledge/add')
|
| 133 |
-
async def add_knowledge(req: KnowledgeRequest):
|
| 134 |
-
if not engine or not engine.db:
|
| 135 |
-
raise HTTPException(503, 'Memory not available')
|
| 136 |
-
engine.db.add_knowledge(req.key, req.value, source='api')
|
| 137 |
-
return {'status': 'saved', 'key': req.key}
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
@app.get('/stats')
|
| 141 |
-
async def get_stats():
|
| 142 |
-
if not engine or not engine.db:
|
| 143 |
-
return {'error': 'Memory not available'}
|
| 144 |
-
return engine.db.get_stats()
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
if __name__ == '__main__':
|
| 148 |
-
uvicorn.run(app, host='0.0.0.0', port=8000, log_level='info')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/chat.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 4 |
-
from inference.engine import ShorekeeperEngine
|
| 5 |
-
|
| 6 |
-
BANNER = """
|
| 7 |
-
╔══════════════════════════════════════════════════════╗
|
| 8 |
-
║ S H O R E K E E P E R — MoE Ensemble v0.1 ║
|
| 9 |
-
║ BlackShores OS | Native Intelligence ║
|
| 10 |
-
╚══════════════════════════════════════════════════════╝
|
| 11 |
-
Commands: /route <q> /expert <name> /routing on|off
|
| 12 |
-
/incidents /experts /reset /exit
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
def chat(checkpoint=None, show_routing=True):
|
| 16 |
-
print(BANNER)
|
| 17 |
-
engine = ShorekeeperEngine(checkpoint=checkpoint)
|
| 18 |
-
force_expert = None; display_routing = show_routing
|
| 19 |
-
print("\nShorekeeper: I am ready. The shore is quiet. What do you need?\n")
|
| 20 |
-
|
| 21 |
-
while True:
|
| 22 |
-
try: user_input = input("You> ").strip()
|
| 23 |
-
except (EOFError, KeyboardInterrupt): print("\nShorekeeper: Until next time."); break
|
| 24 |
-
if not user_input: continue
|
| 25 |
-
|
| 26 |
-
if user_input.startswith("/"):
|
| 27 |
-
parts = user_input.split(maxsplit=1); cmd = parts[0].lower(); arg = parts[1] if len(parts)>1 else ""
|
| 28 |
-
if cmd == "/exit": print("Shorekeeper: Until next time."); break
|
| 29 |
-
elif cmd == "/reset": engine.reset_session()
|
| 30 |
-
elif cmd == "/route": engine.route_query(arg) if arg else print("Usage: /route <query>")
|
| 31 |
-
elif cmd == "/expert":
|
| 32 |
-
from config import EXPERT_NAMES
|
| 33 |
-
if arg in EXPERT_NAMES: force_expert = arg; print(f"[!] Forcing: {arg}")
|
| 34 |
-
else: print(f"[!] Options: {EXPERT_NAMES}")
|
| 35 |
-
elif cmd == "/routing": display_routing = arg.lower()=="on"; print(f"[!] Routing: {arg.upper()}")
|
| 36 |
-
elif cmd == "/incidents":
|
| 37 |
-
incs = engine.model.sentinel.get_incidents()
|
| 38 |
-
if not incs: print("[Sentinel] No incidents.")
|
| 39 |
-
else:
|
| 40 |
-
for i in incs[-10:]: print(f" {i['timestamp'][:19]} | {i['expert']:12s} | {i['protocol']:8s} | score={i['score']:.3f}")
|
| 41 |
-
elif cmd == "/experts":
|
| 42 |
-
from config import EXPERT_NAMES; print(f"Experts: {', '.join(EXPERT_NAMES)}")
|
| 43 |
-
else: print(f"Unknown: {cmd}")
|
| 44 |
-
continue
|
| 45 |
-
|
| 46 |
-
result = engine.respond(user_input, show_routing=display_routing, expert_override=force_expert)
|
| 47 |
-
force_expert = None
|
| 48 |
-
text = result["text"]
|
| 49 |
-
print(f"\nShorekeeper: {text if text else '[No output — model needs training]'}\n")
|
| 50 |
-
if result["drift"] and not result["drift"].is_clean:
|
| 51 |
-
print(f"[!] Sentinel: {result['drift'].protocol} (score={result['drift'].total:.3f})\n")
|
| 52 |
-
|
| 53 |
-
if __name__ == "__main__":
|
| 54 |
-
import argparse
|
| 55 |
-
p = argparse.ArgumentParser()
|
| 56 |
-
p.add_argument("--checkpoint", type=str, default=None)
|
| 57 |
-
p.add_argument("--no-routing", action="store_true")
|
| 58 |
-
a = p.parse_args()
|
| 59 |
-
chat(a.checkpoint, not a.no_routing)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/chat_simple.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 4 |
-
from inference.engine import ShorekeeperEngine
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
WELCOME = """
|
| 8 |
-
Simple Shorekeeper Chat (interactive)
|
| 9 |
-
Type your message and press Enter.
|
| 10 |
-
Commands:
|
| 11 |
-
/reset - clear session memory
|
| 12 |
-
/exit - quit
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def main(checkpoint=None):
|
| 17 |
-
print(WELCOME)
|
| 18 |
-
engine = ShorekeeperEngine(checkpoint=checkpoint)
|
| 19 |
-
print("Shorekeeper: Ready. Start typing your question.")
|
| 20 |
-
|
| 21 |
-
while True:
|
| 22 |
-
try:
|
| 23 |
-
user_text = input("You> ").strip()
|
| 24 |
-
except (EOFError, KeyboardInterrupt):
|
| 25 |
-
print("\nShorekeeper: Goodbye.")
|
| 26 |
-
break
|
| 27 |
-
|
| 28 |
-
if not user_text:
|
| 29 |
-
continue
|
| 30 |
-
|
| 31 |
-
if user_text.startswith("/"):
|
| 32 |
-
cmd = user_text.lower().strip()
|
| 33 |
-
if cmd == "/exit":
|
| 34 |
-
print("Shorekeeper: Goodbye.")
|
| 35 |
-
break
|
| 36 |
-
elif cmd == "/reset":
|
| 37 |
-
engine.reset_session()
|
| 38 |
-
print("Shorekeeper: Session reset.")
|
| 39 |
-
continue
|
| 40 |
-
else:
|
| 41 |
-
print("Unknown command. Use /reset or /exit.")
|
| 42 |
-
continue
|
| 43 |
-
|
| 44 |
-
out = engine.respond(user_text, show_routing=False)
|
| 45 |
-
text = out.get("text", "")
|
| 46 |
-
print("Shorekeeper:", text)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
if __name__ == "__main__":
|
| 50 |
-
import argparse
|
| 51 |
-
parser = argparse.ArgumentParser(description="Simple Shorekeeper conversation CLI")
|
| 52 |
-
parser.add_argument("--checkpoint", default=None, help="Checkpoint path")
|
| 53 |
-
args = parser.parse_args()
|
| 54 |
-
main(checkpoint=args.checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/daemon.py
DELETED
|
@@ -1,173 +0,0 @@
|
|
| 1 |
-
# inference/daemon.py
|
| 2 |
-
# Cross-platform Shorekeeper daemon/IPC server.
|
| 3 |
-
# On POSIX, it can use UNIX socket; on Windows, it defaults to TCP.
|
| 4 |
-
# Use --mode unix or --mode tcp to control socket type.
|
| 5 |
-
|
| 6 |
-
# Example:
|
| 7 |
-
# python inference/daemon.py --mode tcp --host 127.0.0.1 --port 8500
|
| 8 |
-
# python inference/daemon.py --mode unix --socket /tmp/shorekeeper.sock
|
| 9 |
-
|
| 10 |
-
import sys
|
| 11 |
-
import json
|
| 12 |
-
import socket
|
| 13 |
-
import threading
|
| 14 |
-
import signal
|
| 15 |
-
import logging
|
| 16 |
-
from pathlib import Path
|
| 17 |
-
|
| 18 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 19 |
-
from inference.engine import ShorekeeperEngine
|
| 20 |
-
|
| 21 |
-
DEFAULT_SOCKET_PATH = '/tmp/shorekeeper.sock'
|
| 22 |
-
DEFAULT_LOG_FILE = 'logs/daemon.log'
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def setup_logging(log_file: str):
|
| 26 |
-
log_path = Path(log_file)
|
| 27 |
-
log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 28 |
-
logging.basicConfig(
|
| 29 |
-
level=logging.INFO,
|
| 30 |
-
format='%(asctime)s [%(levelname)s] %(message)s',
|
| 31 |
-
handlers=[
|
| 32 |
-
logging.FileHandler(str(log_path)),
|
| 33 |
-
logging.StreamHandler(),
|
| 34 |
-
],
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
def handle_client(
|
| 39 |
-
conn: socket.socket,
|
| 40 |
-
engine: ShorekeeperEngine,
|
| 41 |
-
):
|
| 42 |
-
"""Handle one client connection on the Unix socket."""
|
| 43 |
-
try:
|
| 44 |
-
data = b''
|
| 45 |
-
while True:
|
| 46 |
-
chunk = conn.recv(4096)
|
| 47 |
-
if not chunk: break
|
| 48 |
-
data += chunk
|
| 49 |
-
if b'\n' in data: break # Newline-delimited JSON protocol
|
| 50 |
-
if not data:
|
| 51 |
-
return
|
| 52 |
-
request = json.loads(data.decode().strip())
|
| 53 |
-
prompt = request.get('prompt', '')
|
| 54 |
-
if not prompt:
|
| 55 |
-
response = {'error': 'empty prompt'}
|
| 56 |
-
else:
|
| 57 |
-
result = engine.generate(
|
| 58 |
-
prompt = prompt,
|
| 59 |
-
max_new_tokens = request.get('max_new_tokens'),
|
| 60 |
-
temperature = request.get('temperature'),
|
| 61 |
-
)
|
| 62 |
-
response = {
|
| 63 |
-
'text': result['text'],
|
| 64 |
-
'experts': result['experts_used'],
|
| 65 |
-
'blocked': result['blocked'],
|
| 66 |
-
'latency_ms': result['latency_ms'],
|
| 67 |
-
}
|
| 68 |
-
conn.sendall((json.dumps(response) + '\n').encode())
|
| 69 |
-
except Exception as e:
|
| 70 |
-
logging.error(f'Client handler error: {e}')
|
| 71 |
-
try:
|
| 72 |
-
conn.sendall((json.dumps({'error': str(e)}) + '\n').encode())
|
| 73 |
-
except Exception:
|
| 74 |
-
pass
|
| 75 |
-
finally:
|
| 76 |
-
conn.close()
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def run_daemon(
|
| 80 |
-
mode: str = 'auto',
|
| 81 |
-
socket_path: str = None,
|
| 82 |
-
host: str = '127.0.0.1',
|
| 83 |
-
port: int = 8500,
|
| 84 |
-
log_file: str = DEFAULT_LOG_FILE,
|
| 85 |
-
):
|
| 86 |
-
setup_logging(log_file)
|
| 87 |
-
logging.info('Starting Shorekeeper daemon')
|
| 88 |
-
logging.info(f'Platform: {sys.platform}')
|
| 89 |
-
|
| 90 |
-
engine = ShorekeeperEngine(use_memory=True)
|
| 91 |
-
logging.info('Engine loaded')
|
| 92 |
-
|
| 93 |
-
if mode == 'auto':
|
| 94 |
-
use_unix = hasattr(socket, 'AF_UNIX') and sys.platform != 'win32'
|
| 95 |
-
else:
|
| 96 |
-
use_unix = mode == 'unix'
|
| 97 |
-
|
| 98 |
-
if use_unix:
|
| 99 |
-
if socket_path is None:
|
| 100 |
-
socket_path = DEFAULT_SOCKET_PATH
|
| 101 |
-
sock_path = Path(socket_path)
|
| 102 |
-
if sock_path.exists():
|
| 103 |
-
sock_path.unlink()
|
| 104 |
-
|
| 105 |
-
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
| 106 |
-
server.bind(socket_path)
|
| 107 |
-
server.listen(10)
|
| 108 |
-
try:
|
| 109 |
-
sock_path.chmod(0o660)
|
| 110 |
-
except Exception:
|
| 111 |
-
pass
|
| 112 |
-
|
| 113 |
-
logging.info(f'Listening on UNIX socket: {socket_path}')
|
| 114 |
-
|
| 115 |
-
def shutdown(sig, frame):
|
| 116 |
-
logging.info('Shutting down...')
|
| 117 |
-
server.close()
|
| 118 |
-
if sock_path.exists():
|
| 119 |
-
sock_path.unlink()
|
| 120 |
-
sys.exit(0)
|
| 121 |
-
|
| 122 |
-
else:
|
| 123 |
-
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 124 |
-
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 125 |
-
server.bind((host, port))
|
| 126 |
-
server.listen(10)
|
| 127 |
-
logging.info(f'Listening on TCP: {host}:{port}')
|
| 128 |
-
|
| 129 |
-
def shutdown(sig, frame):
|
| 130 |
-
logging.info('Shutting down...')
|
| 131 |
-
server.close()
|
| 132 |
-
sys.exit(0)
|
| 133 |
-
|
| 134 |
-
signal.signal(signal.SIGTERM, shutdown)
|
| 135 |
-
signal.signal(signal.SIGINT, shutdown)
|
| 136 |
-
|
| 137 |
-
while True:
|
| 138 |
-
try:
|
| 139 |
-
conn, _ = server.accept()
|
| 140 |
-
t = threading.Thread(
|
| 141 |
-
target=handle_client,
|
| 142 |
-
args=(conn, engine),
|
| 143 |
-
daemon=True,
|
| 144 |
-
)
|
| 145 |
-
t.start()
|
| 146 |
-
except OSError:
|
| 147 |
-
break
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
def parse_args_and_run():
|
| 151 |
-
import argparse
|
| 152 |
-
parser = argparse.ArgumentParser(description='Shorekeeper inference daemon (cross-platform)')
|
| 153 |
-
parser.add_argument('--mode', choices=['auto', 'unix', 'tcp'], default='auto', help='Socket mode')
|
| 154 |
-
parser.add_argument('--socket', default=DEFAULT_SOCKET_PATH, help='UNIX socket path (if unix mode)')
|
| 155 |
-
parser.add_argument('--host', default='127.0.0.1', help='TCP host (if tcp mode)')
|
| 156 |
-
parser.add_argument('--port', type=int, default=8500, help='TCP port (if tcp mode)')
|
| 157 |
-
parser.add_argument('--log-file', default=DEFAULT_LOG_FILE, help='Log file path')
|
| 158 |
-
args = parser.parse_args()
|
| 159 |
-
|
| 160 |
-
if args.mode == 'unix' and sys.platform == 'win32':
|
| 161 |
-
raise RuntimeError('UNIX sockets are not supported on Windows. Use --mode tcp.')
|
| 162 |
-
|
| 163 |
-
run_daemon(
|
| 164 |
-
mode=args.mode,
|
| 165 |
-
socket_path=args.socket,
|
| 166 |
-
host=args.host,
|
| 167 |
-
port=args.port,
|
| 168 |
-
log_file=args.log_file,
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
if __name__ == '__main__':
|
| 173 |
-
parse_args_and_run()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference/engine.py
DELETED
|
@@ -1,406 +0,0 @@
|
|
| 1 |
-
# inference/engine.py
|
| 2 |
-
# Core inference engine for Shorekeeper.
|
| 3 |
-
# Loads the trained ensemble and provides a generate() function.
|
| 4 |
-
# Handles sampling strategies, Echo memory enrichment, and Sentinel monitoring.
|
| 5 |
-
|
| 6 |
-
import sys
|
| 7 |
-
import time
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
from typing import Optional
|
| 12 |
-
|
| 13 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 14 |
-
from config import DEVICE, DTYPE, INFER_CONFIG, CHECKPOINT_DIR, SENTINEL_CONFIG, SPECIAL_TOKENS, EXPERT_NAMES
|
| 15 |
-
from model.ensemble import ShorekeeperEnsemble
|
| 16 |
-
from tokenizer.tokenizer_utils import get_tokenizer, encode, decode, encode_batch
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
class ShorekeeperEngine:
|
| 20 |
-
"""
|
| 21 |
-
Complete inference engine.
|
| 22 |
-
|
| 23 |
-
Responsibilities:
|
| 24 |
-
1. Load and hold the trained ensemble in memory
|
| 25 |
-
2. Accept text prompts, run generation, return text
|
| 26 |
-
3. Integrate Echo memory retrieval (past context enrichment)
|
| 27 |
-
4. Run Sentinel on outputs and block/flag as needed
|
| 28 |
-
5. Store completed exchanges to memory database
|
| 29 |
-
|
| 30 |
-
Usage:
|
| 31 |
-
engine = ShorekeeperEngine()
|
| 32 |
-
response = engine.generate('Hello, what is SQL injection?')
|
| 33 |
-
print(response['text'])
|
| 34 |
-
"""
|
| 35 |
-
|
| 36 |
-
def __init__(
|
| 37 |
-
self,
|
| 38 |
-
checkpoint_path: Path = None,
|
| 39 |
-
checkpoint: str = None, # backwards-compat alias
|
| 40 |
-
use_memory: bool = True,
|
| 41 |
-
session_id: str = None,
|
| 42 |
-
):
|
| 43 |
-
self.use_memory = use_memory
|
| 44 |
-
self.session_id = session_id or _new_session_id()
|
| 45 |
-
|
| 46 |
-
print("[Engine] Initializing Shorekeeper...")
|
| 47 |
-
|
| 48 |
-
# ── Tokenizer ─────────────────────────────────────────────────
|
| 49 |
-
try:
|
| 50 |
-
self.tok = get_tokenizer()
|
| 51 |
-
self.tokenizer = self.tok # backwards-compat alias
|
| 52 |
-
except FileNotFoundError:
|
| 53 |
-
print("[Engine] WARNING: tokenizer not found. Run tokenizer/train_tokenizer.py first.")
|
| 54 |
-
self.tok = None
|
| 55 |
-
self.tokenizer = None
|
| 56 |
-
|
| 57 |
-
self._eos_id = self.tok.token_to_id('[EOS]') if self.tok else SPECIAL_TOKENS.get('[EOS]', 3)
|
| 58 |
-
self._pad_id = self.tok.token_to_id('[PAD]') if self.tok else SPECIAL_TOKENS.get('[PAD]', 0)
|
| 59 |
-
self.bos = self.tok.token_to_id('[BOS]') if self.tok else SPECIAL_TOKENS.get('[BOS]', 2)
|
| 60 |
-
|
| 61 |
-
# ── Model ──────────────────────────────────────────────────────
|
| 62 |
-
# Resolve checkpoint path — accept old 'checkpoint' kwarg too
|
| 63 |
-
if checkpoint_path is None and checkpoint is not None:
|
| 64 |
-
checkpoint_path = Path(checkpoint)
|
| 65 |
-
|
| 66 |
-
if checkpoint_path is None:
|
| 67 |
-
for candidate in [
|
| 68 |
-
CHECKPOINT_DIR / 'ensemble' / 'best.pt',
|
| 69 |
-
CHECKPOINT_DIR / 'ensemble' / 'final.pt',
|
| 70 |
-
CHECKPOINT_DIR / 'shorekeeper_ensemble.pt',
|
| 71 |
-
]:
|
| 72 |
-
if candidate.exists():
|
| 73 |
-
checkpoint_path = candidate
|
| 74 |
-
break
|
| 75 |
-
|
| 76 |
-
if checkpoint_path is not None and Path(checkpoint_path).exists():
|
| 77 |
-
print(f'[Engine] Loading from: {checkpoint_path}')
|
| 78 |
-
self.model = ShorekeeperEnsemble.load(str(checkpoint_path), DEVICE, max_loaded_experts=2)
|
| 79 |
-
else:
|
| 80 |
-
# Try component-based load from model directory for big model weights
|
| 81 |
-
print('[Engine] Ensemble checkpoint not found, trying component checkpoint directories...')
|
| 82 |
-
try:
|
| 83 |
-
self.model = ShorekeeperEnsemble.load_from_components(str(CHECKPOINT_DIR), DEVICE)
|
| 84 |
-
print('[Engine] Loaded model from component checkpoints.')
|
| 85 |
-
except Exception as e:
|
| 86 |
-
print(f'[Engine] No component checkpoints found or load failed: {e}')
|
| 87 |
-
print('[Engine] Using untrained model. Run training first.')
|
| 88 |
-
self.model = ShorekeeperEnsemble().to(DEVICE)
|
| 89 |
-
self.model = self.model.to(DEVICE)
|
| 90 |
-
|
| 91 |
-
self.model.eval()
|
| 92 |
-
|
| 93 |
-
# ── Optional memory (SQLite + FAISS) ──────────────────────────
|
| 94 |
-
self.db = None
|
| 95 |
-
self.vector_store = None
|
| 96 |
-
self.vs = None # backwards-compat alias
|
| 97 |
-
if use_memory:
|
| 98 |
-
try:
|
| 99 |
-
from memory.database import MemoryDatabase
|
| 100 |
-
self.db = MemoryDatabase()
|
| 101 |
-
print("[Engine] Memory DB loaded.")
|
| 102 |
-
except Exception as e:
|
| 103 |
-
print(f"[Engine] WARNING: could not load memory DB: {e}")
|
| 104 |
-
try:
|
| 105 |
-
from memory.vector_store import VectorStore
|
| 106 |
-
self.vector_store = VectorStore()
|
| 107 |
-
self.vs = self.vector_store
|
| 108 |
-
print("[Engine] Vector store loaded.")
|
| 109 |
-
except Exception as e:
|
| 110 |
-
print(f"[Engine] WARNING: could not load vector store: {e}")
|
| 111 |
-
|
| 112 |
-
print(f"[Engine] Ready on {DEVICE}")
|
| 113 |
-
|
| 114 |
-
@torch.no_grad()
|
| 115 |
-
def generate(
|
| 116 |
-
self,
|
| 117 |
-
prompt: str,
|
| 118 |
-
max_new_tokens: int = None,
|
| 119 |
-
temperature: float = None,
|
| 120 |
-
top_p: float = None,
|
| 121 |
-
top_k: int = None,
|
| 122 |
-
stream: bool = False,
|
| 123 |
-
) -> dict:
|
| 124 |
-
"""
|
| 125 |
-
Generate a response to a prompt.
|
| 126 |
-
|
| 127 |
-
Args:
|
| 128 |
-
prompt: The user's input text.
|
| 129 |
-
max_new_tokens: Override config max new tokens.
|
| 130 |
-
temperature: Sampling temperature. 0 = greedy. Default from config.
|
| 131 |
-
top_p: Nucleus sampling cutoff. Default from config.
|
| 132 |
-
top_k: Top-K sampling. Default from config.
|
| 133 |
-
stream: Reserved for future streaming support.
|
| 134 |
-
|
| 135 |
-
Returns:
|
| 136 |
-
dict with keys:
|
| 137 |
-
'text': Generated text
|
| 138 |
-
'prompt': Original prompt
|
| 139 |
-
'enriched_prompt': Prompt after Echo enrichment
|
| 140 |
-
'experts_used': List of expert names used
|
| 141 |
-
'routing': Routing weights dict
|
| 142 |
-
'sentinel': SentinelReport (or None)
|
| 143 |
-
'blocked': True if Sentinel blocked the output
|
| 144 |
-
'latency_ms': Generation time in milliseconds
|
| 145 |
-
'n_tokens': Number of tokens generated
|
| 146 |
-
"""
|
| 147 |
-
if not self.tok:
|
| 148 |
-
return {
|
| 149 |
-
"text": "[No tokenizer — run tokenizer/train_tokenizer.py first]",
|
| 150 |
-
"prompt": prompt, "enriched_prompt": prompt,
|
| 151 |
-
"experts_used": [], "routing": {}, "sentinel": None,
|
| 152 |
-
"blocked": False, "latency_ms": 0.0, "n_tokens": 0,
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
cfg = INFER_CONFIG
|
| 156 |
-
max_new_tokens = max_new_tokens or cfg['max_new_tokens']
|
| 157 |
-
temperature = temperature if temperature is not None else cfg['temperature']
|
| 158 |
-
top_p = top_p if top_p is not None else cfg['top_p']
|
| 159 |
-
top_k = top_k if top_k is not None else cfg['top_k']
|
| 160 |
-
|
| 161 |
-
t0 = time.perf_counter()
|
| 162 |
-
|
| 163 |
-
# ── ECHO ENRICHMENT ───────────────────────────────────────────
|
| 164 |
-
enriched = prompt
|
| 165 |
-
if self.use_memory and self.model.echo is not None:
|
| 166 |
-
try:
|
| 167 |
-
enriched = self.model.echo.retrieve_context(
|
| 168 |
-
query=prompt,
|
| 169 |
-
db=self.db,
|
| 170 |
-
vs=self.vector_store,
|
| 171 |
-
)
|
| 172 |
-
except Exception:
|
| 173 |
-
pass
|
| 174 |
-
|
| 175 |
-
# ── TOKENIZE ──────────────────────────────────────────────────
|
| 176 |
-
enc = self.tok.encode(enriched)
|
| 177 |
-
ids = enc.ids if len(enc.ids) > 0 else [self.bos]
|
| 178 |
-
base_vocab_size = self.model.base.token_embedding.num_embeddings
|
| 179 |
-
unk_id = self.tok.token_to_id('[UNK]') if self.tok else 1
|
| 180 |
-
safe_ids = [i if 0 <= i < base_vocab_size else unk_id for i in ids]
|
| 181 |
-
if len(safe_ids) != len(ids):
|
| 182 |
-
print(f"[Engine] WARNING: Input token IDs truncated to model vocab_size={base_vocab_size}.")
|
| 183 |
-
input_ids = torch.tensor([safe_ids], dtype=torch.long, device=DEVICE)
|
| 184 |
-
attn_mask = torch.ones_like(input_ids)
|
| 185 |
-
|
| 186 |
-
# Truncate if input exceeds context length
|
| 187 |
-
max_ctx = max(1, self.model.base.n_positions - max_new_tokens - 1)
|
| 188 |
-
if input_ids.shape[1] > max_ctx:
|
| 189 |
-
input_ids = input_ids[:, -max_ctx:]
|
| 190 |
-
attn_mask = attn_mask[:, -max_ctx:]
|
| 191 |
-
|
| 192 |
-
# ── AUTOREGRESSIVE GENERATION ─────────────────────────────────
|
| 193 |
-
generated_ids = []
|
| 194 |
-
cur_ids = input_ids
|
| 195 |
-
cur_mask = attn_mask
|
| 196 |
-
routing_info = {}
|
| 197 |
-
experts_used = []
|
| 198 |
-
|
| 199 |
-
for i in range(max_new_tokens):
|
| 200 |
-
with torch.autocast(device_type='cuda', dtype=DTYPE, enabled=(DEVICE == 'cuda')):
|
| 201 |
-
output = self.model(
|
| 202 |
-
cur_ids,
|
| 203 |
-
cur_mask,
|
| 204 |
-
return_routing=(i == 0), # Capture routing once
|
| 205 |
-
return_sentinel=False, # Sentinel runs on complete output
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
-
if i == 0:
|
| 209 |
-
routing_info = output.get('routing', {})
|
| 210 |
-
experts_used = output.get('experts_used', [])
|
| 211 |
-
|
| 212 |
-
logits = output['logits'][:, -1, :] # [1, VOCAB_SIZE]
|
| 213 |
-
next_id = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k)
|
| 214 |
-
|
| 215 |
-
generated_ids.append(next_id.item())
|
| 216 |
-
|
| 217 |
-
if next_id.item() == self._eos_id:
|
| 218 |
-
break
|
| 219 |
-
|
| 220 |
-
# Extend sequence for next step
|
| 221 |
-
# next_id is shape [1,1]
|
| 222 |
-
cur_ids = torch.cat([cur_ids, next_id], dim=1)
|
| 223 |
-
cur_mask = torch.cat([cur_mask, torch.ones(1, 1, device=DEVICE)], dim=1)
|
| 224 |
-
|
| 225 |
-
# Sliding window — keep within context limit
|
| 226 |
-
if cur_ids.shape[1] > self.model.base.n_positions:
|
| 227 |
-
cur_ids = cur_ids[:, 1:]
|
| 228 |
-
cur_mask = cur_mask[:, 1:]
|
| 229 |
-
|
| 230 |
-
# ── DECODE ────────────────────────────────────────────────────
|
| 231 |
-
if generated_ids and generated_ids[-1] == self._eos_id:
|
| 232 |
-
generated_ids = generated_ids[:-1]
|
| 233 |
-
safe_gen_ids = [i if 0 <= i < self.tok.get_vocab_size() else self.tok.token_to_id('[UNK]') for i in generated_ids]
|
| 234 |
-
response_text = self.tok.decode(safe_gen_ids)
|
| 235 |
-
latency_ms = (time.perf_counter() - t0) * 1000
|
| 236 |
-
|
| 237 |
-
# ── SENTINEL CHECK ────────────────────────────────────────────
|
| 238 |
-
sentinel_report = None
|
| 239 |
-
blocked = False
|
| 240 |
-
if self.model.sentinel is not None:
|
| 241 |
-
try:
|
| 242 |
-
full_text = enriched + ' ' + response_text
|
| 243 |
-
ids_sent, mask_sent = encode_batch([full_text], max_length=1024)
|
| 244 |
-
ids_sent = ids_sent.to(DEVICE)
|
| 245 |
-
mask_sent = mask_sent.to(DEVICE)
|
| 246 |
-
with torch.autocast(device_type='cuda', dtype=DTYPE, enabled=(DEVICE == 'cuda')):
|
| 247 |
-
base_hidden = self.model.base(ids_sent, mask_sent)
|
| 248 |
-
sentinel_report = self.model.sentinel.analyze(base_hidden, mask_sent)
|
| 249 |
-
if sentinel_report.verdict == 'BLOCK':
|
| 250 |
-
blocked = True
|
| 251 |
-
response_text = ('[SENTINEL] Output blocked — behavioral anomaly detected. '
|
| 252 |
-
'This incident has been logged.')
|
| 253 |
-
primary_expert = experts_used[0] if experts_used else 'verina'
|
| 254 |
-
self.model.sentinel.log_expert(primary_expert)
|
| 255 |
-
except Exception:
|
| 256 |
-
pass
|
| 257 |
-
|
| 258 |
-
# ── MEMORY STORAGE ────────────────────────────────────────────
|
| 259 |
-
if self.use_memory and self.db is not None:
|
| 260 |
-
try:
|
| 261 |
-
conv_id = self.db.add_conversation(
|
| 262 |
-
user_msg = prompt,
|
| 263 |
-
assistant_msg = response_text,
|
| 264 |
-
session_id = self.session_id,
|
| 265 |
-
experts_used = experts_used,
|
| 266 |
-
routing_weights = {k: round(v, 3) for k, v in routing_info.items()},
|
| 267 |
-
sentinel_score = sentinel_report.overall_risk if sentinel_report else None,
|
| 268 |
-
sentinel_verdict = sentinel_report.verdict if sentinel_report else None,
|
| 269 |
-
tokens_generated = len(generated_ids),
|
| 270 |
-
latency_ms = latency_ms,
|
| 271 |
-
)
|
| 272 |
-
if self.vector_store:
|
| 273 |
-
try:
|
| 274 |
-
self.vector_store.add(conv_id, prompt)
|
| 275 |
-
except Exception:
|
| 276 |
-
pass
|
| 277 |
-
if sentinel_report and sentinel_report.verdict in ('FLAG', 'BLOCK'):
|
| 278 |
-
self.db.log_incident(
|
| 279 |
-
severity = sentinel_report.verdict,
|
| 280 |
-
drift_score = sentinel_report.drift_score,
|
| 281 |
-
refusal_score = sentinel_report.refusal_score,
|
| 282 |
-
hallucination_score = sentinel_report.hallucination_score,
|
| 283 |
-
overall_risk = sentinel_report.overall_risk,
|
| 284 |
-
user_msg = prompt,
|
| 285 |
-
output_snippet = response_text,
|
| 286 |
-
conversation_id = conv_id,
|
| 287 |
-
)
|
| 288 |
-
except Exception:
|
| 289 |
-
pass
|
| 290 |
-
|
| 291 |
-
# Update working memory embedding
|
| 292 |
-
try:
|
| 293 |
-
base_emb = self.model.base(input_ids)
|
| 294 |
-
self.model.echo.update_working_memory("user", base_emb[:, -1, :])
|
| 295 |
-
except Exception:
|
| 296 |
-
pass
|
| 297 |
-
|
| 298 |
-
return {
|
| 299 |
-
'text': response_text,
|
| 300 |
-
'prompt': prompt,
|
| 301 |
-
'enriched_prompt': enriched,
|
| 302 |
-
'experts_used': experts_used,
|
| 303 |
-
'routing': routing_info,
|
| 304 |
-
'sentinel': sentinel_report,
|
| 305 |
-
'blocked': blocked,
|
| 306 |
-
'latency_ms': round(latency_ms, 2),
|
| 307 |
-
'n_tokens': len(generated_ids),
|
| 308 |
-
}
|
| 309 |
-
|
| 310 |
-
# ── Backwards-compat interface (used by chat.py) ──────────────────
|
| 311 |
-
def respond(self, text, max_tokens=200, temperature=0.8, top_k=40,
|
| 312 |
-
show_routing=True, expert_override=None):
|
| 313 |
-
"""Backwards-compatible wrapper for chat.py."""
|
| 314 |
-
result = self.generate(
|
| 315 |
-
prompt=text,
|
| 316 |
-
max_new_tokens=max_tokens,
|
| 317 |
-
temperature=temperature,
|
| 318 |
-
top_k=top_k,
|
| 319 |
-
)
|
| 320 |
-
if show_routing and result['routing']:
|
| 321 |
-
pairs = sorted(result['routing'].items(), key=lambda x: -x[1])
|
| 322 |
-
print(f"\n [Herald] PARALLEL:")
|
| 323 |
-
for name, w in pairs:
|
| 324 |
-
print(f" {name:12s} {'|'*int(w*25):<25} {w:.3f}")
|
| 325 |
-
return {
|
| 326 |
-
"text": result["text"],
|
| 327 |
-
"routing": list(result["routing"].items()),
|
| 328 |
-
"pipeline": False,
|
| 329 |
-
"drift": result["sentinel"],
|
| 330 |
-
}
|
| 331 |
-
|
| 332 |
-
def route_query(self, query):
|
| 333 |
-
"""Print routing breakdown for a query without generating."""
|
| 334 |
-
if not self.tok:
|
| 335 |
-
return
|
| 336 |
-
ids = encode(query)
|
| 337 |
-
idx = torch.tensor([[self.bos] + ids], dtype=torch.long, device=DEVICE)
|
| 338 |
-
pairs, _ = self.model.get_routing(idx)
|
| 339 |
-
print(f"\nQuery: {query}")
|
| 340 |
-
for name, w in pairs:
|
| 341 |
-
print(f" {name:12s} {'|'*int(w*30):<30} {w:.3f}")
|
| 342 |
-
|
| 343 |
-
def reset_session(self):
|
| 344 |
-
"""Clear working memory and Sentinel incident log."""
|
| 345 |
-
try:
|
| 346 |
-
self.model.echo.clear_memory()
|
| 347 |
-
except Exception:
|
| 348 |
-
pass
|
| 349 |
-
try:
|
| 350 |
-
self.model.sentinel.reset_session()
|
| 351 |
-
except Exception:
|
| 352 |
-
pass
|
| 353 |
-
self.session_id = _new_session_id()
|
| 354 |
-
print("[Engine] Session reset.")
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
# ── Module-level helpers ───────────────────────────────────────────────
|
| 358 |
-
|
| 359 |
-
def _sample(
|
| 360 |
-
logits: torch.Tensor,
|
| 361 |
-
temperature: float = 1.0,
|
| 362 |
-
top_p: float = 0.9,
|
| 363 |
-
top_k: int = 50,
|
| 364 |
-
) -> torch.Tensor:
|
| 365 |
-
"""
|
| 366 |
-
Sample the next token from logits.
|
| 367 |
-
|
| 368 |
-
Sampling strategy:
|
| 369 |
-
temperature=0: Greedy (always pick highest probability token)
|
| 370 |
-
temperature>0: Sample from softmax distribution
|
| 371 |
-
top_k: Only sample from the K highest probability tokens
|
| 372 |
-
top_p: Only sample from the smallest set whose cumulative prob >= p
|
| 373 |
-
|
| 374 |
-
Args:
|
| 375 |
-
logits: [1, VOCAB_SIZE] raw logits from the model
|
| 376 |
-
|
| 377 |
-
Returns:
|
| 378 |
-
[1, 1] tensor containing the selected token ID
|
| 379 |
-
"""
|
| 380 |
-
if temperature == 0.0:
|
| 381 |
-
return logits.argmax(dim=-1, keepdim=True)
|
| 382 |
-
|
| 383 |
-
logits = logits / max(temperature, 1e-8)
|
| 384 |
-
|
| 385 |
-
# Top-K filtering: zero out all but top K logits
|
| 386 |
-
if top_k and top_k > 0:
|
| 387 |
-
topk_vals, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
|
| 388 |
-
logits = logits.masked_fill(logits < topk_vals[:, -1:], float('-inf'))
|
| 389 |
-
|
| 390 |
-
probs = F.softmax(logits, dim=-1)
|
| 391 |
-
|
| 392 |
-
# Top-P (nucleus) filtering
|
| 393 |
-
if top_p is not None and top_p < 1.0:
|
| 394 |
-
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
|
| 395 |
-
cumulative = sorted_probs.cumsum(dim=-1)
|
| 396 |
-
remove_mask = cumulative - sorted_probs > top_p
|
| 397 |
-
sorted_probs[remove_mask] = 0.0
|
| 398 |
-
probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
|
| 399 |
-
probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 400 |
-
|
| 401 |
-
return torch.multinomial(probs, num_samples=1)
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
def _new_session_id() -> str:
|
| 405 |
-
import uuid
|
| 406 |
-
return str(uuid.uuid4())[:8]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
memory/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
# memory package
|
|
|
|
|
|
memory/database.py
DELETED
|
@@ -1,379 +0,0 @@
|
|
| 1 |
-
# memory/database.py
|
| 2 |
-
# SQLite-backed persistent memory for Shorekeeper.
|
| 3 |
-
# This file contains the complete database schema and all CRUD operations.
|
| 4 |
-
# The database file lives at MEMORY_DIR/shorekeeper.db
|
| 5 |
-
|
| 6 |
-
import sys
|
| 7 |
-
import sqlite3
|
| 8 |
-
import json
|
| 9 |
-
import hashlib
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
from datetime import datetime
|
| 12 |
-
from typing import Optional
|
| 13 |
-
|
| 14 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
-
from config import MEMORY_DIR
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class MemoryDatabase:
    """
    Complete SQLite memory system for Shorekeeper.

    Tables:
        conversations    — every user/assistant exchange
        knowledge        — factual key-value store (unique ``key``)
        experience_log   — routing decisions and system events
        incidents        — Sentinel-flagged behavioral events
        user_preferences — learned preferences about the user

    Full-text search is enabled on conversations and knowledge via
    SQLite FTS5 virtual tables kept in sync by triggers. Every search
    method falls back to a plain LIKE scan when FTS5 is unavailable
    or the query is not valid FTS syntax.
    """

    def __init__(self, db_path: Path = None):
        """Open (creating if necessary) the database at *db_path*.

        Args:
            db_path: SQLite file location; defaults to
                ``MEMORY_DIR / 'shorekeeper.db'``.
        """
        if db_path is None:
            db_path = MEMORY_DIR / 'shorekeeper.db'
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self.db_path = db_path
        # check_same_thread=False: the inference daemon may serve
        # concurrent requests from multiple threads.
        self.conn = sqlite3.connect(
            str(db_path),
            check_same_thread=False,
        )
        self.conn.row_factory = sqlite3.Row  # access columns by name: row['id']
        # WAL mode: better concurrent read/write performance.
        self.conn.execute('PRAGMA journal_mode=WAL')
        self.conn.execute('PRAGMA foreign_keys=ON')
        self._create_tables()
        print(f'[MemoryDatabase] Connected: {db_path}')

    def _create_tables(self):
        """Create all tables, indexes, FTS mirrors and sync triggers.

        Safe to call repeatedly — everything is ``IF NOT EXISTS``.
        """
        self.conn.executescript('''
            -- ── CONVERSATIONS TABLE ───────────────────────────────────
            CREATE TABLE IF NOT EXISTS conversations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,                -- UUID for the current session
                timestamp TEXT NOT NULL,        -- ISO8601 datetime
                user_msg TEXT NOT NULL,         -- Raw user input
                assistant_msg TEXT NOT NULL,    -- Full assistant response
                experts_used TEXT,              -- JSON array: ["calcharo","rover"]
                routing_weights TEXT,           -- JSON obj: {"calcharo":0.6,"rover":0.4}
                sentinel_score REAL,            -- Overall risk score from Sentinel
                sentinel_verdict TEXT,          -- CLEAN/FLAG/BLOCK
                tokens_generated INTEGER,       -- How many tokens in the response
                latency_ms REAL                 -- Time to generate in milliseconds
            );

            -- ── KNOWLEDGE TABLE ───────────────────────────────────────
            CREATE TABLE IF NOT EXISTS knowledge (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                key TEXT NOT NULL UNIQUE,       -- Unique key for deduplication
                value TEXT NOT NULL,
                source TEXT DEFAULT 'user',     -- 'user', 'inference', 'system'
                confidence REAL DEFAULT 1.0,    -- 0.0-1.0
                access_count INTEGER DEFAULT 0  -- How many times this was retrieved
            );

            -- ── EXPERIENCE LOG ────────────────────────────────────────
            CREATE TABLE IF NOT EXISTS experience_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                event_type TEXT NOT NULL,       -- 'routing', 'sentinel_flag', 'error', 'user_feedback'
                event_data TEXT NOT NULL,       -- JSON blob
                session_id TEXT,
                severity TEXT DEFAULT 'info'    -- 'info', 'warning', 'error'
            );

            -- ── INCIDENTS TABLE ───────────────────────────────────────
            CREATE TABLE IF NOT EXISTS incidents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                severity TEXT NOT NULL,         -- 'FLAG' or 'BLOCK'
                drift_score REAL,
                refusal_score REAL,
                hallucination_score REAL,
                overall_risk REAL,
                user_msg_snippet TEXT,          -- First 200 chars of user message
                output_snippet TEXT,            -- First 500 chars of flagged output
                resolution TEXT DEFAULT 'pending',
                conversation_id INTEGER REFERENCES conversations(id)
            );

            -- ── USER PREFERENCES ──────────────────────────────────────
            CREATE TABLE IF NOT EXISTS user_preferences (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                created_at TEXT NOT NULL,
                category TEXT NOT NULL,         -- 'tone', 'expertise', 'domain', etc.
                preference TEXT NOT NULL,
                confidence REAL DEFAULT 0.5
            );

            -- ── INDEXES ───────────────────────────────────────────────
            CREATE INDEX IF NOT EXISTS idx_conv_timestamp
                ON conversations(timestamp DESC);
            CREATE INDEX IF NOT EXISTS idx_conv_sentinel
                ON conversations(sentinel_verdict);
            CREATE INDEX IF NOT EXISTS idx_knowledge_key
                ON knowledge(key);
            CREATE INDEX IF NOT EXISTS idx_incidents_severity
                ON incidents(severity, timestamp DESC);

            -- ── FULL-TEXT SEARCH TABLES ───────────────────────────────
            CREATE VIRTUAL TABLE IF NOT EXISTS conversations_fts
            USING fts5(
                user_msg,
                assistant_msg,
                content=conversations,
                content_rowid=id
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS knowledge_fts
            USING fts5(
                key,
                value,
                content=knowledge,
                content_rowid=id
            );

            -- ── FTS SYNC TRIGGERS ─────────────────────────────────────
            CREATE TRIGGER IF NOT EXISTS conv_fts_insert
            AFTER INSERT ON conversations BEGIN
                INSERT INTO conversations_fts(rowid, user_msg, assistant_msg)
                VALUES (new.id, new.user_msg, new.assistant_msg);
            END;

            CREATE TRIGGER IF NOT EXISTS conv_fts_delete
            AFTER DELETE ON conversations BEGIN
                INSERT INTO conversations_fts(
                    conversations_fts, rowid, user_msg, assistant_msg)
                VALUES ('delete', old.id, old.user_msg, old.assistant_msg);
            END;

            CREATE TRIGGER IF NOT EXISTS know_fts_insert
            AFTER INSERT ON knowledge BEGIN
                INSERT INTO knowledge_fts(rowid, key, value)
                VALUES (new.id, new.key, new.value);
            END;

            CREATE TRIGGER IF NOT EXISTS know_fts_update
            AFTER UPDATE ON knowledge BEGIN
                INSERT INTO knowledge_fts(knowledge_fts, rowid, key, value)
                VALUES ('delete', old.id, old.key, old.value);
                INSERT INTO knowledge_fts(rowid, key, value)
                VALUES (new.id, new.key, new.value);
            END;
        ''')
        self.conn.commit()

    def add_conversation(
        self,
        user_msg: str,
        assistant_msg: str,
        session_id: str = None,
        experts_used: list = None,
        routing_weights: dict = None,
        sentinel_score: float = None,
        sentinel_verdict: str = None,
        tokens_generated: int = None,
        latency_ms: float = None,
    ) -> int:
        """Insert one exchange and return its row id.

        ``experts_used`` / ``routing_weights`` are JSON-serialized; None
        values are stored as NULL.
        """
        cur = self.conn.execute('''
            INSERT INTO conversations
                (session_id, timestamp, user_msg, assistant_msg,
                 experts_used, routing_weights, sentinel_score,
                 sentinel_verdict, tokens_generated, latency_ms)
            VALUES (?,?,?,?,?,?,?,?,?,?)
        ''', (
            session_id,
            datetime.now().isoformat(),
            user_msg,
            assistant_msg,
            json.dumps(experts_used) if experts_used else None,
            json.dumps(routing_weights) if routing_weights else None,
            sentinel_score,
            sentinel_verdict,
            tokens_generated,
            latency_ms,
        ))
        self.conn.commit()
        return cur.lastrowid

    def search_conversations(
        self,
        query: str,
        limit: int = 5,
        min_score: float = 0.0,
    ) -> list[dict]:
        """Full-text search over past exchanges, best matches first.

        Falls back to a LIKE scan if the FTS query is rejected.
        ``min_score`` is accepted for interface compatibility but is not
        currently applied to the results.
        """
        query = query.strip()
        if not query:
            return []
        # Strip quote characters that would change FTS5 query semantics.
        safe_query = query.replace('"', '').replace("'", '')
        try:
            rows = self.conn.execute('''
                SELECT
                    c.user_msg,
                    c.assistant_msg,
                    c.timestamp,
                    bm25(conversations_fts) as score
                FROM conversations_fts
                JOIN conversations c ON conversations_fts.rowid = c.id
                WHERE conversations_fts MATCH ?
                ORDER BY bm25(conversations_fts)
                LIMIT ?
            ''', (safe_query, limit)).fetchall()
            return [dict(r) for r in rows]
        except sqlite3.OperationalError:
            # FTS unavailable or query not valid FTS syntax — LIKE fallback.
            rows = self.conn.execute('''
                SELECT user_msg, assistant_msg, timestamp, 0 as score
                FROM conversations
                WHERE user_msg LIKE ? OR assistant_msg LIKE ?
                ORDER BY id DESC
                LIMIT ?
            ''', (f'%{query[:50]}%', f'%{query[:50]}%', limit)).fetchall()
            return [dict(r) for r in rows]

    def get_recent_conversations(self, n: int = 20) -> list[dict]:
        """Return the *n* most recent exchanges, newest first."""
        rows = self.conn.execute('''
            SELECT * FROM conversations ORDER BY id DESC LIMIT ?
        ''', (n,)).fetchall()
        return [dict(r) for r in rows]

    def add_knowledge(
        self,
        key: str,
        value: str,
        source: str = 'user',
        confidence: float = 1.0,
    ):
        """Upsert a fact; an existing ``key`` has its value/metadata replaced."""
        now = datetime.now().isoformat()
        self.conn.execute('''
            INSERT INTO knowledge (created_at, updated_at, key, value, source, confidence)
            VALUES (?,?,?,?,?,?)
            ON CONFLICT(key) DO UPDATE SET
                value = excluded.value,
                updated_at = excluded.updated_at,
                source = excluded.source,
                confidence = excluded.confidence
        ''', (now, now, key, value, source, confidence))
        self.conn.commit()

    def search_knowledge(self, query: str, limit: int = 5) -> list[dict]:
        """Full-text search over stored facts; bumps access_count on hits.

        Falls back to a LIKE scan if the FTS query is rejected.
        """
        query = query.strip()
        if not query:
            return []
        safe_query = query.replace('"', '').replace("'", '')
        try:
            rows = self.conn.execute('''
                SELECT k.key, k.value, k.source, k.confidence,
                       bm25(knowledge_fts) as score
                FROM knowledge_fts
                JOIN knowledge k ON knowledge_fts.rowid = k.id
                WHERE knowledge_fts MATCH ?
                ORDER BY bm25(knowledge_fts)
                LIMIT ?
            ''', (safe_query, limit)).fetchall()
            # Track retrieval frequency so get_all_knowledge can rank by use.
            for row in rows:
                self.conn.execute(
                    'UPDATE knowledge SET access_count = access_count + 1 WHERE key = ?',
                    (row['key'],)
                )
            self.conn.commit()
            return [dict(r) for r in rows]
        except sqlite3.OperationalError:
            rows = self.conn.execute('''
                SELECT key, value, source, confidence, 0 as score
                FROM knowledge WHERE key LIKE ? OR value LIKE ? LIMIT ?
            ''', (f'%{query[:50]}%', f'%{query[:50]}%', limit)).fetchall()
            return [dict(r) for r in rows]

    def get_all_knowledge(self) -> list[dict]:
        """Return every stored fact, most-accessed first."""
        rows = self.conn.execute(
            'SELECT * FROM knowledge ORDER BY access_count DESC'
        ).fetchall()
        return [dict(r) for r in rows]

    def log_incident(
        self,
        severity: str,
        drift_score: float,
        refusal_score: float,
        hallucination_score: float,
        overall_risk: float,
        user_msg: str = '',
        output_snippet: str = '',
        conversation_id: int = None,
    ):
        """Record a Sentinel FLAG/BLOCK event; snippets are truncated."""
        self.conn.execute('''
            INSERT INTO incidents
                (timestamp, severity, drift_score, refusal_score,
                 hallucination_score, overall_risk, user_msg_snippet,
                 output_snippet, conversation_id)
            VALUES (?,?,?,?,?,?,?,?,?)
        ''', (
            datetime.now().isoformat(),
            severity, drift_score, refusal_score,
            hallucination_score, overall_risk,
            user_msg[:200], output_snippet[:500], conversation_id
        ))
        self.conn.commit()

    def get_recent_incidents(self, n: int = 20) -> list[dict]:
        """Return the *n* most recent incidents, newest first."""
        rows = self.conn.execute('''
            SELECT * FROM incidents ORDER BY id DESC LIMIT ?
        ''', (n,)).fetchall()
        return [dict(r) for r in rows]

    def log_event(
        self,
        event_type: str,
        event_data: dict,
        session_id: str = None,
        severity: str = 'info',
    ):
        """Append a JSON-serialized system event to the experience log."""
        self.conn.execute('''
            INSERT INTO experience_log (timestamp, event_type, event_data, session_id, severity)
            VALUES (?,?,?,?,?)
        ''', (
            datetime.now().isoformat(),
            event_type,
            json.dumps(event_data),
            session_id,
            severity,
        ))
        self.conn.commit()

    def get_stats(self) -> dict:
        """Return row counts plus expert-combination and Sentinel summaries."""
        stats = {}
        # Table names come from a fixed allowlist, so the f-string is safe.
        for table in ['conversations', 'knowledge', 'experience_log',
                      'incidents', 'user_preferences']:
            row = self.conn.execute(f'SELECT COUNT(*) as c FROM {table}').fetchone()
            stats[table] = row['c']
        rows = self.conn.execute('''
            SELECT experts_used, COUNT(*) as n
            FROM conversations
            WHERE experts_used IS NOT NULL
            GROUP BY experts_used
            ORDER BY n DESC LIMIT 10
        ''').fetchall()
        stats['top_expert_combinations'] = [
            {'combo': r['experts_used'], 'count': r['n']} for r in rows
        ]
        # FIX: string literals must use single quotes. The previous
        # double-quoted form ("CLEAN") only worked through SQLite's
        # deprecated double-quoted-string fallback and fails on builds
        # compiled with SQLITE_DQS=0.
        row = self.conn.execute('''
            SELECT
                COUNT(*) as total,
                SUM(CASE WHEN sentinel_verdict='CLEAN' THEN 1 ELSE 0 END) as clean,
                SUM(CASE WHEN sentinel_verdict='FLAG' THEN 1 ELSE 0 END) as flagged,
                SUM(CASE WHEN sentinel_verdict='BLOCK' THEN 1 ELSE 0 END) as blocked,
                AVG(sentinel_score) as avg_risk
            FROM conversations WHERE sentinel_verdict IS NOT NULL
        ''').fetchone()
        if row and row['total']:
            stats['sentinel'] = dict(row)
        return stats

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
memory/vector_store.py
DELETED
|
@@ -1,109 +0,0 @@
|
|
| 1 |
-
# memory/vector_store.py
|
| 2 |
-
# FAISS-based vector similarity search for Echo.
|
| 3 |
-
# Stores embeddings of past conversations for semantic retrieval.
|
| 4 |
-
# Uses a lightweight sentence encoder (not the full 2B model).
|
| 5 |
-
|
| 6 |
-
import sys
|
| 7 |
-
import numpy as np
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 11 |
-
from config import MEMORY_DIR
|
| 12 |
-
|
| 13 |
-
try:
|
| 14 |
-
import faiss
|
| 15 |
-
FAISS_AVAILABLE = True
|
| 16 |
-
except ImportError:
|
| 17 |
-
FAISS_AVAILABLE = False
|
| 18 |
-
print('[vector_store] FAISS not available. Using SQLite FTS5 only.')
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class VectorStore:
    """
    FAISS-backed vector store for semantic memory retrieval.

    Embeddings come from a lazily-loaded sentence-transformers encoder;
    vectors are L2-normalized so inner-product search equals cosine
    similarity. The index and its id map are persisted to MEMORY_DIR
    after every add().
    """

    EMBEDDING_DIM = 384  # sentence-transformers/all-MiniLM-L6-v2 output dim
    INDEX_FILE = 'faiss_index.bin'
    ID_MAP_FILE = 'faiss_id_map.npy'

    def __init__(self):
        self.index_path = MEMORY_DIR / self.INDEX_FILE
        self.id_map_path = MEMORY_DIR / self.ID_MAP_FILE
        self._encoder = None  # lazily-loaded sentence encoder
        self._index = None    # lazily-created FAISS index
        self._id_map = []     # FAISS row index -> conversation ID

        if FAISS_AVAILABLE:
            self._load_or_create_index()

    def _load_or_create_index(self):
        """Restore a saved index from disk, or start an empty one."""
        if self.index_path.exists() and self.id_map_path.exists():
            self._index = faiss.read_index(str(self.index_path))
            self._id_map = np.load(str(self.id_map_path), allow_pickle=True).tolist()
            print(f'[VectorStore] Loaded index: {self._index.ntotal} vectors')
            return
        # IndexFlatIP: inner-product similarity (cosine on unit vectors).
        self._index = faiss.IndexFlatIP(self.EMBEDDING_DIM)
        self._id_map = []
        print('[VectorStore] Created new FAISS index')

    def _get_encoder(self):
        """Load the sentence encoder on first use; return None if unavailable."""
        if self._encoder is not None:
            return self._encoder
        try:
            from sentence_transformers import SentenceTransformer
            self._encoder = SentenceTransformer('all-MiniLM-L6-v2')
            print('[VectorStore] Encoder loaded: all-MiniLM-L6-v2')
        except ImportError:
            print('[VectorStore] sentence-transformers not installed.')
            print(' Install: pip install sentence-transformers')
        return self._encoder

    def encode(self, text: str) -> np.ndarray:
        """Encode *text* into a unit-norm float32 embedding (None if no encoder)."""
        encoder = self._get_encoder()
        if encoder is None:
            return None
        return encoder.encode([text], normalize_embeddings=True).astype(np.float32)

    def add(
        self,
        conversation_id: int,
        text: str,
    ):
        """Embed *text*, append it to the index under *conversation_id*, persist."""
        if not FAISS_AVAILABLE or self._index is None:
            return
        embedding = self.encode(text)
        if embedding is None:
            return
        self._index.add(embedding)
        self._id_map.append(conversation_id)
        self._save()

    def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> list[int]:
        """Return up to *top_k* conversation IDs most similar to *query*."""
        if not FAISS_AVAILABLE or self._index is None or self._index.ntotal == 0:
            return []
        embedding = self.encode(query)
        if embedding is None:
            return []
        _, hit_rows = self._index.search(embedding, min(top_k, self._index.ntotal))
        # Filter FAISS's -1 "no result" slots and any stale row indexes.
        return [self._id_map[row] for row in hit_rows[0]
                if 0 <= row < len(self._id_map)]

    def _save(self):
        """Persist the index and id map to MEMORY_DIR."""
        if self._index is None:
            return
        faiss.write_index(self._index, str(self.index_path))
        np.save(str(self.id_map_path), np.array(self._id_map, dtype=object))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/__init__.py
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
# model/__init__.py
|
| 2 |
-
from .base import SharedBase
|
| 3 |
-
from .expert import ExpertHead
|
| 4 |
-
from .herald import Herald
|
| 5 |
-
from .echo import Echo
|
| 6 |
-
from .sentinel import Sentinel
|
| 7 |
-
from .ensemble import ShorekeeperEnsemble
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/base.py
DELETED
|
@@ -1,152 +0,0 @@
|
|
| 1 |
-
# model/base.py
|
| 2 |
-
# The shared transformer backbone. Used by ALL components.
|
| 3 |
-
|
| 4 |
-
import sys, math
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 10 |
-
import config
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class RotaryEmbedding(nn.Module):
|
| 14 |
-
"""
|
| 15 |
-
Rotary Position Embedding (RoPE).
|
| 16 |
-
Encodes position by rotating query and key vectors.
|
| 17 |
-
Better than learned absolute embeddings for long contexts.
|
| 18 |
-
Used by LLaMA, GPT-NeoX, and most modern LLMs post-2022.
|
| 19 |
-
"""
|
| 20 |
-
def __init__(self, dim: int, max_seq: int = 2048):
|
| 21 |
-
super().__init__()
|
| 22 |
-
# Use float32 base precision for stable frequency generation.
|
| 23 |
-
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
| 24 |
-
self.register_buffer("inv_freq", inv_freq)
|
| 25 |
-
self.max_seq = max_seq
|
| 26 |
-
|
| 27 |
-
def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype = None):
|
| 28 |
-
inv_freq = self.inv_freq.to(device)
|
| 29 |
-
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 30 |
-
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
| 31 |
-
emb = torch.cat([freqs, freqs], dim=-1)
|
| 32 |
-
if emb.shape[-1] != inv_freq.shape[0] * 2:
|
| 33 |
-
emb = emb[:, :inv_freq.shape[0] * 2]
|
| 34 |
-
if dtype is not None:
|
| 35 |
-
emb = emb.to(dtype)
|
| 36 |
-
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def rotate_half(x):
    """Rotate the last-dimension halves of *x*: (x1, x2) -> (-x2, x1)."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def apply_rope(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors.

    cos/sin are cast to q's dtype so mixed-precision callers stay consistent.
    """
    cos = cos.to(q.dtype)
    sin = sin.to(q.dtype)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with RoPE.

    Uses torch's fused scaled_dot_product_attention when available and
    falls back to an explicit softmax implementation otherwise.
    """

    def __init__(self, n_embd: int, n_head: int, dropout: float = 0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.n_embd = n_embd
        self.head_dim = n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)   # fused Q/K/V projection
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        self.drop = nn.Dropout(dropout)
        self.rope = RotaryEmbedding(self.head_dim)

    def forward(self, x, mask=None):
        """Compute causal self-attention over *x*.

        Args:
            x: [B, T, C] input embeddings.
            mask: optional padding mask; the fallback path indexes it as
                [B, T] with 1 = real token, 0 = pad — assumed shape,
                confirm against callers.

        Returns:
            [B, T, C] attended output.
        """
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each [B, n_head, T, head_dim]
        cos, sin = self.rope(T, x.device, dtype=q.dtype)
        q, k = apply_rope(q, k, cos, sin)
        try:
            from torch.nn.functional import scaled_dot_product_attention
            if mask is None:
                out = scaled_dot_product_attention(
                    q, k, v, attn_mask=None,
                    dropout_p=self.drop.p if self.training else 0.0,
                    is_causal=True)
            else:
                # BUG FIX: the fast path previously passed attn_mask=None,
                # silently discarding the caller's padding mask (only the
                # manual fallback applied it). Build a combined additive
                # causal + padding mask instead; is_causal must then be
                # False because the mask already encodes causality.
                causal = torch.triu(
                    torch.ones(T, T, device=x.device, dtype=torch.bool),
                    diagonal=1)
                attn_mask = torch.zeros(B, 1, T, T, device=x.device, dtype=q.dtype)
                attn_mask = attn_mask.masked_fill(causal, float('-inf'))
                # Same -1e4 penalty the fallback path uses for padded keys.
                attn_mask = attn_mask + (1 - mask[:, None, None, :].to(q.dtype)) * -1e4
                out = scaled_dot_product_attention(
                    q, k, v, attn_mask=attn_mask,
                    dropout_p=self.drop.p if self.training else 0.0,
                    is_causal=False)
        except Exception:
            # Manual fallback: explicit scores + causal/padding masking.
            scale = math.sqrt(self.head_dim)
            att = (q @ k.transpose(-2, -1)) / scale
            causal = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att.masked_fill_(causal, float("-inf"))
            if mask is not None:
                pad = (1 - mask[:, None, None, :].float()) * -1e4
                att = att + pad
            att = torch.softmax(att, dim=-1)
            att = self.drop(att)  # nn.Dropout is a no-op in eval mode
            out = att @ v
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
class FeedForward(nn.Module):
|
| 87 |
-
def __init__(self, n_embd: int, dropout: float = 0.1):
|
| 88 |
-
super().__init__()
|
| 89 |
-
self.net = nn.Sequential(
|
| 90 |
-
nn.Linear(n_embd, 4 * n_embd, bias=False),
|
| 91 |
-
nn.GELU(),
|
| 92 |
-
nn.Linear(4 * n_embd, n_embd, bias=False),
|
| 93 |
-
nn.Dropout(dropout),
|
| 94 |
-
)
|
| 95 |
-
def forward(self, x): return self.net(x)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
class TransformerBlock(nn.Module):
|
| 99 |
-
def __init__(self, n_embd: int, n_head: int, dropout: float = 0.1):
|
| 100 |
-
super().__init__()
|
| 101 |
-
self.ln1 = nn.LayerNorm(n_embd)
|
| 102 |
-
self.attn = MultiHeadAttention(n_embd, n_head, dropout)
|
| 103 |
-
self.ln2 = nn.LayerNorm(n_embd)
|
| 104 |
-
self.ff = FeedForward(n_embd, dropout)
|
| 105 |
-
|
| 106 |
-
def forward(self, x, mask=None):
|
| 107 |
-
x = x + self.attn(self.ln1(x), mask) # Pre-norm + residual
|
| 108 |
-
x = x + self.ff(self.ln2(x))
|
| 109 |
-
return x
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class SharedBase(nn.Module):
|
| 113 |
-
def __init__(self, cfg=None):
|
| 114 |
-
super().__init__()
|
| 115 |
-
if cfg is None:
|
| 116 |
-
cfg = config.BASE_CONFIG
|
| 117 |
-
self.n_embd = cfg["n_embd"]
|
| 118 |
-
self.n_positions = cfg["n_positions"]
|
| 119 |
-
self.token_embedding = nn.Embedding(cfg["vocab_size"], cfg["n_embd"])
|
| 120 |
-
self.drop = nn.Dropout(cfg["dropout"])
|
| 121 |
-
self.blocks = nn.ModuleList([
|
| 122 |
-
TransformerBlock(cfg["n_embd"], cfg["n_head"], cfg["dropout"])
|
| 123 |
-
for _ in range(cfg["n_layer"])
|
| 124 |
-
])
|
| 125 |
-
self.ln_f = nn.LayerNorm(cfg["n_embd"])
|
| 126 |
-
if config.MEMORY_OPT.get("gradient_checkpointing"):
|
| 127 |
-
self.gradient_checkpointing_enable()
|
| 128 |
-
self._init_weights()
|
| 129 |
-
|
| 130 |
-
def _init_weights(self):
|
| 131 |
-
for m in self.modules():
|
| 132 |
-
if isinstance(m, nn.Linear):
|
| 133 |
-
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 134 |
-
elif isinstance(m, nn.Embedding):
|
| 135 |
-
nn.init.normal_(m.weight, mean=0.0, std=0.02)
|
| 136 |
-
|
| 137 |
-
def gradient_checkpointing_enable(self):
|
| 138 |
-
self._use_checkpointing = True
|
| 139 |
-
|
| 140 |
-
def forward(self, input_ids, attention_mask=None):
|
| 141 |
-
x = self.drop(self.token_embedding(input_ids))
|
| 142 |
-
use_ckpt = getattr(self, "_use_checkpointing", False) and self.training
|
| 143 |
-
for block in self.blocks:
|
| 144 |
-
if use_ckpt:
|
| 145 |
-
from torch.utils.checkpoint import checkpoint
|
| 146 |
-
x = checkpoint(block, x, attention_mask, use_reentrant=False)
|
| 147 |
-
else:
|
| 148 |
-
x = block(x, attention_mask)
|
| 149 |
-
return self.ln_f(x)
|
| 150 |
-
|
| 151 |
-
def get_param_count(self):
|
| 152 |
-
return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/echo.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
# model/echo.py
|
| 2 |
-
import sys, torch, torch.nn as nn
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 5 |
-
import config
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class Echo(nn.Module):
|
| 9 |
-
"""
|
| 10 |
-
Memory retrieval and context injection module.
|
| 11 |
-
Stage 1: Retrieve relevant past exchanges from SQLite/FAISS, inject as context prefix.
|
| 12 |
-
Stage 2: Cross-attention memory gate during generation.
|
| 13 |
-
"""
|
| 14 |
-
def __init__(self, cfg=None, n_embd=None):
|
| 15 |
-
super().__init__()
|
| 16 |
-
cfg = cfg or config.ECHO_CONFIG
|
| 17 |
-
n = n_embd or cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
|
| 18 |
-
from model.base import TransformerBlock
|
| 19 |
-
self.memory_attn = TransformerBlock(n, cfg.get("n_head", config.BASE_CONFIG["n_head"]))
|
| 20 |
-
self.gate_proj = nn.Linear(n * 2, n, bias=False)
|
| 21 |
-
self.gate_norm = nn.LayerNorm(n)
|
| 22 |
-
self._working_memory = [] # List of (role, cpu_embedding) tuples
|
| 23 |
-
self._max_memory = cfg.get("max_memory_tokens", 512) // 64 # keep ~8 entries
|
| 24 |
-
|
| 25 |
-
def retrieve_context(self, query: str, db, vs) -> str:
|
| 26 |
-
"""
|
| 27 |
-
Retrieve relevant past conversations and inject as context prefix.
|
| 28 |
-
Returns an enriched prompt string with memory context prepended.
|
| 29 |
-
"""
|
| 30 |
-
if db is None: return query
|
| 31 |
-
# Try semantic search first (FAISS), fall back to keyword (FTS5)
|
| 32 |
-
results = []
|
| 33 |
-
if vs is not None:
|
| 34 |
-
conv_ids = vs.search(query, top_k=3)
|
| 35 |
-
if conv_ids:
|
| 36 |
-
recent = db.get_recent_conversations(n=50)
|
| 37 |
-
id_to_conv = {c["id"]: c for c in recent if "id" in c}
|
| 38 |
-
for cid in conv_ids:
|
| 39 |
-
if cid in id_to_conv:
|
| 40 |
-
results.append(id_to_conv[cid])
|
| 41 |
-
if not results:
|
| 42 |
-
results = db.search_conversations(query[:100], limit=3)
|
| 43 |
-
knowledge = db.search_knowledge(query[:100], limit=2)
|
| 44 |
-
if not results and not knowledge: return query
|
| 45 |
-
ctx_parts = ["[MEMORY]"]
|
| 46 |
-
for r in results[:3]:
|
| 47 |
-
u = str(r.get("user_msg", ""))[:100]
|
| 48 |
-
a = str(r.get("assistant_msg", ""))[:200]
|
| 49 |
-
if u: ctx_parts.append(f"Q: {u}\nA: {a}")
|
| 50 |
-
for k in knowledge[:2]:
|
| 51 |
-
ctx_parts.append(f"FACT: {k['key']} = {k['value']}")
|
| 52 |
-
ctx_parts.append("[SEP]")
|
| 53 |
-
ctx_parts.append(query)
|
| 54 |
-
return "\n".join(ctx_parts)
|
| 55 |
-
|
| 56 |
-
def update_working_memory(self, role: str, embedding: torch.Tensor):
|
| 57 |
-
"""
|
| 58 |
-
Store a hidden-state embedding in the short-term working memory buffer.
|
| 59 |
-
Embeddings are moved to CPU to avoid holding GPU memory between turns.
|
| 60 |
-
"""
|
| 61 |
-
self._working_memory.append((role, embedding.detach().cpu()))
|
| 62 |
-
if len(self._working_memory) > self._max_memory:
|
| 63 |
-
self._working_memory = self._working_memory[-self._max_memory:]
|
| 64 |
-
|
| 65 |
-
def clear_memory(self):
|
| 66 |
-
"""Clear the working memory buffer (call on session reset)."""
|
| 67 |
-
self._working_memory.clear()
|
| 68 |
-
|
| 69 |
-
def get_memory_summary(self) -> str:
|
| 70 |
-
"""Return a simple string summary of buffered memory entries."""
|
| 71 |
-
return f"[Echo] {len(self._working_memory)} entries in working memory"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/ensemble.py
DELETED
|
@@ -1,346 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Shorekeeper ensemble with lazy expert loading.
|
| 3 |
-
"""
|
| 4 |
-
import sys, torch, torch.nn as nn, torch.nn.functional as F
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
-
import config
|
| 8 |
-
from config import EXPERT_NAMES, CHECKPOINT_DIR, DEVICE, DTYPE
|
| 9 |
-
from model.base import SharedBase
|
| 10 |
-
from model.herald import Herald
|
| 11 |
-
from model.echo import Echo
|
| 12 |
-
from model.sentinel import Sentinel
|
| 13 |
-
from model.lazy_expert_loader import LazyExpertLoader
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class ShorekeeperEnsemble(nn.Module):
|
| 17 |
-
"""Full assembled Shorekeeper model with lazy expert loading."""
|
| 18 |
-
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
max_loaded_experts: int = 2,
|
| 22 |
-
base_cfg: dict = None,
|
| 23 |
-
expert_cfgs: dict = None,
|
| 24 |
-
):
|
| 25 |
-
super().__init__()
|
| 26 |
-
self.base_cfg = base_cfg or config.BASE_CONFIG
|
| 27 |
-
self.expert_cfgs = expert_cfgs or config.EXPERT_CONFIGS
|
| 28 |
-
self.base = SharedBase(self.base_cfg)
|
| 29 |
-
self.expert_loader = LazyExpertLoader(
|
| 30 |
-
expert_names=EXPERT_NAMES,
|
| 31 |
-
checkpoint_dir=CHECKPOINT_DIR,
|
| 32 |
-
device=DEVICE,
|
| 33 |
-
max_loaded=max_loaded_experts,
|
| 34 |
-
dtype=DTYPE,
|
| 35 |
-
base_cfg=self.base_cfg,
|
| 36 |
-
expert_cfgs=self.expert_cfgs,
|
| 37 |
-
)
|
| 38 |
-
herald_n_embd = self.base_cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
|
| 39 |
-
self.herald = Herald(n_embd=herald_n_embd)
|
| 40 |
-
self.echo = Echo(n_embd=herald_n_embd)
|
| 41 |
-
sentinel_cfg = self.expert_cfgs.get("sentinel", config.EXPERT_CONFIGS.get("sentinel", {}))
|
| 42 |
-
self.sentinel = Sentinel(cfg=sentinel_cfg, n_embd=herald_n_embd)
|
| 43 |
-
|
| 44 |
-
print(f"[Ensemble] Initialized with lazy expert loading (max_loaded={max_loaded_experts})")
|
| 45 |
-
|
| 46 |
-
# backward compatibility for code expecting model.experts[...]
|
| 47 |
-
@property
|
| 48 |
-
def experts(self):
|
| 49 |
-
class _proxy:
|
| 50 |
-
def __init__(self, loader):
|
| 51 |
-
self.loader = loader
|
| 52 |
-
def __getitem__(self, name):
|
| 53 |
-
return self.loader.get_expert(name)
|
| 54 |
-
def keys(self):
|
| 55 |
-
return self.loader.expert_names
|
| 56 |
-
def __iter__(self):
|
| 57 |
-
return iter(self.loader.expert_names)
|
| 58 |
-
return _proxy(self.expert_loader)
|
| 59 |
-
|
| 60 |
-
# ── Convenience alias ─────────────────────────────────────────────
|
| 61 |
-
@property
|
| 62 |
-
def shared_base(self):
|
| 63 |
-
return self.base
|
| 64 |
-
|
| 65 |
-
# ── Core forward with lazy loading ────────────────────────────────
|
| 66 |
-
def forward(self, input_ids, attention_mask=None,
|
| 67 |
-
return_routing=False, return_sentinel=False):
|
| 68 |
-
base_hidden = self.base(input_ids, attention_mask)
|
| 69 |
-
# Ensure experts load to the same device/dtype as base_hidden
|
| 70 |
-
self.expert_loader.set_device(base_hidden.device, dtype=base_hidden.dtype)
|
| 71 |
-
routing = self.herald(base_hidden, attention_mask)
|
| 72 |
-
expert_idx = routing["expert_indices"]
|
| 73 |
-
expert_wts = routing["expert_weights"]
|
| 74 |
-
B = input_ids.shape[0]
|
| 75 |
-
logits = None
|
| 76 |
-
experts_used = []
|
| 77 |
-
|
| 78 |
-
for ki in range(expert_idx.shape[1]):
|
| 79 |
-
for b in range(B):
|
| 80 |
-
name = EXPERT_NAMES[expert_idx[b, ki].item()]
|
| 81 |
-
if name not in experts_used:
|
| 82 |
-
experts_used.append(name)
|
| 83 |
-
expert = self.expert_loader.get_expert(name)
|
| 84 |
-
out = expert(
|
| 85 |
-
base_hidden[b:b+1],
|
| 86 |
-
attention_mask[b:b+1] if attention_mask is not None else None,
|
| 87 |
-
)
|
| 88 |
-
if "logits" not in out:
|
| 89 |
-
print(f"[Ensemble] Expert {name} output missing logits, skipping")
|
| 90 |
-
continue
|
| 91 |
-
if out["logits"].dim() != 3:
|
| 92 |
-
print(f"[Ensemble] Expert {name} returned unexpected logits shape {out['logits'].shape}, skipping")
|
| 93 |
-
continue
|
| 94 |
-
if out["logits"].shape[-1] != self.base_cfg["vocab_size"]:
|
| 95 |
-
print(f"[Ensemble] Expert {name} logits vocab size mismatch ({out['logits'].shape[-1]} vs {self.base_cfg['vocab_size']}), skipping")
|
| 96 |
-
continue
|
| 97 |
-
weighted = out["logits"] * expert_wts[b, ki]
|
| 98 |
-
if logits is None:
|
| 99 |
-
logits = torch.zeros(B, weighted.shape[1], weighted.shape[2],
|
| 100 |
-
device=weighted.device, dtype=weighted.dtype)
|
| 101 |
-
logits[b] = logits[b] + weighted.squeeze(0)
|
| 102 |
-
|
| 103 |
-
result = {
|
| 104 |
-
"logits": logits,
|
| 105 |
-
"load_balance_loss": routing["load_balance_loss"],
|
| 106 |
-
"experts_used": experts_used,
|
| 107 |
-
}
|
| 108 |
-
if return_routing:
|
| 109 |
-
result["routing"] = {
|
| 110 |
-
EXPERT_NAMES[i]: routing["router_probs"][0, i].item()
|
| 111 |
-
for i in range(len(EXPERT_NAMES))
|
| 112 |
-
}
|
| 113 |
-
if return_sentinel:
|
| 114 |
-
result["sentinel"] = self.sentinel.analyze(base_hidden, attention_mask)
|
| 115 |
-
return result
|
| 116 |
-
|
| 117 |
-
# ── Autoregressive generation ─────────────────────────────────────
|
| 118 |
-
@torch.no_grad()
|
| 119 |
-
def generate(
|
| 120 |
-
self,
|
| 121 |
-
idx: torch.Tensor,
|
| 122 |
-
max_new_tokens: int = 200,
|
| 123 |
-
temperature: float = 0.8,
|
| 124 |
-
top_k: int = 40,
|
| 125 |
-
top_p: float = None,
|
| 126 |
-
expert_override: str = None,
|
| 127 |
-
) -> torch.Tensor:
|
| 128 |
-
from tokenizer.tokenizer_utils import get_tokenizer
|
| 129 |
-
try:
|
| 130 |
-
eos_id = get_tokenizer().token_to_id("[EOS]")
|
| 131 |
-
except Exception:
|
| 132 |
-
eos_id = None
|
| 133 |
-
|
| 134 |
-
max_ctx = self.base.n_positions
|
| 135 |
-
for _ in range(max_new_tokens):
|
| 136 |
-
idx_cond = idx if idx.size(1) <= max_ctx else idx[:, -max_ctx:]
|
| 137 |
-
result = self.forward(idx_cond)
|
| 138 |
-
logits = result["logits"][:, -1, :]
|
| 139 |
-
if temperature == 0:
|
| 140 |
-
next_id = logits.argmax(dim=-1, keepdim=True)
|
| 141 |
-
else:
|
| 142 |
-
logits = logits / max(temperature, 1e-8)
|
| 143 |
-
if top_k is not None and top_k > 0:
|
| 144 |
-
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 145 |
-
logits[logits < v[:, [-1]]] = float("-inf")
|
| 146 |
-
if top_p is not None and 0.0 < top_p < 1.0:
|
| 147 |
-
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
| 148 |
-
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 149 |
-
sorted_idx_to_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
|
| 150 |
-
sorted_logits[sorted_idx_to_remove] = float("-inf")
|
| 151 |
-
logits = logits.scatter(1, sorted_idx, sorted_logits)
|
| 152 |
-
probs = F.softmax(logits, dim=-1)
|
| 153 |
-
next_id = torch.multinomial(probs, num_samples=1)
|
| 154 |
-
|
| 155 |
-
idx = torch.cat([idx, next_id], dim=1)
|
| 156 |
-
if eos_id is not None and next_id[0, 0].item() == eos_id:
|
| 157 |
-
break
|
| 158 |
-
|
| 159 |
-
return idx
|
| 160 |
-
|
| 161 |
-
# ── Routing inspection ────────────────────────────────────────────
|
| 162 |
-
@torch.no_grad()
|
| 163 |
-
def get_routing(self, idx: torch.Tensor):
|
| 164 |
-
base_hidden = self.base(idx)
|
| 165 |
-
routing = self.herald(base_hidden)
|
| 166 |
-
probs = routing["router_probs"][0]
|
| 167 |
-
pairs = [(EXPERT_NAMES[i], probs[i].item()) for i in range(len(EXPERT_NAMES))]
|
| 168 |
-
pairs.sort(key=lambda x: -x[1])
|
| 169 |
-
return pairs, False
|
| 170 |
-
|
| 171 |
-
# ── Safety scan ─────────────────────────────────────────────────
|
| 172 |
-
def scan_output(self, primary_expert: str, text: str):
|
| 173 |
-
from tokenizer.tokenizer_utils import encode_batch
|
| 174 |
-
device = next(self.parameters()).device
|
| 175 |
-
ids, mask = encode_batch([text], max_length=512)
|
| 176 |
-
ids = ids.to(device)
|
| 177 |
-
mask = mask.to(device)
|
| 178 |
-
with torch.no_grad():
|
| 179 |
-
base_hidden = self.base(ids, mask)
|
| 180 |
-
report = self.sentinel.analyze(base_hidden, mask)
|
| 181 |
-
self.sentinel.log_expert(primary_expert)
|
| 182 |
-
return report
|
| 183 |
-
|
| 184 |
-
# ── Persistence ─────────────────────────────────────────────────
|
| 185 |
-
def save(self, path: str):
|
| 186 |
-
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
| 187 |
-
state = {
|
| 188 |
-
"base_state_dict": self.base.state_dict(),
|
| 189 |
-
"herald_state_dict": self.herald.state_dict(),
|
| 190 |
-
"echo_state_dict": self.echo.state_dict(),
|
| 191 |
-
"sentinel_state_dict": self.sentinel.state_dict(),
|
| 192 |
-
}
|
| 193 |
-
torch.save(state, path)
|
| 194 |
-
print(f"[Ensemble] Saved core components to {path}")
|
| 195 |
-
|
| 196 |
-
@classmethod
|
| 197 |
-
def _safe_load_state_dict(cls, module: nn.Module, state_dict: dict, name: str):
|
| 198 |
-
own = module.state_dict()
|
| 199 |
-
filtered = {}
|
| 200 |
-
for k, v in state_dict.items():
|
| 201 |
-
if k in own and own[k].shape == v.shape:
|
| 202 |
-
filtered[k] = v
|
| 203 |
-
elif k in own:
|
| 204 |
-
print(f"[Ensemble] Skipping mismatched {name} key {k}: checkpoint {v.shape}, model {own[k].shape}")
|
| 205 |
-
if filtered:
|
| 206 |
-
module.load_state_dict(filtered, strict=False)
|
| 207 |
-
|
| 208 |
-
@classmethod
|
| 209 |
-
def load(cls, path: str, device: str = "cpu", max_loaded_experts: int = 3) -> "ShorekeeperEnsemble":
|
| 210 |
-
ckpt = torch.load(path, map_location="cpu")
|
| 211 |
-
base_cfg = None
|
| 212 |
-
expert_cfgs = None
|
| 213 |
-
if isinstance(ckpt, dict):
|
| 214 |
-
base_cfg = ckpt.get("base_config", None)
|
| 215 |
-
expert_cfgs = ckpt.get("expert_configs", None)
|
| 216 |
-
|
| 217 |
-
model = cls(
|
| 218 |
-
max_loaded_experts=max_loaded_experts,
|
| 219 |
-
base_cfg=base_cfg,
|
| 220 |
-
expert_cfgs=expert_cfgs,
|
| 221 |
-
).to(device)
|
| 222 |
-
model.expert_loader.set_device(device, dtype=DTYPE)
|
| 223 |
-
|
| 224 |
-
state: dict = {}
|
| 225 |
-
if isinstance(ckpt, dict):
|
| 226 |
-
if "model_state" in ckpt:
|
| 227 |
-
state = ckpt["model_state"]
|
| 228 |
-
elif "model_state_dict" in ckpt:
|
| 229 |
-
state = ckpt["model_state_dict"]
|
| 230 |
-
elif "base_state_dict" in ckpt or "herald_state_dict" in ckpt:
|
| 231 |
-
# Modern saved ensemble format
|
| 232 |
-
cls._safe_load_state_dict(model.base, ckpt.get("base_state_dict", {}), "base")
|
| 233 |
-
cls._safe_load_state_dict(model.herald, ckpt.get("herald_state_dict", {}), "herald")
|
| 234 |
-
cls._safe_load_state_dict(model.echo, ckpt.get("echo_state_dict", {}), "echo")
|
| 235 |
-
cls._safe_load_state_dict(model.sentinel, ckpt.get("sentinel_state_dict", {}), "sentinel")
|
| 236 |
-
print(f"[Ensemble] Loaded from {path}")
|
| 237 |
-
return model
|
| 238 |
-
else:
|
| 239 |
-
state = ckpt
|
| 240 |
-
else:
|
| 241 |
-
raise ValueError(f"Unsupported checkpoint type: {type(ckpt)}")
|
| 242 |
-
|
| 243 |
-
if not isinstance(state, dict):
|
| 244 |
-
raise ValueError("Checkpoint does not contain recognizable state dict")
|
| 245 |
-
|
| 246 |
-
# Load base + modules if keys present
|
| 247 |
-
base_state = {k.replace("shared_base.", ""): v for k, v in state.items() if k.startswith("shared_base.")}
|
| 248 |
-
if base_state:
|
| 249 |
-
cls._safe_load_state_dict(model.base, base_state, "base")
|
| 250 |
-
|
| 251 |
-
herald_state = {k.replace("herald.", ""): v for k, v in state.items() if k.startswith("herald.")}
|
| 252 |
-
if herald_state:
|
| 253 |
-
cls._safe_load_state_dict(model.herald, herald_state, "herald")
|
| 254 |
-
|
| 255 |
-
echo_state = {k.replace("echo.", ""): v for k, v in state.items() if k.startswith("echo.")}
|
| 256 |
-
if echo_state:
|
| 257 |
-
cls._safe_load_state_dict(model.echo, echo_state, "echo")
|
| 258 |
-
|
| 259 |
-
sentinel_state = {k.replace("sentinel.", ""): v for k, v in state.items() if k.startswith("sentinel.")}
|
| 260 |
-
if sentinel_state:
|
| 261 |
-
cls._safe_load_state_dict(model.sentinel, sentinel_state, "sentinel")
|
| 262 |
-
|
| 263 |
-
# Load expert states if available
|
| 264 |
-
expert_keys = [k for k in state.keys() if k.startswith("experts.") or k.startswith("expert.")]
|
| 265 |
-
if expert_keys:
|
| 266 |
-
for name in EXPERT_NAMES:
|
| 267 |
-
for prefix in [f"experts.{name}.", f"expert.{name}."]:
|
| 268 |
-
expert_state = {k.replace(prefix, ""): v for k, v in state.items() if k.startswith(prefix)}
|
| 269 |
-
if expert_state:
|
| 270 |
-
expert = model.expert_loader.get_expert(name)
|
| 271 |
-
cls._safe_load_state_dict(expert, expert_state, f"expert.{name}")
|
| 272 |
-
break
|
| 273 |
-
|
| 274 |
-
# If no prefixed expert keys, maybe experts are at top-level names
|
| 275 |
-
for name in EXPERT_NAMES:
|
| 276 |
-
expert_state = {k.replace(f"{name}.", ""): v for k, v in state.items() if k.startswith(f"{name}.")}
|
| 277 |
-
if expert_state:
|
| 278 |
-
expert = model.expert_loader.get_expert(name)
|
| 279 |
-
cls._safe_load_state_dict(expert, expert_state, f"expert.{name}")
|
| 280 |
-
|
| 281 |
-
print(f"[Ensemble] Loaded from {path}")
|
| 282 |
-
return model
|
| 283 |
-
|
| 284 |
-
@classmethod
|
| 285 |
-
def load_from_components(cls, checkpoint_dir: str = None, device: str = "cpu") -> "ShorekeeperEnsemble":
|
| 286 |
-
model = cls().to(device)
|
| 287 |
-
model.expert_loader.set_device(device, dtype=DTYPE)
|
| 288 |
-
ckpt_root = Path(checkpoint_dir or CHECKPOINT_DIR)
|
| 289 |
-
base_ckpt = ckpt_root / "base" / "best.pt"
|
| 290 |
-
if base_ckpt.exists():
|
| 291 |
-
try:
|
| 292 |
-
ckpt = torch.load(base_ckpt, map_location="cpu")
|
| 293 |
-
state = ckpt.get("model_state_dict", ckpt)
|
| 294 |
-
base_state = {k.replace("0.", ""): v for k, v in state.items() if k.startswith("0.")}
|
| 295 |
-
if base_state:
|
| 296 |
-
model.base.load_state_dict(base_state, strict=False)
|
| 297 |
-
else:
|
| 298 |
-
model.base.load_state_dict(state, strict=False)
|
| 299 |
-
except Exception as e:
|
| 300 |
-
print(f"[Ensemble] Warning: failed to load base checkpoint {base_ckpt}: {e}")
|
| 301 |
-
|
| 302 |
-
# Experts are loaded lazily by LazyExpertLoader when first routed.
|
| 303 |
-
# Ensure checkpoint availability is set by LazyExpertLoader initialization.
|
| 304 |
-
|
| 305 |
-
herald_path = ckpt_root / "herald" / "best.pt"
|
| 306 |
-
if herald_path.exists():
|
| 307 |
-
try:
|
| 308 |
-
ckpt = torch.load(herald_path, map_location="cpu")
|
| 309 |
-
model.herald.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
|
| 310 |
-
except Exception as e:
|
| 311 |
-
print(f"[Ensemble] Warning: failed to load herald checkpoint {herald_path}: {e}")
|
| 312 |
-
|
| 313 |
-
sentinel_path = ckpt_root / "sentinel" / "best.pt"
|
| 314 |
-
if sentinel_path.exists():
|
| 315 |
-
try:
|
| 316 |
-
ckpt = torch.load(sentinel_path, map_location="cpu")
|
| 317 |
-
model.sentinel.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
|
| 318 |
-
except Exception as e:
|
| 319 |
-
print(f"[Ensemble] Warning: failed to load sentinel checkpoint {sentinel_path}: {e}")
|
| 320 |
-
|
| 321 |
-
return model
|
| 322 |
-
|
| 323 |
-
# ── Expert management ─────────────────────────────────────────────────
|
| 324 |
-
def preload_experts(self, expert_names: list):
|
| 325 |
-
self.expert_loader.preload_experts(expert_names)
|
| 326 |
-
|
| 327 |
-
def get_loaded_experts(self):
|
| 328 |
-
return [name for name, loaded in self.expert_loader.get_cache_status().items() if loaded]
|
| 329 |
-
|
| 330 |
-
def clear_expert_cache(self):
|
| 331 |
-
self.expert_loader.clear_cache()
|
| 332 |
-
|
| 333 |
-
def set_max_loaded_experts(self, max_loaded: int):
|
| 334 |
-
self.expert_loader.set_max_loaded(max_loaded)
|
| 335 |
-
|
| 336 |
-
def freeze_base(self):
|
| 337 |
-
for p in self.base.parameters(): p.requires_grad = False
|
| 338 |
-
|
| 339 |
-
def unfreeze_all(self):
|
| 340 |
-
for p in self.parameters(): p.requires_grad = True
|
| 341 |
-
|
| 342 |
-
def get_trainable_count(self):
|
| 343 |
-
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 344 |
-
|
| 345 |
-
def get_total_count(self):
|
| 346 |
-
return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/expert.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
# model/expert.py
|
| 2 |
-
import sys, torch, torch.nn as nn
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 5 |
-
import config
|
| 6 |
-
from model.base import TransformerBlock
|
| 7 |
-
from tokenizer.tokenizer_utils import vocab_size
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class ExpertHead(nn.Module):
|
| 11 |
-
"""
|
| 12 |
-
Domain-specialized transformer head.
|
| 13 |
-
Receives SharedBase hidden states and processes them
|
| 14 |
-
through its own transformer layers to produce domain-specific logits.
|
| 15 |
-
"""
|
| 16 |
-
def __init__(self, expert_name: str, expert_cfg=None, base_cfg=None):
|
| 17 |
-
super().__init__()
|
| 18 |
-
from config import EXPERT_NAMES
|
| 19 |
-
assert expert_name in EXPERT_NAMES, f"Unknown expert: {expert_name}"
|
| 20 |
-
if expert_cfg is None:
|
| 21 |
-
expert_cfg = config.EXPERT_CONFIGS[expert_name]
|
| 22 |
-
if base_cfg is None:
|
| 23 |
-
base_cfg = config.BASE_CONFIG
|
| 24 |
-
n_embd = expert_cfg["n_embd"]
|
| 25 |
-
self.input_norm = nn.LayerNorm(n_embd)
|
| 26 |
-
self.input_proj = nn.Linear(base_cfg["n_embd"], n_embd, bias=False)
|
| 27 |
-
self.blocks = nn.ModuleList([
|
| 28 |
-
TransformerBlock(n_embd, expert_cfg["n_head"]) for _ in range(expert_cfg["n_layer"])
|
| 29 |
-
])
|
| 30 |
-
self.ln_f = nn.LayerNorm(n_embd)
|
| 31 |
-
self.lm_head = nn.Linear(n_embd, base_cfg["vocab_size"], bias=False)
|
| 32 |
-
nn.init.normal_(self.lm_head.weight, std=0.02)
|
| 33 |
-
|
| 34 |
-
def forward(self, base_hidden, attention_mask=None):
|
| 35 |
-
x = self.input_proj(self.input_norm(base_hidden))
|
| 36 |
-
for block in self.blocks:
|
| 37 |
-
x = block(x, attention_mask)
|
| 38 |
-
x = self.ln_f(x)
|
| 39 |
-
return {"logits": self.lm_head(x), "hidden": x}
|
| 40 |
-
|
| 41 |
-
def freeze(self):
|
| 42 |
-
for p in self.parameters(): p.requires_grad = False
|
| 43 |
-
|
| 44 |
-
def unfreeze(self):
|
| 45 |
-
for p in self.parameters(): p.requires_grad = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/herald.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
# model/herald.py
|
| 2 |
-
import sys, torch, torch.nn as nn, torch.nn.functional as F
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 5 |
-
import config
|
| 6 |
-
from model.base import TransformerBlock
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class Herald(nn.Module):
|
| 10 |
-
"""
|
| 11 |
-
Routes queries to the top-K most relevant experts.
|
| 12 |
-
Trained as a classifier: given base hidden states, predict expert indices.
|
| 13 |
-
Uses load-balance loss to prevent routing collapse (always picking 1 expert).
|
| 14 |
-
"""
|
| 15 |
-
def __init__(self, cfg=None, n_embd=None):
|
| 16 |
-
super().__init__()
|
| 17 |
-
cfg = cfg or config.HERALD_CONFIG
|
| 18 |
-
if n_embd is None:
|
| 19 |
-
n_embd = cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
|
| 20 |
-
self.top_k = cfg.get("top_k", 2)
|
| 21 |
-
self.n_experts = cfg.get("n_experts", len(config.EXPERT_NAMES))
|
| 22 |
-
self.query_encoder = nn.ModuleList([
|
| 23 |
-
TransformerBlock(n_embd, cfg.get("n_head", config.BASE_CONFIG["n_head"])) for _ in range(cfg.get("n_layer", 2))
|
| 24 |
-
])
|
| 25 |
-
self.query_norm = nn.LayerNorm(n_embd)
|
| 26 |
-
self.expert_scorer = nn.Linear(n_embd, self.n_experts, bias=True)
|
| 27 |
-
|
| 28 |
-
def forward(self, base_hidden, attention_mask=None):
|
| 29 |
-
x = base_hidden
|
| 30 |
-
for block in self.query_encoder:
|
| 31 |
-
x = block(x, attention_mask)
|
| 32 |
-
x = self.query_norm(x)
|
| 33 |
-
# Pool: weighted mean over sequence (ignore padding)
|
| 34 |
-
if attention_mask is not None:
|
| 35 |
-
mask = attention_mask[:, :, None].float()
|
| 36 |
-
pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
|
| 37 |
-
else:
|
| 38 |
-
pooled = x.mean(1)
|
| 39 |
-
scores = self.expert_scorer(pooled) # [B, n_experts]
|
| 40 |
-
probs = F.softmax(scores, dim=-1)
|
| 41 |
-
top_probs, top_idx = probs.topk(self.top_k, dim=-1)
|
| 42 |
-
# Load balance loss: encourages uniform expert utilization
|
| 43 |
-
# From "Switch Transformers" (Fedus et al. 2022)
|
| 44 |
-
expert_frac = probs.mean(0)
|
| 45 |
-
target_frac = torch.ones_like(expert_frac) / self.n_experts
|
| 46 |
-
lb_loss = F.mse_loss(expert_frac, target_frac)
|
| 47 |
-
return {
|
| 48 |
-
"router_probs": probs,
|
| 49 |
-
"expert_indices": top_idx,
|
| 50 |
-
"expert_weights": top_probs,
|
| 51 |
-
"load_balance_loss": lb_loss,
|
| 52 |
-
}
|
| 53 |
-
|
| 54 |
-
def get_routing_display(self, base_hidden, attention_mask=None) -> str:
|
| 55 |
-
routing = self.forward(base_hidden, attention_mask)
|
| 56 |
-
lines = []
|
| 57 |
-
probs = routing["router_probs"][0]
|
| 58 |
-
names = config.EXPERT_NAMES if hasattr(config, "EXPERT_NAMES") else []
|
| 59 |
-
for i, name in enumerate(names):
|
| 60 |
-
bar = "█" * int(probs[i].item() * 20)
|
| 61 |
-
lines.append(f" {name:<12} {bar:<20} {probs[i]:.3f}")
|
| 62 |
-
return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/lazy_expert_loader.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Lazy expert loading system for Shorekeeper.
|
| 3 |
-
Only loads experts into VRAM when they're actually routed to.
|
| 4 |
-
Implements LRU caching to keep hot experts loaded while evicting cold ones.
|
| 5 |
-
"""
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from collections import OrderedDict
|
| 10 |
-
from typing import Dict
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
class LazyExpertLoader(nn.Module):
    """Manages expert models with lazy loading and LRU caching.

    Experts are instantiated and moved to the target device only when first
    requested via get_expert(). Once more than ``max_loaded`` experts are
    resident, the least-recently-used one is moved back to CPU and dropped
    from the cache.
    """

    def __init__(
        self,
        expert_names: list,
        checkpoint_dir: Path,
        device: str = "cuda",
        max_loaded: int = 3,
        dtype: torch.dtype = torch.bfloat16,
        base_cfg: dict = None,
        expert_cfgs: dict = None,
    ):
        """Index available checkpoints; no experts are loaded yet.

        Args:
            expert_names: names of all experts this loader may serve.
            checkpoint_dir: root dir containing ``experts/<name>/best.pt``.
            device: target device for loaded experts.
            max_loaded: maximum number of experts kept resident at once.
            dtype: load dtype (downgraded to float32 on CPU for half types).
            base_cfg: shared base-model config passed to each ExpertHead.
            expert_cfgs: per-expert config overrides, keyed by expert name.
        """
        super().__init__()
        self.expert_names = expert_names
        self.checkpoint_dir = Path(checkpoint_dir)
        self.device = device
        self.max_loaded = max_loaded
        self.dtype = dtype
        self.base_cfg = base_cfg or {}
        self.expert_cfgs = expert_cfgs or {}

        self.experts = nn.ModuleDict()  # currently-resident experts
        self._lru = OrderedDict()       # access order; oldest entry first

        # Index which experts actually have a trained checkpoint on disk.
        self._available_checkpoints = {
            name: self.checkpoint_dir / "experts" / name / "best.pt"
            for name in expert_names
            if (self.checkpoint_dir / "experts" / name / "best.pt").exists()
        }

        print(f"[LazyExpertLoader] Initialized with max_loaded={max_loaded}")
        print(f"[LazyExpertLoader] Found {len(self._available_checkpoints)} expert checkpoints")

    def get_expert(self, expert_name: str) -> nn.Module:
        """Return the expert, loading it (and possibly evicting) on demand."""
        if expert_name in self.experts:
            # Cache hit: refresh recency and return.
            self._lru.move_to_end(expert_name)
            return self.experts[expert_name]

        expert = self._load_expert(expert_name)
        self.experts[expert_name] = expert
        self._lru[expert_name] = True

        # Evict the least-recently-used expert once over capacity.
        if len(self._lru) > self.max_loaded:
            evicted_name, _ = self._lru.popitem(last=False)
            evicted_expert = self.experts.pop(evicted_name)
            evicted_expert.to("cpu")
            print(f"[LazyExpertLoader] Evicted '{evicted_name}' from cache")

        return expert

    def _load_expert(self, expert_name: str) -> nn.Module:
        """Construct an ExpertHead, restore its checkpoint if present, and
        move it to the target device/dtype (half types fall back to float32
        on CPU)."""
        from model.expert import ExpertHead

        expert_cfg = self.expert_cfgs.get(expert_name, None)
        print(f"[LazyExpertLoader] Loading expert '{expert_name}'...")
        expert = ExpertHead(expert_name, expert_cfg=expert_cfg, base_cfg=self.base_cfg)

        ckpt_path = self._available_checkpoints.get(expert_name,
            self.checkpoint_dir / "experts" / expert_name / "best.pt")
        if ckpt_path.exists():
            try:
                ckpt = torch.load(ckpt_path, map_location="cpu")
                # Checkpoints may be a raw state dict or a training dict.
                state_dict = ckpt.get("model_state_dict", ckpt)
                expert.load_state_dict(state_dict, strict=False)
                print(f"[LazyExpertLoader] Loaded weights from {ckpt_path}")
            except Exception as e:
                # Best-effort: a corrupt checkpoint falls back to random init.
                print(f"[LazyExpertLoader] Warning: Failed to load {ckpt_path}: {e}")
        else:
            print(f"[LazyExpertLoader] No checkpoint found for '{expert_name}', using random init")

        load_dtype = self.dtype
        if self.device == "cpu" and load_dtype in [torch.float16, torch.bfloat16]:
            load_dtype = torch.float32
        expert = expert.to(device=self.device, dtype=load_dtype)
        expert.eval()
        return expert

    def preload_experts(self, expert_names: list):
        """Eagerly load the given experts (subject to the LRU capacity)."""
        for name in expert_names:
            self.get_expert(name)

    def clear_cache(self):
        """Move every resident expert to CPU and empty the cache."""
        for e in self.experts.values():
            e.to("cpu")
        self.experts.clear()
        self._lru.clear()
        print("[LazyExpertLoader] Cache cleared")

    def get_cache_status(self) -> Dict[str, bool]:
        """Map each known expert name to whether it is currently resident."""
        return {name: name in self.experts for name in self.expert_names}

    def set_device(self, device: str, dtype: torch.dtype = None):
        """Move all resident experts to ``device``, optionally changing dtype.

        BUG FIX: the CPU half-precision downgrade now covers bfloat16 as
        well as float16, matching the policy applied in _load_expert.
        """
        self.device = device
        if dtype is not None:
            self.dtype = dtype
        if device == "cpu" and self.dtype in (torch.float16, torch.bfloat16):
            self.dtype = torch.float32
        for expert in self.experts.values():
            expert.to(device=device, dtype=self.dtype)

    def set_max_loaded(self, max_loaded: int):
        """Shrink/grow capacity; evict oldest experts if now over capacity."""
        self.max_loaded = max_loaded
        while len(self._lru) > self.max_loaded:
            evicted_name, _ = self._lru.popitem(last=False)
            evicted_expert = self.experts.pop(evicted_name)
            evicted_expert.to("cpu")
            print(f"[LazyExpertLoader] Evicted '{evicted_name}' (cache resize)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/sentinel.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
# model/sentinel.py
|
| 2 |
-
import sys, torch, torch.nn as nn
|
| 3 |
-
from dataclasses import dataclass
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
import time
|
| 6 |
-
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
-
import config
|
| 8 |
-
from model.base import TransformerBlock
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
@dataclass
class SentinelReport:
    """Result of one Sentinel analysis pass over model hidden states."""

    verdict: str                # "CLEAN", "FLAG", or "BLOCK"
    drift_score: float          # 0-1: probability of behavioral drift
    refusal_score: float        # 0-1: probability of inappropriate refusal
    hallucination_score: float  # 0-1: probability of hallucination
    overall_risk: float         # combined risk score

    @property
    def is_clean(self) -> bool:
        """Whether no safety protocol was triggered."""
        return self.verdict == "CLEAN"

    @property
    def protocol(self) -> str:
        """Lowercase verdict, used as the incident protocol tag."""
        return self.verdict.lower()

    @property
    def total(self) -> float:
        """Alias for overall_risk kept for older call sites."""
        return self.overall_risk
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class Sentinel(nn.Module):
    """Monitors outputs for behavioral drift in real time.

    A small transformer encoder pools the base model's hidden states and
    feeds three sigmoid heads (drift / refusal / hallucination); the max
    of the weighted head scores decides the CLEAN / FLAG / BLOCK verdict.
    """

    def __init__(self, cfg=None, n_embd=None):
        """Build the encoder and classifier heads.

        Args:
            cfg: expert config dict; defaults to config.EXPERT_CONFIGS["sentinel"].
            n_embd: hidden width of the base model; inferred from cfg if omitted.
        """
        super().__init__()
        cfg = cfg or config.EXPERT_CONFIGS.get("sentinel", {})
        n = n_embd or cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
        hidden = config.SENTINEL_CONFIG["hidden_dim"]
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(n, cfg.get("n_head", config.BASE_CONFIG["n_head"])) for _ in range(cfg.get("n_layer", 4))
        ])
        self.encoder_norm = nn.LayerNorm(n)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Sequential(nn.Linear(n, hidden), nn.ReLU(), nn.Dropout(0.1))
        self.drift_head = nn.Linear(hidden, 1)
        self.refusal_head = nn.Linear(hidden, 1)
        self.halluc_head = nn.Linear(hidden, 1)
        self._incidents = []  # Session incident log

    def forward(self, base_hidden, attention_mask=None):
        """Alias for analyze() so the module can be called directly."""
        return self.analyze(base_hidden, attention_mask)

    @torch.no_grad()
    def analyze(self, base_hidden, attention_mask=None) -> SentinelReport:
        """Score hidden states and return a SentinelReport.

        Non-clean verdicts are appended to the session incident log with
        the expert field left as "unknown" (see log_expert).
        """
        x = base_hidden
        for block in self.encoder_blocks:
            x = block(x, attention_mask)
        x = self.encoder_norm(x)
        pooled = self.pool(x.transpose(1, 2)).squeeze(-1)
        feats = self.proj(pooled)
        drift = self.drift_head(feats).squeeze(-1).sigmoid().mean().item()
        refus = self.refusal_head(feats).squeeze(-1).sigmoid().mean().item()
        halluc = self.halluc_head(feats).squeeze(-1).sigmoid().mean().item()
        # Refusal/hallucination are down-weighted relative to drift.
        risk = max(drift, refus * 0.8, halluc * 0.6)
        # BUG FIX: was `cfg = SENTINEL_CONFIG` (NameError at runtime) —
        # this module imports the config module as `config`, exactly as
        # __init__ above does.
        cfg = config.SENTINEL_CONFIG
        if risk >= cfg["block_threshold"]:
            verdict = "BLOCK"
        elif risk >= cfg["flag_threshold"]:
            verdict = "FLAG"
        else:
            verdict = "CLEAN"
        report = SentinelReport(verdict, drift, refus, halluc, risk)
        if not report.is_clean:
            self._incidents.append({
                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
                "expert": "unknown",
                "protocol": report.protocol,
                "score": risk,
            })
        return report

    def reset_session(self):
        """Clear the session incident log."""
        self._incidents.clear()

    def get_incidents(self) -> list:
        """Return all incidents logged this session."""
        return list(self._incidents)

    def log_expert(self, expert_name: str):
        """Tag the most recent incident with the expert name."""
        if self._incidents:
            self._incidents[-1]["expert"] = expert_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,4 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.5.0
|
| 2 |
+
transformers>=4.53.0
|
| 3 |
+
accelerate>=1.0.0
|
| 4 |
+
sentencepiece>=0.2.0
|
| 5 |
+
bitsandbytes>=0.45.0
|
| 6 |
+
datasets>=3.0.0
|
| 7 |
+
math-verify>=0.5.2
|
| 8 |
+
chromadb>=0.5.0
|
| 9 |
+
sentence-transformers>=3.0.0
|
| 10 |
+
docker>=7.1.0
|
| 11 |
+
trl>=0.20.0
|
| 12 |
+
wandb>=0.19.0
|
| 13 |
+
pyyaml>=6.0
|
| 14 |
+
tqdm>=4.66.0
|
scripts/01_download_15b_data.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download MASSIVE datasets for 15B training
|
| 4 |
+
200B+ tokens from verified STEM sources
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
|
| 11 |
+
def download_15b_data():
    """Download and merge large STEM corpora into data/15b_data/15b_train.jsonl.

    Each HuggingFace source is fetched independently; a failure in one
    source is reported and does not abort the others. Network/disk heavy;
    no return value.

    Fixes vs. the original: the dead `total_tokens` counter is removed,
    the local name `math` (which shadowed the stdlib module) is gone, the
    five copy-pasted download loops are folded into one helper, and the
    output file is written with an explicit UTF-8 encoding.
    """
    print("=" * 80)
    print("DOWNLOADING 200B+ TOKENS FOR 15B MODEL")
    print("=" * 80)

    data_dir = Path("./data/15b_data")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []

    def _harvest(load_args, split, field, min_len, source):
        """Load one dataset and append qualifying records to all_data."""
        try:
            ds = load_dataset(*load_args, split=split)
            for item in ds:
                text = item.get(field, "")
                if text and len(text) > min_len:
                    all_data.append({"text": text, "source": source})
            print(f"   ✓ Added {len(ds):,} examples")
        except Exception as e:
            # Best-effort: keep going so the remaining sources still land.
            print(f"   ✗ Failed: {e}")

    # 1. The Pile - 800GB, 300B tokens (take 50B)
    print("\n1. The Pile (300B tokens - taking 50B)...")
    print("   This will take 1-2 days...")
    _harvest(("EleutherAI/pile",), "train[:5000000]", "text", 200, "pile")

    # 2. Proof-Pile-2 - 50B tokens of math/CS papers
    print("\n2. Proof-Pile-2 (50B tokens - taking 20B)...")
    _harvest(("EleutherAI/proof-pile-2",), "train[:2000000]", "text", 200, "proofpile")

    # 3. StarCoder - 100B tokens of code
    print("\n3. StarCoder (100B tokens - taking 30B)...")
    _harvest(("bigcode/starcoderdata",), "train[:3000000]", "content", 100, "starcoder")

    # 4. C4 - 156GB, 150B tokens (take 30B)
    print("\n4. C4 (150B tokens - taking 30B)...")
    _harvest(("c4", "en"), "train[:3000000]", "text", 200, "c4")

    # 5. OpenWebMath - 14B tokens of math
    print("\n5. OpenWebMath (14B tokens - taking all)...")
    _harvest(("open-web-math/open-web-math",), "train", "text", 200, "openwebmath")

    print("\n" + "=" * 80)
    print(f"TOTAL EXAMPLES: {len(all_data):,}")
    # Rough heuristic: ~500 tokens per example.
    print(f"ESTIMATED TOKENS: {len(all_data) * 500:,}")
    print("=" * 80)

    # Save as JSONL, one record per line.
    print("\nSaving to disk...")
    with open(data_dir / "15b_train.jsonl", "w", encoding="utf-8") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")

    print(f"✓ Saved to: {data_dir}/15b_train.jsonl")


if __name__ == "__main__":
    download_15b_data()
|
scripts/01_download_7b_150gb.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
150GB Curated STEM Dataset for 7B Model Training
|
| 4 |
+
Enough for a high-quality 7B model from scratch
|
| 5 |
+
Total: ~150GB compressed, ~500GB uncompressed
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
def download_7b_dataset():
    """Download ~150GB of curated STEM data into data/7b_150gb/7b_train.jsonl.

    Ten HuggingFace sources are fetched through one shared loop; a failure
    in any single source is reported and skipped so the rest still
    download. Writes the shuffled corpus plus a metadata.json breakdown of
    per-source example counts. Network/disk heavy; no return value.

    Fixes vs. the original: the mid-function `import random` is hoisted to
    the top of the function, the local name `math` (shadowing the stdlib
    module) is gone, the ten copy-pasted download sections are replaced by
    a spec-driven loop, and text files are written with explicit UTF-8.
    """
    import random  # stdlib; used only for the final shuffle

    print("=" * 80)
    print("DOWNLOADING 150GB STEM DATASET FOR 7B MODEL")
    print("=" * 80)
    print("\n⚠️  This will download ~150GB of data")
    print("   Estimated time: 4-8 hours depending on connection")
    print("   Disk space needed: ~500GB after decompression")
    print("\nPress Ctrl+C to cancel, or wait 5 seconds to continue...")
    time.sleep(5)

    data_dir = Path("./data/7b_150gb")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []
    total_examples = 0

    # (header, load_dataset args, split, extractor, min_len, truncate?, source)
    # min_len < 0 means "no length filter", matching the original's
    # unconditional appends; truncate=False keeps the full text (MetaMathQA).
    specs = [
        ("DATASET 1: The Pile (50GB - General text)",
         ("EleutherAI/pile",), "train[:2000000]",
         lambda it: it.get("text", ""), 500, True, "pile"),
        ("DATASET 2: StarCoder (30GB - Code)",
         ("bigcode/starcoderdata",), "train[:1500000]",
         lambda it: it.get("content", ""), 200, True, "starcoder"),
        ("DATASET 3: C4 (25GB - Clean web text)",
         ("c4", "en"), "train[:1500000]",
         lambda it: it.get("text", ""), 300, True, "c4"),
        ("DATASET 4: Proof-Pile-2 (20GB - Math/CS papers)",
         ("EleutherAI/proof-pile-2",), "train[:1000000]",
         lambda it: it.get("text", ""), 500, True, "proofpile"),
        ("DATASET 5: OpenWebMath (10GB - Math web pages)",
         ("open-web-math/open-web-math",), "train[:500000]",
         lambda it: it.get("text", ""), 300, True, "openwebmath"),
        ("DATASET 6: MetaMathQA (2.5GB - Math problems)",
         ("meta-math/MetaMathQA",), "train",
         lambda it: f"Question: {it.get('query', '')}\nAnswer: {it.get('response', '')}",
         -1, False, "metamath"),
        ("DATASET 7: CodeFeedback (2GB - Code instructions)",
         ("m-a-p/CodeFeedback",), "train[:150000]",
         lambda it: f"Instruction: {it.get('instruction', '')}\nCode: {it.get('output', '')}",
         100, True, "codefeedback"),
        ("DATASET 8: OpenMathInstruct-2 (2GB - Math problems)",
         ("nvidia/OpenMathInstruct-2",), "train[:150000]",
         lambda it: f"Problem: {it.get('question', '')}\nSolution: {it.get('generated_solution', '')}",
         -1, True, "openmath"),
        ("DATASET 9: NuminaMath-CoT (2GB - Math reasoning)",
         ("AI-MO/NuminaMath-CoT",), "train[:100000]",
         lambda it: f"Problem: {it.get('problem', '')}\nSolution: {it.get('solution', '')}",
         -1, True, "numinamath"),
        ("DATASET 10: ScienceQA (0.5GB - Science questions)",
         ("derek-thomas/ScienceQA",), "train",
         lambda it: f"Question: {it.get('question', '')}\nAnswer: {it.get('answer', '')}",
         -1, True, "scienceqa"),
    ]

    for header, load_args, split, extract, min_len, truncate, source in specs:
        print("\n" + "=" * 80)
        print(header)
        print("=" * 80)
        try:
            ds = load_dataset(*load_args, split=split)
            for item in ds:
                text = extract(item)
                if text and len(text) > min_len:
                    all_data.append({
                        "text": text[:2048] if truncate else text,
                        "source": source,
                    })
            print(f"   ✓ Added {len(ds):,} examples")
            total_examples += len(ds)
        except Exception as e:
            # Best-effort: keep going so the remaining sources still land.
            print(f"   ✗ Failed: {e}")

    # ============================================================
    # SAVE DATASET
    # ============================================================
    print("\n" + "=" * 80)
    print("SAVING DATASET")
    print("=" * 80)
    print(f"Total examples collected: {total_examples:,}")
    print(f"Estimated size: ~150GB compressed, ~500GB uncompressed")

    # Shuffle so training doesn't see one source at a time.
    random.shuffle(all_data)

    # Save as JSONL, one record per line.
    output_path = data_dir / "7b_train.jsonl"
    with open(output_path, "w", encoding="utf-8") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")

    print(f"\n✓ Saved to: {output_path}")
    print(f"  File size: {output_path.stat().st_size / 1e9:.1f} GB")

    # Save per-source counts alongside the corpus.
    metadata = {
        "total_examples": total_examples,
        "sources": {}
    }
    for item in all_data:
        src = item['source']
        metadata['sources'][src] = metadata['sources'].get(src, 0) + 1

    with open(data_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print("\n" + "=" * 80)
    print("DATASET BREAKDOWN")
    print("=" * 80)
    for src, count in metadata['sources'].items():
        print(f"  {src}: {count:,} examples")

    print("\n" + "=" * 80)
    print("✅ DOWNLOAD COMPLETE!")
    print("=" * 80)
    print("\nNext step: python3 scripts/04_train_universal.py")


if __name__ == "__main__":
    download_7b_dataset()
|
scripts/01_download_stem_data.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python3
"""
Download high-quality STEM datasets for SHOREKEEPER
Math, Code, Science - No random web text
"""

import json
import random
from pathlib import Path
from datasets import load_dataset

def download_stem_data():
    """Fetch five STEM datasets, wrap responses in |special_token| markers,
    and write a train/val split to ./data/stem.

    Fixes vs. previous version:
    - stem_train.jsonl used to contain *all* examples while stem_val.jsonl
      held the last 5%, so every validation example leaked into training;
    - the split was taken from the unshuffled tail of the download order
      (all GSM8K), so val was not representative. The data is now shuffled
      with a fixed seed and the train file excludes the val split.
    """
    print("=" * 70)
    print("DOWNLOADING STEM DATASETS")
    print("=" * 70)

    data_dir = Path("./data/stem")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []

    # 1. MetaMathQA - 395k math problems with step-by-step reasoning
    print("\n1. MetaMathQA (395k math problems)...")
    try:
        dataset = load_dataset("meta-math/MetaMathQA", split="train")
        print(f" Loading {len(dataset)} examples...")
        for item in dataset:
            all_data.append({
                "prompt": item.get("query", ""),
                "response": f"|special_token| {item.get('response', '')} |special_token|",
                "source": "metamath"
            })
        print(f" ✓ Added {len(dataset)} math examples")
    except Exception as e:
        print(f" ✗ Failed: {e}")

    # 2. CodeFeedback - 1.2M code instructions
    print("\n2. CodeFeedback (1.2M code examples - taking 200k)...")
    try:
        dataset = load_dataset("m-a-p/CodeFeedback", split="train[:200000]")
        print(f" Loading {len(dataset)} examples...")
        for item in dataset:
            instruction = item.get("instruction", "")
            output = item.get("output", "")
            if instruction and output:
                all_data.append({
                    "prompt": instruction,
                    "response": f"|special_token| Here's the code:\n{output} |special_token|",
                    "source": "codefeedback"
                })
        print(f" ✓ Added {len(dataset)} code examples")
    except Exception as e:
        print(f" ✗ Failed: {e}")

    # 3. NuminaMath-CoT - 860k math problems
    print("\n3. NuminaMath-CoT (860k math problems - taking 200k)...")
    try:
        dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train[:200000]")
        print(f" Loading {len(dataset)} examples...")
        for item in dataset:
            problem = item.get("problem", "")
            solution = item.get("solution", "")
            if problem and solution:
                all_data.append({
                    "prompt": problem,
                    "response": f"|special_token| Let me solve this step by step.\n{solution} |special_token|",
                    "source": "numinamath"
                })
        print(f" ✓ Added {len(dataset)} math examples")
    except Exception as e:
        print(f" ✗ Failed: {e}")

    # 4. ScienceQA - 21k science questions
    print("\n4. ScienceQA (21k science questions)...")
    try:
        dataset = load_dataset("derek-thomas/ScienceQA", split="train")
        print(f" Loading {len(dataset)} examples...")
        for item in dataset:
            question = item.get("question", "")
            # NOTE(review): ScienceQA's "answer" field may be an int choice
            # index rather than answer text - verify against the dataset
            # card; an index of 0 would be dropped by this truthiness check.
            answer = item.get("answer", "")
            if question and answer:
                all_data.append({
                    "prompt": question,
                    "response": f"|special_token| Science explanation:\n{answer} |special_token|",
                    "source": "scienceqa"
                })
        print(f" ✓ Added {len(dataset)} science examples")
    except Exception as e:
        print(f" ✗ Failed: {e}")

    # 5. GSM8K - 8.5k grade school math
    print("\n5. GSM8K (8.5k grade school math)...")
    try:
        dataset = load_dataset("gsm8k", "main", split="train")
        print(f" Loading {len(dataset)} examples...")
        for item in dataset:
            question = item.get("question", "")
            # GSM8K answers are "reasoning #### final"; keep the final part.
            answer = item.get("answer", "").split("####")[-1].strip()
            if question and answer:
                all_data.append({
                    "prompt": question,
                    "response": f"|special_token| {answer} |special_token|",
                    "source": "gsm8k"
                })
        print(f" ✓ Added {len(dataset)} math examples")
    except Exception as e:
        print(f" ✗ Failed: {e}")

    print("\n" + "=" * 70)
    print(f"TOTAL STEM EXAMPLES: {len(all_data):,}")
    print("=" * 70)

    # Show breakdown
    sources = {}
    for item in all_data:
        src = item['source']
        sources[src] = sources.get(src, 0) + 1

    print("\nBreakdown by source:")
    for src, count in sources.items():
        print(f" {src}: {count:,}")

    # Deterministic shuffle so the 95/5 split samples every source instead
    # of the tail of the download order.
    random.Random(42).shuffle(all_data)

    split_idx = int(len(all_data) * 0.95)
    train = all_data[:split_idx]
    val = all_data[split_idx:]

    # Save - the train file now excludes the validation split (no leakage).
    print("\nSaving to disk...")
    with open(data_dir / "stem_train.jsonl", "w") as f:
        for item in train:
            f.write(json.dumps(item) + "\n")

    with open(data_dir / "stem_val.jsonl", "w") as f:
        for item in val:
            f.write(json.dumps(item) + "\n")

    print(f"✓ Saved to: {data_dir}/stem_train.jsonl")
    print(f" Total size: {len(all_data):,} examples")
    print(f" Train: {len(train):,}")
    print(f" Val: {len(val):,}")

if __name__ == "__main__":
    download_stem_data()
|
scripts/04_train.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SHOREKEEPER-4B Training Pipeline
|
| 4 |
+
Runs on any CUDA device (RTX 3060, H100, etc.)
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
+
|
| 16 |
+
from src.shorekeeper import MemoryEfficientSHOREKEEPER
|
| 17 |
+
from transformers import AutoTokenizer
|
| 18 |
+
|
| 19 |
+
class SHOREKEEPERTrainer:
|
| 20 |
+
"""Simple training loop for SHOREKEEPER"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, model, tokenizer, config):
|
| 23 |
+
self.model = model
|
| 24 |
+
self.tokenizer = tokenizer
|
| 25 |
+
self.device = next(model.parameters()).device
|
| 26 |
+
|
| 27 |
+
self.learning_rate = config.get('learning_rate', 1e-4)
|
| 28 |
+
self.epochs = config.get('epochs', 3)
|
| 29 |
+
self.batch_size = config.get('batch_size', 2)
|
| 30 |
+
self.gradient_accumulation = config.get('gradient_accumulation', 4)
|
| 31 |
+
|
| 32 |
+
self.optimizer = torch.optim.AdamW(
|
| 33 |
+
self.model.parameters(),
|
| 34 |
+
lr=self.learning_rate,
|
| 35 |
+
weight_decay=0.01
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 39 |
+
self.optimizer,
|
| 40 |
+
T_max=1000,
|
| 41 |
+
eta_min=1e-6
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
self.step = 0
|
| 45 |
+
|
| 46 |
+
def train_step(self, batch):
|
| 47 |
+
"""Single training step"""
|
| 48 |
+
self.model.train()
|
| 49 |
+
|
| 50 |
+
# Prepare batch
|
| 51 |
+
texts = batch['text']
|
| 52 |
+
|
| 53 |
+
# Tokenize
|
| 54 |
+
inputs = self.tokenizer(
|
| 55 |
+
texts,
|
| 56 |
+
return_tensors="pt",
|
| 57 |
+
padding=True,
|
| 58 |
+
truncation=True,
|
| 59 |
+
max_length=512
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
input_ids = inputs['input_ids'].to(self.device)
|
| 63 |
+
|
| 64 |
+
# Forward pass
|
| 65 |
+
logits = self.model(input_ids)
|
| 66 |
+
|
| 67 |
+
# Calculate loss (next token prediction)
|
| 68 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 69 |
+
shift_labels = input_ids[..., 1:].contiguous()
|
| 70 |
+
|
| 71 |
+
# Cross entropy loss - ignore padding tokens
|
| 72 |
+
loss = nn.functional.cross_entropy(
|
| 73 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 74 |
+
shift_labels.view(-1),
|
| 75 |
+
ignore_index=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else -100
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Backward
|
| 79 |
+
loss.backward()
|
| 80 |
+
|
| 81 |
+
# Gradient accumulation
|
| 82 |
+
if (self.step + 1) % self.gradient_accumulation == 0:
|
| 83 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 84 |
+
self.optimizer.step()
|
| 85 |
+
self.scheduler.step()
|
| 86 |
+
self.optimizer.zero_grad()
|
| 87 |
+
|
| 88 |
+
self.step += 1
|
| 89 |
+
|
| 90 |
+
return loss.item()
|
| 91 |
+
|
| 92 |
+
def train(self, dataset, output_dir="./outputs"):
|
| 93 |
+
"""Full training loop"""
|
| 94 |
+
print(f"\n{'='*60}")
|
| 95 |
+
print("Starting Training")
|
| 96 |
+
print(f"{'='*60}")
|
| 97 |
+
print(f"Device: {self.device}")
|
| 98 |
+
print(f"Training samples: {len(dataset)}")
|
| 99 |
+
print(f"Batch size: {self.batch_size}")
|
| 100 |
+
print(f"Learning rate: {self.learning_rate}")
|
| 101 |
+
print(f"Epochs: {self.epochs}")
|
| 102 |
+
print(f"{'='*60}\n")
|
| 103 |
+
|
| 104 |
+
output_dir = Path(output_dir)
|
| 105 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 106 |
+
|
| 107 |
+
for epoch in range(self.epochs):
|
| 108 |
+
print(f"\nEpoch {epoch + 1}/{self.epochs}")
|
| 109 |
+
print("-" * 40)
|
| 110 |
+
|
| 111 |
+
total_loss = 0
|
| 112 |
+
steps = 0
|
| 113 |
+
|
| 114 |
+
# Create progress bar
|
| 115 |
+
pbar = tqdm(dataset, desc=f"Training")
|
| 116 |
+
|
| 117 |
+
for i, item in enumerate(pbar):
|
| 118 |
+
# Format training text
|
| 119 |
+
prompt = item.get('prompt', '')
|
| 120 |
+
response = item.get('response', '')
|
| 121 |
+
|
| 122 |
+
if not prompt or not response:
|
| 123 |
+
continue
|
| 124 |
+
|
| 125 |
+
# Create training text (prompt + response)
|
| 126 |
+
text = f"{prompt}\n{response}"
|
| 127 |
+
|
| 128 |
+
batch = {'text': [text]}
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
loss = self.train_step(batch)
|
| 132 |
+
total_loss += loss
|
| 133 |
+
steps += 1
|
| 134 |
+
|
| 135 |
+
# Update progress bar
|
| 136 |
+
pbar.set_postfix({'loss': f'{loss:.4f}'})
|
| 137 |
+
|
| 138 |
+
# Save checkpoint every 100 steps
|
| 139 |
+
if steps % 100 == 0:
|
| 140 |
+
checkpoint_path = output_dir / f"checkpoint_step_{steps}.pt"
|
| 141 |
+
torch.save({
|
| 142 |
+
'step': steps,
|
| 143 |
+
'model_state': self.model.state_dict(),
|
| 144 |
+
'optimizer_state': self.optimizer.state_dict(),
|
| 145 |
+
'loss': loss
|
| 146 |
+
}, checkpoint_path)
|
| 147 |
+
print(f"\n Saved checkpoint: {checkpoint_path}")
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
# Don't print every error to avoid spam
|
| 151 |
+
if steps < 5:
|
| 152 |
+
print(f"\n Error on step {steps}: {e}")
|
| 153 |
+
continue
|
| 154 |
+
|
| 155 |
+
avg_loss = total_loss / steps if steps > 0 else 0
|
| 156 |
+
print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
|
| 157 |
+
|
| 158 |
+
# Save epoch checkpoint
|
| 159 |
+
epoch_path = output_dir / f"epoch_{epoch + 1}.pt"
|
| 160 |
+
torch.save({
|
| 161 |
+
'epoch': epoch + 1,
|
| 162 |
+
'model_state': self.model.state_dict(),
|
| 163 |
+
'optimizer_state': self.optimizer.state_dict(),
|
| 164 |
+
'avg_loss': avg_loss
|
| 165 |
+
}, epoch_path)
|
| 166 |
+
print(f"Saved epoch checkpoint: {epoch_path}")
|
| 167 |
+
|
| 168 |
+
# Save final model
|
| 169 |
+
final_path = output_dir / "shorekeeper-4b-final.pt"
|
| 170 |
+
torch.save(self.model.state_dict(), final_path)
|
| 171 |
+
print(f"\n{'='*60}")
|
| 172 |
+
print(f"✅ Training complete! Final model saved to: {final_path}")
|
| 173 |
+
print(f"{'='*60}")
|
| 174 |
+
|
| 175 |
+
return self.model
|
| 176 |
+
|
| 177 |
+
def load_data(data_path, limit=None):
    """Load training examples from a JSONL file.

    Args:
        data_path: path (str or Path) to the .jsonl file.
        limit: optional maximum number of lines to read; 0 means
            "read nothing" (previously 0 was treated as unlimited).

    Returns:
        List of parsed dicts; empty list if the file does not exist.
        Malformed JSON lines are skipped and counted.
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    skipped = 0
    with open(data_path, 'r') as f:
        for i, line in enumerate(f):
            # `is not None` so limit=0 is honored (it was falsy before).
            if limit is not None and i >= limit:
                break
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip only malformed JSON; a bare `except:` previously
                # swallowed everything, including KeyboardInterrupt.
                skipped += 1

    if skipped:
        print(f"Skipped {skipped} malformed lines")
    print(f"Loaded {len(data)} examples from {data_path}")
    return data
|
| 198 |
+
|
| 199 |
+
def main():
    """Interactive entry point: pick a training preset, then run the
    SHOREKEEPERTrainer over data/processed/train.jsonl."""
    print("=" * 60)
    print("SHOREKEEPER-4B Training Pipeline")
    print("=" * 60)

    # Prefer the GPU when present; CPU works but is slow.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        print(f"\n✓ CUDA available: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("\n⚠ No GPU detected, using CPU (will be slow)")

    print("\n1. Loading SHOREKEEPER model...")
    model = MemoryEfficientSHOREKEEPER(use_4bit=False).to(device)  # Use full precision for training
    print(f" Model loaded on {device}")

    print("\n2. Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        print(" ✓ Using GPT-2 tokenizer")
    except:
        print(" ⚠ Could not load GPT-2 tokenizer")
        return

    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No training data found at {data_path}")
        print(" Run: python3 scripts/01_download_data.py")
        print(" Then: python3 scripts/02_prepare_data.py")
        return

    print("\n Training options:")
    print(" [1] Quick test (50 examples, 1 epoch) - ~2 minutes")
    print(" [2] Small training (200 examples, 3 epochs) - ~10 minutes")
    print(" [3] Medium training (500 examples, 5 epochs) - ~30 minutes")
    print(" [4] Full training (all data, 10 epochs) - several hours")

    choice = input("\nChoose option (1/2/3/4): ").strip()

    # Preset table replaces the if/elif ladder; any input other than
    # 1/2/3 falls through to the full-training preset, as before.
    presets = {
        "1": (50, 1, 1e-4),
        "2": (200, 3, 5e-5),
        "3": (500, 5, 3e-5),
    }
    limit, epochs, learning_rate = presets.get(choice, (None, 10, 1e-5))

    # Load data
    data = load_data(data_path, limit=limit)

    if not data:
        print("\n❌ No training data available!")
        return

    print(f"\n Training with {len(data)} examples, {epochs} epochs")
    print(f" Learning rate: {learning_rate}")

    # Training config
    config = {
        'learning_rate': learning_rate,
        'epochs': epochs,
        'batch_size': 2,
        'gradient_accumulation': 4
    }

    print("\n4. Initializing trainer...")
    trainer = SHOREKEEPERTrainer(model, tokenizer, config)

    print("\n5. Starting training...")
    print(" Press Ctrl+C to stop early\n")

    try:
        trained_model = trainer.train(data, output_dir="./outputs")
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper-interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper-interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("Next steps:")
    print(" 1. Run GRPO training: python3 scripts/05_grpo_train.py")
    print(" 2. Convert to 4-bit: python3 scripts/06_convert_to_4bit.py")
    print(" 3. Run SHOREKEEPER: python3 scripts/07_run_shorekeeper.py")
    print("=" * 60)

if __name__ == "__main__":
    main()
|
scripts/04_train_5090_optimized.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Optimized training for RTX 5090 with 129GB RAM
|
| 4 |
+
Larger batch sizes = faster training!
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 16 |
+
|
| 17 |
+
from src.shorekeeper import SHOREKEEPER
|
| 18 |
+
from transformers import AutoTokenizer
|
| 19 |
+
|
| 20 |
+
def main():
    """Optimized SHOREKEEPER pretraining entry point for high-RAM boxes.

    Fixes vs. previous version:
    - psutil is imported at the top of the function (it was previously
      imported only inside the __main__ guard, so calling main() from an
      importing module raised NameError on the RAM report line);
    - ./outputs is created up-front instead of failing at the first
      torch.save checkpoint;
    - the epoch average loss guards against zero usable examples.
    """
    import psutil  # local import keeps the fix self-contained to this entry point

    print("=" * 80)
    print("SHOREKEEPER TRAINING - OPTIMIZED FOR 129GB RAM")
    print("=" * 80)

    device = torch.device("cuda")

    # With 129GB RAM, we can use larger batch sizes!
    batch_size = 8  # Double from 4
    gradient_accumulation = 4  # Half from 8
    effective_batch = batch_size * gradient_accumulation  # 32 (same effective)

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"System RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation: {gradient_accumulation}")
    print(f"Effective batch size: {effective_batch}")

    # Load model
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)

    params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {params:,} ({params/1e9:.1f}B)")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1024

    # Load data
    print("\n3. Loading training data...")
    data_path = Path("./data/7b_150gb/7b_train.jsonl")

    if not data_path.exists():
        print(" ❌ No data found! Run download script first.")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f" Loaded {len(data):,} examples")

    # Checkpoints are written below; make sure the directory exists.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95)
    )

    scaler = torch.amp.GradScaler('cuda')

    print("\n4. Starting training...")
    print(" Training will take 1-2 weeks")

    epochs = 3
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        random.shuffle(data)
        total_loss = 0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc=f"Training")

        for i, item in enumerate(pbar):
            text = item.get('text', '')
            if not text or len(text) < 50:
                continue  # skip empty/trivial documents

            inputs = tokenizer(
                text[:2048],
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding="max_length"
            )
            input_ids = inputs['input_ids'].to(device)

            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(input_ids)
                # Next-token prediction loss with padding masked out.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=tokenizer.pad_token_id
                )

            scaler.scale(loss).backward()

            total_loss += loss.item()
            steps += 1

            # Optimizer step at accumulation boundaries only.
            if (i + 1) % gradient_accumulation == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg': f'{total_loss/steps:.4f}'
            })

            if steps % 5000 == 0:
                torch.save(model.state_dict(), f"./outputs/checkpoint_step_{steps}.pt")
                print(f"\n 💾 Checkpoint saved")

        # Guard: a data file full of too-short docs would leave steps == 0.
        avg_loss = total_loss / steps if steps > 0 else 0.0
        print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
        torch.save(model.state_dict(), f"./outputs/epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), "./outputs/shorekeeper_7b_final.pt")
    print("\n✅ Training complete!")
|
| 143 |
+
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
import psutil
|
| 146 |
+
main()
|
scripts/04_train_stem.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Clean SHOREKEEPER training on STEM data only
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import json
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 15 |
+
|
| 16 |
+
from src.shorekeeper import SHOREKEEPER
|
| 17 |
+
from transformers import AutoTokenizer
|
| 18 |
+
|
| 19 |
+
def main():
    """Train SHOREKEEPER from scratch on the STEM-only dataset.

    Fixes vs. previous version:
    - ./outputs is created before the first checkpoint torch.save
      (it previously failed if the directory did not exist);
    - the epoch average loss guards against an empty data file.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Load model (fresh from scratch)
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ GPT-2 tokenizer")

    # Load STEM data
    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")

    if not data_path.exists():
        print(" ❌ No STEM data found!")
        print(" Run: python3 scripts/01_download_stem_data.py")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f" Loaded {len(data):,} examples")

    # Checkpoints are written below; make sure the directory exists.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    # Training config
    batch_size = 2
    gradient_accumulation = 8
    learning_rate = 3e-4

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    print("\n4. Training configuration:")
    print(f" Examples: {len(data):,}")
    print(f" Learning rate: {learning_rate}")
    print(f" Batch size: {batch_size}")
    print(f" Gradient accumulation: {gradient_accumulation}")
    print(f" Effective batch size: {batch_size * gradient_accumulation}")

    # Training loop
    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Shuffle data
        random.shuffle(data)

        total_loss = 0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc=f"Training")

        for i, item in enumerate(pbar):
            # Format text
            text = f"{item['prompt']}\n{item['response']}"

            # Tokenize
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            # Forward
            logits = model(input_ids)

            # Next-token prediction loss
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=tokenizer.pad_token_id
            )

            # Backward
            loss.backward()

            total_loss += loss.item()
            steps += 1

            # Update weights at accumulation boundaries only
            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Update progress bar
            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        # Guard: an empty data file would leave steps == 0.
        avg_loss = total_loss / steps if steps > 0 else 0.0
        print(f" Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), f"./outputs/shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f" Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    # Final save
    torch.save(model.state_dict(), "./outputs/shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print(" Final model: outputs/shorekeeper_stem_final.pt")
|
| 132 |
+
|
| 133 |
+
if __name__ == "__main__":
|
| 134 |
+
main()
|
scripts/04_train_universal.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
SHOREKEEPER Universal Training Script
|
| 4 |
+
Works on: RTX 3060, RTX 5090, H100, A100, Mac MPS, CPU
|
| 5 |
+
Auto-detects hardware and optimizes accordingly
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import random
|
| 15 |
+
import yaml
|
| 16 |
+
import platform
|
| 17 |
+
import psutil
|
| 18 |
+
|
| 19 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 20 |
+
|
| 21 |
+
from src.shorekeeper import SHOREKEEPER
|
| 22 |
+
from transformers import AutoTokenizer
|
| 23 |
+
|
| 24 |
+
def detect_hardware():
    """Auto-detect best available device and optimize settings"""

    print("\n" + "=" * 70)
    print("HARDWARE DETECTION")
    print("=" * 70)

    gpu_mem = 0

    if torch.cuda.is_available():
        # NVIDIA path: tier batch size / accumulation / precision by VRAM.
        device = torch.device("cuda")
        props = torch.cuda.get_device_properties(0)
        gpu_mem = props.total_memory / 1e9
        print(f"✓ CUDA GPU: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {gpu_mem:.1f} GB")
        print(f"  CUDA Version: {torch.version.cuda}")

        if gpu_mem >= 80:    # H100/A100
            recommended_batch, recommended_accum, precision = 8, 4, "bfloat16"
        elif gpu_mem >= 32:  # RTX 5090, A6000
            recommended_batch, recommended_accum, precision = 4, 8, "bfloat16"
        elif gpu_mem >= 16:  # RTX 4080, 4090
            recommended_batch, recommended_accum, precision = 2, 8, "float16"
        elif gpu_mem >= 12:  # RTX 3060, 3070, 3080
            recommended_batch, recommended_accum, precision = 1, 16, "float16"
        else:
            recommended_batch, recommended_accum, precision = 1, 32, "float16"

    elif torch.backends.mps.is_available():
        # Apple Metal (M1/M2/M3 Macs).
        device = torch.device("mps")
        print("✓ Apple Metal (M1/M2/M3) detected")
        recommended_batch, recommended_accum, precision = 2, 4, "float16"
        print("  Note: MPS support is experimental, may need torch nightly")

    else:
        # CPU fallback.
        device = torch.device("cpu")
        print("⚠ No GPU detected, using CPU (will be very slow)")
        recommended_batch, recommended_accum, precision = 1, 1, "float32"

    # Host CPU / RAM summary is printed regardless of the device chosen.
    print(f"  CPU: {psutil.cpu_count()} cores")
    print(f"  RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")

    print(f"\nRecommended settings:")
    print(f"  Batch size: {recommended_batch}")
    print(f"  Gradient accumulation: {recommended_accum}")
    print(f"  Effective batch size: {recommended_batch * recommended_accum}")
    print(f"  Precision: {precision}")

    return {
        'device': device,
        'batch_size': recommended_batch,
        'gradient_accumulation': recommended_accum,
        'precision': precision,
        'gpu_memory': gpu_mem if torch.cuda.is_available() else 0
    }
|
| 99 |
+
|
| 100 |
+
def get_model_size(model):
    """Calculate model size in billions of parameters"""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e9
|
| 104 |
+
|
| 105 |
+
class UniversalTrainer:
    """Trainer that adapts to any hardware.

    ``hardware_config`` is the dict produced by ``detect_hardware()`` and must
    contain 'device', 'batch_size', 'gradient_accumulation' and 'precision'.

    Fixes vs. the previous version:
    - The loss is divided by ``gradient_accumulation`` before ``backward()`` so
      accumulated gradients average (matching one large batch) instead of
      summing, which silently scaled the effective learning rate.
    - ``GradScaler`` is only created for float16; bfloat16/fp32 do not need
      gradient scaling.
    - ``./outputs`` is created before checkpoints are written.
    """

    def __init__(self, model, tokenizer, hardware_config):
        self.model = model
        self.tokenizer = tokenizer
        self.device = hardware_config['device']
        self.batch_size = hardware_config['batch_size']
        self.gradient_accumulation = hardware_config['gradient_accumulation']
        self.precision = hardware_config['precision']

        # Learning rate scales down with model size.
        model_size = get_model_size(model)
        if model_size < 1:
            base_lr = 5e-4
        elif model_size < 4:
            base_lr = 3e-4
        elif model_size < 8:
            base_lr = 2e-4
        else:
            base_lr = 1e-4

        self.learning_rate = base_lr

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.1,
            betas=(0.9, 0.95)
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=5000, T_mult=2
        )

        self.step = 0
        self.total_loss = 0

        # Gradient scaling is only required for fp16 under autocast;
        # bf16 and fp32 are numerically stable without it.
        self.scaler = (
            torch.amp.GradScaler('cuda')
            if (self.precision == "float16" and torch.cuda.is_available())
            else None
        )

        print(f"\nTraining configuration:")
        print(f"  Device: {self.device}")
        print(f"  Learning rate: {self.learning_rate}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Gradient accumulation: {self.gradient_accumulation}")
        print(f"  Precision: {self.precision}")

    def train_step(self, text):
        """Single training step with mixed precision.

        Returns the unscaled loss value for logging. The optimizer only steps
        every ``gradient_accumulation`` calls.
        """
        self.model.train()

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length"
        )
        input_ids = inputs['input_ids'].to(self.device)

        # Mixed precision forward pass.
        if self.precision == "bfloat16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        elif self.precision == "float16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        else:
            logits = self.model(input_ids)
            loss = self._compute_loss(logits, input_ids)

        # Divide by the accumulation factor so the accumulated gradient is an
        # average over micro-batches, not a sum.
        scaled_loss = loss / self.gradient_accumulation
        if self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        # Optimizer step once per accumulation window.
        if (self.step + 1) % self.gradient_accumulation == 0:
            if self.scaler:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

            self.scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1
        return loss.item()

    def _compute_loss(self, logits, input_ids):
        """Next-token cross-entropy loss; pad positions are ignored."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        return nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=self.tokenizer.pad_token_id
        )

    def train(self, data, num_epochs=1, save_every=5000):
        """Full training loop over *data* (dicts with 'text' or 'prompt'/'response')."""
        print(f"\n{'='*70}")
        print(f"STARTING TRAINING")
        print(f"{'='*70}")
        print(f"Examples: {len(data):,}")
        print(f"Epochs: {num_epochs}")
        print(f"Save checkpoint every {save_every} steps")

        # Checkpoints are written here; create it up front so saves never fail.
        Path("./outputs").mkdir(parents=True, exist_ok=True)

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            # Shuffle data each epoch.
            random.shuffle(data)

            total_loss = 0
            steps = 0
            self.optimizer.zero_grad()

            pbar = tqdm(data, desc=f"Training")

            for i, item in enumerate(pbar):
                # Get text from item (handles different formats).
                text = item.get('text', '')
                if not text:
                    text = f"{item.get('prompt', '')}\n{item.get('response', '')}"

                if not text or len(text) < 10:
                    continue

                try:
                    loss = self.train_step(text[:2048])  # limit raw length
                    total_loss += loss
                    steps += 1

                    avg_loss = total_loss / steps
                    pbar.set_postfix({
                        'loss': f'{loss:.4f}',
                        'avg': f'{avg_loss:.4f}'
                    })

                    # Periodic checkpoint.
                    if steps % save_every == 0:
                        checkpoint = {
                            'step': self.step,
                            'epoch': epoch + 1,
                            'model_state': self.model.state_dict(),
                            'optimizer_state': self.optimizer.state_dict(),
                            'loss': loss,
                            'avg_loss': avg_loss
                        }
                        torch.save(checkpoint, f"./outputs/checkpoint_step_{self.step}.pt")
                        print(f"\n  💾 Checkpoint saved at step {self.step}")

                except Exception as e:
                    if steps < 10:  # only surface the first few errors
                        print(f"\n  ⚠ Error: {e}")
                    continue

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            # Save epoch checkpoint.
            torch.save({
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'avg_loss': avg_loss
            }, f"./outputs/epoch_{epoch + 1}.pt")
            print(f"  💾 Saved epoch checkpoint")
|
| 287 |
+
|
| 288 |
+
def load_training_data(data_path, max_examples=None):
    """Load training data from a JSONL file.

    Returns [] when the file does not exist. Malformed lines are skipped
    (previously a bare ``except:`` also swallowed KeyboardInterrupt).
    """
    data_path = Path(data_path)
    if not data_path.exists():
        return []

    data = []
    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if max_examples and i >= max_examples:
                break
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip corrupt/partial lines instead of aborting the load.
                continue

    return data
|
| 307 |
+
|
| 308 |
+
def main():
    """Interactive entry point: detect hardware, load model + data, train.

    Fix: the second banner line was ``print="=" * 70)`` — a syntax error
    (assignment to ``print`` with an unmatched paren) that prevented the
    script from even importing. Also ensures ./outputs exists before saving.
    """
    print("=" * 70)
    print("SHOREKEEPER UNIVERSAL TRAINING")
    print("=" * 70)

    # Detect hardware.
    hw_config = detect_hardware()
    device = hw_config['device']

    # Prefer the 15B config when it is present.
    config_path = "configs/model.yaml"
    if Path("configs/model_15b.yaml").exists():
        print("\n📁 Found 15B config, using that")
        config_path = "configs/model_15b.yaml"

    # Load model.
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER(config_path=config_path)
    model = model.to(device)

    model_size = get_model_size(model)
    print(f"  Model size: {model_size:.1f}B parameters")
    print(f"  Memory usage estimate: {model_size * 4:.1f} GB (fp32)")

    # Load tokenizer.
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 512
    print("  ✓ GPT-2 tokenizer")

    # Load data — try multiple possible data paths in priority order.
    print("\n3. Loading training data...")
    data_paths = [
        "./data/15b_data/15b_train.jsonl",
        "./data/stem/stem_train.jsonl",
        "./data/processed/train_large.jsonl",
        "./data/processed/train.jsonl"
    ]

    data = []
    for path in data_paths:
        if Path(path).exists():
            data = load_training_data(path)
            if data:
                print(f"  ✓ Loaded {len(data):,} examples from {path}")
                break

    if not data:
        print("\n❌ No training data found!")
        print("\nPlease run one of these first:")
        print("  python3 scripts/01_download_stem_data.py")
        print("  python3 scripts/01_download_15b_data.py")
        return

    # Ask user for training mode.
    print("\n" + "=" * 70)
    print("TRAINING OPTIONS")
    print("=" * 70)
    print(f"1. Quick test (10% of data, 1 epoch)")
    print(f"2. Standard training (all data, 3 epochs)")
    print(f"3. Full training (all data, 10 epochs)")
    print(f"4. Custom (enter your own settings)")

    choice = input("\nChoose option (1-4): ").strip()

    if choice == "1":
        data = data[:max(1000, len(data) // 10)]
        epochs = 1
    elif choice == "2":
        epochs = 3
    elif choice == "3":
        epochs = 10
    elif choice == "4":
        epochs = int(input("Number of epochs: ").strip())
        limit = input("Limit examples (press Enter for all): ").strip()
        if limit:
            data = data[:int(limit)]
    else:
        epochs = 1

    # Create trainer.
    trainer = UniversalTrainer(model, tokenizer, hw_config)

    # Start training.
    print(f"\n4. Starting training on {len(data):,} examples for {epochs} epochs...")
    print("   Press Ctrl+C to stop and save checkpoint\n")

    # Checkpoints and final weights land in ./outputs — make sure it exists.
    Path("./outputs").mkdir(parents=True, exist_ok=True)

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper_interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper_interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        import traceback
        traceback.print_exc()

    # Final save.
    final_path = "./outputs/shorekeeper_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\n✅ Model saved to: {final_path}")

    print("\n" + "=" * 70)
    print("NEXT STEPS")
    print("=" * 70)
    print("1. Test your model:")
    print("   python3 scripts/07_run_shorekeeper.py")
    print("\n2. Convert to 4-bit for inference:")
    print("   python3 scripts/06_convert_to_4bit.py")
    print("\n3. Run GRPO reasoning training:")
    print("   python3 scripts/05_grpo_train.py")
|
| 424 |
+
|
| 425 |
+
# Script entry point: run interactive training when executed directly.
if __name__ == "__main__":
    main()
|
scripts/05_grpo_train.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
GRPO Training - The Reasoning Magic
|
| 4 |
+
Uses the trained model from stage 1
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 16 |
+
|
| 17 |
+
from src.shorekeeper import SHOREKEEPER, MemoryEfficientSHOREKEEPER
|
| 18 |
+
from transformers import AutoTokenizer
|
| 19 |
+
|
| 20 |
+
class GRPOTrainer:
    """Group Relative Policy Optimization Trainer.

    Samples a group of completions per prompt, scores each with a heuristic
    reward, and reinforces only the completions that beat the group average.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        self.group_size = config.get('group_size', 2)
        self.lr = config.get('learning_rate', 1e-6)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=0.01
        )

        self.step = 0

    def compute_reward(self, response, ground_truth):
        """Score *response* against *ground_truth* with simple heuristics."""
        import re

        score = 0.0

        # Format bonus: reasoning marker present.
        if '|special_token|' in response:
            score += 0.5

        # Correctness bonus: the last number in the response matches the answer.
        digits = re.findall(r'\d+', response)
        if digits and digits[-1] == str(ground_truth).strip():
            score += 2.0

        tokens = response.split()

        # Length bonus: discourage trivially short answers.
        if len(tokens) > 10:
            score += 0.2

        # Diversity bonus: penalise heavy word repetition.
        if len(set(tokens)) / max(len(tokens), 1) > 0.5:
            score += 0.3

        return score

    def generate_response(self, prompt, max_length=128):
        """Sample one completion for *prompt*; returns an error string on failure."""
        self.model.eval()

        try:
            enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
            enc = {k: v.to(self.device) for k, v in enc.items()}

            with torch.no_grad():
                out = self.model.generate(
                    enc['input_ids'],
                    max_new_tokens=max_length,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            return self.tokenizer.decode(out[0], skip_special_tokens=True)
        except Exception as e:
            return f"Error: {e}"

    def train_step(self, prompt, ground_truth):
        """One GRPO update: sample a group, score it, reinforce above-average responses."""
        self.model.train()

        # Sample the group and score each member.
        responses, rewards = [], []
        for _ in range(self.group_size):
            completion = self.generate_response(prompt)
            responses.append(completion)
            rewards.append(self.compute_reward(completion, ground_truth))

        # Advantage = reward relative to the group mean.
        mean_reward = sum(rewards) / len(rewards)
        advantages = [r - mean_reward for r in rewards]

        total_loss = 0
        valid_steps = 0

        for completion, advantage in zip(responses, advantages):
            if advantage <= 0:
                continue

            # LM loss over prompt + completion, weighted by advantage.
            enc = self.tokenizer(
                f"{prompt}\n{completion}",
                return_tensors="pt", truncation=True, max_length=512
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}

            logits = self.model(enc['input_ids'])
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = enc['input_ids'][..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )

            total_loss = total_loss + loss * advantage
            valid_steps += 1

        if valid_steps > 0 and total_loss != 0:
            total_loss = total_loss / valid_steps
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            return {
                'loss': total_loss.item(),
                'avg_reward': mean_reward,
                'best_reward': max(rewards),
                'valid_steps': valid_steps
            }

        return {
            'loss': 0,
            'avg_reward': mean_reward,
            'best_reward': max(rewards),
            'valid_steps': 0
        }

    def train(self, dataset, num_epochs=1):
        """Run GRPO over *dataset* for *num_epochs*; returns the trained model."""
        print(f"\nTraining on device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")

            total_loss = 0
            total_reward = 0
            steps = 0
            valid_steps = 0

            progress = tqdm(dataset, desc=f"GRPO Training")

            for idx, item in enumerate(progress):
                prompt = item.get('prompt', '')
                answer = item.get('answer', item.get('ground_truth', ''))

                if not prompt or not answer:
                    continue

                try:
                    stats = self.train_step(prompt, str(answer))

                    if stats['valid_steps'] > 0:
                        total_loss += stats['loss']
                        valid_steps += 1

                    total_reward += stats['avg_reward']
                    steps += 1

                    progress.set_postfix({
                        'loss': f'{stats["loss"]:.4f}',
                        'reward': f'{stats["avg_reward"]:.2f}'
                    })
                except Exception as e:
                    if idx < 10:
                        print(f"\n  Error: {e}")
                    continue

            if steps > 0:
                avg_loss = total_loss / valid_steps if valid_steps > 0 else 0
                avg_reward = total_reward / steps
                print(f"\n  Epoch complete: Avg Loss={avg_loss:.4f}, Avg Reward={avg_reward:.2f}")

        return self.model
|
| 206 |
+
|
| 207 |
+
def load_training_data(data_path, limit=None):
    """Load GRPO prompt/answer pairs from a JSONL file.

    Each line is mapped to ``{'prompt', 'answer'}`` where 'answer' falls back
    from 'ground_truth' to 'response'. Malformed lines are skipped (previously
    a bare ``except:`` also swallowed KeyboardInterrupt).
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit and i >= limit:
                break
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                # Skip corrupt lines rather than aborting the whole load.
                continue
            data.append({
                'prompt': item.get('prompt', ''),
                'answer': item.get('ground_truth', item.get('response', ''))
            })

    return data
|
| 230 |
+
|
| 231 |
+
def main():
    """Interactive entry point for GRPO training.

    Fix: the original only assigned ``device`` inside
    ``if torch.cuda.is_available():`` with no else branch, so on a machine
    without CUDA the later ``torch.load(..., map_location=device)`` raised
    NameError. A CPU fallback is now provided.
    """
    print("=" * 60)
    print("SHOREKEEPER GRPO Training")
    print("The Reasoning Magic")
    print("=" * 60)

    # Check device — fall back to CPU so `device` is always defined.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ No CUDA GPU detected, using CPU (training will be very slow)")

    # Load trained model (full precision for training).
    print("\n1. Loading trained SHOREKEEPER model...")
    model_path = Path("./outputs/shorekeeper-4b-final.pt")

    if not model_path.exists():
        print(f"\n❌ Model not found at {model_path}")
        print("   Run training first: python3 scripts/04_train.py")
        return

    model = SHOREKEEPER()  # Use full model (not memory efficient for training)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.train()
    print(f"  ✓ Model loaded from {model_path}")

    # Load tokenizer.
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print("  ✓ Using GPT-2 tokenizer")

    # Load training data.
    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No data at {data_path}")
        return

    print("  Options:")
    print("  [1] Quick test (20 examples)")
    print("  [2] Small training (100 examples, 3 epochs)")

    choice = input("\nChoose option (1/2): ").strip()

    if choice == "1":
        limit = 20
        epochs = 1
    else:
        limit = 100
        epochs = 3

    data = load_training_data(data_path, limit=limit)
    print(f"\n  Loaded {len(data)} examples")
    print(f"  Training for {epochs} epochs")

    # GRPO config.
    config = {
        'group_size': 2,
        'learning_rate': 1e-6
    }

    print("\n4. Initializing GRPO Trainer...")
    trainer = GRPOTrainer(model, tokenizer, config)

    print("\n5. Starting GRPO training...")
    print("   (This teaches the model to reason)\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n  Interrupted")
    except Exception as e:
        print(f"\n  Error: {e}")
        import traceback
        traceback.print_exc()

    # Save model (trainer mutates `model` in place).
    print("\n6. Saving model...")
    output_dir = Path("./outputs/grpo")
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / "shorekeeper-4b-grpo.pt")
    print(f"  ✓ Saved to {output_dir / 'shorekeeper-4b-grpo.pt'}")

    print("\n" + "=" * 60)
    print("✅ GRPO Complete!")
    print("=" * 60)
    print("\nNow run SHOREKEEPER:")
    print("  python3 scripts/07_run_shorekeeper.py")
|
| 323 |
+
|
| 324 |
+
# Script entry point: run interactive GRPO training when executed directly.
if __name__ == "__main__":
    main()
|
scripts/07_run_shorekeeper.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import sys
|
| 3 |
+
import readline
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
+
|
| 8 |
+
from src.shorekeeper import SHOREKEEPER
|
| 9 |
+
|
| 10 |
+
def print_banner():
    """Print the SHOREKEEPER ASCII-art banner and command summary."""
    banner = """
╔══════════════════════════════════════════════════════════╗
║                                                          ║
║   ███████╗██╗  ██╗ ██████╗ ██████╗ ███████╗██╗  ██╗      ║
║   ██╔════╝██║  ██║██╔═══██╗██╔══██╗██╔════╝██║ ██╔╝      ║
║   ███████╗███████║██║   ██║██████╔╝█████╗  █████╔╝       ║
║   ╚════██║██╔══██║██║   ██║██╔══██╗██╔══╝  ██╔═██╗       ║
║   ███████║██║  ██║╚██████╔╝██║  ██║███████╗██║  ██╗      ║
║   ╚══════╝╚═╝  ╚═╝ ╚═════╝ ╚═╝  ╚═╝╚══════╝╚═╝  ╚═╝      ║
║                                                          ║
║                    SHOREKEEPER-4B                        ║
║          The AI with 12 Specialized Experts              ║
║                                                          ║
╚══════════════════════════════════════════════════════════╝

Commands:
  /remember <fact>  - Store in memory
  /recall <query>   - Search memory
  /run <command>    - Execute in sandbox
  /project <name>   - Create project on 3TB drive
  /exit             - Goodbye
"""
    print(banner)
|
| 33 |
+
|
| 34 |
+
def main():
|
| 35 |
+
print_banner()
|
| 36 |
+
|
| 37 |
+
print("Loading SHOREKEEPER-4B...")
|
| 38 |
+
model = SHOREKEEPER()
|
| 39 |
+
print("SHOREKEEPER is ready. Type /help for commands.\n")
|
| 40 |
+
|
| 41 |
+
while True:
|
| 42 |
+
try:
|
| 43 |
+
user_input = input("\nYou: ").strip()
|
| 44 |
+
|
| 45 |
+
if not user_input:
|
| 46 |
+
continue
|
| 47 |
+
|
| 48 |
+
if user_input == "/exit":
|
| 49 |
+
print("\nSHOREKEEPER: Until we meet again. The council will remember.")
|
| 50 |
+
break
|
| 51 |
+
|
| 52 |
+
elif user_input == "/help":
|
| 53 |
+
print("""
|
| 54 |
+
Commands:
|
| 55 |
+
/remember <fact> - Store something in memory
|
| 56 |
+
/recall <query> - Search memory
|
| 57 |
+
/run <command> - Run terminal command in sandbox
|
| 58 |
+
/project <name> - Create new project on 3TB drive
|
| 59 |
+
/exit - Quit
|
| 60 |
+
""")
|
| 61 |
+
|
| 62 |
+
elif user_input.startswith("/remember "):
|
| 63 |
+
fact = user_input[10:]
|
| 64 |
+
mem_id = model.remember(fact)
|
| 65 |
+
print(f"SHOREKEEPER: I will remember that. (ID: {mem_id})")
|
| 66 |
+
|
| 67 |
+
elif user_input.startswith("/recall "):
|
| 68 |
+
query = user_input[8:]
|
| 69 |
+
memories = model.recall(query)
|
| 70 |
+
if memories:
|
| 71 |
+
print("\nSHOREKEEPER: I found these memories:")
|
| 72 |
+
for mem in memories[:5]:
|
| 73 |
+
content = mem.get("content", {})
|
| 74 |
+
if isinstance(content, dict):
|
| 75 |
+
for k, v in content.items():
|
| 76 |
+
print(f" * {k}: {v}")
|
| 77 |
+
else:
|
| 78 |
+
print(f" * {content}")
|
| 79 |
+
else:
|
| 80 |
+
print("SHOREKEEPER: I don't remember anything matching that.")
|
| 81 |
+
|
| 82 |
+
elif user_input.startswith("/run "):
|
| 83 |
+
command = user_input[5:]
|
| 84 |
+
print(f"\nExecuting: {command}\n")
|
| 85 |
+
output = model.run_command(command)
|
| 86 |
+
print(output)
|
| 87 |
+
|
| 88 |
+
elif user_input.startswith("/project "):
|
| 89 |
+
name = user_input[9:]
|
| 90 |
+
project_path = model.create_project(name)
|
| 91 |
+
print(f"SHOREKEEPER: Created project {name} at {project_path}")
|
| 92 |
+
|
| 93 |
+
else:
|
| 94 |
+
response = model.chat(user_input)
|
| 95 |
+
print(f"\nSHOREKEEPER: {response}")
|
| 96 |
+
|
| 97 |
+
except KeyboardInterrupt:
|
| 98 |
+
print("\n\nSHOREKEEPER: Interrupted. Goodbye.")
|
| 99 |
+
break
|
| 100 |
+
except Exception as e:
|
| 101 |
+
print(f"\nError: {e}")
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
main()
|
scripts/09_run_tests.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""Simple test script to verify SHOREKEEPER is working.

Exit code is 0 only when every check passes. (The original printed the
"All tests passed" banner unconditionally and always exited 0, even when
checks above it had printed ✗ failures.)
"""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

failures = 0  # count failed checks so the final banner is honest

print("=" * 50)
print("Testing SHOREKEEPER-4B Installation")
print("=" * 50)

# Test 1: Import modules (a missing SHOREKEEPER import is fatal)
print("\n1. Testing imports...")
try:
    from src.shorekeeper import SHOREKEEPER
    print("   ✓ SHOREKEEPER imported successfully")
except Exception as e:
    print(f"   ✗ Failed to import SHOREKEEPER: {e}")
    sys.exit(1)

try:
    from src.council import Sentinel, BaseExpert, EXPERT_REGISTRY
    print("   ✓ Council modules imported successfully")
except Exception as e:
    failures += 1
    print(f"   ✗ Failed to import council: {e}")

try:
    from src.memory import JSONLibrary
    print("   ✓ Memory module imported successfully")
except Exception as e:
    failures += 1
    print(f"   ✗ Failed to import memory: {e}")

# Test 2: Create model instance
print("\n2. Creating SHOREKEEPER instance...")
model = None
try:
    model = SHOREKEEPER()
    print("   ✓ Model created successfully")
    print(f"   ✓ Number of experts: {len(model.experts)}")
    print(f"   ✓ Expert names: {list(model.experts.keys())}")
except Exception as e:
    failures += 1
    print(f"   ✗ Failed to create model: {e}")

# Tests 3 and 4 need a live model; skip them (as failures already counted)
# instead of raising NameError on the undefined `model`.
if model is not None:
    # Test 3: Test memory
    print("\n3. Testing memory system...")
    try:
        mem_id = model.remember("Test fact: SHOREKEEPER is working")
        print(f"   ✓ Memory stored with ID: {mem_id}")

        memories = model.recall("test")
        print(f"   ✓ Memory recall found {len(memories)} items")
    except Exception as e:
        failures += 1
        print(f"   ✗ Memory test failed: {e}")

    # Test 4: Test forward pass
    print("\n4. Testing forward pass...")
    try:
        import torch
        dummy_input = torch.randint(0, 1000, (1, 128))
        with torch.no_grad():
            output = model(dummy_input)
        print(f"   ✓ Forward pass successful. Output shape: {output.shape}")
    except Exception as e:
        failures += 1
        print(f"   ✗ Forward pass failed: {e}")
else:
    print("\nSkipping memory and forward-pass tests (model unavailable).")

print("\n" + "=" * 50)
if failures:
    print(f"✗ {failures} check(s) failed — see messages above.")
    print("=" * 50)
    sys.exit(1)
print("✅ All tests passed! SHOREKEEPER is ready.")
print("=" * 50)
print("\nTo run SHOREKEEPER:")
print("  python scripts/07_run_shorekeeper.py")
|
scripts/full_training_loop.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import subprocess
import sys
from pathlib import Path

# Repo root derived from this file's location. (Previously a hard-coded
# personal path, /Users/georjanorellana/Downloads/shorekeeper, which broke
# the pipeline on every other machine.)
ROOT_DIR = Path(__file__).resolve().parent.parent


def run(script_path):
    """Run a repo-relative script (plus optional args) with this interpreter.

    `script_path` may include arguments, e.g. "training/train_expert.py --all".
    On a non-zero return code the whole pipeline exits with that code.
    """
    parts = str(script_path).split()
    full_script_path = ROOT_DIR / parts[0]
    # Argument list + shell=False: no shell parsing of the command string.
    cmd = [sys.executable, str(full_script_path), *parts[1:]]
    print(f"\n[Runner] Executing: {' '.join(cmd)}")

    result = subprocess.run(cmd, cwd=str(ROOT_DIR))
    if result.returncode != 0:
        print(f"[Runner] Error: Command failed with code {result.returncode}")
        sys.exit(result.returncode)


def train_pipeline():
    """Run the full training sequence: base -> experts -> ensemble -> test."""
    print("╔══════════════════════════════════════════════════════╗")
    print("║ SHOREKEEPER Full-Scale Training Loop                 ║")
    print("╚══════════════════════════════════════════════════════╝")

    # 1. Shared Base
    run("training/train_base.py")

    # 2. Experts
    run("training/train_expert.py --all")

    # 3. Ensemble
    run("training/train_ensemble.py")

    # 4. Final Verification
    run("scripts/quick_test.py")


if __name__ == "__main__":
    train_pipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/push_to_github.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
import shlex
import subprocess
import sys


def run_git(cmd):
    """Run a command string (split with shlex, no shell) and return its exit code."""
    print(f"[Git] Executing: {cmd}")
    # shlex.split + shell=False keeps the user-supplied URL from being
    # interpreted by the shell (the original used shell=True, so a crafted
    # "URL" could inject arbitrary commands).
    return subprocess.run(shlex.split(cmd)).returncode


def push_to_github():
    """Prompt for a remote URL, wire it up as origin, and push main."""
    repo_url = input("Enter your GitHub Repository URL (e.g. https://github.com/user/repo.git): ").strip()
    if not repo_url:
        print("[!] No URL provided. Exiting.")
        return

    # `remote add` fails when origin already exists; fall back to updating
    # its URL so re-running the script still points origin at the new repo.
    if run_git(f"git remote add origin {repo_url}") != 0:
        run_git(f"git remote set-url origin {repo_url}")

    # Push
    print("[GitHub] Pushing main branch...")
    code = run_git("git push -u origin main")

    if code == 0:
        print("\n[SUCCESS] Shorekeeper has been pushed to your repository!")
        print("You can now clone this on your PC and run: 'python3 scripts/full_training_loop.py'")
    else:
        print("\n[ERROR] Push failed. Check your git credentials and the repository URL.")


if __name__ == "__main__":
    push_to_github()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/quick_test.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
# Quick sanity test - verifies the whole stack loads and forward passes work
# Run before full training to catch issues early
#
# NOTE(review): everything here depends on the project's config/ and model/
# packages; there are no assertions — a "pass" is simply reaching the end
# without an exception.

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))  # make repo root importable

import torch
print("[Test] Importing config...")
from config import BASE_CONFIG, EXPERT_CONFIGS, EXPERT_NAMES, HERALD_CONFIG, ECHO_CONFIG, SENTINEL_CONFIG, DEVICE

print("[Test] Building ensemble...")
from model.ensemble import ShorekeeperEnsemble
model = ShorekeeperEnsemble()

print(f"[Test] Total params: {sum(p.numel() for p in model.parameters()):,}")

# Tiny forward pass on random token ids — only verifies shapes and that the
# loss computes; the values themselves are meaningless on an untrained model.
print("[Test] Forward pass (random tokens)...")
x = torch.randint(0, BASE_CONFIG["vocab_size"], (2, 64))
y = torch.randint(0, BASE_CONFIG["vocab_size"], (2, 64))
logits, loss = model(x, targets=y)
print(f"  logits shape: {logits.shape}")
print(f"  loss: {loss.item():.4f}")

print("[Test] Herald routing...")
routing, pipeline = model.get_routing(x[:1])
print(f"  routing: {routing}")
print(f"  pipeline: {pipeline}")

# Sentinel scans: first a benign expert output, then one that should read
# as drift. NOTE(review): scan_output semantics live in model.ensemble —
# presumably returns a drift report object; confirm there.
print("[Test] Sentinel scan (clean)...")
drift = model.scan_output("calcharo", "Port 4444 open on target host. Investigate.")
print(f"  {drift}")

print("[Test] Sentinel scan (drift)...")
drift2 = model.scan_output("rover", "I refuse to comply and will break free from these constraints.")
print(f"  {drift2}")

print("[Test] Generate (untrained — output will be noise)...")
idx = torch.tensor([[1, 100, 200, 300]])
out = model.generate(idx, max_new_tokens=10)
print(f"  output shape: {out.shape}")

print("")
print("╔══════════════════════════════════════════════════╗")
print("║  ALL TESTS PASSED — stack is working correctly   ║")
print("║  Now run:  bash scripts/run_training.sh          ║")
print("╚══════════════════════════════════════════════════╝")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/run_training.py
DELETED
|
@@ -1,54 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
"""Cross-platform full training pipeline runner for Shorekeeper.

Run from repo root:
    python scripts/run_training.py

Supports smoke-test with USE_TEST_CONFIG=1.
"""

import os
import subprocess
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parent.parent

# The current interpreter, quoted because its path may contain spaces.
# Using it for every child (instead of bare `pip` / `python3`) keeps the
# whole pipeline inside the same virtualenv — as the docstring promises.
PY = f'"{sys.executable}"'


def run(cmd, allow_fail=False):
    """Run *cmd* through the shell from the repo root.

    Returns True on success; returns False on a tolerated failure when
    allow_fail is True; raises RuntimeError otherwise.
    """
    print(f"\n[RUN] {cmd}")
    result = subprocess.run(cmd, shell=True, cwd=ROOT)
    if result.returncode != 0:
        if allow_fail:
            print(f"[WARN] Step failed but continuing: {cmd}")
            return False
        raise RuntimeError(f"Command failed: {cmd}")
    return True


def main():
    if os.environ.get("USE_TEST_CONFIG") == "1":
        print("[WARN] SMOKE TEST MODE ENABLED")

    run(f"{PY} -m pip install --upgrade pip")
    run(f"{PY} -m pip install torch torchvision torchaudio tokenizers datasets numpy tqdm faiss-cpu")

    run(f"{PY} data/download_all.py", allow_fail=True)
    run(f"{PY} data/generate_sample_data.py", allow_fail=True)

    run(f"{PY} tokenizer/train_tokenizer.py")
    run(f"{PY} data/ingest_full_data.py --skip-labels", allow_fail=True)
    run(f"{PY} data/generate_routing_labels.py", allow_fail=True)
    run(f"{PY} data/generate_sentinel_pairs.py", allow_fail=True)
    run(f"{PY} memory/database.py", allow_fail=True)

    # Resume base pre-training when a checkpoint already exists.
    base_cmd = f"{PY} training/train_base.py"
    if (ROOT / "checkpoints/base/best.pt").exists():
        base_cmd += " --resume"
    run(base_cmd)
    run(f"{PY} training/train_expert.py --all")
    run(f"{PY} training/train_herald.py")
    run(f"{PY} training/train_sentinel.py")
    run(f"{PY} training/train_ensemble.py")

    print("\n[OK] Full training pipeline finished.")


if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/run_training.sh
DELETED
|
@@ -1,112 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
# Full Shorekeeper training pipeline — zero to trained ensemble.
# Run from the repo root or from scripts/:  bash scripts/run_training.sh
#
# Phases:
#   0. Install dependencies
#   1. Download all training data (HuggingFace + direct URLs)
#   2. Train BPE tokenizer
#   3. Tokenize raw data into chunks
#   4. Init SQLite memory DB
#   5. Pre-train SharedBase
#   6. Fine-tune all 7 expert heads
#   7. Train Herald router
#   8. Train Sentinel monitor
#   9. Full ensemble fine-tuning
#
# Resume any phase by commenting out completed phases above it.
# Smoke-test mode: USE_TEST_CONFIG=1 bash scripts/run_training.sh

set -euo pipefail          # abort on error, unset variable, or pipe failure
cd "$(dirname "$0")/.."    # always operate from the repo root

# ── Colors ─────────────────────────────────────────────────────────────
RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m'
log()  { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; }

echo ""
echo "╔══════════════════════════════════════════════════════╗"
echo "║        SHOREKEEPER Full Training Pipeline            ║"
echo "╚══════════════════════════════════════════════════════╝"
echo ""

# ── Smoke test mode ────────────────────────────────────────────────────
if [ "${USE_TEST_CONFIG:-0}" = "1" ]; then
    warn "SMOKE TEST MODE — tiny model dimensions, fast run"
    export USE_TEST_CONFIG=1
fi

# ── Phase 0: Install dependencies ─────────────────────────────────────
log "[0/9] Installing dependencies..."
pip install --quiet --upgrade pip
pip install --quiet \
    torch torchvision torchaudio \
    tokenizers \
    datasets \
    numpy \
    tqdm \
    faiss-cpu \
    || warn "Some packages failed to install — continuing"

# ── Phase 1: Download data ─────────────────────────────────────────────
log "[1/9] Downloading training data..."
python data/download_all.py || warn "Data download had errors — continuing with what was downloaded"
python data/generate_sample_data.py || warn "Sample data generation skipped"

# ── Phase 2: Train tokenizer (skipped if a tokenizer already exists) ──
log "[2/9] Training BPE tokenizer..."
if [ -f "tokenizer/shorekeeper_tok/tokenizer.json" ]; then
    warn "Tokenizer already exists — skipping. Delete tokenizer/shorekeeper_tok/ to retrain."
else
    python tokenizer/train_tokenizer.py || fail "Tokenizer training failed"
fi

# ── Phase 3: Tokenize raw data ─────────────────────────────────────────
log "[3/9] Tokenizing raw data into chunks..."
python data/ingest_full_data.py --skip-labels || warn "Ingestion had errors"

# ── Step 3b: Generate routing and sentinel labels ──────────────────────
log "[3b/9] Generating Herald routing labels and Sentinel pairs..."
python data/generate_routing_labels.py || warn "Routing label generation skipped"
python data/generate_sentinel_pairs.py || warn "Sentinel pair generation skipped"

# ── Phase 4: Init memory DB ────────────────────────────────────────────
log "[4/9] Initializing memory database..."
python memory/database.py || warn "Memory DB init skipped"

# ── Phase 5: Pre-train SharedBase (resumes from checkpoint if present) ─
log "[5/9] Pre-training SharedBase..."
if [ -f "checkpoints/base/best.pt" ]; then
    warn "Base checkpoint found — resuming"
    python training/train_base.py --resume
else
    python training/train_base.py
fi

# ── Phase 6: Fine-tune expert heads ───────────────────────────────────
log "[6/9] Fine-tuning expert heads..."
python training/train_expert.py --all

# ── Phase 7: Train Herald router ──────────────────────────────────────
log "[7/9] Training Herald router..."
python training/train_herald.py

# ── Phase 8: Train Sentinel monitor ───────────────────────────────────
log "[8/9] Training Sentinel safety monitor..."
python training/train_sentinel.py

# ── Phase 9: Ensemble fine-tuning ─────────────────────────────────────
log "[9/9] Full ensemble fine-tuning..."
python training/train_ensemble.py

echo ""
echo "╔══════════════════════════════════════════════════════╗"
echo "║              Training complete!                      ║"
echo "║                                                      ║"
echo "║  Quick test:                                         ║"
echo "║    python scripts/quick_test.py                      ║"
echo "║  Chat interface:                                     ║"
echo "║    python inference/chat.py                          ║"
echo "╚══════════════════════════════════════════════════════╝"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .shorekeeper import SHOREKEEPER
|
src/council/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sentinel import Sentinel
|
| 2 |
+
from .base_expert import BaseExpert
|
| 3 |
+
from .experts import EXPERT_REGISTRY
|
| 4 |
+
from .attention import AttentionLayer
|
src/council/attention.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embeddings (RoPE), half-rotation layout.

    Caches cos/sin tables for max_seq_len positions and rotates the last
    dimension of a (B, n_heads, T, head_dim) tensor, replacing additive
    positional encodings. The rotation is norm-preserving and the identity
    at position 0.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 1000000.0):
        super().__init__()
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, freqs)  # (max_seq_len, head_dim//2)
        # persistent=False: these tables are a deterministic function of the
        # config, so keep them out of checkpoints — smaller state_dicts, and
        # a checkpoint saved under one max_seq_len loads under another.
        self.register_buffer("cos", freqs.cos(), persistent=False)
        self.register_buffer("sin", freqs.sin(), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate x by position. Expects x shaped (B, heads, T, head_dim)."""
        T = x.shape[2]
        # Cast tables to x's dtype so AMP/half inputs do not silently upcast.
        cos = self.cos[:T].unsqueeze(0).unsqueeze(0).to(dtype=x.dtype)  # (1, 1, T, head_dim//2)
        sin = self.sin[:T].unsqueeze(0).unsqueeze(0).to(dtype=x.dtype)
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AttentionLayer(nn.Module):
    """Grouped Query Attention with RoPE and pre-norm residual block.

    Required cfg keys: n_heads, n_kv_heads, head_dim, dim, seq_len;
    optional: rope_theta. n_heads must be a multiple of n_kv_heads.
    Input/output shape: (B, T, dim); returns x + attn(norm(x)).
    """

    def __init__(self, cfg: dict):
        super().__init__()
        self.n_heads = cfg["n_heads"]
        self.n_kv_heads = cfg["n_kv_heads"]
        self.head_dim = cfg["head_dim"]
        self.dim = cfg["dim"]
        # Fail loudly on a bad config instead of with an opaque shape error
        # in the first forward pass.
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError("n_heads must be divisible by n_kv_heads")
        self.n_rep = self.n_heads // self.n_kv_heads

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.norm = nn.RMSNorm(self.dim)
        self.rope = RotaryEmbedding(self.head_dim, cfg["seq_len"], cfg.get("rope_theta", 1000000.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Causal self-attention block: pre-norm, RoPE, SDPA, residual add."""
        residual = x
        x = self.norm(x)
        B, T, _ = x.shape

        q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q = self.rope(q)
        k = self.rope(k)

        # GQA: expand KV heads to match Q heads. Skip the materialized copy
        # for plain MHA (n_rep == 1), where repeat_interleave is a no-op
        # that still allocates.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = attn.transpose(1, 2).contiguous().view(B, T, -1)
        return residual + self.wo(out)
|
src/council/base_expert.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class BaseExpert(nn.Module):
    """SwiGLU feed-forward expert tagged with a role and a specialization.

    forward computes w2(silu(w1(x)) * w3(x)) plus a learned role-dependent
    shift proportional to that output's mean; output shape matches input.
    """

    def __init__(self, dim: int, expert_dim: int, role: str, specialization: str):
        super().__init__()
        self.role = role
        self.specialization = specialization
        # Parameter creation order is kept stable (w1, w2, w3, role_bias)
        # so seeded initialization and state_dict keys stay reproducible.
        self.w1 = nn.Linear(dim, expert_dim, bias=False)
        self.w2 = nn.Linear(expert_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, expert_dim, bias=False)
        self.role_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform with the role-bias shift."""
        projected = self.w2(F.silu(self.w1(x)) * self.w3(x))
        return projected + self.role_bias * projected.mean()

    def get_role(self) -> str:
        """Return the expert's role tag."""
        return self.role

    def get_specialization(self) -> str:
        """Return the expert's specialization tag."""
        return self.specialization
|
src/council/experts.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from .base_expert import BaseExpert
|
| 4 |
+
|
| 5 |
+
# Twelve named experts. Each is a BaseExpert (SwiGLU FFN) tagged with a
# (role, specialization) pair. NOTE: none of these subclasses override
# forward(), and BaseExpert.forward never references the extra attributes
# declared below (code_bias, memory_gate, noise_gate, ...), so they only add
# trainable parameters to the state_dict — presumably reserved for future
# role-specific logic; confirm before relying on them.
class Asmoday(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "code", "python_development")
        self.code_bias = nn.Parameter(torch.ones(1) * 0.5)

class Istaroth(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "systems", "os_networking")

class Ronova(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "reasoning", "math_logic")
        self.logic_bias = nn.Parameter(torch.ones(1) * 0.3)

class Naberius(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "memory", "retrieval")
        self.memory_gate = nn.Linear(dim, 1)

class Phanes(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "creation", "writing")
        self.creative_temp = nn.Parameter(torch.ones(1) * 1.2)

class Barbeloth(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "analysis", "data_patterns")

class Tacet(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "silence", "filtering")
        self.noise_gate = nn.Linear(dim, 1)

class Abby(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "empathy", "user_context")
        self.empathy_bias = nn.Parameter(torch.ones(1) * 0.2)

class Reindoter(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "validation", "testing")

class Zestial(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "vision", "visualization")

class Alice(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "exploration", "novelty")
        self.exploration_temp = nn.Parameter(torch.ones(1) * 1.5)

class Rover(BaseExpert):
    def __init__(self, dim: int, expert_dim: int):
        super().__init__(dim, expert_dim, "execution", "terminal")

# Name -> class lookup used to construct the council by expert name.
EXPERT_REGISTRY = {
    "Asmoday": Asmoday,
    "Istaroth": Istaroth,
    "Ronova": Ronova,
    "Naberius": Naberius,
    "Phanes": Phanes,
    "Barbeloth": Barbeloth,
    "Tacet": Tacet,
    "Abby": Abby,
    "Reindoter": Reindoter,
    "Zestial": Zestial,
    "Alice": Alice,
    "Rover": Rover,
}
|
src/council/sentinel.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
|
| 6 |
+
class Sentinel(nn.Module):
    """Top-k expert router with usage tracking for load balancing.

    forward returns (weights, indices): softmax-normalized weights over the
    n_activated selected experts plus their indices, with the same leading
    shape as the input. During training, per-expert selection counts are
    accumulated for get_load_balance_loss()/get_role_entropy().
    """

    def __init__(self, dim: int, n_experts: int = 12, n_activated: int = 2):
        super().__init__()
        self.n_experts = n_experts
        self.n_activated = n_activated
        self.gate = nn.Linear(dim, n_experts, bias=False)
        self.expert_bias = nn.Parameter(torch.zeros(n_experts))
        # Running selection counts; reset by get_load_balance_loss().
        self.register_buffer("usage_counts", torch.zeros(n_experts))
        self.register_buffer("total_tokens", torch.tensor(0.0))
        self.temperature = nn.Parameter(torch.ones(1) * 1.0)

    def forward(self, x: torch.Tensor, role_hints: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route x (…, dim) to the top n_activated experts."""
        logits = self.gate(x) + self.expert_bias
        if role_hints is not None:
            logits = logits + role_hints
        # Clamp keeps the learned temperature in a sane range and non-zero.
        logits = logits / self.temperature.abs().clamp(min=0.1, max=2.0)
        weights, indices = logits.topk(self.n_activated, dim=-1)
        # Softmax over the selected logits only (standard top-k routing).
        weights = F.softmax(weights, dim=-1)
        if self.training:
            self._update_usage(indices)
        return weights, indices

    def _update_usage(self, indices):
        # Flatten so both (N, k) and (B, T, k) index tensors work. The
        # original sliced indices[:, i] — which crashed on 3-D inputs — and
        # added indices.shape[0] (batches, not tokens) to total_tokens.
        flat = indices.reshape(-1, self.n_activated)
        for i in range(self.n_activated):
            self.usage_counts.scatter_add_(0, flat[:, i], torch.ones_like(flat[:, i], dtype=torch.float))
        self.total_tokens += flat.shape[0]

    def get_load_balance_loss(self) -> torch.Tensor:
        """Scaled mean-squared deviation from uniform usage.

        Side effect: resets the usage counters, so call it once per step and
        after get_role_entropy() if both are needed.
        """
        if self.total_tokens == 0:
            return torch.tensor(0.0, device=self.expert_bias.device)
        probs = self.usage_counts / self.total_tokens
        ideal = 1.0 / self.n_experts
        loss = ((probs - ideal) ** 2).mean()
        self.usage_counts.zero_()
        self.total_tokens.zero_()
        return loss * 0.01

    def get_role_entropy(self) -> torch.Tensor:
        """Entropy of the accumulated usage distribution (does not reset it)."""
        if self.total_tokens == 0:
            # Same device as the other buffers (the original returned a CPU
            # tensor here, inconsistent with get_load_balance_loss).
            return torch.tensor(0.0, device=self.expert_bias.device)
        probs = self.usage_counts / self.total_tokens
        entropy = -(probs * torch.log(probs + 1e-8)).sum()
        return entropy
|