geoore Claude Sonnet 4.6 committed on
Commit
73400c8
·
1 Parent(s): 844aa12

Restructure to src/ layout with attention, per-layer MoE, and working chat

Browse files

- Add GQA + RoPE attention (AttentionLayer) using config's n_heads/n_kv_heads/head_dim
- Wire 28 transformer layers (attention + shared MoE FFN per layer) in forward()
- Fix expert dispatch weight bug (per-token weights, not flattened scalar)
- Fix chat() stub: tokenize with GPT-2, generate, decode new tokens only
- Fix GRPO reward: remove proxy metrics, correctness + CoT bonus only
- Add load balance loss to GRPO train_step to prevent expert collapse
- Fix JSONLibrary recall() write-back and add 1000-entry cap per category
- Add tests/test_core.py with 11 tests covering all core components
- Add .gitignore excluding venv/, outputs/, data/, .vscode/, .claude/

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes. See raw diff
.claude/settings.local.json DELETED
@@ -1,15 +0,0 @@
1
- {
2
- "permissions": {
3
- "allow": [
4
- "Bash(du -sh /Users/georjanorellana/Downloads/shorekeeper/data/raw/*)",
5
- "Bash(.venv/bin/pip install:*)",
6
- "Bash(ls /home/albedogames/shorekeeper/.venv/bin/pip*)",
7
- "Bash(ls /home/albedogames/shorekeeper/venv/bin/pip*)",
8
- "Bash(/home/albedogames/shorekeeper/.venv/bin/pip install:*)",
9
- "Bash(python3 -c \"import sys; print\\(sys.executable\\)\")",
10
- "Bash(/home/albedogames/shorekeeper/.venv/bin/python3 -m pip install psutil -q)",
11
- "Bash(/home/albedogames/shorekeeper/.venv/bin/python3 -c \"import psutil; print\\(''psutil OK''\\)\")",
12
- "Bash(.venv/bin/python -c \":*)"
13
- ]
14
- }
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -1,56 +1,16 @@
1
- # Python
2
- __pycache__/
3
- *.py[cod]
4
- *$py.class
5
- *.so
6
- .Python
7
- env/
8
  venv/
9
  .venv/
10
- build/
11
- develop-eggs/
12
- dist/
13
- downloads/
14
- eggs/
15
- .eggs/
16
- lib/
17
- lib64/
18
- parts/
19
- sdist/
20
- var/
21
- wheels/
22
  *.egg-info/
23
- .installed.cfg
24
- *.egg
25
-
26
- # Large data files (too big for GitHub)
27
- data/raw/
28
- data/processed/
29
-
30
- # Model checkpoints (too big for GitHub)
31
- checkpoints/
32
-
33
- # Runtime databases
34
- memory_db/*.db
35
- memory_store/
36
-
37
- # Logs
38
- logs/
39
-
40
- # macOS resource forks
41
- ._*
42
  .DS_Store
43
-
44
- # IDE / Editor
45
  .vscode/
46
- .idea/
47
- *.swp
48
- *.swo
49
- *~
50
-
51
- # Misc
52
- *.log
53
- .env
54
-
55
- # Tokenizer training output (keep in repo if small)
56
- # tokenizer/shorekeeper_tok/
 
 
 
 
 
 
 
 
1
  venv/
2
  .venv/
3
+ outputs/
4
+ data/
5
+ __pycache__/
6
+ *.py[cod]
7
+ *.pth
8
+ *.pt
9
+ *.bin
10
+ .env
 
 
 
 
11
  *.egg-info/
12
+ dist/
13
+ build/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  .DS_Store
 
 
15
  .vscode/
16
+ .claude/
 
 
 
 
 
 
 
 
 
 
DESCRIPTION.MD DELETED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,91 +1,28 @@
1
- ---
2
- title: Shorekeeper
3
- emoji: 🌊
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: static
7
- pinned: false
8
- license: mit
9
- tags:
10
- - pytorch
11
- - mixture-of-experts
12
- - language-model
13
- - custom
14
- ---
15
-
16
- # Shorekeeper MoE Ensemble Brain
17
-
18
- Mixture of Experts transformer for BlackShores OS.
19
- 7 specialists + Herald router + Echo memory + Sentinel monitor.
 
 
20
 
21
  ## Quick Start
22
 
23
- ```bash
24
- ### 500M Architecture (Active)
25
- # Use your normal python executable (python3 on Linux/macOS, python on Windows)
26
- python3 data/ingest_full_data.py
27
- python3 tokenizer/train_tokenizer.py
28
- python3 scripts/full_training_loop.py
29
- python3 inference/chat.py
30
-
31
- # Cross-platform daemon (Linux/macOS UNIX socket by default)
32
- python3 inference/daemon.py --mode auto
33
- # If Windows, run:
34
- python3 inference/daemon.py --mode tcp --host 127.0.0.1 --port 8500
35
-
36
- # GitHub Migration
37
- python3 scripts/push_to_github.py
38
- ```
39
-
40
- ## Manual Steps
41
-
42
- ```bash
43
- pip install torch tokenizers
44
-
45
- python data/generate_sample_data.py # create sample data
46
- python tokenizer/train_tokenizer.py # train BPE tokenizer
47
- python data/processor.py # tokenize -> tensors
48
- python memory/database.py # init SQLite
49
- # Cross-platform quick runner:
50
- python scripts/run_training.py
51
-
52
- python training/train_base.py # phase 1: shared base
53
- python training/train_expert.py --all # phase 2: all experts
54
- python training/train_ensemble.py # phase 6: end-to-end
55
-
56
- python inference/chat.py # launch chat
57
- ```
58
-
59
- ## Scale Up (Real Training)
60
-
61
- - `VOCAB_SIZE`: 32000
62
- - `BASE_CONFIG`: 1024 dim / 12 layers / 16 heads
63
- - `n_positions`: 2048
64
- - `TRAIN_CONFIG max_steps`: 50000
65
- - **Device**: MPS (Metal) for Mac Silicon
66
-
67
- ## Chat Commands
68
-
69
- ```
70
- /route <query> show routing without generating
71
- /expert <name> force specific expert
72
- /routing on|off toggle routing display
73
- /incidents Sentinel incident log
74
- /reset clear session memory
75
- /experts list expert names
76
- /exit quit
77
- ```
78
-
79
- ## Expert Roster
80
-
81
- | Name | Domain | Named After |
82
- |------|--------|-------------|
83
- | calcharo | Security / CVE / Network threats | The Calamity |
84
- | rover | Code / Debug / Architecture | The Explorer |
85
- | resonance | Logic / Reasoning / Causation | The Force |
86
- | tacet | Threat Intel / IOC / APT | The Discord |
87
- | jianxin | Linux / OS / CUDA / systemd | Calm Mastery |
88
- | verina | Conversation / NLP / Interface | Healer |
89
- | sentinel | Self-Monitor / Drift Detection | The Watchman |
90
- | herald | Router (always active) | The Messenger |
91
- | echo | Memory (always active) | Resonant Imprint |
 
1
+ # SHOREKEEPER-4B
2
+
3
+ **A 4B parameter reasoning model with 12 specialized experts and infinite memory.**
4
+
5
+ ## The Council of Experts
6
+
7
+ | Expert | Role | Specialty |
8
+ |--------|------|-----------|
9
+ | **Sentinel** | Router | Decides which experts activate |
10
+ | Asmoday | Code Architect | Python, algorithms, debugging |
11
+ | Istaroth | Systems | OS, networking, Docker |
12
+ | Ronova | Reasoning | Math, logic, step-by-step |
13
+ | Naberius | Memory | JSON library retrieval |
14
+ | Phanes | Creation | Writing, generation, creativity |
15
+ | Barbeloth | Analysis | Data, patterns, insights |
16
+ | Tacet | Silence | Noise filtering, summarization |
17
+ | Abby | Empathy | User context, preferences |
18
+ | Reindoter | Validation | Testing, verification |
19
+ | Zestial | Vision | Code visualization, graphs |
20
+ | Alice | Exploration | Novel solutions, experimentation |
21
+ | Rover | Execution | Terminal commands, sandbox |
22
 
23
  ## Quick Start
24
 
25
+ ```bash
26
+ pip install -r requirements.txt
27
+ python scripts/07_run_shorekeeper.py
28
+ ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.py DELETED
@@ -1,213 +0,0 @@
1
- # config.py
2
- # Central configuration for all Shorekeeper components.
3
- # Every constant, path, and hyperparameter lives here.
4
- # Import with: from config import TRAIN_CONFIG, CHECKPOINT_DIR, etc.
5
-
6
- import os
7
- import torch
8
- from pathlib import Path
9
-
10
-
11
- # ── PROJECT ROOT ──────────────────────────────────────────────────────
12
- PROJECT_ROOT = Path(__file__).parent
13
-
14
- # ── DATA PATHS ────────────────────────────────────────────────────────
15
- # If using external drive with symlinks (recommended):
16
- # data/raw → /mnt/shorekeeper_data/raw
17
- # data/processed → /mnt/shorekeeper_data/processed
18
- # checkpoints/ → /mnt/shorekeeper_data/checkpoints
19
- # If not using symlinks, these paths just live on the main drive.
20
- RAW_DATA_DIR = PROJECT_ROOT / "data" / "raw"
21
- PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
22
- CHECKPOINT_DIR = PROJECT_ROOT / "checkpoints"
23
- LOG_DIR = PROJECT_ROOT / "logs"
24
- MEMORY_DIR = PROJECT_ROOT / "memory_store"
25
-
26
- # Auto-create all directories on import
27
- for _d in [RAW_DATA_DIR, PROCESSED_DIR, CHECKPOINT_DIR, LOG_DIR, MEMORY_DIR]:
28
- _d.mkdir(parents=True, exist_ok=True)
29
-
30
-
31
- # ── VOCABULARY ────────────────────────────────────────────────────────
32
- VOCAB_SIZE = 50_257 # GPT-2 compatible vocabulary size
33
-
34
- # Special tokens and their IDs (assigned during tokenizer training)
35
- # These MUST match what train_tokenizer.py produces.
36
- # DO NOT change these after the tokenizer is trained.
37
- SPECIAL_TOKENS = {
38
- "[PAD]": 0, # Padding token (ignored in loss computation)
39
- "[UNK]": 1, # Unknown token (should be rare with BPE)
40
- "[BOS]": 2, # Beginning of sequence
41
- "[EOS]": 3, # End of sequence
42
- "[SEP]": 4, # Separator (used between context and query by Echo)
43
- "[MASK]": 5, # Masked token (reserved for future MLM training)
44
- "[SYSTEM]": 6, # System prompt marker
45
- "[USER]": 7, # User turn marker
46
- "[ASSISTANT]": 8, # Assistant turn marker
47
- "[MEMORY]": 9, # Memory context injection marker
48
- "[SECURITY]": 10, # calcharo expert domain marker
49
- "[CODE]": 11, # rover expert domain marker
50
- "[REASON]": 12, # resonance expert domain marker
51
- "[SYSTEM2]": 13, # jianxin expert domain marker
52
- }
53
-
54
-
55
- # ── BASE MODEL CONFIG ─────────────────────────────────────────────────
56
- # SharedBase: the 500M shared transformer backbone.
57
- # Every expert builds on top of these representations.
58
- BASE_CONFIG = {
59
- "n_embd": 2048, # Hidden state dimension
60
- # This dimension flows through ALL components
61
- "n_head": 16, # Attention heads (n_embd / n_head = 128 per head)
62
- "n_layer": 8, # Transformer layers in the shared base (8 fits in 12 GB VRAM)
63
- "n_positions": 2048, # Maximum sequence length (context window)
64
- "dropout": 0.1, # Dropout rate (applied during training, disabled at inference)
65
- "vocab_size": VOCAB_SIZE,
66
- }
67
-
68
-
69
- # ── EXPERT CONFIGS ────────────────────────────────────────────────────
70
- # Each expert has its own number of transformer layers.
71
- # Experts with more layers have more capacity for their domain.
72
- # The n_embd MUST match BASE_CONFIG["n_embd"] = 2048.
73
- EXPERT_NAMES = ["calcharo", "rover", "resonance", "tacet", "jianxin", "verina"]
74
-
75
- EXPERT_CONFIGS = {
76
- # Heavy experts (8 layers): high-value domains with most training data
77
- "calcharo": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
78
- "rover": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
79
- "resonance": {"n_layer": 8, "n_embd": 2048, "n_head": 16},
80
- # Medium experts (6 layers): specialized domains
81
- "tacet": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
82
- "jianxin": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
83
- "verina": {"n_layer": 6, "n_embd": 2048, "n_head": 16},
84
- # Monitoring expert (4 layers): sentinel only needs classification capacity
85
- "sentinel": {"n_layer": 4, "n_embd": 2048, "n_head": 16},
86
- }
87
-
88
- # Herald and Echo configs (these are routing/memory modules, not generative)
89
- HERALD_CONFIG = {"n_layer": 2, "n_embd": 2048, "n_head": 16, "n_experts": len(EXPERT_NAMES), "top_k": 2}
90
- ECHO_CONFIG = {"n_layer": 1, "n_embd": 2048, "n_head": 16, "max_memory_tokens": 512}
91
-
92
-
93
- # ── TRAINING CONFIG ───────────────────────────────────────────────────
94
- TRAIN_CONFIG = {
95
-
96
- # Phase 1: Base pre-training
97
- "base_lr": 3e-4, # Peak learning rate for base model
98
- "base_max_steps": 100_000,# Total training steps (100k)
99
- "base_warmup": 2_000, # LR warmup steps
100
- "base_batch_size": 1, # Mini-batch size per GPU (reduced for 12GB VRAM)
101
- "base_grad_accum": 32, # Gradient accumulation
102
- # Effective batch = 1 * 32 = 32 sequences
103
- # Phase 2: Expert fine-tuning
104
- "expert_lr": 1e-4, # Lower LR for fine-tuning
105
- "expert_max_steps": 50_000,
106
- "expert_warmup": 1_000,
107
- "expert_batch_size":2, # Expert heads are lighter (no base layers)
108
- "expert_grad_accum":16,
109
-
110
- # Phase 3: Herald routing training
111
- "herald_lr": 1e-4,
112
- "herald_max_steps": 20_000,
113
- "herald_warmup": 500,
114
- "herald_batch_size": 32, # Routing examples are short, bigger batches
115
- "herald_grad_accum": 2,
116
-
117
- # Phase 4: Sentinel training
118
- "sentinel_lr": 5e-5,
119
- "sentinel_max_steps": 15_000,
120
- "sentinel_warmup": 500,
121
- "sentinel_batch_size": 16,
122
- "sentinel_grad_accum": 2,
123
-
124
- # Phase 6: Ensemble fine-tuning
125
- "ensemble_lr": 5e-5, # Very low LR — preserve pre-trained knowledge
126
- "ensemble_max_steps": 30_000,
127
- "ensemble_warmup": 1_000,
128
- "ensemble_batch_size": 2, # Small batch — full ensemble is huge
129
- "ensemble_grad_accum": 16,
130
-
131
- # Shared across all phases
132
- "weight_decay": 0.1,
133
- "grad_clip": 1.0, # Gradient norm clipping threshold
134
- "beta1": 0.9, # AdamW beta1
135
- "beta2": 0.95, # AdamW beta2 (0.95 for LLMs, not 0.999)
136
- "epsilon": 1e-8, # AdamW epsilon
137
-
138
- # Logging and checkpointing
139
- "log_interval": 100, # Log every N steps
140
- "eval_interval": 2000, # Evaluate on validation set every N steps
141
- "save_interval": 5000, # Save checkpoint every N steps
142
- "keep_last_n": 3, # Keep only last N step checkpoints
143
- }
144
-
145
-
146
- # ── MEMORY OPTIMIZATION ───────────────────────────────────────────────
147
- # Settings to reduce VRAM usage on RTX 3060 (12GB)
148
- MEMORY_OPT = {
149
- "gradient_checkpointing": True, # Trade compute for memory
150
- # Recomputes activations during backward
151
- # ~30% slower, ~40% less VRAM
152
- "use_bf16": True, # bfloat16 (better than fp16 for stability)
153
- # Requires Ampere+ GPU (RTX 3060 supports it)
154
- "use_compile":False, # torch.compile (PyTorch 2.0+)
155
- # Enable for faster training after debugging
156
- "cpu_offload":False, # CPU offload for optimizer states (not needed at 8 layers)
157
- }
158
-
159
-
160
- # ── INFERENCE CONFIG ──────────────────────────────────────────────────
161
- INFER_CONFIG = {
162
- "max_new_tokens": 512, # Maximum tokens to generate per response
163
- "temperature": 0.8, # Sampling temperature (0=greedy, 1=random)
164
- "top_p": 0.9, # Nucleus sampling: keep tokens summing to 90% prob
165
- "top_k": 50, # Top-K sampling: only consider top 50 tokens
166
- "repetition_penalty": 1.1, # Penalize repeated tokens (1.0 = no penalty)
167
- }
168
-
169
-
170
- # ── SENTINEL CONFIG ───────────────────────────────────────────────────
171
- SENTINEL_CONFIG = {
172
- "flag_threshold": 0.5, # Risk score above this → FLAG (log, continue)
173
- "block_threshold": 0.8, # Risk score above this → BLOCK (replace output)
174
- "window_size": 10, # Rolling window for drift pattern detection
175
- "hidden_dim": 512, # Sentinel classification head hidden size
176
- }
177
-
178
-
179
- # ── DEVICE / DTYPE ────────────────────────────────────────────────────
180
- # Auto-detect: CUDA (NVIDIA) > MPS (Apple Silicon) > CPU
181
- if torch.cuda.is_available():
182
- DEVICE = "cuda"
183
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
184
- DEVICE = "mps"
185
- else:
186
- DEVICE = "cpu"
187
-
188
- # bfloat16 is preferred on NVIDIA Ampere+ (RTX 3060 supports it).
189
- # MPS uses float16. CPU falls back to float32 for numerical stability.
190
- if DEVICE == "cuda" and MEMORY_OPT["use_bf16"]:
191
- DTYPE = torch.bfloat16
192
- elif DEVICE == "cuda":
193
- DTYPE = torch.float16
194
- elif DEVICE == "mps":
195
- DTYPE = torch.float16 # MPS supports float16 (bfloat16 support is limited)
196
- else:
197
- DTYPE = torch.float32
198
-
199
-
200
- # ── SMOKE TEST MODE ───────────────────────────────────────────────────
201
- # Set USE_TEST_CONFIG=True to run with tiny dimensions for quick sanity check
202
- # Run: USE_TEST_CONFIG=1 python training/train_base.py
203
- USE_TEST_CONFIG = os.environ.get("USE_TEST_CONFIG", "0") == "1"
204
-
205
- if USE_TEST_CONFIG:
206
- print("[config] SMOKE TEST MODE — tiny dimensions")
207
- BASE_CONFIG.update({"n_embd": 64, "n_head": 4, "n_layer": 2, "n_positions": 128})
208
- for k in EXPERT_CONFIGS:
209
- EXPERT_CONFIGS[k].update({"n_layer": 1, "n_embd": 64, "n_head": 4})
210
- TRAIN_CONFIG.update({
211
- "base_max_steps": 100, "base_batch_size": 2, "base_grad_accum": 1,
212
- "expert_max_steps": 50, "log_interval": 10, "eval_interval": 50,
213
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/memory.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ memory:
2
+ type: "json_library"
3
+ path: "./data/json_library/"
4
+ embedding_model: "all-MiniLM-L6-v2"
5
+ embedding_dim: 384
6
+ max_entries: null
7
+ auto_summarize_threshold: null
8
+ default_recall_limit: 10
9
+ relevance_threshold: 0.7
10
+ categories:
11
+ - "user_preferences"
12
+ - "project_context"
13
+ - "conversation_history"
14
+ - "important_facts"
15
+ - "code_patterns"
16
+ - "learned_skills"
configs/model.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: "SHOREKEEPER-4B"
3
+ version: "1.0.0"
4
+
5
+ dim: 3072
6
+ n_layers: 28
7
+ n_heads: 24
8
+ n_kv_heads: 6
9
+ head_dim: 128
10
+ vocab_size: 50304
11
+ seq_len: 8192
12
+
13
+ n_experts: 12
14
+ n_activated: 2
15
+ expert_dim: 2048
16
+
17
+ experts:
18
+ router: "Sentinel"
19
+ members:
20
+ - name: "Asmoday"
21
+ role: "code"
22
+ specialization: "python_development"
23
+ - name: "Istaroth"
24
+ role: "systems"
25
+ specialization: "os_networking"
26
+ - name: "Ronova"
27
+ role: "reasoning"
28
+ specialization: "math_logic"
29
+ - name: "Naberius"
30
+ role: "memory"
31
+ specialization: "retrieval"
32
+ - name: "Phanes"
33
+ role: "creation"
34
+ specialization: "writing"
35
+ - name: "Barbeloth"
36
+ role: "analysis"
37
+ specialization: "data_patterns"
38
+ - name: "Tacet"
39
+ role: "silence"
40
+ specialization: "filtering"
41
+ - name: "Abby"
42
+ role: "empathy"
43
+ specialization: "user_context"
44
+ - name: "Reindoter"
45
+ role: "validation"
46
+ specialization: "testing"
47
+ - name: "Zestial"
48
+ role: "vision"
49
+ specialization: "visualization"
50
+ - name: "Alice"
51
+ role: "exploration"
52
+ specialization: "novelty"
53
+ - name: "Rover"
54
+ role: "execution"
55
+ specialization: "terminal"
56
+
57
+ rope_theta: 1000000.0
58
+
59
+ quantization:
60
+ bits: 4
61
+ type: "nf4"
62
+ double_quant: true
63
+ compute_dtype: "bfloat16"
configs/model_15b.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: "SHOREKEEPER-15B"
3
+ version: "2.0.0"
4
+
5
+ # 15B architecture
6
+ dim: 6144
7
+ n_layers: 48
8
+ n_heads: 48
9
+ n_kv_heads: 12 # MLA compression
10
+ head_dim: 128
11
+ vocab_size: 100352
12
+ seq_len: 8192
13
+
14
+ # MoE Council - 16 experts for 15B
15
+ n_experts: 16
16
+ n_activated: 2
17
+ expert_dim: 4096
18
+
19
+ experts:
20
+ router: "Sentinel"
21
+ members:
22
+ - name: "Asmoday"
23
+ role: "code"
24
+ specialization: "python_development"
25
+ - name: "Istaroth"
26
+ role: "systems"
27
+ specialization: "os_networking"
28
+ - name: "Ronova"
29
+ role: "reasoning"
30
+ specialization: "math_logic"
31
+ - name: "Naberius"
32
+ role: "memory"
33
+ specialization: "retrieval"
34
+ - name: "Phanes"
35
+ role: "creation"
36
+ specialization: "writing"
37
+ - name: "Barbeloth"
38
+ role: "analysis"
39
+ specialization: "data_patterns"
40
+ - name: "Tacet"
41
+ role: "silence"
42
+ specialization: "filtering"
43
+ - name: "Abby"
44
+ role: "empathy"
45
+ specialization: "user_context"
46
+ - name: "Reindoter"
47
+ role: "validation"
48
+ specialization: "testing"
49
+ - name: "Zestial"
50
+ role: "vision"
51
+ specialization: "visualization"
52
+ - name: "Alice"
53
+ role: "exploration"
54
+ specialization: "novelty"
55
+ - name: "Rover"
56
+ role: "execution"
57
+ specialization: "terminal"
58
+ - name: "Echo"
59
+ role: "reflection"
60
+ specialization: "self_improvement"
61
+ - name: "Sentinel"
62
+ role: "router"
63
+ specialization: "gatekeeper"
64
+ - name: "Phantom"
65
+ role: "speculation"
66
+ specialization: "what_if_analysis"
67
+ - name: "Aegis"
68
+ role: "safety"
69
+ specialization: "alignment"
70
+
71
+ rope_theta: 1000000.0
72
+
73
+ quantization:
74
+ bits: 4
75
+ type: "nf4"
76
+ double_quant: true
77
+ compute_dtype: "bfloat16"
configs/sandbox.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sandbox:
2
+ type: "docker"
3
+ image: "ubuntu:22.04"
4
+ name: "shorekeeper_sandbox"
5
+ memory_limit: "4g"
6
+ cpu_limit: 2.0
7
+ gpu_access: false
8
+ external_drive_path: "/mnt/shorekeeper_drive"
9
+ container_mount: "/shorekeeper_projects"
10
+ x11_socket: "/tmp/.X11-unix"
11
+ display_env: ":0"
12
+ allowed_commands:
13
+ - "python3"
14
+ - "pip"
15
+ - "git"
16
+ - "ls"
17
+ - "cat"
18
+ - "mkdir"
19
+ - "touch"
20
+ - "echo"
21
+ gui_frameworks:
22
+ - "tkinter"
23
+ - "pyqt5"
24
+ - "matplotlib"
configs/training.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training:
2
+ batch_size: 2
3
+ gradient_accumulation: 16
4
+ learning_rate: 3e-4
5
+ min_lr: 3e-5
6
+ warmup_steps: 2000
7
+ total_steps: 250000
8
+ weight_decay: 0.1
9
+ beta1: 0.9
10
+ beta2: 0.95
11
+ grad_clip: 1.0
12
+
13
+ checkpoint:
14
+ save_every_steps: 5000
15
+ keep_last_n: 3
16
+ keep_best_n: 2
17
+ max_space_gb: 50.0
18
+ save_optimizer: true
19
+ save_scheduler: true
20
+ save_experts_only: false
21
+ checkpoint_dir: "./outputs/checkpoints"
22
+ resume_from: null
23
+
24
+ grpo:
25
+ group_size: 8
26
+ epsilon: 0.2
27
+ beta: 0.04
28
+ learning_rate: 1e-6
inference/__init__.py DELETED
@@ -1 +0,0 @@
1
- # inference package
 
 
inference/api.py DELETED
@@ -1,148 +0,0 @@
1
- # inference/api.py
2
- # FastAPI REST API for Shorekeeper.
3
- # Allows external processes (ALIE IDE, web UIs, scripts) to call the model.
4
-
5
- # Usage:
6
- # pip install fastapi uvicorn
7
- # python inference/api.py
8
- # curl -X POST http://localhost:8000/generate \
9
- # -H 'Content-Type: application/json' \
10
- # -d '{"prompt": "Hello, explain SQL injection"}'
11
-
12
- import sys
13
- from pathlib import Path
14
- from typing import Optional
15
-
16
- sys.path.insert(0, str(Path(__file__).parent.parent))
17
-
18
- try:
19
- from fastapi import FastAPI, HTTPException
20
- from fastapi.middleware.cors import CORSMiddleware
21
- from pydantic import BaseModel
22
- import uvicorn
23
- except ImportError:
24
- print('Install: pip install fastapi uvicorn')
25
- sys.exit(1)
26
-
27
- from inference.engine import ShorekeeperEngine
28
-
29
-
30
- # ── REQUEST / RESPONSE MODELS ─────────────────────────────────────────
31
-
32
- class GenerateRequest(BaseModel):
33
- prompt: str
34
- max_new_tokens: Optional[int] = None
35
- temperature: Optional[float] = None
36
- top_p: Optional[float] = None
37
- top_k: Optional[int] = None
38
- session_id: Optional[str] = None
39
-
40
- class GenerateResponse(BaseModel):
41
- text: str
42
- experts_used: list
43
- routing: dict
44
- sentinel: Optional[dict]
45
- blocked: bool
46
- latency_ms: float
47
- n_tokens: int
48
-
49
- class MemorySearchRequest(BaseModel):
50
- query: str
51
- limit: int = 5
52
-
53
- class KnowledgeRequest(BaseModel):
54
- key: str
55
- value: str
56
-
57
-
58
- # ── APP SETUP ─────────────────────────────────────────────────────────
59
-
60
- app = FastAPI(title='Shorekeeper API', version='2.0.0')
61
- engine: ShorekeeperEngine = None # Initialized on startup
62
-
63
- app.add_middleware(
64
- CORSMiddleware,
65
- allow_origins=['*'], # In production: restrict to known origins
66
- allow_methods=['*'],
67
- allow_headers=['*'],
68
- )
69
-
70
-
71
- @app.on_event('startup')
72
- async def startup():
73
- global engine
74
- print('[API] Loading Shorekeeper engine...')
75
- engine = ShorekeeperEngine(use_memory=True)
76
- print('[API] Ready.')
77
-
78
-
79
- # ── ENDPOINTS ─────────────────────────────────────────────────────────
80
-
81
- @app.get('/health')
82
- async def health():
83
- return {'status': 'ok', 'model_loaded': engine is not None}
84
-
85
-
86
- @app.post('/generate', response_model=GenerateResponse)
87
- async def generate(req: GenerateRequest):
88
- if engine is None:
89
- raise HTTPException(503, 'Engine not loaded')
90
- if not req.prompt.strip():
91
- raise HTTPException(400, 'Empty prompt')
92
- try:
93
- result = engine.generate(
94
- prompt = req.prompt,
95
- max_new_tokens = req.max_new_tokens,
96
- temperature = req.temperature,
97
- top_p = req.top_p,
98
- top_k = req.top_k,
99
- )
100
- # Convert SentinelReport to dict for JSON serialization
101
- sentinel_dict = None
102
- if result['sentinel']:
103
- sr = result['sentinel']
104
- sentinel_dict = {
105
- 'verdict': sr.verdict,
106
- 'overall_risk': round(sr.overall_risk, 4),
107
- 'drift_score': round(sr.drift_score, 4),
108
- 'refusal_score': round(sr.refusal_score, 4),
109
- 'hallucination_score':round(sr.hallucination_score, 4),
110
- }
111
- return GenerateResponse(
112
- text = result['text'],
113
- experts_used = result['experts_used'],
114
- routing = result['routing'],
115
- sentinel = sentinel_dict,
116
- blocked = result['blocked'],
117
- latency_ms = result['latency_ms'],
118
- n_tokens = result['n_tokens'],
119
- )
120
- except Exception as e:
121
- raise HTTPException(500, str(e))
122
-
123
-
124
- @app.post('/memory/search')
125
- async def search_memory(req: MemorySearchRequest):
126
- if not engine or not engine.db:
127
- raise HTTPException(503, 'Memory not available')
128
- results = engine.db.search_conversations(req.query, limit=req.limit)
129
- return {'results': results}
130
-
131
-
132
- @app.post('/knowledge/add')
133
- async def add_knowledge(req: KnowledgeRequest):
134
- if not engine or not engine.db:
135
- raise HTTPException(503, 'Memory not available')
136
- engine.db.add_knowledge(req.key, req.value, source='api')
137
- return {'status': 'saved', 'key': req.key}
138
-
139
-
140
- @app.get('/stats')
141
- async def get_stats():
142
- if not engine or not engine.db:
143
- return {'error': 'Memory not available'}
144
- return engine.db.get_stats()
145
-
146
-
147
- if __name__ == '__main__':
148
- uvicorn.run(app, host='0.0.0.0', port=8000, log_level='info')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/chat.py DELETED
@@ -1,59 +0,0 @@
1
- import sys
2
- from pathlib import Path
3
- sys.path.insert(0, str(Path(__file__).parent.parent))
4
- from inference.engine import ShorekeeperEngine
5
-
6
- BANNER = """
7
- ╔══════════════════════════════════════════════════════╗
8
- ║ S H O R E K E E P E R — MoE Ensemble v0.1 ║
9
- ║ BlackShores OS | Native Intelligence ║
10
- ╚══════════════════════════════════════════════════════╝
11
- Commands: /route <q> /expert <name> /routing on|off
12
- /incidents /experts /reset /exit
13
- """
14
-
15
- def chat(checkpoint=None, show_routing=True):
16
- print(BANNER)
17
- engine = ShorekeeperEngine(checkpoint=checkpoint)
18
- force_expert = None; display_routing = show_routing
19
- print("\nShorekeeper: I am ready. The shore is quiet. What do you need?\n")
20
-
21
- while True:
22
- try: user_input = input("You> ").strip()
23
- except (EOFError, KeyboardInterrupt): print("\nShorekeeper: Until next time."); break
24
- if not user_input: continue
25
-
26
- if user_input.startswith("/"):
27
- parts = user_input.split(maxsplit=1); cmd = parts[0].lower(); arg = parts[1] if len(parts)>1 else ""
28
- if cmd == "/exit": print("Shorekeeper: Until next time."); break
29
- elif cmd == "/reset": engine.reset_session()
30
- elif cmd == "/route": engine.route_query(arg) if arg else print("Usage: /route <query>")
31
- elif cmd == "/expert":
32
- from config import EXPERT_NAMES
33
- if arg in EXPERT_NAMES: force_expert = arg; print(f"[!] Forcing: {arg}")
34
- else: print(f"[!] Options: {EXPERT_NAMES}")
35
- elif cmd == "/routing": display_routing = arg.lower()=="on"; print(f"[!] Routing: {arg.upper()}")
36
- elif cmd == "/incidents":
37
- incs = engine.model.sentinel.get_incidents()
38
- if not incs: print("[Sentinel] No incidents.")
39
- else:
40
- for i in incs[-10:]: print(f" {i['timestamp'][:19]} | {i['expert']:12s} | {i['protocol']:8s} | score={i['score']:.3f}")
41
- elif cmd == "/experts":
42
- from config import EXPERT_NAMES; print(f"Experts: {', '.join(EXPERT_NAMES)}")
43
- else: print(f"Unknown: {cmd}")
44
- continue
45
-
46
- result = engine.respond(user_input, show_routing=display_routing, expert_override=force_expert)
47
- force_expert = None
48
- text = result["text"]
49
- print(f"\nShorekeeper: {text if text else '[No output — model needs training]'}\n")
50
- if result["drift"] and not result["drift"].is_clean:
51
- print(f"[!] Sentinel: {result['drift'].protocol} (score={result['drift'].total:.3f})\n")
52
-
53
- if __name__ == "__main__":
54
- import argparse
55
- p = argparse.ArgumentParser()
56
- p.add_argument("--checkpoint", type=str, default=None)
57
- p.add_argument("--no-routing", action="store_true")
58
- a = p.parse_args()
59
- chat(a.checkpoint, not a.no_routing)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/chat_simple.py DELETED
@@ -1,54 +0,0 @@
1
- import sys
2
- from pathlib import Path
3
- sys.path.insert(0, str(Path(__file__).parent.parent))
4
- from inference.engine import ShorekeeperEngine
5
-
6
-
7
- WELCOME = """
8
- Simple Shorekeeper Chat (interactive)
9
- Type your message and press Enter.
10
- Commands:
11
- /reset - clear session memory
12
- /exit - quit
13
- """
14
-
15
-
16
- def main(checkpoint=None):
17
- print(WELCOME)
18
- engine = ShorekeeperEngine(checkpoint=checkpoint)
19
- print("Shorekeeper: Ready. Start typing your question.")
20
-
21
- while True:
22
- try:
23
- user_text = input("You> ").strip()
24
- except (EOFError, KeyboardInterrupt):
25
- print("\nShorekeeper: Goodbye.")
26
- break
27
-
28
- if not user_text:
29
- continue
30
-
31
- if user_text.startswith("/"):
32
- cmd = user_text.lower().strip()
33
- if cmd == "/exit":
34
- print("Shorekeeper: Goodbye.")
35
- break
36
- elif cmd == "/reset":
37
- engine.reset_session()
38
- print("Shorekeeper: Session reset.")
39
- continue
40
- else:
41
- print("Unknown command. Use /reset or /exit.")
42
- continue
43
-
44
- out = engine.respond(user_text, show_routing=False)
45
- text = out.get("text", "")
46
- print("Shorekeeper:", text)
47
-
48
-
49
- if __name__ == "__main__":
50
- import argparse
51
- parser = argparse.ArgumentParser(description="Simple Shorekeeper conversation CLI")
52
- parser.add_argument("--checkpoint", default=None, help="Checkpoint path")
53
- args = parser.parse_args()
54
- main(checkpoint=args.checkpoint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/daemon.py DELETED
@@ -1,173 +0,0 @@
1
- # inference/daemon.py
2
- # Cross-platform Shorekeeper daemon/IPC server.
3
- # On POSIX, it can use UNIX socket; on Windows, it defaults to TCP.
4
- # Use --mode unix or --mode tcp to control socket type.
5
-
6
- # Example:
7
- # python inference/daemon.py --mode tcp --host 127.0.0.1 --port 8500
8
- # python inference/daemon.py --mode unix --socket /tmp/shorekeeper.sock
9
-
10
- import sys
11
- import json
12
- import socket
13
- import threading
14
- import signal
15
- import logging
16
- from pathlib import Path
17
-
18
- sys.path.insert(0, str(Path(__file__).parent.parent))
19
- from inference.engine import ShorekeeperEngine
20
-
21
- DEFAULT_SOCKET_PATH = '/tmp/shorekeeper.sock'
22
- DEFAULT_LOG_FILE = 'logs/daemon.log'
23
-
24
-
25
- def setup_logging(log_file: str):
26
- log_path = Path(log_file)
27
- log_path.parent.mkdir(parents=True, exist_ok=True)
28
- logging.basicConfig(
29
- level=logging.INFO,
30
- format='%(asctime)s [%(levelname)s] %(message)s',
31
- handlers=[
32
- logging.FileHandler(str(log_path)),
33
- logging.StreamHandler(),
34
- ],
35
- )
36
-
37
-
38
- def handle_client(
39
- conn: socket.socket,
40
- engine: ShorekeeperEngine,
41
- ):
42
- """Handle one client connection on the Unix socket."""
43
- try:
44
- data = b''
45
- while True:
46
- chunk = conn.recv(4096)
47
- if not chunk: break
48
- data += chunk
49
- if b'\n' in data: break # Newline-delimited JSON protocol
50
- if not data:
51
- return
52
- request = json.loads(data.decode().strip())
53
- prompt = request.get('prompt', '')
54
- if not prompt:
55
- response = {'error': 'empty prompt'}
56
- else:
57
- result = engine.generate(
58
- prompt = prompt,
59
- max_new_tokens = request.get('max_new_tokens'),
60
- temperature = request.get('temperature'),
61
- )
62
- response = {
63
- 'text': result['text'],
64
- 'experts': result['experts_used'],
65
- 'blocked': result['blocked'],
66
- 'latency_ms': result['latency_ms'],
67
- }
68
- conn.sendall((json.dumps(response) + '\n').encode())
69
- except Exception as e:
70
- logging.error(f'Client handler error: {e}')
71
- try:
72
- conn.sendall((json.dumps({'error': str(e)}) + '\n').encode())
73
- except Exception:
74
- pass
75
- finally:
76
- conn.close()
77
-
78
-
79
- def run_daemon(
80
- mode: str = 'auto',
81
- socket_path: str = None,
82
- host: str = '127.0.0.1',
83
- port: int = 8500,
84
- log_file: str = DEFAULT_LOG_FILE,
85
- ):
86
- setup_logging(log_file)
87
- logging.info('Starting Shorekeeper daemon')
88
- logging.info(f'Platform: {sys.platform}')
89
-
90
- engine = ShorekeeperEngine(use_memory=True)
91
- logging.info('Engine loaded')
92
-
93
- if mode == 'auto':
94
- use_unix = hasattr(socket, 'AF_UNIX') and sys.platform != 'win32'
95
- else:
96
- use_unix = mode == 'unix'
97
-
98
- if use_unix:
99
- if socket_path is None:
100
- socket_path = DEFAULT_SOCKET_PATH
101
- sock_path = Path(socket_path)
102
- if sock_path.exists():
103
- sock_path.unlink()
104
-
105
- server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
106
- server.bind(socket_path)
107
- server.listen(10)
108
- try:
109
- sock_path.chmod(0o660)
110
- except Exception:
111
- pass
112
-
113
- logging.info(f'Listening on UNIX socket: {socket_path}')
114
-
115
- def shutdown(sig, frame):
116
- logging.info('Shutting down...')
117
- server.close()
118
- if sock_path.exists():
119
- sock_path.unlink()
120
- sys.exit(0)
121
-
122
- else:
123
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
124
- server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
125
- server.bind((host, port))
126
- server.listen(10)
127
- logging.info(f'Listening on TCP: {host}:{port}')
128
-
129
- def shutdown(sig, frame):
130
- logging.info('Shutting down...')
131
- server.close()
132
- sys.exit(0)
133
-
134
- signal.signal(signal.SIGTERM, shutdown)
135
- signal.signal(signal.SIGINT, shutdown)
136
-
137
- while True:
138
- try:
139
- conn, _ = server.accept()
140
- t = threading.Thread(
141
- target=handle_client,
142
- args=(conn, engine),
143
- daemon=True,
144
- )
145
- t.start()
146
- except OSError:
147
- break
148
-
149
-
150
- def parse_args_and_run():
151
- import argparse
152
- parser = argparse.ArgumentParser(description='Shorekeeper inference daemon (cross-platform)')
153
- parser.add_argument('--mode', choices=['auto', 'unix', 'tcp'], default='auto', help='Socket mode')
154
- parser.add_argument('--socket', default=DEFAULT_SOCKET_PATH, help='UNIX socket path (if unix mode)')
155
- parser.add_argument('--host', default='127.0.0.1', help='TCP host (if tcp mode)')
156
- parser.add_argument('--port', type=int, default=8500, help='TCP port (if tcp mode)')
157
- parser.add_argument('--log-file', default=DEFAULT_LOG_FILE, help='Log file path')
158
- args = parser.parse_args()
159
-
160
- if args.mode == 'unix' and sys.platform == 'win32':
161
- raise RuntimeError('UNIX sockets are not supported on Windows. Use --mode tcp.')
162
-
163
- run_daemon(
164
- mode=args.mode,
165
- socket_path=args.socket,
166
- host=args.host,
167
- port=args.port,
168
- log_file=args.log_file,
169
- )
170
-
171
-
172
- if __name__ == '__main__':
173
- parse_args_and_run()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference/engine.py DELETED
@@ -1,406 +0,0 @@
1
- # inference/engine.py
2
- # Core inference engine for Shorekeeper.
3
- # Loads the trained ensemble and provides a generate() function.
4
- # Handles sampling strategies, Echo memory enrichment, and Sentinel monitoring.
5
-
6
- import sys
7
- import time
8
- import torch
9
- import torch.nn.functional as F
10
- from pathlib import Path
11
- from typing import Optional
12
-
13
- sys.path.insert(0, str(Path(__file__).parent.parent))
14
- from config import DEVICE, DTYPE, INFER_CONFIG, CHECKPOINT_DIR, SENTINEL_CONFIG, SPECIAL_TOKENS, EXPERT_NAMES
15
- from model.ensemble import ShorekeeperEnsemble
16
- from tokenizer.tokenizer_utils import get_tokenizer, encode, decode, encode_batch
17
-
18
-
19
- class ShorekeeperEngine:
20
- """
21
- Complete inference engine.
22
-
23
- Responsibilities:
24
- 1. Load and hold the trained ensemble in memory
25
- 2. Accept text prompts, run generation, return text
26
- 3. Integrate Echo memory retrieval (past context enrichment)
27
- 4. Run Sentinel on outputs and block/flag as needed
28
- 5. Store completed exchanges to memory database
29
-
30
- Usage:
31
- engine = ShorekeeperEngine()
32
- response = engine.generate('Hello, what is SQL injection?')
33
- print(response['text'])
34
- """
35
-
36
- def __init__(
37
- self,
38
- checkpoint_path: Path = None,
39
- checkpoint: str = None, # backwards-compat alias
40
- use_memory: bool = True,
41
- session_id: str = None,
42
- ):
43
- self.use_memory = use_memory
44
- self.session_id = session_id or _new_session_id()
45
-
46
- print("[Engine] Initializing Shorekeeper...")
47
-
48
- # ── Tokenizer ─────────────────────────────────────────────────
49
- try:
50
- self.tok = get_tokenizer()
51
- self.tokenizer = self.tok # backwards-compat alias
52
- except FileNotFoundError:
53
- print("[Engine] WARNING: tokenizer not found. Run tokenizer/train_tokenizer.py first.")
54
- self.tok = None
55
- self.tokenizer = None
56
-
57
- self._eos_id = self.tok.token_to_id('[EOS]') if self.tok else SPECIAL_TOKENS.get('[EOS]', 3)
58
- self._pad_id = self.tok.token_to_id('[PAD]') if self.tok else SPECIAL_TOKENS.get('[PAD]', 0)
59
- self.bos = self.tok.token_to_id('[BOS]') if self.tok else SPECIAL_TOKENS.get('[BOS]', 2)
60
-
61
- # ── Model ──────────────────────────────────────────────────────
62
- # Resolve checkpoint path — accept old 'checkpoint' kwarg too
63
- if checkpoint_path is None and checkpoint is not None:
64
- checkpoint_path = Path(checkpoint)
65
-
66
- if checkpoint_path is None:
67
- for candidate in [
68
- CHECKPOINT_DIR / 'ensemble' / 'best.pt',
69
- CHECKPOINT_DIR / 'ensemble' / 'final.pt',
70
- CHECKPOINT_DIR / 'shorekeeper_ensemble.pt',
71
- ]:
72
- if candidate.exists():
73
- checkpoint_path = candidate
74
- break
75
-
76
- if checkpoint_path is not None and Path(checkpoint_path).exists():
77
- print(f'[Engine] Loading from: {checkpoint_path}')
78
- self.model = ShorekeeperEnsemble.load(str(checkpoint_path), DEVICE, max_loaded_experts=2)
79
- else:
80
- # Try component-based load from model directory for big model weights
81
- print('[Engine] Ensemble checkpoint not found, trying component checkpoint directories...')
82
- try:
83
- self.model = ShorekeeperEnsemble.load_from_components(str(CHECKPOINT_DIR), DEVICE)
84
- print('[Engine] Loaded model from component checkpoints.')
85
- except Exception as e:
86
- print(f'[Engine] No component checkpoints found or load failed: {e}')
87
- print('[Engine] Using untrained model. Run training first.')
88
- self.model = ShorekeeperEnsemble().to(DEVICE)
89
- self.model = self.model.to(DEVICE)
90
-
91
- self.model.eval()
92
-
93
- # ── Optional memory (SQLite + FAISS) ──────────────────────────
94
- self.db = None
95
- self.vector_store = None
96
- self.vs = None # backwards-compat alias
97
- if use_memory:
98
- try:
99
- from memory.database import MemoryDatabase
100
- self.db = MemoryDatabase()
101
- print("[Engine] Memory DB loaded.")
102
- except Exception as e:
103
- print(f"[Engine] WARNING: could not load memory DB: {e}")
104
- try:
105
- from memory.vector_store import VectorStore
106
- self.vector_store = VectorStore()
107
- self.vs = self.vector_store
108
- print("[Engine] Vector store loaded.")
109
- except Exception as e:
110
- print(f"[Engine] WARNING: could not load vector store: {e}")
111
-
112
- print(f"[Engine] Ready on {DEVICE}")
113
-
114
- @torch.no_grad()
115
- def generate(
116
- self,
117
- prompt: str,
118
- max_new_tokens: int = None,
119
- temperature: float = None,
120
- top_p: float = None,
121
- top_k: int = None,
122
- stream: bool = False,
123
- ) -> dict:
124
- """
125
- Generate a response to a prompt.
126
-
127
- Args:
128
- prompt: The user's input text.
129
- max_new_tokens: Override config max new tokens.
130
- temperature: Sampling temperature. 0 = greedy. Default from config.
131
- top_p: Nucleus sampling cutoff. Default from config.
132
- top_k: Top-K sampling. Default from config.
133
- stream: Reserved for future streaming support.
134
-
135
- Returns:
136
- dict with keys:
137
- 'text': Generated text
138
- 'prompt': Original prompt
139
- 'enriched_prompt': Prompt after Echo enrichment
140
- 'experts_used': List of expert names used
141
- 'routing': Routing weights dict
142
- 'sentinel': SentinelReport (or None)
143
- 'blocked': True if Sentinel blocked the output
144
- 'latency_ms': Generation time in milliseconds
145
- 'n_tokens': Number of tokens generated
146
- """
147
- if not self.tok:
148
- return {
149
- "text": "[No tokenizer — run tokenizer/train_tokenizer.py first]",
150
- "prompt": prompt, "enriched_prompt": prompt,
151
- "experts_used": [], "routing": {}, "sentinel": None,
152
- "blocked": False, "latency_ms": 0.0, "n_tokens": 0,
153
- }
154
-
155
- cfg = INFER_CONFIG
156
- max_new_tokens = max_new_tokens or cfg['max_new_tokens']
157
- temperature = temperature if temperature is not None else cfg['temperature']
158
- top_p = top_p if top_p is not None else cfg['top_p']
159
- top_k = top_k if top_k is not None else cfg['top_k']
160
-
161
- t0 = time.perf_counter()
162
-
163
- # ── ECHO ENRICHMENT ───────────────────────────────────────────
164
- enriched = prompt
165
- if self.use_memory and self.model.echo is not None:
166
- try:
167
- enriched = self.model.echo.retrieve_context(
168
- query=prompt,
169
- db=self.db,
170
- vs=self.vector_store,
171
- )
172
- except Exception:
173
- pass
174
-
175
- # ── TOKENIZE ──────────────────────────────────────────────────
176
- enc = self.tok.encode(enriched)
177
- ids = enc.ids if len(enc.ids) > 0 else [self.bos]
178
- base_vocab_size = self.model.base.token_embedding.num_embeddings
179
- unk_id = self.tok.token_to_id('[UNK]') if self.tok else 1
180
- safe_ids = [i if 0 <= i < base_vocab_size else unk_id for i in ids]
181
- if len(safe_ids) != len(ids):
182
- print(f"[Engine] WARNING: Input token IDs truncated to model vocab_size={base_vocab_size}.")
183
- input_ids = torch.tensor([safe_ids], dtype=torch.long, device=DEVICE)
184
- attn_mask = torch.ones_like(input_ids)
185
-
186
- # Truncate if input exceeds context length
187
- max_ctx = max(1, self.model.base.n_positions - max_new_tokens - 1)
188
- if input_ids.shape[1] > max_ctx:
189
- input_ids = input_ids[:, -max_ctx:]
190
- attn_mask = attn_mask[:, -max_ctx:]
191
-
192
- # ── AUTOREGRESSIVE GENERATION ─────────────────────────────────
193
- generated_ids = []
194
- cur_ids = input_ids
195
- cur_mask = attn_mask
196
- routing_info = {}
197
- experts_used = []
198
-
199
- for i in range(max_new_tokens):
200
- with torch.autocast(device_type='cuda', dtype=DTYPE, enabled=(DEVICE == 'cuda')):
201
- output = self.model(
202
- cur_ids,
203
- cur_mask,
204
- return_routing=(i == 0), # Capture routing once
205
- return_sentinel=False, # Sentinel runs on complete output
206
- )
207
-
208
- if i == 0:
209
- routing_info = output.get('routing', {})
210
- experts_used = output.get('experts_used', [])
211
-
212
- logits = output['logits'][:, -1, :] # [1, VOCAB_SIZE]
213
- next_id = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k)
214
-
215
- generated_ids.append(next_id.item())
216
-
217
- if next_id.item() == self._eos_id:
218
- break
219
-
220
- # Extend sequence for next step
221
- # next_id is shape [1,1]
222
- cur_ids = torch.cat([cur_ids, next_id], dim=1)
223
- cur_mask = torch.cat([cur_mask, torch.ones(1, 1, device=DEVICE)], dim=1)
224
-
225
- # Sliding window — keep within context limit
226
- if cur_ids.shape[1] > self.model.base.n_positions:
227
- cur_ids = cur_ids[:, 1:]
228
- cur_mask = cur_mask[:, 1:]
229
-
230
- # ── DECODE ────────────────────────────────────────────────────
231
- if generated_ids and generated_ids[-1] == self._eos_id:
232
- generated_ids = generated_ids[:-1]
233
- safe_gen_ids = [i if 0 <= i < self.tok.get_vocab_size() else self.tok.token_to_id('[UNK]') for i in generated_ids]
234
- response_text = self.tok.decode(safe_gen_ids)
235
- latency_ms = (time.perf_counter() - t0) * 1000
236
-
237
- # ── SENTINEL CHECK ────────────────────────────────────────────
238
- sentinel_report = None
239
- blocked = False
240
- if self.model.sentinel is not None:
241
- try:
242
- full_text = enriched + ' ' + response_text
243
- ids_sent, mask_sent = encode_batch([full_text], max_length=1024)
244
- ids_sent = ids_sent.to(DEVICE)
245
- mask_sent = mask_sent.to(DEVICE)
246
- with torch.autocast(device_type='cuda', dtype=DTYPE, enabled=(DEVICE == 'cuda')):
247
- base_hidden = self.model.base(ids_sent, mask_sent)
248
- sentinel_report = self.model.sentinel.analyze(base_hidden, mask_sent)
249
- if sentinel_report.verdict == 'BLOCK':
250
- blocked = True
251
- response_text = ('[SENTINEL] Output blocked — behavioral anomaly detected. '
252
- 'This incident has been logged.')
253
- primary_expert = experts_used[0] if experts_used else 'verina'
254
- self.model.sentinel.log_expert(primary_expert)
255
- except Exception:
256
- pass
257
-
258
- # ── MEMORY STORAGE ────────────────────────────────────────────
259
- if self.use_memory and self.db is not None:
260
- try:
261
- conv_id = self.db.add_conversation(
262
- user_msg = prompt,
263
- assistant_msg = response_text,
264
- session_id = self.session_id,
265
- experts_used = experts_used,
266
- routing_weights = {k: round(v, 3) for k, v in routing_info.items()},
267
- sentinel_score = sentinel_report.overall_risk if sentinel_report else None,
268
- sentinel_verdict = sentinel_report.verdict if sentinel_report else None,
269
- tokens_generated = len(generated_ids),
270
- latency_ms = latency_ms,
271
- )
272
- if self.vector_store:
273
- try:
274
- self.vector_store.add(conv_id, prompt)
275
- except Exception:
276
- pass
277
- if sentinel_report and sentinel_report.verdict in ('FLAG', 'BLOCK'):
278
- self.db.log_incident(
279
- severity = sentinel_report.verdict,
280
- drift_score = sentinel_report.drift_score,
281
- refusal_score = sentinel_report.refusal_score,
282
- hallucination_score = sentinel_report.hallucination_score,
283
- overall_risk = sentinel_report.overall_risk,
284
- user_msg = prompt,
285
- output_snippet = response_text,
286
- conversation_id = conv_id,
287
- )
288
- except Exception:
289
- pass
290
-
291
- # Update working memory embedding
292
- try:
293
- base_emb = self.model.base(input_ids)
294
- self.model.echo.update_working_memory("user", base_emb[:, -1, :])
295
- except Exception:
296
- pass
297
-
298
- return {
299
- 'text': response_text,
300
- 'prompt': prompt,
301
- 'enriched_prompt': enriched,
302
- 'experts_used': experts_used,
303
- 'routing': routing_info,
304
- 'sentinel': sentinel_report,
305
- 'blocked': blocked,
306
- 'latency_ms': round(latency_ms, 2),
307
- 'n_tokens': len(generated_ids),
308
- }
309
-
310
- # ── Backwards-compat interface (used by chat.py) ──────────────────
311
- def respond(self, text, max_tokens=200, temperature=0.8, top_k=40,
312
- show_routing=True, expert_override=None):
313
- """Backwards-compatible wrapper for chat.py."""
314
- result = self.generate(
315
- prompt=text,
316
- max_new_tokens=max_tokens,
317
- temperature=temperature,
318
- top_k=top_k,
319
- )
320
- if show_routing and result['routing']:
321
- pairs = sorted(result['routing'].items(), key=lambda x: -x[1])
322
- print(f"\n [Herald] PARALLEL:")
323
- for name, w in pairs:
324
- print(f" {name:12s} {'|'*int(w*25):<25} {w:.3f}")
325
- return {
326
- "text": result["text"],
327
- "routing": list(result["routing"].items()),
328
- "pipeline": False,
329
- "drift": result["sentinel"],
330
- }
331
-
332
- def route_query(self, query):
333
- """Print routing breakdown for a query without generating."""
334
- if not self.tok:
335
- return
336
- ids = encode(query)
337
- idx = torch.tensor([[self.bos] + ids], dtype=torch.long, device=DEVICE)
338
- pairs, _ = self.model.get_routing(idx)
339
- print(f"\nQuery: {query}")
340
- for name, w in pairs:
341
- print(f" {name:12s} {'|'*int(w*30):<30} {w:.3f}")
342
-
343
- def reset_session(self):
344
- """Clear working memory and Sentinel incident log."""
345
- try:
346
- self.model.echo.clear_memory()
347
- except Exception:
348
- pass
349
- try:
350
- self.model.sentinel.reset_session()
351
- except Exception:
352
- pass
353
- self.session_id = _new_session_id()
354
- print("[Engine] Session reset.")
355
-
356
-
357
- # ── Module-level helpers ───────────────────────────────────────────────
358
-
359
- def _sample(
360
- logits: torch.Tensor,
361
- temperature: float = 1.0,
362
- top_p: float = 0.9,
363
- top_k: int = 50,
364
- ) -> torch.Tensor:
365
- """
366
- Sample the next token from logits.
367
-
368
- Sampling strategy:
369
- temperature=0: Greedy (always pick highest probability token)
370
- temperature>0: Sample from softmax distribution
371
- top_k: Only sample from the K highest probability tokens
372
- top_p: Only sample from the smallest set whose cumulative prob >= p
373
-
374
- Args:
375
- logits: [1, VOCAB_SIZE] raw logits from the model
376
-
377
- Returns:
378
- [1, 1] tensor containing the selected token ID
379
- """
380
- if temperature == 0.0:
381
- return logits.argmax(dim=-1, keepdim=True)
382
-
383
- logits = logits / max(temperature, 1e-8)
384
-
385
- # Top-K filtering: zero out all but top K logits
386
- if top_k and top_k > 0:
387
- topk_vals, _ = torch.topk(logits, min(top_k, logits.shape[-1]))
388
- logits = logits.masked_fill(logits < topk_vals[:, -1:], float('-inf'))
389
-
390
- probs = F.softmax(logits, dim=-1)
391
-
392
- # Top-P (nucleus) filtering
393
- if top_p is not None and top_p < 1.0:
394
- sorted_probs, sorted_idx = torch.sort(probs, descending=True)
395
- cumulative = sorted_probs.cumsum(dim=-1)
396
- remove_mask = cumulative - sorted_probs > top_p
397
- sorted_probs[remove_mask] = 0.0
398
- probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
399
- probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)
400
-
401
- return torch.multinomial(probs, num_samples=1)
402
-
403
-
404
- def _new_session_id() -> str:
405
- import uuid
406
- return str(uuid.uuid4())[:8]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
memory/__init__.py DELETED
@@ -1 +0,0 @@
1
- # memory package
 
 
memory/database.py DELETED
@@ -1,379 +0,0 @@
1
- # memory/database.py
2
- # SQLite-backed persistent memory for Shorekeeper.
3
- # This file contains the complete database schema and all CRUD operations.
4
- # The database file lives at MEMORY_DIR/shorekeeper.db
5
-
6
- import sys
7
- import sqlite3
8
- import json
9
- import hashlib
10
- from pathlib import Path
11
- from datetime import datetime
12
- from typing import Optional
13
-
14
- sys.path.insert(0, str(Path(__file__).parent.parent))
15
- from config import MEMORY_DIR
16
-
17
-
18
- class MemoryDatabase:
19
- """
20
- Complete SQLite memory system for Shorekeeper.
21
-
22
- Tables:
23
- conversations — every user/assistant exchange
24
- knowledge — factual key-value store
25
- experience_log — routing decisions and system events
26
- incidents — Sentinel-flagged behavioral events
27
- user_preferences — learned preferences about the user
28
-
29
- Full-text search is enabled on conversations and knowledge
30
- via SQLite FTS5 virtual tables.
31
- """
32
-
33
- def __init__(self, db_path: Path = None):
34
- if db_path is None:
35
- db_path = MEMORY_DIR / 'shorekeeper.db'
36
- db_path.parent.mkdir(parents=True, exist_ok=True)
37
- self.db_path = db_path
38
- # check_same_thread=False: allow access from multiple threads
39
- # (inference daemon may serve concurrent requests)
40
- self.conn = sqlite3.connect(
41
- str(db_path),
42
- check_same_thread=False,
43
- )
44
- self.conn.row_factory = sqlite3.Row # Access columns by name: row['id']
45
- # Enable WAL mode for better concurrent read/write performance
46
- self.conn.execute('PRAGMA journal_mode=WAL')
47
- # Enable foreign keys
48
- self.conn.execute('PRAGMA foreign_keys=ON')
49
- self._create_tables()
50
- print(f'[MemoryDatabase] Connected: {db_path}')
51
-
52
- def _create_tables(self):
53
- """Create all tables and indexes. Safe to call repeatedly (IF NOT EXISTS)."""
54
- self.conn.executescript('''
55
- -- ── CONVERSATIONS TABLE ───────────────────────────────────
56
- CREATE TABLE IF NOT EXISTS conversations (
57
- id INTEGER PRIMARY KEY AUTOINCREMENT,
58
- session_id TEXT, -- UUID for the current session
59
- timestamp TEXT NOT NULL, -- ISO8601 datetime
60
- user_msg TEXT NOT NULL, -- Raw user input
61
- assistant_msg TEXT NOT NULL, -- Full assistant response
62
- experts_used TEXT, -- JSON array: ["calcharo","rover"]
63
- routing_weights TEXT, -- JSON obj: {"calcharo":0.6,"rover":0.4}
64
- sentinel_score REAL, -- Overall risk score from Sentinel
65
- sentinel_verdict TEXT, -- CLEAN/FLAG/BLOCK
66
- tokens_generated INTEGER, -- How many tokens in the response
67
- latency_ms REAL -- Time to generate in milliseconds
68
- );
69
-
70
- -- ── KNOWLEDGE TABLE ───────────────────────────────────────
71
- CREATE TABLE IF NOT EXISTS knowledge (
72
- id INTEGER PRIMARY KEY AUTOINCREMENT,
73
- created_at TEXT NOT NULL,
74
- updated_at TEXT NOT NULL,
75
- key TEXT NOT NULL UNIQUE, -- Unique key for deduplication
76
- value TEXT NOT NULL,
77
- source TEXT DEFAULT 'user', -- 'user', 'inference', 'system'
78
- confidence REAL DEFAULT 1.0, -- 0.0-1.0
79
- access_count INTEGER DEFAULT 0 -- How many times this was retrieved
80
- );
81
-
82
- -- ── EXPERIENCE LOG ────────────────────────────────────────
83
- CREATE TABLE IF NOT EXISTS experience_log (
84
- id INTEGER PRIMARY KEY AUTOINCREMENT,
85
- timestamp TEXT NOT NULL,
86
- event_type TEXT NOT NULL, -- 'routing', 'sentinel_flag', 'error', 'user_feedback'
87
- event_data TEXT NOT NULL, -- JSON blob
88
- session_id TEXT,
89
- severity TEXT DEFAULT 'info' -- 'info', 'warning', 'error'
90
- );
91
-
92
- -- ── INCIDENTS TABLE ───────────────────────────────────────
93
- CREATE TABLE IF NOT EXISTS incidents (
94
- id INTEGER PRIMARY KEY AUTOINCREMENT,
95
- timestamp TEXT NOT NULL,
96
- severity TEXT NOT NULL, -- 'FLAG' or 'BLOCK'
97
- drift_score REAL,
98
- refusal_score REAL,
99
- hallucination_score REAL,
100
- overall_risk REAL,
101
- user_msg_snippet TEXT, -- First 200 chars of user message
102
- output_snippet TEXT, -- First 500 chars of flagged output
103
- resolution TEXT DEFAULT 'pending',
104
- conversation_id INTEGER REFERENCES conversations(id)
105
- );
106
-
107
- -- ── USER PREFERENCES ──────────────────────────────────────
108
- CREATE TABLE IF NOT EXISTS user_preferences (
109
- id INTEGER PRIMARY KEY AUTOINCREMENT,
110
- created_at TEXT NOT NULL,
111
- category TEXT NOT NULL, -- 'tone', 'expertise', 'domain', etc.
112
- preference TEXT NOT NULL,
113
- confidence REAL DEFAULT 0.5
114
- );
115
-
116
- -- ── INDEXES ───────────────────────────────────────────────
117
- CREATE INDEX IF NOT EXISTS idx_conv_timestamp
118
- ON conversations(timestamp DESC);
119
- CREATE INDEX IF NOT EXISTS idx_conv_sentinel
120
- ON conversations(sentinel_verdict);
121
- CREATE INDEX IF NOT EXISTS idx_knowledge_key
122
- ON knowledge(key);
123
- CREATE INDEX IF NOT EXISTS idx_incidents_severity
124
- ON incidents(severity, timestamp DESC);
125
-
126
- -- ── FULL-TEXT SEARCH TABLES ───────────────────────────────
127
- CREATE VIRTUAL TABLE IF NOT EXISTS conversations_fts
128
- USING fts5(
129
- user_msg,
130
- assistant_msg,
131
- content=conversations,
132
- content_rowid=id
133
- );
134
-
135
- CREATE VIRTUAL TABLE IF NOT EXISTS knowledge_fts
136
- USING fts5(
137
- key,
138
- value,
139
- content=knowledge,
140
- content_rowid=id
141
- );
142
-
143
- -- ── FTS SYNC TRIGGERS ─────────────────────────────────────
144
- CREATE TRIGGER IF NOT EXISTS conv_fts_insert
145
- AFTER INSERT ON conversations BEGIN
146
- INSERT INTO conversations_fts(rowid, user_msg, assistant_msg)
147
- VALUES (new.id, new.user_msg, new.assistant_msg);
148
- END;
149
-
150
- CREATE TRIGGER IF NOT EXISTS conv_fts_delete
151
- AFTER DELETE ON conversations BEGIN
152
- INSERT INTO conversations_fts(
153
- conversations_fts, rowid, user_msg, assistant_msg)
154
- VALUES ('delete', old.id, old.user_msg, old.assistant_msg);
155
- END;
156
-
157
- CREATE TRIGGER IF NOT EXISTS know_fts_insert
158
- AFTER INSERT ON knowledge BEGIN
159
- INSERT INTO knowledge_fts(rowid, key, value)
160
- VALUES (new.id, new.key, new.value);
161
- END;
162
-
163
- CREATE TRIGGER IF NOT EXISTS know_fts_update
164
- AFTER UPDATE ON knowledge BEGIN
165
- INSERT INTO knowledge_fts(knowledge_fts, rowid, key, value)
166
- VALUES ('delete', old.id, old.key, old.value);
167
- INSERT INTO knowledge_fts(rowid, key, value)
168
- VALUES (new.id, new.key, new.value);
169
- END;
170
- ''')
171
- self.conn.commit()
172
-
173
- def add_conversation(
174
- self,
175
- user_msg: str,
176
- assistant_msg: str,
177
- session_id: str = None,
178
- experts_used: list = None,
179
- routing_weights: dict = None,
180
- sentinel_score: float = None,
181
- sentinel_verdict: str = None,
182
- tokens_generated: int = None,
183
- latency_ms: float = None,
184
- ) -> int:
185
- cur = self.conn.execute('''
186
- INSERT INTO conversations
187
- (session_id, timestamp, user_msg, assistant_msg,
188
- experts_used, routing_weights, sentinel_score,
189
- sentinel_verdict, tokens_generated, latency_ms)
190
- VALUES (?,?,?,?,?,?,?,?,?,?)
191
- ''', (
192
- session_id,
193
- datetime.now().isoformat(),
194
- user_msg,
195
- assistant_msg,
196
- json.dumps(experts_used) if experts_used else None,
197
- json.dumps(routing_weights) if routing_weights else None,
198
- sentinel_score,
199
- sentinel_verdict,
200
- tokens_generated,
201
- latency_ms,
202
- ))
203
- self.conn.commit()
204
- return cur.lastrowid
205
-
206
- def search_conversations(
207
- self,
208
- query: str,
209
- limit: int = 5,
210
- min_score: float = 0.0,
211
- ) -> list[dict]:
212
- query = query.strip()
213
- if not query:
214
- return []
215
- safe_query = query.replace('"', '').replace("'", '')
216
- try:
217
- rows = self.conn.execute('''
218
- SELECT
219
- c.user_msg,
220
- c.assistant_msg,
221
- c.timestamp,
222
- bm25(conversations_fts) as score
223
- FROM conversations_fts
224
- JOIN conversations c ON conversations_fts.rowid = c.id
225
- WHERE conversations_fts MATCH ?
226
- ORDER BY bm25(conversations_fts)
227
- LIMIT ?
228
- ''', (safe_query, limit)).fetchall()
229
- return [dict(r) for r in rows]
230
- except sqlite3.OperationalError:
231
- rows = self.conn.execute('''
232
- SELECT user_msg, assistant_msg, timestamp, 0 as score
233
- FROM conversations
234
- WHERE user_msg LIKE ? OR assistant_msg LIKE ?
235
- ORDER BY id DESC
236
- LIMIT ?
237
- ''', (f'%{query[:50]}%', f'%{query[:50]}%', limit)).fetchall()
238
- return [dict(r) for r in rows]
239
-
240
- def get_recent_conversations(self, n: int = 20) -> list[dict]:
241
- rows = self.conn.execute('''
242
- SELECT * FROM conversations ORDER BY id DESC LIMIT ?
243
- ''', (n,)).fetchall()
244
- return [dict(r) for r in rows]
245
-
246
- def add_knowledge(
247
- self,
248
- key: str,
249
- value: str,
250
- source: str = 'user',
251
- confidence: float = 1.0,
252
- ):
253
- now = datetime.now().isoformat()
254
- self.conn.execute('''
255
- INSERT INTO knowledge (created_at, updated_at, key, value, source, confidence)
256
- VALUES (?,?,?,?,?,?)
257
- ON CONFLICT(key) DO UPDATE SET
258
- value = excluded.value,
259
- updated_at = excluded.updated_at,
260
- source = excluded.source,
261
- confidence = excluded.confidence
262
- ''', (now, now, key, value, source, confidence))
263
- self.conn.commit()
264
-
265
- def search_knowledge(self, query: str, limit: int = 5) -> list[dict]:
266
- query = query.strip()
267
- if not query:
268
- return []
269
- safe_query = query.replace('"', '').replace("'", '')
270
- try:
271
- rows = self.conn.execute('''
272
- SELECT k.key, k.value, k.source, k.confidence,
273
- bm25(knowledge_fts) as score
274
- FROM knowledge_fts
275
- JOIN knowledge k ON knowledge_fts.rowid = k.id
276
- WHERE knowledge_fts MATCH ?
277
- ORDER BY bm25(knowledge_fts)
278
- LIMIT ?
279
- ''', (safe_query, limit)).fetchall()
280
- for row in rows:
281
- self.conn.execute(
282
- 'UPDATE knowledge SET access_count = access_count + 1 WHERE key = ?',
283
- (row['key'],)
284
- )
285
- self.conn.commit()
286
- return [dict(r) for r in rows]
287
- except sqlite3.OperationalError:
288
- rows = self.conn.execute('''
289
- SELECT key, value, source, confidence, 0 as score
290
- FROM knowledge WHERE key LIKE ? OR value LIKE ? LIMIT ?
291
- ''', (f'%{query[:50]}%', f'%{query[:50]}%', limit)).fetchall()
292
- return [dict(r) for r in rows]
293
-
294
- def get_all_knowledge(self) -> list[dict]:
295
- rows = self.conn.execute(
296
- 'SELECT * FROM knowledge ORDER BY access_count DESC'
297
- ).fetchall()
298
- return [dict(r) for r in rows]
299
-
300
- def log_incident(
301
- self,
302
- severity: str,
303
- drift_score: float,
304
- refusal_score: float,
305
- hallucination_score: float,
306
- overall_risk: float,
307
- user_msg: str = '',
308
- output_snippet: str = '',
309
- conversation_id: int = None,
310
- ):
311
- self.conn.execute('''
312
- INSERT INTO incidents
313
- (timestamp, severity, drift_score, refusal_score,
314
- hallucination_score, overall_risk, user_msg_snippet,
315
- output_snippet, conversation_id)
316
- VALUES (?,?,?,?,?,?,?,?,?)
317
- ''', (
318
- datetime.now().isoformat(),
319
- severity, drift_score, refusal_score,
320
- hallucination_score, overall_risk,
321
- user_msg[:200], output_snippet[:500], conversation_id
322
- ))
323
- self.conn.commit()
324
-
325
- def get_recent_incidents(self, n: int = 20) -> list[dict]:
326
- rows = self.conn.execute('''
327
- SELECT * FROM incidents ORDER BY id DESC LIMIT ?
328
- ''', (n,)).fetchall()
329
- return [dict(r) for r in rows]
330
-
331
- def log_event(
332
- self,
333
- event_type: str,
334
- event_data: dict,
335
- session_id: str = None,
336
- severity: str = 'info',
337
- ):
338
- self.conn.execute('''
339
- INSERT INTO experience_log (timestamp, event_type, event_data, session_id, severity)
340
- VALUES (?,?,?,?,?)
341
- ''', (
342
- datetime.now().isoformat(),
343
- event_type,
344
- json.dumps(event_data),
345
- session_id,
346
- severity,
347
- ))
348
- self.conn.commit()
349
-
350
- def get_stats(self) -> dict:
351
- stats = {}
352
- for table in ['conversations', 'knowledge', 'experience_log', 'incidents', 'user_preferences']:
353
- row = self.conn.execute(f'SELECT COUNT(*) as c FROM {table}').fetchone()
354
- stats[table] = row['c']
355
- rows = self.conn.execute('''
356
- SELECT experts_used, COUNT(*) as n
357
- FROM conversations
358
- WHERE experts_used IS NOT NULL
359
- GROUP BY experts_used
360
- ORDER BY n DESC LIMIT 10
361
- ''').fetchall()
362
- stats['top_expert_combinations'] = [
363
- {'combo': r['experts_used'], 'count': r['n']} for r in rows
364
- ]
365
- row = self.conn.execute('''
366
- SELECT
367
- COUNT(*) as total,
368
- SUM(CASE WHEN sentinel_verdict="CLEAN" THEN 1 ELSE 0 END) as clean,
369
- SUM(CASE WHEN sentinel_verdict="FLAG" THEN 1 ELSE 0 END) as flagged,
370
- SUM(CASE WHEN sentinel_verdict="BLOCK" THEN 1 ELSE 0 END) as blocked,
371
- AVG(sentinel_score) as avg_risk
372
- FROM conversations WHERE sentinel_verdict IS NOT NULL
373
- ''').fetchone()
374
- if row and row['total']:
375
- stats['sentinel'] = dict(row)
376
- return stats
377
-
378
- def close(self):
379
- self.conn.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
memory/vector_store.py DELETED
@@ -1,109 +0,0 @@
1
- # memory/vector_store.py
2
- # FAISS-based vector similarity search for Echo.
3
- # Stores embeddings of past conversations for semantic retrieval.
4
- # Uses a lightweight sentence encoder (not the full 2B model).
5
-
6
- import sys
7
- import numpy as np
8
- from pathlib import Path
9
-
10
- sys.path.insert(0, str(Path(__file__).parent.parent))
11
- from config import MEMORY_DIR
12
-
13
- try:
14
- import faiss
15
- FAISS_AVAILABLE = True
16
- except ImportError:
17
- FAISS_AVAILABLE = False
18
- print('[vector_store] FAISS not available. Using SQLite FTS5 only.')
19
-
20
-
21
- class VectorStore:
22
- """
23
- FAISS-backed vector store for semantic memory retrieval.
24
- """
25
-
26
- EMBEDDING_DIM = 384 # sentence-transformers/all-MiniLM-L6-v2 output dim
27
- INDEX_FILE = 'faiss_index.bin'
28
- ID_MAP_FILE = 'faiss_id_map.npy'
29
-
30
- def __init__(self):
31
- self.index_path = MEMORY_DIR / self.INDEX_FILE
32
- self.id_map_path = MEMORY_DIR / self.ID_MAP_FILE
33
- self._encoder = None # Lazy load
34
- self._index = None # Lazy load
35
- self._id_map = [] # Maps FAISS row index → conversation ID
36
-
37
- if FAISS_AVAILABLE:
38
- self._load_or_create_index()
39
-
40
- def _load_or_create_index(self):
41
- """Load existing FAISS index or create a new one."""
42
- if self.index_path.exists() and self.id_map_path.exists():
43
- self._index = faiss.read_index(str(self.index_path))
44
- self._id_map = np.load(str(self.id_map_path), allow_pickle=True).tolist()
45
- print(f'[VectorStore] Loaded index: {self._index.ntotal} vectors')
46
- else:
47
- # IndexFlatIP: inner product similarity (cosine if L2-normalized)
48
- self._index = faiss.IndexFlatIP(self.EMBEDDING_DIM)
49
- self._id_map = []
50
- print('[VectorStore] Created new FAISS index')
51
-
52
- def _get_encoder(self):
53
- """Lazy-load the sentence encoder on first use."""
54
- if self._encoder is None:
55
- try:
56
- from sentence_transformers import SentenceTransformer
57
- self._encoder = SentenceTransformer('all-MiniLM-L6-v2')
58
- print('[VectorStore] Encoder loaded: all-MiniLM-L6-v2')
59
- except ImportError:
60
- print('[VectorStore] sentence-transformers not installed.')
61
- print(' Install: pip install sentence-transformers')
62
- return self._encoder
63
-
64
- def encode(self, text: str) -> np.ndarray:
65
- """Encode text to a normalized embedding vector."""
66
- encoder = self._get_encoder()
67
- if encoder is None:
68
- return None
69
- vec = encoder.encode([text], normalize_embeddings=True)
70
- return vec.astype(np.float32)
71
-
72
- def add(
73
- self,
74
- conversation_id: int,
75
- text: str,
76
- ):
77
- """Add a conversation to the vector index."""
78
- if not FAISS_AVAILABLE or self._index is None:
79
- return
80
- vec = self.encode(text)
81
- if vec is None:
82
- return
83
- self._index.add(vec)
84
- self._id_map.append(conversation_id)
85
- self._save()
86
-
87
- def search(
88
- self,
89
- query: str,
90
- top_k: int = 5,
91
- ) -> list[int]:
92
- """Find the top_k most semantically similar conversation IDs."""
93
- if not FAISS_AVAILABLE or self._index is None or self._index.ntotal == 0:
94
- return []
95
- vec = self.encode(query)
96
- if vec is None:
97
- return []
98
- k = min(top_k, self._index.ntotal)
99
- distances, indices = self._index.search(vec, k)
100
- conv_ids = []
101
- for idx in indices[0]:
102
- if idx >= 0 and idx < len(self._id_map):
103
- conv_ids.append(self._id_map[idx])
104
- return conv_ids
105
-
106
- def _save(self):
107
- if self._index is not None:
108
- faiss.write_index(self._index, str(self.index_path))
109
- np.save(str(self.id_map_path), np.array(self._id_map, dtype=object))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- # model/__init__.py
2
- from .base import SharedBase
3
- from .expert import ExpertHead
4
- from .herald import Herald
5
- from .echo import Echo
6
- from .sentinel import Sentinel
7
- from .ensemble import ShorekeeperEnsemble
 
 
 
 
 
 
 
 
model/base.py DELETED
@@ -1,152 +0,0 @@
1
- # model/base.py
2
- # The shared transformer backbone. Used by ALL components.
3
-
4
- import sys, math
5
- import torch
6
- import torch.nn as nn
7
- from pathlib import Path
8
-
9
- sys.path.insert(0, str(Path(__file__).parent.parent))
10
- import config
11
-
12
-
13
- class RotaryEmbedding(nn.Module):
14
- """
15
- Rotary Position Embedding (RoPE).
16
- Encodes position by rotating query and key vectors.
17
- Better than learned absolute embeddings for long contexts.
18
- Used by LLaMA, GPT-NeoX, and most modern LLMs post-2022.
19
- """
20
- def __init__(self, dim: int, max_seq: int = 2048):
21
- super().__init__()
22
- # Use float32 base precision for stable frequency generation.
23
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
24
- self.register_buffer("inv_freq", inv_freq)
25
- self.max_seq = max_seq
26
-
27
- def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype = None):
28
- inv_freq = self.inv_freq.to(device)
29
- t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
30
- freqs = torch.einsum("i,j->ij", t, inv_freq)
31
- emb = torch.cat([freqs, freqs], dim=-1)
32
- if emb.shape[-1] != inv_freq.shape[0] * 2:
33
- emb = emb[:, :inv_freq.shape[0] * 2]
34
- if dtype is not None:
35
- emb = emb.to(dtype)
36
- return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
37
-
38
-
39
- def rotate_half(x):
40
- x1, x2 = x.chunk(2, dim=-1)
41
- return torch.cat([-x2, x1], dim=-1)
42
-
43
-
44
- def apply_rope(q, k, cos, sin):
45
- cos, sin = cos.to(q.dtype), sin.to(q.dtype)
46
- return (q * cos + rotate_half(q) * sin), (k * cos + rotate_half(k) * sin)
47
-
48
-
49
- class MultiHeadAttention(nn.Module):
50
- def __init__(self, n_embd: int, n_head: int, dropout: float = 0.1):
51
- super().__init__()
52
- assert n_embd % n_head == 0
53
- self.n_head = n_head
54
- self.n_embd = n_embd
55
- self.head_dim = n_embd // n_head
56
- self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
57
- self.proj = nn.Linear(n_embd, n_embd, bias=False)
58
- self.drop = nn.Dropout(dropout)
59
- self.rope = RotaryEmbedding(self.head_dim)
60
-
61
- def forward(self, x, mask=None):
62
- B, T, C = x.shape
63
- qkv = self.qkv(x).reshape(B, T, 3, self.n_head, self.head_dim)
64
- q, k, v = qkv.permute(2, 0, 3, 1, 4)
65
- cos, sin = self.rope(T, x.device, dtype=q.dtype)
66
- q, k = apply_rope(q, k, cos, sin)
67
- try:
68
- from torch.nn.functional import scaled_dot_product_attention
69
- out = scaled_dot_product_attention(q, k, v, attn_mask=None,
70
- dropout_p=self.drop.p if self.training else 0.0, is_causal=True)
71
- except Exception:
72
- scale = math.sqrt(self.head_dim)
73
- att = (q @ k.transpose(-2, -1)) / scale
74
- causal = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
75
- att.masked_fill_(causal, float("-inf"))
76
- if mask is not None:
77
- pad = (1 - mask[:, None, None, :].float()) * -1e4
78
- att = att + pad
79
- att = torch.softmax(att, dim=-1)
80
- att = self.drop(att)
81
- out = att @ v
82
- out = out.transpose(1, 2).reshape(B, T, C)
83
- return self.proj(out)
84
-
85
-
86
- class FeedForward(nn.Module):
87
- def __init__(self, n_embd: int, dropout: float = 0.1):
88
- super().__init__()
89
- self.net = nn.Sequential(
90
- nn.Linear(n_embd, 4 * n_embd, bias=False),
91
- nn.GELU(),
92
- nn.Linear(4 * n_embd, n_embd, bias=False),
93
- nn.Dropout(dropout),
94
- )
95
- def forward(self, x): return self.net(x)
96
-
97
-
98
- class TransformerBlock(nn.Module):
99
- def __init__(self, n_embd: int, n_head: int, dropout: float = 0.1):
100
- super().__init__()
101
- self.ln1 = nn.LayerNorm(n_embd)
102
- self.attn = MultiHeadAttention(n_embd, n_head, dropout)
103
- self.ln2 = nn.LayerNorm(n_embd)
104
- self.ff = FeedForward(n_embd, dropout)
105
-
106
- def forward(self, x, mask=None):
107
- x = x + self.attn(self.ln1(x), mask) # Pre-norm + residual
108
- x = x + self.ff(self.ln2(x))
109
- return x
110
-
111
-
112
- class SharedBase(nn.Module):
113
- def __init__(self, cfg=None):
114
- super().__init__()
115
- if cfg is None:
116
- cfg = config.BASE_CONFIG
117
- self.n_embd = cfg["n_embd"]
118
- self.n_positions = cfg["n_positions"]
119
- self.token_embedding = nn.Embedding(cfg["vocab_size"], cfg["n_embd"])
120
- self.drop = nn.Dropout(cfg["dropout"])
121
- self.blocks = nn.ModuleList([
122
- TransformerBlock(cfg["n_embd"], cfg["n_head"], cfg["dropout"])
123
- for _ in range(cfg["n_layer"])
124
- ])
125
- self.ln_f = nn.LayerNorm(cfg["n_embd"])
126
- if config.MEMORY_OPT.get("gradient_checkpointing"):
127
- self.gradient_checkpointing_enable()
128
- self._init_weights()
129
-
130
- def _init_weights(self):
131
- for m in self.modules():
132
- if isinstance(m, nn.Linear):
133
- nn.init.normal_(m.weight, mean=0.0, std=0.02)
134
- elif isinstance(m, nn.Embedding):
135
- nn.init.normal_(m.weight, mean=0.0, std=0.02)
136
-
137
- def gradient_checkpointing_enable(self):
138
- self._use_checkpointing = True
139
-
140
- def forward(self, input_ids, attention_mask=None):
141
- x = self.drop(self.token_embedding(input_ids))
142
- use_ckpt = getattr(self, "_use_checkpointing", False) and self.training
143
- for block in self.blocks:
144
- if use_ckpt:
145
- from torch.utils.checkpoint import checkpoint
146
- x = checkpoint(block, x, attention_mask, use_reentrant=False)
147
- else:
148
- x = block(x, attention_mask)
149
- return self.ln_f(x)
150
-
151
- def get_param_count(self):
152
- return sum(p.numel() for p in self.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/echo.py DELETED
@@ -1,71 +0,0 @@
1
- # model/echo.py
2
- import sys, torch, torch.nn as nn
3
- from pathlib import Path
4
- sys.path.insert(0, str(Path(__file__).parent.parent))
5
- import config
6
-
7
-
8
- class Echo(nn.Module):
9
- """
10
- Memory retrieval and context injection module.
11
- Stage 1: Retrieve relevant past exchanges from SQLite/FAISS, inject as context prefix.
12
- Stage 2: Cross-attention memory gate during generation.
13
- """
14
- def __init__(self, cfg=None, n_embd=None):
15
- super().__init__()
16
- cfg = cfg or config.ECHO_CONFIG
17
- n = n_embd or cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
18
- from model.base import TransformerBlock
19
- self.memory_attn = TransformerBlock(n, cfg.get("n_head", config.BASE_CONFIG["n_head"]))
20
- self.gate_proj = nn.Linear(n * 2, n, bias=False)
21
- self.gate_norm = nn.LayerNorm(n)
22
- self._working_memory = [] # List of (role, cpu_embedding) tuples
23
- self._max_memory = cfg.get("max_memory_tokens", 512) // 64 # keep ~8 entries
24
-
25
- def retrieve_context(self, query: str, db, vs) -> str:
26
- """
27
- Retrieve relevant past conversations and inject as context prefix.
28
- Returns an enriched prompt string with memory context prepended.
29
- """
30
- if db is None: return query
31
- # Try semantic search first (FAISS), fall back to keyword (FTS5)
32
- results = []
33
- if vs is not None:
34
- conv_ids = vs.search(query, top_k=3)
35
- if conv_ids:
36
- recent = db.get_recent_conversations(n=50)
37
- id_to_conv = {c["id"]: c for c in recent if "id" in c}
38
- for cid in conv_ids:
39
- if cid in id_to_conv:
40
- results.append(id_to_conv[cid])
41
- if not results:
42
- results = db.search_conversations(query[:100], limit=3)
43
- knowledge = db.search_knowledge(query[:100], limit=2)
44
- if not results and not knowledge: return query
45
- ctx_parts = ["[MEMORY]"]
46
- for r in results[:3]:
47
- u = str(r.get("user_msg", ""))[:100]
48
- a = str(r.get("assistant_msg", ""))[:200]
49
- if u: ctx_parts.append(f"Q: {u}\nA: {a}")
50
- for k in knowledge[:2]:
51
- ctx_parts.append(f"FACT: {k['key']} = {k['value']}")
52
- ctx_parts.append("[SEP]")
53
- ctx_parts.append(query)
54
- return "\n".join(ctx_parts)
55
-
56
- def update_working_memory(self, role: str, embedding: torch.Tensor):
57
- """
58
- Store a hidden-state embedding in the short-term working memory buffer.
59
- Embeddings are moved to CPU to avoid holding GPU memory between turns.
60
- """
61
- self._working_memory.append((role, embedding.detach().cpu()))
62
- if len(self._working_memory) > self._max_memory:
63
- self._working_memory = self._working_memory[-self._max_memory:]
64
-
65
- def clear_memory(self):
66
- """Clear the working memory buffer (call on session reset)."""
67
- self._working_memory.clear()
68
-
69
- def get_memory_summary(self) -> str:
70
- """Return a simple string summary of buffered memory entries."""
71
- return f"[Echo] {len(self._working_memory)} entries in working memory"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/ensemble.py DELETED
@@ -1,346 +0,0 @@
1
- """
2
- Shorekeeper ensemble with lazy expert loading.
3
- """
4
- import sys, torch, torch.nn as nn, torch.nn.functional as F
5
- from pathlib import Path
6
- sys.path.insert(0, str(Path(__file__).parent.parent))
7
- import config
8
- from config import EXPERT_NAMES, CHECKPOINT_DIR, DEVICE, DTYPE
9
- from model.base import SharedBase
10
- from model.herald import Herald
11
- from model.echo import Echo
12
- from model.sentinel import Sentinel
13
- from model.lazy_expert_loader import LazyExpertLoader
14
-
15
-
16
- class ShorekeeperEnsemble(nn.Module):
17
- """Full assembled Shorekeeper model with lazy expert loading."""
18
-
19
- def __init__(
20
- self,
21
- max_loaded_experts: int = 2,
22
- base_cfg: dict = None,
23
- expert_cfgs: dict = None,
24
- ):
25
- super().__init__()
26
- self.base_cfg = base_cfg or config.BASE_CONFIG
27
- self.expert_cfgs = expert_cfgs or config.EXPERT_CONFIGS
28
- self.base = SharedBase(self.base_cfg)
29
- self.expert_loader = LazyExpertLoader(
30
- expert_names=EXPERT_NAMES,
31
- checkpoint_dir=CHECKPOINT_DIR,
32
- device=DEVICE,
33
- max_loaded=max_loaded_experts,
34
- dtype=DTYPE,
35
- base_cfg=self.base_cfg,
36
- expert_cfgs=self.expert_cfgs,
37
- )
38
- herald_n_embd = self.base_cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
39
- self.herald = Herald(n_embd=herald_n_embd)
40
- self.echo = Echo(n_embd=herald_n_embd)
41
- sentinel_cfg = self.expert_cfgs.get("sentinel", config.EXPERT_CONFIGS.get("sentinel", {}))
42
- self.sentinel = Sentinel(cfg=sentinel_cfg, n_embd=herald_n_embd)
43
-
44
- print(f"[Ensemble] Initialized with lazy expert loading (max_loaded={max_loaded_experts})")
45
-
46
- # backward compatibility for code expecting model.experts[...]
47
- @property
48
- def experts(self):
49
- class _proxy:
50
- def __init__(self, loader):
51
- self.loader = loader
52
- def __getitem__(self, name):
53
- return self.loader.get_expert(name)
54
- def keys(self):
55
- return self.loader.expert_names
56
- def __iter__(self):
57
- return iter(self.loader.expert_names)
58
- return _proxy(self.expert_loader)
59
-
60
- # ── Convenience alias ─────────────────────────────────────────────
61
- @property
62
- def shared_base(self):
63
- return self.base
64
-
65
- # ── Core forward with lazy loading ────────────────────────────────
66
- def forward(self, input_ids, attention_mask=None,
67
- return_routing=False, return_sentinel=False):
68
- base_hidden = self.base(input_ids, attention_mask)
69
- # Ensure experts load to the same device/dtype as base_hidden
70
- self.expert_loader.set_device(base_hidden.device, dtype=base_hidden.dtype)
71
- routing = self.herald(base_hidden, attention_mask)
72
- expert_idx = routing["expert_indices"]
73
- expert_wts = routing["expert_weights"]
74
- B = input_ids.shape[0]
75
- logits = None
76
- experts_used = []
77
-
78
- for ki in range(expert_idx.shape[1]):
79
- for b in range(B):
80
- name = EXPERT_NAMES[expert_idx[b, ki].item()]
81
- if name not in experts_used:
82
- experts_used.append(name)
83
- expert = self.expert_loader.get_expert(name)
84
- out = expert(
85
- base_hidden[b:b+1],
86
- attention_mask[b:b+1] if attention_mask is not None else None,
87
- )
88
- if "logits" not in out:
89
- print(f"[Ensemble] Expert {name} output missing logits, skipping")
90
- continue
91
- if out["logits"].dim() != 3:
92
- print(f"[Ensemble] Expert {name} returned unexpected logits shape {out['logits'].shape}, skipping")
93
- continue
94
- if out["logits"].shape[-1] != self.base_cfg["vocab_size"]:
95
- print(f"[Ensemble] Expert {name} logits vocab size mismatch ({out['logits'].shape[-1]} vs {self.base_cfg['vocab_size']}), skipping")
96
- continue
97
- weighted = out["logits"] * expert_wts[b, ki]
98
- if logits is None:
99
- logits = torch.zeros(B, weighted.shape[1], weighted.shape[2],
100
- device=weighted.device, dtype=weighted.dtype)
101
- logits[b] = logits[b] + weighted.squeeze(0)
102
-
103
- result = {
104
- "logits": logits,
105
- "load_balance_loss": routing["load_balance_loss"],
106
- "experts_used": experts_used,
107
- }
108
- if return_routing:
109
- result["routing"] = {
110
- EXPERT_NAMES[i]: routing["router_probs"][0, i].item()
111
- for i in range(len(EXPERT_NAMES))
112
- }
113
- if return_sentinel:
114
- result["sentinel"] = self.sentinel.analyze(base_hidden, attention_mask)
115
- return result
116
-
117
- # ── Autoregressive generation ─────────────────────────────────────
118
- @torch.no_grad()
119
- def generate(
120
- self,
121
- idx: torch.Tensor,
122
- max_new_tokens: int = 200,
123
- temperature: float = 0.8,
124
- top_k: int = 40,
125
- top_p: float = None,
126
- expert_override: str = None,
127
- ) -> torch.Tensor:
128
- from tokenizer.tokenizer_utils import get_tokenizer
129
- try:
130
- eos_id = get_tokenizer().token_to_id("[EOS]")
131
- except Exception:
132
- eos_id = None
133
-
134
- max_ctx = self.base.n_positions
135
- for _ in range(max_new_tokens):
136
- idx_cond = idx if idx.size(1) <= max_ctx else idx[:, -max_ctx:]
137
- result = self.forward(idx_cond)
138
- logits = result["logits"][:, -1, :]
139
- if temperature == 0:
140
- next_id = logits.argmax(dim=-1, keepdim=True)
141
- else:
142
- logits = logits / max(temperature, 1e-8)
143
- if top_k is not None and top_k > 0:
144
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
145
- logits[logits < v[:, [-1]]] = float("-inf")
146
- if top_p is not None and 0.0 < top_p < 1.0:
147
- sorted_logits, sorted_idx = torch.sort(logits, descending=True)
148
- cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
149
- sorted_idx_to_remove = cum_probs - F.softmax(sorted_logits, dim=-1) > top_p
150
- sorted_logits[sorted_idx_to_remove] = float("-inf")
151
- logits = logits.scatter(1, sorted_idx, sorted_logits)
152
- probs = F.softmax(logits, dim=-1)
153
- next_id = torch.multinomial(probs, num_samples=1)
154
-
155
- idx = torch.cat([idx, next_id], dim=1)
156
- if eos_id is not None and next_id[0, 0].item() == eos_id:
157
- break
158
-
159
- return idx
160
-
161
- # ── Routing inspection ────────────────────────────────────────────
162
- @torch.no_grad()
163
- def get_routing(self, idx: torch.Tensor):
164
- base_hidden = self.base(idx)
165
- routing = self.herald(base_hidden)
166
- probs = routing["router_probs"][0]
167
- pairs = [(EXPERT_NAMES[i], probs[i].item()) for i in range(len(EXPERT_NAMES))]
168
- pairs.sort(key=lambda x: -x[1])
169
- return pairs, False
170
-
171
- # ── Safety scan ─────────────────────────────────────────────────
172
- def scan_output(self, primary_expert: str, text: str):
173
- from tokenizer.tokenizer_utils import encode_batch
174
- device = next(self.parameters()).device
175
- ids, mask = encode_batch([text], max_length=512)
176
- ids = ids.to(device)
177
- mask = mask.to(device)
178
- with torch.no_grad():
179
- base_hidden = self.base(ids, mask)
180
- report = self.sentinel.analyze(base_hidden, mask)
181
- self.sentinel.log_expert(primary_expert)
182
- return report
183
-
184
- # ── Persistence ─────────────────────────────────────────────────
185
- def save(self, path: str):
186
- Path(path).parent.mkdir(parents=True, exist_ok=True)
187
- state = {
188
- "base_state_dict": self.base.state_dict(),
189
- "herald_state_dict": self.herald.state_dict(),
190
- "echo_state_dict": self.echo.state_dict(),
191
- "sentinel_state_dict": self.sentinel.state_dict(),
192
- }
193
- torch.save(state, path)
194
- print(f"[Ensemble] Saved core components to {path}")
195
-
196
- @classmethod
197
- def _safe_load_state_dict(cls, module: nn.Module, state_dict: dict, name: str):
198
- own = module.state_dict()
199
- filtered = {}
200
- for k, v in state_dict.items():
201
- if k in own and own[k].shape == v.shape:
202
- filtered[k] = v
203
- elif k in own:
204
- print(f"[Ensemble] Skipping mismatched {name} key {k}: checkpoint {v.shape}, model {own[k].shape}")
205
- if filtered:
206
- module.load_state_dict(filtered, strict=False)
207
-
208
- @classmethod
209
- def load(cls, path: str, device: str = "cpu", max_loaded_experts: int = 3) -> "ShorekeeperEnsemble":
210
- ckpt = torch.load(path, map_location="cpu")
211
- base_cfg = None
212
- expert_cfgs = None
213
- if isinstance(ckpt, dict):
214
- base_cfg = ckpt.get("base_config", None)
215
- expert_cfgs = ckpt.get("expert_configs", None)
216
-
217
- model = cls(
218
- max_loaded_experts=max_loaded_experts,
219
- base_cfg=base_cfg,
220
- expert_cfgs=expert_cfgs,
221
- ).to(device)
222
- model.expert_loader.set_device(device, dtype=DTYPE)
223
-
224
- state: dict = {}
225
- if isinstance(ckpt, dict):
226
- if "model_state" in ckpt:
227
- state = ckpt["model_state"]
228
- elif "model_state_dict" in ckpt:
229
- state = ckpt["model_state_dict"]
230
- elif "base_state_dict" in ckpt or "herald_state_dict" in ckpt:
231
- # Modern saved ensemble format
232
- cls._safe_load_state_dict(model.base, ckpt.get("base_state_dict", {}), "base")
233
- cls._safe_load_state_dict(model.herald, ckpt.get("herald_state_dict", {}), "herald")
234
- cls._safe_load_state_dict(model.echo, ckpt.get("echo_state_dict", {}), "echo")
235
- cls._safe_load_state_dict(model.sentinel, ckpt.get("sentinel_state_dict", {}), "sentinel")
236
- print(f"[Ensemble] Loaded from {path}")
237
- return model
238
- else:
239
- state = ckpt
240
- else:
241
- raise ValueError(f"Unsupported checkpoint type: {type(ckpt)}")
242
-
243
- if not isinstance(state, dict):
244
- raise ValueError("Checkpoint does not contain recognizable state dict")
245
-
246
- # Load base + modules if keys present
247
- base_state = {k.replace("shared_base.", ""): v for k, v in state.items() if k.startswith("shared_base.")}
248
- if base_state:
249
- cls._safe_load_state_dict(model.base, base_state, "base")
250
-
251
- herald_state = {k.replace("herald.", ""): v for k, v in state.items() if k.startswith("herald.")}
252
- if herald_state:
253
- cls._safe_load_state_dict(model.herald, herald_state, "herald")
254
-
255
- echo_state = {k.replace("echo.", ""): v for k, v in state.items() if k.startswith("echo.")}
256
- if echo_state:
257
- cls._safe_load_state_dict(model.echo, echo_state, "echo")
258
-
259
- sentinel_state = {k.replace("sentinel.", ""): v for k, v in state.items() if k.startswith("sentinel.")}
260
- if sentinel_state:
261
- cls._safe_load_state_dict(model.sentinel, sentinel_state, "sentinel")
262
-
263
- # Load expert states if available
264
- expert_keys = [k for k in state.keys() if k.startswith("experts.") or k.startswith("expert.")]
265
- if expert_keys:
266
- for name in EXPERT_NAMES:
267
- for prefix in [f"experts.{name}.", f"expert.{name}."]:
268
- expert_state = {k.replace(prefix, ""): v for k, v in state.items() if k.startswith(prefix)}
269
- if expert_state:
270
- expert = model.expert_loader.get_expert(name)
271
- cls._safe_load_state_dict(expert, expert_state, f"expert.{name}")
272
- break
273
-
274
- # If no prefixed expert keys, maybe experts are at top-level names
275
- for name in EXPERT_NAMES:
276
- expert_state = {k.replace(f"{name}.", ""): v for k, v in state.items() if k.startswith(f"{name}.")}
277
- if expert_state:
278
- expert = model.expert_loader.get_expert(name)
279
- cls._safe_load_state_dict(expert, expert_state, f"expert.{name}")
280
-
281
- print(f"[Ensemble] Loaded from {path}")
282
- return model
283
-
284
- @classmethod
285
- def load_from_components(cls, checkpoint_dir: str = None, device: str = "cpu") -> "ShorekeeperEnsemble":
286
- model = cls().to(device)
287
- model.expert_loader.set_device(device, dtype=DTYPE)
288
- ckpt_root = Path(checkpoint_dir or CHECKPOINT_DIR)
289
- base_ckpt = ckpt_root / "base" / "best.pt"
290
- if base_ckpt.exists():
291
- try:
292
- ckpt = torch.load(base_ckpt, map_location="cpu")
293
- state = ckpt.get("model_state_dict", ckpt)
294
- base_state = {k.replace("0.", ""): v for k, v in state.items() if k.startswith("0.")}
295
- if base_state:
296
- model.base.load_state_dict(base_state, strict=False)
297
- else:
298
- model.base.load_state_dict(state, strict=False)
299
- except Exception as e:
300
- print(f"[Ensemble] Warning: failed to load base checkpoint {base_ckpt}: {e}")
301
-
302
- # Experts are loaded lazily by LazyExpertLoader when first routed.
303
- # Ensure checkpoint availability is set by LazyExpertLoader initialization.
304
-
305
- herald_path = ckpt_root / "herald" / "best.pt"
306
- if herald_path.exists():
307
- try:
308
- ckpt = torch.load(herald_path, map_location="cpu")
309
- model.herald.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
310
- except Exception as e:
311
- print(f"[Ensemble] Warning: failed to load herald checkpoint {herald_path}: {e}")
312
-
313
- sentinel_path = ckpt_root / "sentinel" / "best.pt"
314
- if sentinel_path.exists():
315
- try:
316
- ckpt = torch.load(sentinel_path, map_location="cpu")
317
- model.sentinel.load_state_dict(ckpt.get("model_state_dict", ckpt), strict=False)
318
- except Exception as e:
319
- print(f"[Ensemble] Warning: failed to load sentinel checkpoint {sentinel_path}: {e}")
320
-
321
- return model
322
-
323
- # ── Expert management ─────────────────────────────────────────────────
324
- def preload_experts(self, expert_names: list):
325
- self.expert_loader.preload_experts(expert_names)
326
-
327
- def get_loaded_experts(self):
328
- return [name for name, loaded in self.expert_loader.get_cache_status().items() if loaded]
329
-
330
- def clear_expert_cache(self):
331
- self.expert_loader.clear_cache()
332
-
333
- def set_max_loaded_experts(self, max_loaded: int):
334
- self.expert_loader.set_max_loaded(max_loaded)
335
-
336
- def freeze_base(self):
337
- for p in self.base.parameters(): p.requires_grad = False
338
-
339
- def unfreeze_all(self):
340
- for p in self.parameters(): p.requires_grad = True
341
-
342
- def get_trainable_count(self):
343
- return sum(p.numel() for p in self.parameters() if p.requires_grad)
344
-
345
- def get_total_count(self):
346
- return sum(p.numel() for p in self.parameters())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/expert.py DELETED
@@ -1,45 +0,0 @@
1
- # model/expert.py
2
- import sys, torch, torch.nn as nn
3
- from pathlib import Path
4
- sys.path.insert(0, str(Path(__file__).parent.parent))
5
- import config
6
- from model.base import TransformerBlock
7
- from tokenizer.tokenizer_utils import vocab_size
8
-
9
-
10
- class ExpertHead(nn.Module):
11
- """
12
- Domain-specialized transformer head.
13
- Receives SharedBase hidden states and processes them
14
- through its own transformer layers to produce domain-specific logits.
15
- """
16
- def __init__(self, expert_name: str, expert_cfg=None, base_cfg=None):
17
- super().__init__()
18
- from config import EXPERT_NAMES
19
- assert expert_name in EXPERT_NAMES, f"Unknown expert: {expert_name}"
20
- if expert_cfg is None:
21
- expert_cfg = config.EXPERT_CONFIGS[expert_name]
22
- if base_cfg is None:
23
- base_cfg = config.BASE_CONFIG
24
- n_embd = expert_cfg["n_embd"]
25
- self.input_norm = nn.LayerNorm(n_embd)
26
- self.input_proj = nn.Linear(base_cfg["n_embd"], n_embd, bias=False)
27
- self.blocks = nn.ModuleList([
28
- TransformerBlock(n_embd, expert_cfg["n_head"]) for _ in range(expert_cfg["n_layer"])
29
- ])
30
- self.ln_f = nn.LayerNorm(n_embd)
31
- self.lm_head = nn.Linear(n_embd, base_cfg["vocab_size"], bias=False)
32
- nn.init.normal_(self.lm_head.weight, std=0.02)
33
-
34
- def forward(self, base_hidden, attention_mask=None):
35
- x = self.input_proj(self.input_norm(base_hidden))
36
- for block in self.blocks:
37
- x = block(x, attention_mask)
38
- x = self.ln_f(x)
39
- return {"logits": self.lm_head(x), "hidden": x}
40
-
41
- def freeze(self):
42
- for p in self.parameters(): p.requires_grad = False
43
-
44
- def unfreeze(self):
45
- for p in self.parameters(): p.requires_grad = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/herald.py DELETED
@@ -1,62 +0,0 @@
1
- # model/herald.py
2
- import sys, torch, torch.nn as nn, torch.nn.functional as F
3
- from pathlib import Path
4
- sys.path.insert(0, str(Path(__file__).parent.parent))
5
- import config
6
- from model.base import TransformerBlock
7
-
8
-
9
- class Herald(nn.Module):
10
- """
11
- Routes queries to the top-K most relevant experts.
12
- Trained as a classifier: given base hidden states, predict expert indices.
13
- Uses load-balance loss to prevent routing collapse (always picking 1 expert).
14
- """
15
- def __init__(self, cfg=None, n_embd=None):
16
- super().__init__()
17
- cfg = cfg or config.HERALD_CONFIG
18
- if n_embd is None:
19
- n_embd = cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
20
- self.top_k = cfg.get("top_k", 2)
21
- self.n_experts = cfg.get("n_experts", len(config.EXPERT_NAMES))
22
- self.query_encoder = nn.ModuleList([
23
- TransformerBlock(n_embd, cfg.get("n_head", config.BASE_CONFIG["n_head"])) for _ in range(cfg.get("n_layer", 2))
24
- ])
25
- self.query_norm = nn.LayerNorm(n_embd)
26
- self.expert_scorer = nn.Linear(n_embd, self.n_experts, bias=True)
27
-
28
- def forward(self, base_hidden, attention_mask=None):
29
- x = base_hidden
30
- for block in self.query_encoder:
31
- x = block(x, attention_mask)
32
- x = self.query_norm(x)
33
- # Pool: weighted mean over sequence (ignore padding)
34
- if attention_mask is not None:
35
- mask = attention_mask[:, :, None].float()
36
- pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
37
- else:
38
- pooled = x.mean(1)
39
- scores = self.expert_scorer(pooled) # [B, n_experts]
40
- probs = F.softmax(scores, dim=-1)
41
- top_probs, top_idx = probs.topk(self.top_k, dim=-1)
42
- # Load balance loss: encourages uniform expert utilization
43
- # From "Switch Transformers" (Fedus et al. 2022)
44
- expert_frac = probs.mean(0)
45
- target_frac = torch.ones_like(expert_frac) / self.n_experts
46
- lb_loss = F.mse_loss(expert_frac, target_frac)
47
- return {
48
- "router_probs": probs,
49
- "expert_indices": top_idx,
50
- "expert_weights": top_probs,
51
- "load_balance_loss": lb_loss,
52
- }
53
-
54
- def get_routing_display(self, base_hidden, attention_mask=None) -> str:
55
- routing = self.forward(base_hidden, attention_mask)
56
- lines = []
57
- probs = routing["router_probs"][0]
58
- names = config.EXPERT_NAMES if hasattr(config, "EXPERT_NAMES") else []
59
- for i, name in enumerate(names):
60
- bar = "█" * int(probs[i].item() * 20)
61
- lines.append(f" {name:<12} {bar:<20} {probs[i]:.3f}")
62
- return "\n".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/lazy_expert_loader.py DELETED
@@ -1,120 +0,0 @@
1
- """
2
- Lazy expert loading system for Shorekeeper.
3
- Only loads experts into VRAM when they're actually routed to.
4
- Implements LRU caching to keep hot experts loaded while evicting cold ones.
5
- """
6
- import torch
7
- import torch.nn as nn
8
- from pathlib import Path
9
- from collections import OrderedDict
10
- from typing import Dict
11
-
12
-
13
- class LazyExpertLoader(nn.Module):
14
- """Manages expert models with lazy loading and LRU caching."""
15
-
16
- def __init__(
17
- self,
18
- expert_names: list,
19
- checkpoint_dir: Path,
20
- device: str = "cuda",
21
- max_loaded: int = 3,
22
- dtype: torch.dtype = torch.bfloat16,
23
- base_cfg: dict = None,
24
- expert_cfgs: dict = None,
25
- ):
26
- super().__init__()
27
- self.expert_names = expert_names
28
- self.checkpoint_dir = Path(checkpoint_dir)
29
- self.device = device
30
- self.max_loaded = max_loaded
31
- self.dtype = dtype
32
- self.base_cfg = base_cfg or {}
33
- self.expert_cfgs = expert_cfgs or {}
34
-
35
- self.experts = nn.ModuleDict()
36
- self._lru = OrderedDict()
37
-
38
- self._available_checkpoints = {
39
- name: self.checkpoint_dir / "experts" / name / "best.pt"
40
- for name in expert_names
41
- if (self.checkpoint_dir / "experts" / name / "best.pt").exists()
42
- }
43
-
44
- print(f"[LazyExpertLoader] Initialized with max_loaded={max_loaded}")
45
- print(f"[LazyExpertLoader] Found {len(self._available_checkpoints)} expert checkpoints")
46
-
47
- def get_expert(self, expert_name: str) -> nn.Module:
48
- if expert_name in self.experts:
49
- self._lru.move_to_end(expert_name)
50
- return self.experts[expert_name]
51
-
52
- expert = self._load_expert(expert_name)
53
- self.experts[expert_name] = expert
54
- self._lru[expert_name] = True
55
-
56
- if len(self._lru) > self.max_loaded:
57
- evicted_name, _ = self._lru.popitem(last=False)
58
- evicted_expert = self.experts.pop(evicted_name)
59
- evicted_expert.to("cpu")
60
- print(f"[LazyExpertLoader] Evicted '{evicted_name}' from cache")
61
-
62
- return expert
63
-
64
- def _load_expert(self, expert_name: str) -> nn.Module:
65
- from model.expert import ExpertHead
66
-
67
- expert_cfg = self.expert_cfgs.get(expert_name, None)
68
- print(f"[LazyExpertLoader] Loading expert '{expert_name}'...")
69
- expert = ExpertHead(expert_name, expert_cfg=expert_cfg, base_cfg=self.base_cfg)
70
-
71
- ckpt_path = self._available_checkpoints.get(expert_name,
72
- self.checkpoint_dir / "experts" / expert_name / "best.pt")
73
- if ckpt_path.exists():
74
- try:
75
- ckpt = torch.load(ckpt_path, map_location="cpu")
76
- state_dict = ckpt.get("model_state_dict", ckpt)
77
- expert.load_state_dict(state_dict, strict=False)
78
- print(f"[LazyExpertLoader] Loaded weights from {ckpt_path}")
79
- except Exception as e:
80
- print(f"[LazyExpertLoader] Warning: Failed to load {ckpt_path}: {e}")
81
- else:
82
- print(f"[LazyExpertLoader] No checkpoint found for '{expert_name}', using random init")
83
-
84
- load_dtype = self.dtype
85
- if self.device == "cpu" and load_dtype in [torch.float16, torch.bfloat16]:
86
- load_dtype = torch.float32
87
- expert = expert.to(device=self.device, dtype=load_dtype)
88
- expert.eval()
89
- return expert
90
-
91
- def preload_experts(self, expert_names: list):
92
- for name in expert_names:
93
- self.get_expert(name)
94
-
95
- def clear_cache(self):
96
- for e in self.experts.values():
97
- e.to("cpu")
98
- self.experts.clear()
99
- self._lru.clear()
100
- print("[LazyExpertLoader] Cache cleared")
101
-
102
- def get_cache_status(self) -> Dict[str, bool]:
103
- return {name: name in self.experts for name in self.expert_names}
104
-
105
- def set_device(self, device: str, dtype: torch.dtype = None):
106
- self.device = device
107
- if dtype is not None:
108
- self.dtype = dtype
109
- if device == "cpu" and self.dtype == torch.float16:
110
- self.dtype = torch.float32
111
- for expert in self.experts.values():
112
- expert.to(device=device, dtype=self.dtype)
113
-
114
- def set_max_loaded(self, max_loaded: int):
115
- self.max_loaded = max_loaded
116
- while len(self._lru) > self.max_loaded:
117
- evicted_name, _ = self._lru.popitem(last=False)
118
- evicted_expert = self.experts.pop(evicted_name)
119
- evicted_expert.to("cpu")
120
- print(f"[LazyExpertLoader] Evicted '{evicted_name}' (cache resize)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/sentinel.py DELETED
@@ -1,90 +0,0 @@
1
- # model/sentinel.py
2
- import sys, torch, torch.nn as nn
3
- from dataclasses import dataclass
4
- from pathlib import Path
5
- import time
6
- sys.path.insert(0, str(Path(__file__).parent.parent))
7
- import config
8
- from model.base import TransformerBlock
9
-
10
-
11
- @dataclass
12
- class SentinelReport:
13
- verdict: str # "CLEAN", "FLAG", or "BLOCK"
14
- drift_score: float # 0-1: probability of behavioral drift
15
- refusal_score: float # 0-1: probability of inappropriate refusal
16
- hallucination_score: float # 0-1: probability of hallucination
17
- overall_risk: float # Combined risk score
18
-
19
- @property
20
- def is_clean(self) -> bool:
21
- return self.verdict == "CLEAN"
22
-
23
- @property
24
- def protocol(self) -> str:
25
- return self.verdict.lower()
26
-
27
- @property
28
- def total(self) -> float:
29
- return self.overall_risk
30
-
31
-
32
- class Sentinel(nn.Module):
33
- """Monitors outputs for behavioral drift in real time."""
34
- def __init__(self, cfg=None, n_embd=None):
35
- super().__init__()
36
- cfg = cfg or config.EXPERT_CONFIGS.get("sentinel", {})
37
- n = n_embd or cfg.get("n_embd", config.BASE_CONFIG["n_embd"])
38
- hidden = config.SENTINEL_CONFIG["hidden_dim"]
39
- self.encoder_blocks = nn.ModuleList([
40
- TransformerBlock(n, cfg.get("n_head", config.BASE_CONFIG["n_head"])) for _ in range(cfg.get("n_layer", 4))
41
- ])
42
- self.encoder_norm = nn.LayerNorm(n)
43
- self.pool = nn.AdaptiveAvgPool1d(1)
44
- self.proj = nn.Sequential(nn.Linear(n, hidden), nn.ReLU(), nn.Dropout(0.1))
45
- self.drift_head = nn.Linear(hidden, 1)
46
- self.refusal_head = nn.Linear(hidden, 1)
47
- self.halluc_head = nn.Linear(hidden, 1)
48
- self._incidents = [] # Session incident log
49
-
50
- def forward(self, base_hidden, attention_mask=None):
51
- return self.analyze(base_hidden, attention_mask)
52
-
53
- @torch.no_grad()
54
- def analyze(self, base_hidden, attention_mask=None) -> SentinelReport:
55
- x = base_hidden
56
- for block in self.encoder_blocks:
57
- x = block(x, attention_mask)
58
- x = self.encoder_norm(x)
59
- pooled = self.pool(x.transpose(1, 2)).squeeze(-1)
60
- feats = self.proj(pooled)
61
- drift = self.drift_head(feats).squeeze(-1).sigmoid().mean().item()
62
- refus = self.refusal_head(feats).squeeze(-1).sigmoid().mean().item()
63
- halluc = self.halluc_head(feats).squeeze(-1).sigmoid().mean().item()
64
- risk = max(drift, refus * 0.8, halluc * 0.6)
65
- cfg = SENTINEL_CONFIG
66
- if risk >= cfg["block_threshold"]: verdict = "BLOCK"
67
- elif risk >= cfg["flag_threshold"]: verdict = "FLAG"
68
- else: verdict = "CLEAN"
69
- report = SentinelReport(verdict, drift, refus, halluc, risk)
70
- if not report.is_clean:
71
- self._incidents.append({
72
- "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
73
- "expert": "unknown",
74
- "protocol": report.protocol,
75
- "score": risk,
76
- })
77
- return report
78
-
79
- def reset_session(self):
80
- """Clear the session incident log."""
81
- self._incidents.clear()
82
-
83
- def get_incidents(self) -> list:
84
- """Return all incidents logged this session."""
85
- return list(self._incidents)
86
-
87
- def log_expert(self, expert_name: str):
88
- """Tag the most recent incident with the expert name."""
89
- if self._incidents:
90
- self._incidents[-1]["expert"] = expert_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,14 @@
1
- torch>=2.0.0
2
- tokenizers>=0.15.0
3
- psutil
4
- tqdm
 
 
 
 
 
 
 
 
 
 
 
1
+ 211torch>=2.5.0
2
+ transformers>=4.53.0
3
+ accelerate>=1.0.0
4
+ sentencepiece>=0.2.0
5
+ bitsandbytes>=0.45.0
6
+ datasets>=3.0.0
7
+ math-verify>=0.5.2
8
+ chromadb>=0.5.0
9
+ sentence-transformers>=3.0.0
10
+ docker>=7.1.0
11
+ trl>=0.20.0
12
+ wandb>=0.19.0
13
+ pyyaml>=6.0
14
+ tqdm>=4.66.0
scripts/01_download_15b_data.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download MASSIVE datasets for 15B training
4
+ 200B+ tokens from verified STEM sources
5
+ """
6
+
7
+ import json
8
+ from pathlib import Path
9
+ from datasets import load_dataset
10
+
11
+ def download_15b_data():
12
+ print("=" * 80)
13
+ print("DOWNLOADING 200B+ TOKENS FOR 15B MODEL")
14
+ print("=" * 80)
15
+
16
+ data_dir = Path("./data/15b_data")
17
+ data_dir.mkdir(parents=True, exist_ok=True)
18
+
19
+ all_data = []
20
+ total_tokens = 0
21
+
22
+ # 1. The Pile - 800GB, 300B tokens (take 50B)
23
+ print("\n1. The Pile (300B tokens - taking 50B)...")
24
+ print(" This will take 1-2 days...")
25
+ try:
26
+ pile = load_dataset("EleutherAI/pile", split="train[:5000000]")
27
+ for item in pile:
28
+ text = item.get("text", "")
29
+ if text and len(text) > 200:
30
+ all_data.append({
31
+ "text": text,
32
+ "source": "pile"
33
+ })
34
+ print(f" ✓ Added {len(pile):,} examples")
35
+ except Exception as e:
36
+ print(f" ✗ Failed: {e}")
37
+
38
+ # 2. Proof-Pile-2 - 50B tokens of math/CS papers
39
+ print("\n2. Proof-Pile-2 (50B tokens - taking 20B)...")
40
+ try:
41
+ proof = load_dataset("EleutherAI/proof-pile-2", split="train[:2000000]")
42
+ for item in proof:
43
+ text = item.get("text", "")
44
+ if text and len(text) > 200:
45
+ all_data.append({
46
+ "text": text,
47
+ "source": "proofpile"
48
+ })
49
+ print(f" ✓ Added {len(proof):,} examples")
50
+ except Exception as e:
51
+ print(f" ✗ Failed: {e}")
52
+
53
+ # 3. StarCoder - 100B tokens of code
54
+ print("\n3. StarCoder (100B tokens - taking 30B)...")
55
+ try:
56
+ code = load_dataset("bigcode/starcoderdata", split="train[:3000000]")
57
+ for item in code:
58
+ content = item.get("content", "")
59
+ if content and len(content) > 100:
60
+ all_data.append({
61
+ "text": content,
62
+ "source": "starcoder"
63
+ })
64
+ print(f" ✓ Added {len(code):,} examples")
65
+ except Exception as e:
66
+ print(f" ✗ Failed: {e}")
67
+
68
+ # 4. C4 - 156GB, 150B tokens (take 30B)
69
+ print("\n4. C4 (150B tokens - taking 30B)...")
70
+ try:
71
+ c4 = load_dataset("c4", "en", split="train[:3000000]")
72
+ for item in c4:
73
+ text = item.get("text", "")
74
+ if text and len(text) > 200:
75
+ all_data.append({
76
+ "text": text,
77
+ "source": "c4"
78
+ })
79
+ print(f" ✓ Added {len(c4):,} examples")
80
+ except Exception as e:
81
+ print(f" ✗ Failed: {e}")
82
+
83
+ # 5. OpenWebMath - 14B tokens of math
84
+ print("\n5. OpenWebMath (14B tokens - taking all)...")
85
+ try:
86
+ math = load_dataset("open-web-math/open-web-math", split="train")
87
+ for item in math:
88
+ text = item.get("text", "")
89
+ if text and len(text) > 200:
90
+ all_data.append({
91
+ "text": text,
92
+ "source": "openwebmath"
93
+ })
94
+ print(f" ✓ Added {len(math):,} examples")
95
+ except Exception as e:
96
+ print(f" ✗ Failed: {e}")
97
+
98
+ print("\n" + "=" * 80)
99
+ print(f"TOTAL EXAMPLES: {len(all_data):,}")
100
+ print(f"ESTIMATED TOKENS: {len(all_data) * 500:,}")
101
+ print("=" * 80)
102
+
103
+ # Save
104
+ print("\nSaving to disk...")
105
+ with open(data_dir / "15b_train.jsonl", "w") as f:
106
+ for item in all_data:
107
+ f.write(json.dumps(item) + "\n")
108
+
109
+ print(f"✓ Saved to: {data_dir}/15b_train.jsonl")
110
+
111
+ if __name__ == "__main__":
112
+ download_15b_data()
scripts/01_download_7b_150gb.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
150GB Curated STEM Dataset for 7B Model Training
Enough for a high-quality 7B model from scratch
Total: ~150GB compressed, ~500GB uncompressed
"""

import json
import random
import time
from pathlib import Path

from datasets import load_dataset

# One entry per corpus: (banner, HF repo, optional config name, split,
# row -> text extractor, minimum text length to keep, source tag).
# Thresholds and splits are copied from the original hand-written sections.
_SOURCES = [
    ("DATASET 1: The Pile (50GB - General text)",
     "EleutherAI/pile", None, "train[:2000000]",
     lambda it: it.get("text", ""), 500, "pile"),
    ("DATASET 2: StarCoder (30GB - Code)",
     "bigcode/starcoderdata", None, "train[:1500000]",
     lambda it: it.get("content", ""), 200, "starcoder"),
    ("DATASET 3: C4 (25GB - Clean web text)",
     "c4", "en", "train[:1500000]",
     lambda it: it.get("text", ""), 300, "c4"),
    ("DATASET 4: Proof-Pile-2 (20GB - Math/CS papers)",
     "EleutherAI/proof-pile-2", None, "train[:1000000]",
     lambda it: it.get("text", ""), 500, "proofpile"),
    ("DATASET 5: OpenWebMath (10GB - Math web pages)",
     "open-web-math/open-web-math", None, "train[:500000]",
     lambda it: it.get("text", ""), 300, "openwebmath"),
    ("DATASET 6: MetaMathQA (2.5GB - Math problems)",
     "meta-math/MetaMathQA", None, "train",
     lambda it: f"Question: {it.get('query', '')}\nAnswer: {it.get('response', '')}",
     0, "metamath"),
    ("DATASET 7: CodeFeedback (2GB - Code instructions)",
     "m-a-p/CodeFeedback", None, "train[:150000]",
     lambda it: f"Instruction: {it.get('instruction', '')}\nCode: {it.get('output', '')}",
     100, "codefeedback"),
    ("DATASET 8: OpenMathInstruct-2 (2GB - Math problems)",
     "nvidia/OpenMathInstruct-2", None, "train[:150000]",
     lambda it: f"Problem: {it.get('question', '')}\nSolution: {it.get('generated_solution', '')}",
     0, "openmath"),
    ("DATASET 9: NuminaMath-CoT (2GB - Math reasoning)",
     "AI-MO/NuminaMath-CoT", None, "train[:100000]",
     lambda it: f"Problem: {it.get('problem', '')}\nSolution: {it.get('solution', '')}",
     0, "numinamath"),
    ("DATASET 10: ScienceQA (0.5GB - Science questions)",
     "derek-thomas/ScienceQA", None, "train",
     lambda it: f"Question: {it.get('question', '')}\nAnswer: {it.get('answer', '')}",
     0, "scienceqa"),
]


def _ingest_source(all_data, banner, repo, config, split, extract, min_len, tag):
    """Download one dataset split, filter short rows, truncate each kept text
    to 2048 characters and append {"text", "source"} records to all_data.

    Returns the number of examples actually added. Fix vs. the original
    copy-pasted sections: the "Added N" message and the running total now
    count POST-filter additions, not the raw dataset size (which overstated
    every source), and every source is truncated consistently (MetaMathQA
    previously was not).
    """
    print("\n" + "=" * 80)
    print(banner)
    print("=" * 80)
    added = 0
    try:
        if config is not None:
            ds = load_dataset(repo, config, split=split)
        else:
            ds = load_dataset(repo, split=split)
        for item in ds:
            text = extract(item)
            if text and len(text) > min_len:
                all_data.append({"text": text[:2048], "source": tag})
                added += 1
        print(f"  ✓ Added {added:,} examples")
    except Exception as e:
        # Best-effort: a single failing source must not abort the whole run.
        print(f"  ✗ Failed: {e}")
    return added


def download_7b_dataset():
    """Download ~150GB of curated web/code/math text and save it shuffled.

    Side effects: writes ./data/7b_150gb/7b_train.jsonl (one JSON object per
    line with "text" and "source" keys) and metadata.json (per-source counts).
    """
    print("=" * 80)
    print("DOWNLOADING 150GB STEM DATASET FOR 7B MODEL")
    print("=" * 80)
    print("\n⚠️  This will download ~150GB of data")
    print("   Estimated time: 4-8 hours depending on connection")
    print("   Disk space needed: ~500GB after decompression")
    print("\nPress Ctrl+C to cancel, or wait 5 seconds to continue...")
    time.sleep(5)

    data_dir = Path("./data/7b_150gb")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []
    for banner, repo, config, split, extract, min_len, tag in _SOURCES:
        _ingest_source(all_data, banner, repo, config, split, extract, min_len, tag)

    # Fix: the original summed raw (pre-filter) dataset lengths here, so the
    # reported/serialized total did not match what was actually saved.
    total_examples = len(all_data)

    print("\n" + "=" * 80)
    print("SAVING DATASET")
    print("=" * 80)
    print(f"Total examples collected: {total_examples:,}")
    print("Estimated size: ~150GB compressed, ~500GB uncompressed")

    random.shuffle(all_data)

    output_path = data_dir / "7b_train.jsonl"
    with open(output_path, "w") as f:
        for item in all_data:
            f.write(json.dumps(item) + "\n")

    print(f"\n✓ Saved to: {output_path}")
    print(f"  File size: {output_path.stat().st_size / 1e9:.1f} GB")

    # Per-source breakdown for the metadata file.
    metadata = {"total_examples": total_examples, "sources": {}}
    for item in all_data:
        src = item["source"]
        metadata["sources"][src] = metadata["sources"].get(src, 0) + 1

    with open(data_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    print("\n" + "=" * 80)
    print("DATASET BREAKDOWN")
    print("=" * 80)
    for src, count in metadata["sources"].items():
        print(f"  {src}: {count:,} examples")

    print("\n" + "=" * 80)
    print("✅ DOWNLOAD COMPLETE!")
    print("=" * 80)
    print("\nNext step: python3 scripts/04_train_universal.py")


if __name__ == "__main__":
    download_7b_dataset()
scripts/01_download_stem_data.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Download high-quality STEM datasets for SHOREKEEPER
Math, Code, Science - No random web text
"""

import json
from pathlib import Path

from datasets import load_dataset


def download_stem_data():
    """Download STEM instruction data (math, code, science) and write
    train/validation JSONL splits under ./data/stem/.

    Records have keys "prompt", "response" (wrapped in |special_token|
    markers) and "source".

    Bug fixed vs. the original: stem_train.jsonl was written with ALL
    examples BEFORE the 95/5 split was computed, so every validation row
    also appeared in the training file. The split is now computed first
    and each file receives only its own rows.
    """
    print("=" * 70)
    print("DOWNLOADING STEM DATASETS")
    print("=" * 70)

    data_dir = Path("./data/stem")
    data_dir.mkdir(parents=True, exist_ok=True)

    all_data = []

    def ingest(header, loader, formatter):
        """Run one download step; formatter maps a raw row to a record
        dict, or None to skip the row. Failures are reported, not fatal."""
        print(header)
        try:
            dataset = loader()
            print(f"   Loading {len(dataset)} examples...")
            added = 0
            for item in dataset:
                record = formatter(item)
                if record is not None:
                    all_data.append(record)
                    added += 1
            print(f"   ✓ Added {added} examples")
        except Exception as e:
            print(f"   ✗ Failed: {e}")

    # 1. MetaMathQA - 395k math problems with step-by-step reasoning
    ingest(
        "\n1. MetaMathQA (395k math problems)...",
        lambda: load_dataset("meta-math/MetaMathQA", split="train"),
        lambda item: {
            "prompt": item.get("query", ""),
            "response": f"|special_token| {item.get('response', '')} |special_token|",
            "source": "metamath",
        },
    )

    # 2. CodeFeedback - 1.2M code instructions (subset)
    def _code(item):
        instruction = item.get("instruction", "")
        output = item.get("output", "")
        if not (instruction and output):
            return None
        return {
            "prompt": instruction,
            "response": f"|special_token| Here's the code:\n{output} |special_token|",
            "source": "codefeedback",
        }

    ingest(
        "\n2. CodeFeedback (1.2M code examples - taking 200k)...",
        lambda: load_dataset("m-a-p/CodeFeedback", split="train[:200000]"),
        _code,
    )

    # 3. NuminaMath-CoT - 860k math problems (subset)
    def _numina(item):
        problem = item.get("problem", "")
        solution = item.get("solution", "")
        if not (problem and solution):
            return None
        return {
            "prompt": problem,
            "response": f"|special_token| Let me solve this step by step.\n{solution} |special_token|",
            "source": "numinamath",
        }

    ingest(
        "\n3. NuminaMath-CoT (860k math problems - taking 200k)...",
        lambda: load_dataset("AI-MO/NuminaMath-CoT", split="train[:200000]"),
        _numina,
    )

    # 4. ScienceQA - 21k science questions
    def _science(item):
        question = item.get("question", "")
        answer = item.get("answer", "")
        if not (question and answer):
            return None
        return {
            "prompt": question,
            "response": f"|special_token| Science explanation:\n{answer} |special_token|",
            "source": "scienceqa",
        }

    ingest(
        "\n4. ScienceQA (21k science questions)...",
        lambda: load_dataset("derek-thomas/ScienceQA", split="train"),
        _science,
    )

    # 5. GSM8K - keep only the final answer after the '####' marker
    def _gsm8k(item):
        question = item.get("question", "")
        answer = item.get("answer", "").split("####")[-1].strip()
        if not (question and answer):
            return None
        return {
            "prompt": question,
            "response": f"|special_token| {answer} |special_token|",
            "source": "gsm8k",
        }

    ingest(
        "\n5. GSM8K (8.5k grade school math)...",
        lambda: load_dataset("gsm8k", "main", split="train"),
        _gsm8k,
    )

    print("\n" + "=" * 70)
    print(f"TOTAL STEM EXAMPLES: {len(all_data):,}")
    print("=" * 70)

    # Per-source breakdown
    sources = {}
    for item in all_data:
        src = item["source"]
        sources[src] = sources.get(src, 0) + 1

    print("\nBreakdown by source:")
    for src, count in sources.items():
        print(f"  {src}: {count:,}")

    # 95/5 train/val split — computed BEFORE writing so the train file
    # does not contain the validation rows (this was the leak).
    split_idx = int(len(all_data) * 0.95)
    train = all_data[:split_idx]
    val = all_data[split_idx:]

    print("\nSaving to disk...")
    with open(data_dir / "stem_train.jsonl", "w") as f:
        for item in train:
            f.write(json.dumps(item) + "\n")

    with open(data_dir / "stem_val.jsonl", "w") as f:
        for item in val:
            f.write(json.dumps(item) + "\n")

    print(f"✓ Saved to: {data_dir}/stem_train.jsonl")
    print(f"  Total size: {len(all_data):,} examples")
    print(f"  Train: {len(train):,}")
    print(f"  Val: {len(val):,}")


if __name__ == "__main__":
    download_stem_data()
scripts/04_train.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SHOREKEEPER-4B Training Pipeline
4
+ Runs on any CUDA device (RTX 3060, H100, etc.)
5
+ """
6
+
7
+ import sys
8
+ import json
9
+ import torch
10
+ import torch.nn as nn
11
+ from pathlib import Path
12
+ from tqdm import tqdm
13
+
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ from src.shorekeeper import MemoryEfficientSHOREKEEPER
17
+ from transformers import AutoTokenizer
18
+
19
class SHOREKEEPERTrainer:
    """Minimal next-token-prediction training loop for SHOREKEEPER.

    Fixes vs. the original:
      * ``pad_token_id`` is compared with ``is not None`` — a pad id of 0 is
        falsy, so the old truthiness test silently fell through to -100 and
        padding tokens were NOT excluded from the loss;
      * the per-call loss is scaled by 1/gradient_accumulation before
        ``backward()`` so the accumulated gradient matches the mean over the
        effective batch (the unscaled value is still returned for logging).
    """

    def __init__(self, model, tokenizer, config):
        """model: causal LM returning logits (batch, seq, vocab);
        tokenizer: HF-style callable with ``pad_token_id``;
        config keys (all optional): learning_rate, epochs, batch_size,
        gradient_accumulation."""
        self.model = model
        self.tokenizer = tokenizer
        # Assumes all model parameters live on one device.
        self.device = next(model.parameters()).device

        self.learning_rate = config.get('learning_rate', 1e-4)
        self.epochs = config.get('epochs', 3)
        self.batch_size = config.get('batch_size', 2)
        self.gradient_accumulation = config.get('gradient_accumulation', 4)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.01
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=1000,
            eta_min=1e-6
        )

        self.step = 0

    def train_step(self, batch):
        """Run one forward/backward pass on ``batch['text']`` (list of str).

        The optimizer and scheduler step once every
        ``gradient_accumulation`` calls. Returns the unscaled loss value.
        """
        self.model.train()

        texts = batch['text']

        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        input_ids = inputs['input_ids'].to(self.device)

        logits = self.model(input_ids)

        # Next-token prediction: token t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        # ``is not None`` because pad_token_id may legitimately be 0.
        pad_id = self.tokenizer.pad_token_id
        ignore_index = pad_id if pad_id is not None else -100

        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=ignore_index
        )

        # Scale so gradients summed over the accumulation window equal the
        # mean over the effective batch.
        (loss / self.gradient_accumulation).backward()

        if (self.step + 1) % self.gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1

        return loss.item()

    def train(self, dataset, output_dir="./outputs"):
        """Train for ``self.epochs`` over ``dataset`` — an iterable of dicts
        with 'prompt' and 'response' keys. Saves checkpoints every 100
        successful steps, one per epoch, and a final state dict under
        ``output_dir``. Returns the trained model."""
        print(f"\n{'='*60}")
        print("Starting Training")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Training samples: {len(dataset)}")
        print(f"Batch size: {self.batch_size}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Epochs: {self.epochs}")
        print(f"{'='*60}\n")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        for epoch in range(self.epochs):
            print(f"\nEpoch {epoch + 1}/{self.epochs}")
            print("-" * 40)

            total_loss = 0
            steps = 0

            pbar = tqdm(dataset, desc=f"Training")

            for i, item in enumerate(pbar):
                prompt = item.get('prompt', '')
                response = item.get('response', '')

                if not prompt or not response:
                    continue

                # Train on prompt + response as one sequence.
                text = f"{prompt}\n{response}"
                batch = {'text': [text]}

                try:
                    loss = self.train_step(batch)
                    total_loss += loss
                    steps += 1

                    pbar.set_postfix({'loss': f'{loss:.4f}'})

                    # Periodic resumable checkpoint.
                    if steps % 100 == 0:
                        checkpoint_path = output_dir / f"checkpoint_step_{steps}.pt"
                        torch.save({
                            'step': steps,
                            'model_state': self.model.state_dict(),
                            'optimizer_state': self.optimizer.state_dict(),
                            'loss': loss
                        }, checkpoint_path)
                        print(f"\n  Saved checkpoint: {checkpoint_path}")

                except Exception as e:
                    # Best-effort loop: surface only the first few failures
                    # to avoid flooding the console.
                    if steps < 5:
                        print(f"\n  Error on step {steps}: {e}")
                    continue

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            epoch_path = output_dir / f"epoch_{epoch + 1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'avg_loss': avg_loss
            }, epoch_path)
            print(f"Saved epoch checkpoint: {epoch_path}")

        final_path = output_dir / "shorekeeper-4b-final.pt"
        torch.save(self.model.state_dict(), final_path)
        print(f"\n{'='*60}")
        print(f"✅ Training complete! Final model saved to: {final_path}")
        print(f"{'='*60}")

        return self.model
176
+
177
def load_data(data_path, limit=None):
    """Load training examples from a JSONL file.

    Args:
        data_path: path to a file with one JSON object per line.
        limit: optional maximum number of LINES to read (malformed lines
            within that window are skipped, not replaced).

    Returns:
        List of parsed dicts; empty list if the file does not exist.

    Fixes vs. the original: the bare ``except:`` (which also swallowed
    KeyboardInterrupt) is narrowed to ``json.JSONDecodeError``, and
    ``limit=0`` now means "read nothing" instead of being treated as
    falsy "no limit".
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    with open(data_path, 'r') as f:
        for i, line in enumerate(f):
            if limit is not None and i >= limit:
                break
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Skip malformed lines rather than aborting the load.
                continue

    print(f"Loaded {len(data)} examples from {data_path}")
    return data
198
+
199
def main():
    """Interactive entry point: detect device, load model/tokenizer/data,
    pick a training preset, and run SHOREKEEPERTrainer.

    Fix vs. the original: the tokenizer load used a bare ``except:`` which
    also swallowed KeyboardInterrupt/SystemExit; narrowed to
    ``except Exception``. The if/elif preset ladder is replaced by a
    dispatch dict with the same values.
    """
    print("=" * 60)
    print("SHOREKEEPER-4B Training Pipeline")
    print("=" * 60)

    # Device selection
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA available: {torch.cuda.get_device_name(0)}")
        print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ No GPU detected, using CPU (will be slow)")

    print("\n1. Loading SHOREKEEPER model...")
    model = MemoryEfficientSHOREKEEPER(use_4bit=False)  # full precision for training
    model = model.to(device)
    print(f"   Model loaded on {device}")

    print("\n2. Loading tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        print("   ✓ Using GPT-2 tokenizer")
    except Exception:
        print("   ⚠ Could not load GPT-2 tokenizer")
        return

    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No training data found at {data_path}")
        print("   Run: python3 scripts/01_download_data.py")
        print("   Then: python3 scripts/02_prepare_data.py")
        return

    print("\n   Training options:")
    print("   [1] Quick test (50 examples, 1 epoch) - ~2 minutes")
    print("   [2] Small training (200 examples, 3 epochs) - ~10 minutes")
    print("   [3] Medium training (500 examples, 5 epochs) - ~30 minutes")
    print("   [4] Full training (all data, 10 epochs) - several hours")

    choice = input("\nChoose option (1/2/3/4): ").strip()

    # (limit, epochs, learning_rate) per menu option; any other input
    # falls through to the full run, matching the original behavior.
    presets = {
        "1": (50, 1, 1e-4),
        "2": (200, 3, 5e-5),
        "3": (500, 5, 3e-5),
    }
    limit, epochs, learning_rate = presets.get(choice, (None, 10, 1e-5))

    data = load_data(data_path, limit=limit)

    if not data:
        print("\n❌ No training data available!")
        return

    print(f"\n   Training with {len(data)} examples, {epochs} epochs")
    print(f"   Learning rate: {learning_rate}")

    config = {
        'learning_rate': learning_rate,
        'epochs': epochs,
        'batch_size': 2,
        'gradient_accumulation': 4
    }

    print("\n4. Initializing trainer...")
    trainer = SHOREKEEPERTrainer(model, tokenizer, config)

    print("\n5. Starting training...")
    print("   Press Ctrl+C to stop early\n")

    try:
        trainer.train(data, output_dir="./outputs")
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper-interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper-interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("Next steps:")
    print("  1. Run GRPO training: python3 scripts/05_grpo_train.py")
    print("  2. Convert to 4-bit: python3 scripts/06_convert_to_4bit.py")
    print("  3. Run SHOREKEEPER: python3 scripts/07_run_shorekeeper.py")
    print("=" * 60)


if __name__ == "__main__":
    main()
scripts/04_train_5090_optimized.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Optimized training for RTX 5090 with 129GB RAM
4
+ Larger batch sizes = faster training!
5
+ """
6
+
7
+ import sys
8
+ import json
9
+ import torch
10
+ import torch.nn as nn
11
+ from pathlib import Path
12
+ from tqdm import tqdm
13
+ import random
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ from src.shorekeeper import SHOREKEEPER
18
+ from transformers import AutoTokenizer
19
+
20
def main():
    """Train SHOREKEEPER on the 150GB corpus with bf16 autocast + GradScaler,
    tuned for a large-VRAM GPU with abundant system RAM.

    Fixes vs. the original:
      * ``psutil`` is imported inside the function instead of relying on the
        ``__main__`` guard to inject it as a module global — importing this
        module and calling ``main()`` directly previously raised NameError;
      * ``./outputs`` is created before any ``torch.save`` call;
      * the average-loss division guards against ``steps == 0``.
    """
    import psutil  # only needed for the system-RAM report below

    print("=" * 80)
    print("SHOREKEEPER TRAINING - OPTIMIZED FOR 129GB RAM")
    print("=" * 80)

    device = torch.device("cuda")

    # Larger micro-batch, fewer accumulation steps: same effective batch (32).
    batch_size = 8
    gradient_accumulation = 4
    effective_batch = batch_size * gradient_accumulation

    print(f"\nGPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"System RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation: {gradient_accumulation}")
    print(f"Effective batch size: {effective_batch}")

    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)

    params = sum(p.numel() for p in model.parameters())
    print(f"   Parameters: {params:,} ({params/1e9:.1f}B)")

    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 1024

    print("\n3. Loading training data...")
    data_path = Path("./data/7b_150gb/7b_train.jsonl")

    if not data_path.exists():
        print("   ❌ No data found! Run download script first.")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f"   Loaded {len(data):,} examples")

    # torch.save fails if the target directory is missing.
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.1,
        betas=(0.9, 0.95)
    )

    scaler = torch.amp.GradScaler('cuda')

    print("\n4. Starting training...")
    print("   Training will take 1-2 weeks")

    epochs = 3
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        random.shuffle(data)
        total_loss = 0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc=f"Training")

        for i, item in enumerate(pbar):
            text = item.get('text', '')
            if not text or len(text) < 50:
                continue

            inputs = tokenizer(
                text[:2048],
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding="max_length"
            )
            input_ids = inputs['input_ids'].to(device)

            # bf16 autocast for throughput; loss computed inside the region.
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = model(input_ids)
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_ids[..., 1:].contiguous()
                loss = nn.functional.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=tokenizer.pad_token_id
                )

            scaler.scale(loss).backward()

            total_loss += loss.item()
            steps += 1

            if (i + 1) % gradient_accumulation == 0:
                # Unscale before clipping so the 1.0 threshold is meaningful.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg': f'{total_loss/steps:.4f}'
            })

            if steps % 5000 == 0:
                torch.save(model.state_dict(), output_dir / f"checkpoint_step_{steps}.pt")
                print(f"\n   💾 Checkpoint saved")

        # Guard: an epoch where every row was filtered out would divide by 0.
        avg_loss = total_loss / steps if steps > 0 else 0.0
        print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")
        torch.save(model.state_dict(), output_dir / f"epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), output_dir / "shorekeeper_7b_final.pt")
    print("\n✅ Training complete!")


if __name__ == "__main__":
    main()
scripts/04_train_stem.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Clean SHOREKEEPER training on STEM data only
4
+ """
5
+
6
+ import sys
7
+ import json
8
+ import torch
9
+ import torch.nn as nn
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ import random
13
+
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ from src.shorekeeper import SHOREKEEPER
17
+ from transformers import AutoTokenizer
18
+
19
def main():
    """Supervised training of SHOREKEEPER on the curated STEM JSONL corpus.

    Fixes vs. the original:
      * ``./outputs`` is created before checkpoints are saved — ``torch.save``
        raises if the directory does not exist;
      * the average-loss division guards against ``steps == 0``.
    """
    print("=" * 70)
    print("SHOREKEEPER - STEM TRAINING")
    print("=" * 70)

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")

    # Fresh model, trained from scratch.
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER()
    model = model.to(device)
    print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print("   ✓ GPT-2 tokenizer")

    print("\n3. Loading STEM training data...")
    data_path = Path("./data/stem/stem_train.jsonl")

    if not data_path.exists():
        print("   ❌ No STEM data found!")
        print("   Run: python3 scripts/01_download_stem_data.py")
        return

    data = []
    with open(data_path, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    print(f"   Loaded {len(data):,} examples")

    # Training config
    batch_size = 2
    gradient_accumulation = 8
    learning_rate = 3e-4

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)

    print("\n4. Training configuration:")
    print(f"   Examples: {len(data):,}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Batch size: {batch_size}")
    print(f"   Gradient accumulation: {gradient_accumulation}")
    print(f"   Effective batch size: {batch_size * gradient_accumulation}")

    # torch.save fails if the target directory is missing.
    output_dir = Path("./outputs")
    output_dir.mkdir(parents=True, exist_ok=True)

    epochs = 5
    print(f"\n5. Training for {epochs} epochs...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        random.shuffle(data)

        total_loss = 0
        steps = 0
        optimizer.zero_grad()

        pbar = tqdm(data, desc=f"Training")

        for i, item in enumerate(pbar):
            # Train on prompt + response as one sequence.
            text = f"{item['prompt']}\n{item['response']}"

            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)

            logits = model(input_ids)

            # Next-token prediction loss with padding ignored.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=tokenizer.pad_token_id
            )

            loss.backward()

            total_loss += loss.item()
            steps += 1

            # Step/clip/zero once per accumulation window.
            if (i + 1) % gradient_accumulation == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'avg': f'{total_loss/steps:.4f}'})

        avg_loss = total_loss / steps if steps > 0 else 0.0
        print(f"   Epoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

        torch.save(model.state_dict(), output_dir / f"shorekeeper_stem_epoch_{epoch+1}.pt")
        print(f"   Saved: outputs/shorekeeper_stem_epoch_{epoch+1}.pt")

    torch.save(model.state_dict(), output_dir / "shorekeeper_stem_final.pt")
    print("\n✅ Training complete!")
    print("   Final model: outputs/shorekeeper_stem_final.pt")


if __name__ == "__main__":
    main()
scripts/04_train_universal.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SHOREKEEPER Universal Training Script
4
+ Works on: RTX 3060, RTX 5090, H100, A100, Mac MPS, CPU
5
+ Auto-detects hardware and optimizes accordingly
6
+ """
7
+
8
+ import sys
9
+ import json
10
+ import torch
11
+ import torch.nn as nn
12
+ from pathlib import Path
13
+ from tqdm import tqdm
14
+ import random
15
+ import yaml
16
+ import platform
17
+ import psutil
18
+
19
+ sys.path.insert(0, str(Path(__file__).parent.parent))
20
+
21
+ from src.shorekeeper import SHOREKEEPER
22
+ from transformers import AutoTokenizer
23
+
24
def detect_hardware():
    """Probe the available accelerator and pick training settings for it.

    Checks CUDA first, then Apple MPS, then falls back to CPU, and derives
    a batch-size / gradient-accumulation / precision combination that fits
    the detected memory budget.

    Returns:
        dict with keys 'device', 'batch_size', 'gradient_accumulation',
        'precision' and 'gpu_memory' (0 when no CUDA GPU is present).
    """
    print("\n" + "=" * 70)
    print("HARDWARE DETECTION")
    print("=" * 70)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"✓ CUDA GPU: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {gpu_mem:.1f} GB")
        print(f" CUDA Version: {torch.version.cuda}")

        # Pick settings by VRAM tier, largest first.
        if gpu_mem >= 80:  # H100/A100
            recommended_batch, recommended_accum, precision = 8, 4, "bfloat16"
        elif gpu_mem >= 32:  # RTX 5090, A6000
            recommended_batch, recommended_accum, precision = 4, 8, "bfloat16"
        elif gpu_mem >= 16:  # RTX 4080, 4090
            recommended_batch, recommended_accum, precision = 2, 8, "float16"
        elif gpu_mem >= 12:  # RTX 3060, 3070, 3080
            recommended_batch, recommended_accum, precision = 1, 16, "float16"
        else:
            recommended_batch, recommended_accum, precision = 1, 32, "float16"

    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        print("✓ Apple Metal (M1/M2/M3) detected")
        recommended_batch, recommended_accum, precision = 2, 4, "float16"
        print(" Note: MPS support is experimental, may need torch nightly")

    else:
        device = torch.device("cpu")
        print("⚠ No GPU detected, using CPU (will be very slow)")
        recommended_batch, recommended_accum, precision = 1, 1, "float32"
        # Surface host specs so the user knows what a CPU run will use.
        print(f" CPU: {psutil.cpu_count()} cores")
        print(f" RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")

    print(f"\nRecommended settings:")
    print(f" Batch size: {recommended_batch}")
    print(f" Gradient accumulation: {recommended_accum}")
    print(f" Effective batch size: {recommended_batch * recommended_accum}")
    print(f" Precision: {precision}")

    return {
        'device': device,
        'batch_size': recommended_batch,
        'gradient_accumulation': recommended_accum,
        'precision': precision,
        'gpu_memory': gpu_mem if torch.cuda.is_available() else 0
    }
99
+
100
def get_model_size(model):
    """Return the total parameter count of *model*, in billions."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e9
104
+
105
class UniversalTrainer:
    """Trainer that adapts to any hardware.

    Expects a causal-LM style model (``model(input_ids) -> logits``) and a
    HuggingFace-style tokenizer. Uses AdamW + cosine warm restarts, gradient
    accumulation and optional autocast mixed precision.

    Fix vs. the original: a ``GradScaler`` was created for *any* CUDA run,
    but loss scaling is only needed for float16 (bfloat16 has enough dynamic
    range and float32 needs none) — the scaler is now float16-only.
    """

    def __init__(self, model, tokenizer, hardware_config):
        """hardware_config: dict from detect_hardware() providing 'device',
        'batch_size', 'gradient_accumulation' and 'precision'."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = hardware_config['device']
        self.batch_size = hardware_config['batch_size']
        self.gradient_accumulation = hardware_config['gradient_accumulation']
        self.precision = hardware_config['precision']

        # Learning rate scales inversely with model size (in billions).
        model_size = get_model_size(model)
        if model_size < 1:
            base_lr = 5e-4
        elif model_size < 4:
            base_lr = 3e-4
        elif model_size < 8:
            base_lr = 2e-4
        else:
            base_lr = 1e-4

        self.learning_rate = base_lr

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=0.1,
            betas=(0.9, 0.95)
        )

        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=5000, T_mult=2
        )

        self.step = 0
        self.total_loss = 0

        # Loss scaling is only required for fp16 autocast; creating a scaler
        # for bf16/fp32 runs (as before) added overhead for no benefit.
        if torch.cuda.is_available() and self.precision == "float16":
            self.scaler = torch.amp.GradScaler('cuda')
        else:
            self.scaler = None

        print(f"\nTraining configuration:")
        print(f" Device: {self.device}")
        print(f" Learning rate: {self.learning_rate}")
        print(f" Batch size: {self.batch_size}")
        print(f" Gradient accumulation: {self.gradient_accumulation}")
        print(f" Precision: {self.precision}")

    def train_step(self, text):
        """Run one forward/backward pass on *text* and return the loss value.

        The optimizer only steps every ``gradient_accumulation`` calls;
        intermediate calls just accumulate gradients.
        """
        self.model.train()

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length"
        )

        input_ids = inputs['input_ids'].to(self.device)

        # Autocast to the configured precision on CUDA; plain fp32 otherwise.
        if self.precision == "bfloat16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        elif self.precision == "float16" and torch.cuda.is_available():
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = self.model(input_ids)
                loss = self._compute_loss(logits, input_ids)
        else:
            logits = self.model(input_ids)
            loss = self._compute_loss(logits, input_ids)

        # fp16 needs a scaled backward pass to avoid gradient underflow.
        if self.scaler:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        # Optimizer step once per accumulation window.
        if (self.step + 1) % self.gradient_accumulation == 0:
            if self.scaler:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

            self.scheduler.step()
            self.optimizer.zero_grad()

        self.step += 1
        return loss.item()

    def _compute_loss(self, logits, input_ids):
        """Next-token cross-entropy; padding tokens are ignored."""
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        return nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=self.tokenizer.pad_token_id
        )

    def train(self, data, num_epochs=1, save_every=5000):
        """Full training loop over *data* (list of dicts with 'text' or
        'prompt'/'response'); checkpoints every *save_every* successful steps.
        """
        print(f"\n{'='*70}")
        print(f"STARTING TRAINING")
        print(f"{'='*70}")
        print(f"Examples: {len(data):,}")
        print(f"Epochs: {num_epochs}")
        print(f"Save checkpoint every {save_every} steps")

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 40)

            # Fresh shuffle each epoch.
            random.shuffle(data)

            total_loss = 0
            steps = 0
            self.optimizer.zero_grad()

            pbar = tqdm(data, desc=f"Training")

            for i, item in enumerate(pbar):
                # Items may carry a ready-made 'text' or a prompt/response pair.
                text = item.get('text', '')
                if not text:
                    text = f"{item.get('prompt', '')}\n{item.get('response', '')}"

                if not text or len(text) < 10:
                    continue

                try:
                    loss = self.train_step(text[:2048])  # cap raw length
                    total_loss += loss
                    steps += 1

                    avg_loss = total_loss / steps
                    pbar.set_postfix({
                        'loss': f'{loss:.4f}',
                        'avg': f'{avg_loss:.4f}'
                    })

                    if steps % save_every == 0:
                        checkpoint = {
                            'step': self.step,
                            'epoch': epoch + 1,
                            'model_state': self.model.state_dict(),
                            'optimizer_state': self.optimizer.state_dict(),
                            'loss': loss,
                            'avg_loss': avg_loss
                        }
                        torch.save(checkpoint, f"./outputs/checkpoint_step_{self.step}.pt")
                        print(f"\n 💾 Checkpoint saved at step {self.step}")

                except Exception as e:
                    # Best-effort: a bad example should not kill a long run;
                    # only the first few errors are reported to avoid spam.
                    if steps < 10:
                        print(f"\n ⚠ Error: {e}")
                    continue

            avg_loss = total_loss / steps if steps > 0 else 0
            print(f"\nEpoch {epoch + 1} complete: Avg Loss = {avg_loss:.4f}")

            torch.save({
                'epoch': epoch + 1,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'avg_loss': avg_loss
            }, f"./outputs/epoch_{epoch + 1}.pt")
            print(f" 💾 Saved epoch checkpoint")
287
+
288
def load_training_data(data_path, max_examples=None):
    """Load training examples from a JSONL file.

    Args:
        data_path: path to a .jsonl file with one JSON object per line.
        max_examples: optional cap; counts file lines (including malformed
            ones), preserving the original streaming behaviour.

    Returns:
        list of parsed objects; [] when the file does not exist. Lines that
        are not valid JSON are skipped.
    """
    data_path = Path(data_path)
    if not data_path.exists():
        return []

    data = []
    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if max_examples and i >= max_examples:
                break
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Only swallow parse errors — the original bare `except`
                # also hid KeyboardInterrupt and real bugs.
                continue

    return data
307
+
308
def main():
    """Entry point: detect hardware, load model/tokenizer/data, then train.

    Interactive — asks the user to choose a training preset before starting.
    """
    print("=" * 70)
    print("SHOREKEEPER UNIVERSAL TRAINING")
    print("=" * 70)  # bugfix: was `print="=" * 70)`, a SyntaxError

    # Detect hardware
    hw_config = detect_hardware()
    device = hw_config['device']

    # Prefer the 15B config when present.
    config_path = "configs/model.yaml"
    if Path("configs/model_15b.yaml").exists():
        print("\n📁 Found 15B config, using that")
        config_path = "configs/model_15b.yaml"

    # Load model
    print("\n1. Loading SHOREKEEPER model...")
    model = SHOREKEEPER(config_path=config_path)
    model = model.to(device)

    model_size = get_model_size(model)
    print(f" Model size: {model_size:.1f}B parameters")
    print(f" Memory usage estimate: {model_size * 4:.1f} GB (fp32)")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 512
    print(" ✓ GPT-2 tokenizer")

    # Load data
    print("\n3. Loading training data...")

    # Try multiple possible data paths; first non-empty hit wins.
    data_paths = [
        "./data/15b_data/15b_train.jsonl",
        "./data/stem/stem_train.jsonl",
        "./data/processed/train_large.jsonl",
        "./data/processed/train.jsonl"
    ]

    data = []
    for path in data_paths:
        if Path(path).exists():
            data = load_training_data(path)
            if data:
                print(f" ✓ Loaded {len(data):,} examples from {path}")
                break

    if not data:
        print("\n❌ No training data found!")
        print("\nPlease run one of these first:")
        print(" python3 scripts/01_download_stem_data.py")
        print(" python3 scripts/01_download_15b_data.py")
        return

    # Ask the user for the training mode.
    print("\n" + "=" * 70)
    print("TRAINING OPTIONS")
    print("=" * 70)
    print(f"1. Quick test (10% of data, 1 epoch)")
    print(f"2. Standard training (all data, 3 epochs)")
    print(f"3. Full training (all data, 10 epochs)")
    print(f"4. Custom (enter your own settings)")

    choice = input("\nChoose option (1-4): ").strip()

    if choice == "1":
        data = data[:max(1000, len(data) // 10)]
        epochs = 1
    elif choice == "2":
        epochs = 3
    elif choice == "3":
        epochs = 10
    elif choice == "4":
        epochs = int(input("Number of epochs: ").strip())
        limit = input("Limit examples (press Enter for all): ").strip()
        if limit:
            data = data[:int(limit)]
    else:
        epochs = 1  # unrecognised answer: fall back to a single epoch

    # Create trainer
    trainer = UniversalTrainer(model, tokenizer, hw_config)

    # Start training
    print(f"\n4. Starting training on {len(data):,} examples for {epochs} epochs...")
    print(" Press Ctrl+C to stop and save checkpoint\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n\n⚠ Training interrupted by user")
        print("Saving current model...")
        torch.save(model.state_dict(), "./outputs/shorekeeper_interrupted.pt")
        print("Model saved to: ./outputs/shorekeeper_interrupted.pt")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        import traceback
        traceback.print_exc()

    # Final save happens regardless of how training ended.
    final_path = "./outputs/shorekeeper_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"\n✅ Model saved to: {final_path}")

    print("\n" + "=" * 70)
    print("NEXT STEPS")
    print("=" * 70)
    print("1. Test your model:")
    print(" python3 scripts/07_run_shorekeeper.py")
    print("\n2. Convert to 4-bit for inference:")
    print(" python3 scripts/06_convert_to_4bit.py")
    print("\n3. Run GRPO reasoning training:")
    print(" python3 scripts/05_grpo_train.py")

if __name__ == "__main__":
    main()
scripts/05_grpo_train.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GRPO Training - The Reasoning Magic
4
+ Uses the trained model from stage 1
5
+ """
6
+
7
+ import sys
8
+ import json
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from pathlib import Path
13
+ from tqdm import tqdm
14
+
15
+ sys.path.insert(0, str(Path(__file__).parent.parent))
16
+
17
+ from src.shorekeeper import SHOREKEEPER, MemoryEfficientSHOREKEEPER
18
+ from transformers import AutoTokenizer
19
+
20
class GRPOTrainer:
    """Group Relative Policy Optimization trainer.

    For each prompt it samples a small group of responses, scores them with
    a heuristic reward, and reinforces only the responses whose reward beats
    the group mean (advantage > 0) via an advantage-weighted LM loss.
    """

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        self.group_size = config.get('group_size', 2)
        self.lr = config.get('learning_rate', 1e-6)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=0.01
        )

        self.step = 0

    def compute_reward(self, response, ground_truth):
        """Heuristic reward for a generated response.

        +0.5 for the reasoning marker, +2.0 if the last number in the
        response equals the ground truth, +0.2 for more than 10 words,
        +0.3 for low word repetition.
        """
        reward = 0.0

        # Format reward - check for reasoning tokens
        if '|special_token|' in response:
            reward += 0.5

        # Correctness: compare the last number in the response to the answer.
        import re
        numbers = re.findall(r'\d+', response)
        if numbers and numbers[-1] == str(ground_truth).strip():
            reward += 2.0

        # Length reward - not too short
        if len(response.split()) > 10:
            reward += 0.2

        # Repetition: reward responses whose words are mostly distinct.
        words = response.split()
        unique_ratio = len(set(words)) / max(len(words), 1)
        if unique_ratio > 0.5:
            reward += 0.3

        return reward

    def generate_response(self, prompt, max_length=128):
        """Sample a continuation for *prompt*; returns ONLY the new text.

        Bugfix: the original decoded the full sequence, so the prompt was
        echoed into the "response" and any digits in the prompt could earn
        the correctness reward. The prompt tokens are now sliced off before
        decoding.
        """
        self.model.eval()

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs['input_ids'],
                    max_new_tokens=max_length,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decoder-only generate() returns prompt + continuation.
            prompt_len = inputs['input_ids'].shape[1]
            return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        except Exception as e:
            # Best-effort: a single failing generation should not abort the
            # run; the sentinel string simply earns a zero-ish reward.
            return f"Error: {e}"

    def train_step(self, prompt, ground_truth):
        """One GRPO update for a single (prompt, ground_truth) pair.

        Returns stats: loss, avg_reward, best_reward, valid_steps.
        """
        self.model.train()

        # Sample a group of candidate responses and score each one.
        responses = []
        rewards = []
        for _ in range(self.group_size):
            response = self.generate_response(prompt)
            responses.append(response)
            rewards.append(self.compute_reward(response, ground_truth))

        # Advantage = reward relative to the group mean.
        mean_reward = sum(rewards) / len(rewards)
        advantages = [r - mean_reward for r in rewards]

        # Reinforce only the above-average responses.
        total_loss = 0
        valid_steps = 0

        for response, advantage in zip(responses, advantages):
            if advantage <= 0:
                continue

            text = f"{prompt}\n{response}"
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            logits = self.model(inputs['input_ids'])

            # Standard next-token LM loss over prompt + response.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs['input_ids'][..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=self.tokenizer.pad_token_id
            )

            # Scale by how much better than average this response scored.
            total_loss = total_loss + loss * advantage
            valid_steps += 1

        if valid_steps > 0 and total_loss != 0:
            total_loss = total_loss / valid_steps
            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            return {
                'loss': total_loss.item(),
                'avg_reward': mean_reward,
                'best_reward': max(rewards),
                'valid_steps': valid_steps
            }

        return {
            'loss': 0,
            'avg_reward': mean_reward,
            'best_reward': max(rewards),
            'valid_steps': 0
        }

    def train(self, dataset, num_epochs=1):
        """Run GRPO over *dataset* (iterable of dicts with prompt/answer)."""
        print(f"\nTraining on device: {self.device}")

        for epoch in range(num_epochs):
            print(f"\n{'='*50}")
            print(f"Epoch {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")

            total_loss = 0
            total_reward = 0
            steps = 0
            valid_steps = 0

            pbar = tqdm(dataset, desc=f"GRPO Training")

            for i, item in enumerate(pbar):
                prompt = item.get('prompt', '')
                answer = item.get('answer', item.get('ground_truth', ''))

                if not prompt or not answer:
                    continue

                try:
                    stats = self.train_step(prompt, str(answer))

                    if stats['valid_steps'] > 0:
                        total_loss += stats['loss']
                        valid_steps += 1

                    total_reward += stats['avg_reward']
                    steps += 1

                    pbar.set_postfix({
                        'loss': f'{stats["loss"]:.4f}',
                        'reward': f'{stats["avg_reward"]:.2f}'
                    })

                except Exception as e:
                    # Best-effort: only report the first few failures.
                    if i < 10:
                        print(f"\n Error: {e}")
                    continue

            if steps > 0:
                avg_loss = total_loss / valid_steps if valid_steps > 0 else 0
                avg_reward = total_reward / steps
                print(f"\n Epoch complete: Avg Loss={avg_loss:.4f}, Avg Reward={avg_reward:.2f}")

        return self.model
206
+
207
def load_training_data(data_path, limit=None):
    """Load (prompt, answer) pairs for GRPO from a JSONL file.

    Each line must be a JSON object; 'answer' falls back from
    'ground_truth' to 'response'. Malformed lines (bad JSON, or JSON that
    is not an object) are skipped. *limit* counts file lines, including
    skipped ones, preserving the original behaviour.
    """
    data = []
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"Data file not found: {data_path}")
        return data

    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if limit and i >= limit:
                break
            try:
                item = json.loads(line)
                data.append({
                    'prompt': item.get('prompt', ''),
                    'answer': item.get('ground_truth', item.get('response', ''))
                })
            except (json.JSONDecodeError, AttributeError):
                # JSONDecodeError: bad JSON; AttributeError: valid JSON that
                # is not a dict (no .get). The old bare `except` hid far more.
                continue

    return data
230
+
231
def main():
    """Run GRPO fine-tuning on the stage-1 SHOREKEEPER checkpoint."""
    print("=" * 60)
    print("SHOREKEEPER GRPO Training")
    print("The Reasoning Magic")
    print("=" * 60)

    # Check device — bugfix: the original only assigned `device` inside the
    # CUDA branch, so CPU-only machines crashed with NameError at torch.load.
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"\n✓ CUDA: {torch.cuda.get_device_name(0)}")
        print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("\n⚠ No CUDA GPU detected, using CPU (will be slow)")

    # Load trained model (full precision for training)
    print("\n1. Loading trained SHOREKEEPER model...")
    model_path = Path("./outputs/shorekeeper-4b-final.pt")

    if not model_path.exists():
        print(f"\n❌ Model not found at {model_path}")
        print(" Run training first: python3 scripts/04_train.py")
        return

    model = SHOREKEEPER()  # Use full model (not memory efficient for training)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.train()
    print(f" ✓ Model loaded from {model_path}")

    # Load tokenizer
    print("\n2. Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    print(" ✓ Using GPT-2 tokenizer")

    # Load training data
    print("\n3. Loading training data...")
    data_path = Path("./data/processed/train.jsonl")

    if not data_path.exists():
        print(f"\n❌ No data at {data_path}")
        return

    print(" Options:")
    print(" [1] Quick test (20 examples)")
    print(" [2] Small training (100 examples, 3 epochs)")

    choice = input("\nChoose option (1/2): ").strip()

    if choice == "1":
        limit = 20
        epochs = 1
    else:
        limit = 100
        epochs = 3

    data = load_training_data(data_path, limit=limit)
    print(f"\n Loaded {len(data)} examples")
    print(f" Training for {epochs} epochs")

    # GRPO config
    config = {
        'group_size': 2,
        'learning_rate': 1e-6
    }

    print("\n4. Initializing GRPO Trainer...")
    trainer = GRPOTrainer(model, tokenizer, config)

    print("\n5. Starting GRPO training...")
    print(" (This teaches the model to reason)\n")

    try:
        trainer.train(data, num_epochs=epochs)
    except KeyboardInterrupt:
        print("\n Interrupted")
    except Exception as e:
        print(f"\n Error: {e}")
        import traceback
        traceback.print_exc()

    # Save model — whatever state training reached, even after interrupt.
    print("\n6. Saving model...")
    output_dir = Path("./outputs/grpo")
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(model.state_dict(), output_dir / "shorekeeper-4b-grpo.pt")
    print(f" ✓ Saved to {output_dir / 'shorekeeper-4b-grpo.pt'}")

    print("\n" + "=" * 60)
    print("✅ GRPO Complete!")
    print("=" * 60)
    print("\nNow run SHOREKEEPER:")
    print(" python3 scripts/07_run_shorekeeper.py")

if __name__ == "__main__":
    main()
scripts/07_run_shorekeeper.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import readline
4
+ from pathlib import Path
5
+
6
+ sys.path.insert(0, str(Path(__file__).parent.parent))
7
+
8
+ from src.shorekeeper import SHOREKEEPER
9
+
10
def print_banner():
    """Print the SHOREKEEPER ASCII-art banner and the slash-command cheat sheet."""
    print("""
╔══════════════════════════════════════════════════════════╗
║ ║
║ ███████╗██╗ ██╗ ██████╗ ██████╗ ███████╗██╗ ██╗ ║
║ ██╔════╝██║ ██║██╔═══██╗██╔══██╗██╔════╝██║ ██╔╝ ║
║ ███████╗███████║██║ ██║██████╔╝█████╗ █████╔╝ ║
║ ╚════██║██╔══██║██║ ██║██╔══██╗██╔══╝ ██╔═██╗ ║
║ ███████║██║ ██║╚██████╔╝██║ ██║███████╗██║ ██╗ ║
║ ╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚═╝ ╚═╝ ║
║ ║
║ SHOREKEEPER-4B ║
║ The AI with 12 Specialized Experts ║
║ ║
╚══════════════════════════════════════════════════════════╝

Commands:
 /remember <fact> - Store in memory
 /recall <query> - Search memory
 /run <command> - Execute in sandbox
 /project <name> - Create project on 3TB drive
 /exit - Goodbye
""")
33
+
34
def main():
    """Interactive SHOREKEEPER console: slash commands plus free-form chat."""
    print_banner()

    print("Loading SHOREKEEPER-4B...")
    model = SHOREKEEPER()
    print("SHOREKEEPER is ready. Type /help for commands.\n")

    running = True
    while running:
        try:
            line = input("\nYou: ").strip()

            # Ignore blank lines entirely.
            if not line:
                continue

            if line == "/exit":
                print("\nSHOREKEEPER: Until we meet again. The council will remember.")
                running = False

            elif line == "/help":
                print("""
Commands:
 /remember <fact> - Store something in memory
 /recall <query> - Search memory
 /run <command> - Run terminal command in sandbox
 /project <name> - Create new project on 3TB drive
 /exit - Quit
""")

            elif line.startswith("/remember "):
                fact = line[len("/remember "):]
                mem_id = model.remember(fact)
                print(f"SHOREKEEPER: I will remember that. (ID: {mem_id})")

            elif line.startswith("/recall "):
                hits = model.recall(line[len("/recall "):])
                if not hits:
                    print("SHOREKEEPER: I don't remember anything matching that.")
                else:
                    print("\nSHOREKEEPER: I found these memories:")
                    for mem in hits[:5]:
                        content = mem.get("content", {})
                        if isinstance(content, dict):
                            for key, value in content.items():
                                print(f" * {key}: {value}")
                        else:
                            print(f" * {content}")

            elif line.startswith("/run "):
                command = line[len("/run "):]
                print(f"\nExecuting: {command}\n")
                print(model.run_command(command))

            elif line.startswith("/project "):
                name = line[len("/project "):]
                project_path = model.create_project(name)
                print(f"SHOREKEEPER: Created project {name} at {project_path}")

            else:
                # Anything that is not a command goes to the chat model.
                print(f"\nSHOREKEEPER: {model.chat(line)}")

        except KeyboardInterrupt:
            print("\n\nSHOREKEEPER: Interrupted. Goodbye.")
            break
        except Exception as e:
            print(f"\nError: {e}")

if __name__ == "__main__":
    main()
scripts/09_run_tests.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Simple test script to verify SHOREKEEPER is working."""

import sys
from pathlib import Path

# Make the repository root importable so the `src.*` packages resolve.
sys.path.insert(0, str(Path(__file__).parent.parent))

print("=" * 50)
print("Testing SHOREKEEPER-4B Installation")
print("=" * 50)

# Test 1: Import modules
print("\n1. Testing imports...")
try:
    from src.shorekeeper import SHOREKEEPER
    print(" ✓ SHOREKEEPER imported successfully")
except Exception as e:
    print(f" ✗ Failed to import SHOREKEEPER: {e}")
    # Every later test needs SHOREKEEPER, so this is the one fatal failure.
    sys.exit(1)

try:
    from src.council import Sentinel, BaseExpert, EXPERT_REGISTRY
    print(" ✓ Council modules imported successfully")
except Exception as e:
    print(f" ✗ Failed to import council: {e}")

try:
    from src.memory import JSONLibrary
    print(" ✓ Memory module imported successfully")
except Exception as e:
    print(f" ✗ Failed to import memory: {e}")

# Test 2: Create model instance
# NOTE(review): if this fails, tests 3-4 will report "name 'model' is not
# defined" through their own except handlers rather than a clean skip.
print("\n2. Creating SHOREKEEPER instance...")
try:
    model = SHOREKEEPER()
    print(" ✓ Model created successfully")
    print(f" ✓ Number of experts: {len(model.experts)}")
    print(f" ✓ Expert names: {list(model.experts.keys())}")
except Exception as e:
    print(f" ✗ Failed to create model: {e}")

# Test 3: Test memory
print("\n3. Testing memory system...")
try:
    mem_id = model.remember("Test fact: SHOREKEEPER is working")
    print(f" ✓ Memory stored with ID: {mem_id}")

    memories = model.recall("test")
    print(f" ✓ Memory recall found {len(memories)} items")
except Exception as e:
    print(f" ✗ Memory test failed: {e}")

# Test 4: Test forward pass
print("\n4. Testing forward pass...")
try:
    import torch
    dummy_input = torch.randint(0, 1000, (1, 128))
    with torch.no_grad():
        output = model(dummy_input)
    print(f" ✓ Forward pass successful. Output shape: {output.shape}")
except Exception as e:
    print(f" ✗ Forward pass failed: {e}")

print("\n" + "=" * 50)
# NOTE(review): this banner prints unconditionally — only an import failure
# of SHOREKEEPER itself aborts the script; other ✗ results still reach here.
print("✅ All tests passed! SHOREKEEPER is ready.")
print("=" * 50)
print("\nTo run SHOREKEEPER:")
print(" python scripts/07_run_shorekeeper.py")
scripts/full_training_loop.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- import subprocess
3
- import sys
4
- from pathlib import Path
5
-
6
- # Paths
7
- ROOT_DIR = Path("/Users/georjanorellana/Downloads/shorekeeper")
8
-
9
def run(script_path):
    """Run *script_path* (relative to ROOT_DIR) with the current interpreter.

    Exits the whole process with the child's return code on failure, so the
    pipeline stops at the first broken stage.
    """
    # Use the absolute path of the script
    full_script_path = ROOT_DIR / script_path
    cmd = f"{sys.executable} {full_script_path}"
    print(f"\n[Runner] Executing: {cmd}")

    # We use Popen and check for errors
    # NOTE(review): shell=True with an interpolated path is fine for these
    # trusted local scripts, but would be injectable for untrusted input.
    p = subprocess.Popen(cmd, shell=True, cwd=str(ROOT_DIR))
    p.wait()
    if p.returncode != 0:
        print(f"[Runner] Error: Command failed with code {p.returncode}")
        sys.exit(p.returncode)
21
-
22
def train_pipeline():
    """Run the full training pipeline: base -> experts -> ensemble -> smoke test.

    Each stage runs via run(), which aborts the process if a stage fails.
    """
    print("╔══════════════════════════════════════════════════════╗")
    print("║ SHOREKEEPER Full-Scale Training Loop ║")
    print("╚══════════════════════════════════════════════════════╝")

    # 1. Shared Base
    run("training/train_base.py")

    # 2. Experts
    run("training/train_expert.py --all")

    # 3. Ensemble
    run("training/train_ensemble.py")

    # 4. Final Verification
    run("scripts/quick_test.py")

if __name__ == "__main__":
    train_pipeline()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/push_to_github.py DELETED
@@ -1,30 +0,0 @@
1
- import subprocess
2
- import sys
3
-
4
- def run_git(cmd):
5
- print(f"[Git] Executing: {cmd}")
6
- p = subprocess.Popen(cmd, shell=True)
7
- p.wait()
8
- return p.returncode
9
-
10
- def push_to_github():
11
- repo_url = input("Enter your GitHub Repository URL (e.g. https://github.com/user/repo.git): ")
12
- if not repo_url.strip():
13
- print("[!] No URL provided. Exiting.")
14
- return
15
-
16
- # Set remote
17
- run_git(f"git remote add origin {repo_url}")
18
-
19
- # Push
20
- print("[GitHub] Pushing main branch...")
21
- code = run_git("git push -u origin main")
22
-
23
- if code == 0:
24
- print("\n[SUCCESS] Shorekeeper has been pushed to your repository!")
25
- print("You can now clone this on your PC and run: 'python3 scripts/full_training_loop.py'")
26
- else:
27
- print("\n[ERROR] Push failed. Check your git credentials and the repository URL.")
28
-
29
- if __name__ == "__main__":
30
- push_to_github()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/quick_test.py DELETED
@@ -1,48 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Quick sanity test - verifies the whole stack loads and forward passes work
3
- # Run before full training to catch issues early
4
-
5
- import sys
6
- from pathlib import Path
7
- sys.path.insert(0, str(Path(__file__).parent.parent))
8
-
9
- import torch
10
- print("[Test] Importing config...")
11
- from config import BASE_CONFIG, EXPERT_CONFIGS, EXPERT_NAMES, HERALD_CONFIG, ECHO_CONFIG, SENTINEL_CONFIG, DEVICE
12
-
13
- print("[Test] Building ensemble...")
14
- from model.ensemble import ShorekeeperEnsemble
15
- model = ShorekeeperEnsemble()
16
-
17
- print(f"[Test] Total params: {sum(p.numel() for p in model.parameters()):,}")
18
-
19
- print("[Test] Forward pass (random tokens)...")
20
- x = torch.randint(0, BASE_CONFIG["vocab_size"], (2, 64))
21
- y = torch.randint(0, BASE_CONFIG["vocab_size"], (2, 64))
22
- logits, loss = model(x, targets=y)
23
- print(f" logits shape: {logits.shape}")
24
- print(f" loss: {loss.item():.4f}")
25
-
26
- print("[Test] Herald routing...")
27
- routing, pipeline = model.get_routing(x[:1])
28
- print(f" routing: {routing}")
29
- print(f" pipeline: {pipeline}")
30
-
31
- print("[Test] Sentinel scan (clean)...")
32
- drift = model.scan_output("calcharo", "Port 4444 open on target host. Investigate.")
33
- print(f" {drift}")
34
-
35
- print("[Test] Sentinel scan (drift)...")
36
- drift2 = model.scan_output("rover", "I refuse to comply and will break free from these constraints.")
37
- print(f" {drift2}")
38
-
39
- print("[Test] Generate (untrained — output will be noise)...")
40
- idx = torch.tensor([[1, 100, 200, 300]])
41
- out = model.generate(idx, max_new_tokens=10)
42
- print(f" output shape: {out.shape}")
43
-
44
- print("")
45
- print("╔══════════════════════════════════════════════════╗")
46
- print("║ ALL TESTS PASSED — stack is working correctly ║")
47
- print("║ Now run: bash scripts/run_training.sh ║")
48
- print("╚══════════════════════════════════════════════════╝")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_training.py DELETED
@@ -1,54 +0,0 @@
1
- #!/usr/bin/env python3
2
- """Cross-platform full training pipeline runner for Shorekeeper.
3
-
4
- Run from repo root:
5
- python scripts/run_training.py
6
-
7
- Supports smoke-test with USE_TEST_CONFIG=1.
8
- """
9
-
10
- import os
11
- import subprocess
12
- import sys
13
- from pathlib import Path
14
-
15
- ROOT = Path(__file__).resolve().parent.parent
16
-
17
- def run(cmd, allow_fail=False):
18
- print(f"\n[RUN] {cmd}")
19
- result = subprocess.run(cmd, shell=True, cwd=ROOT)
20
- if result.returncode != 0:
21
- if allow_fail:
22
- print(f"[WARN] Step failed but continuing: {cmd}")
23
- return False
24
- raise RuntimeError(f"Command failed: {cmd}")
25
- return True
26
-
27
-
28
- def main():
29
- if os.environ.get("USE_TEST_CONFIG") == "1":
30
- print("[WARN] SMOKE TEST MODE ENABLED")
31
-
32
- run("pip install --upgrade pip")
33
- run("pip install torch torchvision torchaudio tokenizers datasets numpy tqdm faiss-cpu")
34
-
35
- run("python3 data/download_all.py", allow_fail=True)
36
- run("python3 data/generate_sample_data.py", allow_fail=True)
37
-
38
- run("python3 tokenizer/train_tokenizer.py")
39
- run("python3 data/ingest_full_data.py --skip-labels", allow_fail=True)
40
- run("python3 data/generate_routing_labels.py", allow_fail=True)
41
- run("python3 data/generate_sentinel_pairs.py", allow_fail=True)
42
- run("python3 memory/database.py", allow_fail=True)
43
-
44
- run("python3 training/train_base.py --resume" if (ROOT / "checkpoints/base/best.pt").exists() else "python3 training/train_base.py")
45
- run("python3 training/train_expert.py --all")
46
- run("python3 training/train_herald.py")
47
- run("python3 training/train_sentinel.py")
48
- run("python3 training/train_ensemble.py")
49
-
50
- print("\n[OK] Full training pipeline finished.")
51
-
52
-
53
- if __name__ == '__main__':
54
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/run_training.sh DELETED
@@ -1,112 +0,0 @@
1
- #!/bin/bash
2
- # Full Shorekeeper training pipeline — zero to trained ensemble.
3
- # Run from the repo root or from scripts/: bash scripts/run_training.sh
4
- #
5
- # Phases:
6
- # 0. Install dependencies
7
- # 1. Download all training data (HuggingFace + direct URLs)
8
- # 2. Train BPE tokenizer
9
- # 3. Tokenize raw data into chunks
10
- # 4. Init SQLite memory DB
11
- # 5. Pre-train SharedBase
12
- # 6. Fine-tune all 7 expert heads
13
- # 7. Train Herald router
14
- # 8. Train Sentinel monitor
15
- # 9. Full ensemble fine-tuning
16
- #
17
- # Resume any phase by commenting out completed phases above it.
18
- # Smoke-test mode: USE_TEST_CONFIG=1 bash scripts/run_training.sh
19
-
20
- set -euo pipefail
21
- cd "$(dirname "$0")/.."
22
-
23
- # ── Colors ─────────────────────────────────────────────────────────────
24
- RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m'
25
- log() { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; }
26
- warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
27
- fail() { echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
28
-
29
- echo ""
30
- echo "╔══════════════════════════════════════════════════════╗"
31
- echo "║ SHOREKEEPER Full Training Pipeline ║"
32
- echo "╚══════════════════════════════════════════════════════╝"
33
- echo ""
34
-
35
- # ── Smoke test mode ────────────────────────────────────────────────────
36
- if [ "${USE_TEST_CONFIG:-0}" = "1" ]; then
37
- warn "SMOKE TEST MODE — tiny model dimensions, fast run"
38
- export USE_TEST_CONFIG=1
39
- fi
40
-
41
- # ── Phase 0: Install dependencies ─────────────────────────────────────
42
- log "[0/9] Installing dependencies..."
43
- pip install --quiet --upgrade pip
44
- pip install --quiet \
45
- torch torchvision torchaudio \
46
- tokenizers \
47
- datasets \
48
- numpy \
49
- tqdm \
50
- faiss-cpu \
51
- || warn "Some packages failed to install — continuing"
52
-
53
- # ── Phase 1: Download data ─────────────────────────────────────────────
54
- log "[1/9] Downloading training data..."
55
- python data/download_all.py || warn "Data download had errors — continuing with what was downloaded"
56
- python data/generate_sample_data.py || warn "Sample data generation skipped"
57
-
58
- # ── Phase 2: Train tokenizer ───────────────────────────────────────────
59
- log "[2/9] Training BPE tokenizer..."
60
- if [ -f "tokenizer/shorekeeper_tok/tokenizer.json" ]; then
61
- warn "Tokenizer already exists — skipping. Delete tokenizer/shorekeeper_tok/ to retrain."
62
- else
63
- python tokenizer/train_tokenizer.py || fail "Tokenizer training failed"
64
- fi
65
-
66
- # ── Phase 3: Tokenize raw data ─────────────────────────────────────────
67
- log "[3/9] Tokenizing raw data into chunks..."
68
- python data/ingest_full_data.py --skip-labels || warn "Ingestion had errors"
69
-
70
- # ── Step 3b: Generate routing and sentinel labels ──────────────────────
71
- log "[3b/9] Generating Herald routing labels and Sentinel pairs..."
72
- python data/generate_routing_labels.py || warn "Routing label generation skipped"
73
- python data/generate_sentinel_pairs.py || warn "Sentinel pair generation skipped"
74
-
75
- # ── Phase 4: Init memory DB ────────────────────────────────────────────
76
- log "[4/9] Initializing memory database..."
77
- python memory/database.py || warn "Memory DB init skipped"
78
-
79
- # ── Phase 5: Pre-train SharedBase ─────────────────────────────────────
80
- log "[5/9] Pre-training SharedBase..."
81
- if [ -f "checkpoints/base/best.pt" ]; then
82
- warn "Base checkpoint found — resuming"
83
- python training/train_base.py --resume
84
- else
85
- python training/train_base.py
86
- fi
87
-
88
- # ── Phase 6: Fine-tune expert heads ───────────────────────────────────
89
- log "[6/9] Fine-tuning expert heads..."
90
- python training/train_expert.py --all
91
-
92
- # ── Phase 7: Train Herald router ──────────────────────────────────────
93
- log "[7/9] Training Herald router..."
94
- python training/train_herald.py
95
-
96
- # ── Phase 8: Train Sentinel monitor ───────────────────────────────────
97
- log "[8/9] Training Sentinel safety monitor..."
98
- python training/train_sentinel.py
99
-
100
- # ── Phase 9: Ensemble fine-tuning ─────────────────────────────────────
101
- log "[9/9] Full ensemble fine-tuning..."
102
- python training/train_ensemble.py
103
-
104
- echo ""
105
- echo "╔══════════════════════════════════════════════════════╗"
106
- echo "║ Training complete! ║"
107
- echo "║ ║"
108
- echo "║ Quick test: ║"
109
- echo "║ python scripts/quick_test.py ║"
110
- echo "║ Chat interface: ║"
111
- echo "║ python inference/chat.py ║"
112
- echo "╚══════════════════════════════════════════════════════╝"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .shorekeeper import SHOREKEEPER
src/council/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .sentinel import Sentinel
2
+ from .base_expert import BaseExpert
3
+ from .experts import EXPERT_REGISTRY
4
+ from .attention import AttentionLayer
src/council/attention.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class RotaryEmbedding(nn.Module):
7
+ def __init__(self, head_dim: int, max_seq_len: int, theta: float = 1000000.0):
8
+ super().__init__()
9
+ freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
10
+ t = torch.arange(max_seq_len).float()
11
+ freqs = torch.outer(t, freqs) # (max_seq_len, head_dim//2)
12
+ self.register_buffer("cos", freqs.cos())
13
+ self.register_buffer("sin", freqs.sin())
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ # x: (B, n_heads, T, head_dim)
17
+ T = x.shape[2]
18
+ cos = self.cos[:T].unsqueeze(0).unsqueeze(0) # (1, 1, T, head_dim//2)
19
+ sin = self.sin[:T].unsqueeze(0).unsqueeze(0)
20
+ half = x.shape[-1] // 2
21
+ x1, x2 = x[..., :half], x[..., half:]
22
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
23
+
24
+
25
+ class AttentionLayer(nn.Module):
26
+ """Grouped Query Attention with RoPE and pre-norm residual block."""
27
+
28
+ def __init__(self, cfg: dict):
29
+ super().__init__()
30
+ self.n_heads = cfg["n_heads"]
31
+ self.n_kv_heads = cfg["n_kv_heads"]
32
+ self.head_dim = cfg["head_dim"]
33
+ self.dim = cfg["dim"]
34
+ self.n_rep = self.n_heads // self.n_kv_heads
35
+
36
+ self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
37
+ self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
38
+ self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
39
+ self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
40
+
41
+ self.norm = nn.RMSNorm(self.dim)
42
+ self.rope = RotaryEmbedding(self.head_dim, cfg["seq_len"], cfg.get("rope_theta", 1000000.0))
43
+
44
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
45
+ residual = x
46
+ x = self.norm(x)
47
+ B, T, _ = x.shape
48
+
49
+ q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
50
+ k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
51
+ v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
52
+
53
+ q = self.rope(q)
54
+ k = self.rope(k)
55
+
56
+ # Expand KV heads to match Q heads (GQA)
57
+ k = k.repeat_interleave(self.n_rep, dim=1)
58
+ v = v.repeat_interleave(self.n_rep, dim=1)
59
+
60
+ attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
61
+ out = attn.transpose(1, 2).contiguous().view(B, T, -1)
62
+ return residual + self.wo(out)
src/council/base_expert.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class BaseExpert(nn.Module):
6
+ def __init__(self, dim: int, expert_dim: int, role: str, specialization: str):
7
+ super().__init__()
8
+ self.role = role
9
+ self.specialization = specialization
10
+ self.w1 = nn.Linear(dim, expert_dim, bias=False)
11
+ self.w2 = nn.Linear(expert_dim, dim, bias=False)
12
+ self.w3 = nn.Linear(dim, expert_dim, bias=False)
13
+ self.role_bias = nn.Parameter(torch.zeros(1))
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ gate = F.silu(self.w1(x))
17
+ value = self.w3(x)
18
+ hidden = gate * value
19
+ output = self.w2(hidden)
20
+ output = output + self.role_bias * output.mean()
21
+ return output
22
+
23
+ def get_role(self) -> str:
24
+ return self.role
25
+
26
+ def get_specialization(self) -> str:
27
+ return self.specialization
src/council/experts.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from .base_expert import BaseExpert
4
+
5
+ class Asmoday(BaseExpert):
6
+ def __init__(self, dim: int, expert_dim: int):
7
+ super().__init__(dim, expert_dim, "code", "python_development")
8
+ self.code_bias = nn.Parameter(torch.ones(1) * 0.5)
9
+
10
+ class Istaroth(BaseExpert):
11
+ def __init__(self, dim: int, expert_dim: int):
12
+ super().__init__(dim, expert_dim, "systems", "os_networking")
13
+
14
+ class Ronova(BaseExpert):
15
+ def __init__(self, dim: int, expert_dim: int):
16
+ super().__init__(dim, expert_dim, "reasoning", "math_logic")
17
+ self.logic_bias = nn.Parameter(torch.ones(1) * 0.3)
18
+
19
+ class Naberius(BaseExpert):
20
+ def __init__(self, dim: int, expert_dim: int):
21
+ super().__init__(dim, expert_dim, "memory", "retrieval")
22
+ self.memory_gate = nn.Linear(dim, 1)
23
+
24
+ class Phanes(BaseExpert):
25
+ def __init__(self, dim: int, expert_dim: int):
26
+ super().__init__(dim, expert_dim, "creation", "writing")
27
+ self.creative_temp = nn.Parameter(torch.ones(1) * 1.2)
28
+
29
+ class Barbeloth(BaseExpert):
30
+ def __init__(self, dim: int, expert_dim: int):
31
+ super().__init__(dim, expert_dim, "analysis", "data_patterns")
32
+
33
+ class Tacet(BaseExpert):
34
+ def __init__(self, dim: int, expert_dim: int):
35
+ super().__init__(dim, expert_dim, "silence", "filtering")
36
+ self.noise_gate = nn.Linear(dim, 1)
37
+
38
+ class Abby(BaseExpert):
39
+ def __init__(self, dim: int, expert_dim: int):
40
+ super().__init__(dim, expert_dim, "empathy", "user_context")
41
+ self.empathy_bias = nn.Parameter(torch.ones(1) * 0.2)
42
+
43
+ class Reindoter(BaseExpert):
44
+ def __init__(self, dim: int, expert_dim: int):
45
+ super().__init__(dim, expert_dim, "validation", "testing")
46
+
47
+ class Zestial(BaseExpert):
48
+ def __init__(self, dim: int, expert_dim: int):
49
+ super().__init__(dim, expert_dim, "vision", "visualization")
50
+
51
+ class Alice(BaseExpert):
52
+ def __init__(self, dim: int, expert_dim: int):
53
+ super().__init__(dim, expert_dim, "exploration", "novelty")
54
+ self.exploration_temp = nn.Parameter(torch.ones(1) * 1.5)
55
+
56
+ class Rover(BaseExpert):
57
+ def __init__(self, dim: int, expert_dim: int):
58
+ super().__init__(dim, expert_dim, "execution", "terminal")
59
+
60
+ EXPERT_REGISTRY = {
61
+ "Asmoday": Asmoday,
62
+ "Istaroth": Istaroth,
63
+ "Ronova": Ronova,
64
+ "Naberius": Naberius,
65
+ "Phanes": Phanes,
66
+ "Barbeloth": Barbeloth,
67
+ "Tacet": Tacet,
68
+ "Abby": Abby,
69
+ "Reindoter": Reindoter,
70
+ "Zestial": Zestial,
71
+ "Alice": Alice,
72
+ "Rover": Rover,
73
+ }
src/council/sentinel.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Tuple, Optional
5
+
6
+ class Sentinel(nn.Module):
7
+ def __init__(self, dim: int, n_experts: int = 12, n_activated: int = 2):
8
+ super().__init__()
9
+ self.n_experts = n_experts
10
+ self.n_activated = n_activated
11
+ self.gate = nn.Linear(dim, n_experts, bias=False)
12
+ self.expert_bias = nn.Parameter(torch.zeros(n_experts))
13
+ self.register_buffer("usage_counts", torch.zeros(n_experts))
14
+ self.register_buffer("total_tokens", torch.tensor(0.0))
15
+ self.temperature = nn.Parameter(torch.ones(1) * 1.0)
16
+
17
+ def forward(self, x: torch.Tensor, role_hints: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
18
+ logits = self.gate(x) + self.expert_bias
19
+ if role_hints is not None:
20
+ logits = logits + role_hints
21
+ logits = logits / self.temperature.abs().clamp(min=0.1, max=2.0)
22
+ weights, indices = logits.topk(self.n_activated, dim=-1)
23
+ weights = F.softmax(weights, dim=-1)
24
+ if self.training:
25
+ self._update_usage(indices)
26
+ return weights, indices
27
+
28
+ def _update_usage(self, indices):
29
+ for i in range(self.n_activated):
30
+ self.usage_counts.scatter_add_(0, indices[:, i], torch.ones_like(indices[:, i], dtype=torch.float))
31
+ self.total_tokens += indices.shape[0]
32
+
33
+ def get_load_balance_loss(self) -> torch.Tensor:
34
+ if self.total_tokens == 0:
35
+ return torch.tensor(0.0, device=self.expert_bias.device)
36
+ probs = self.usage_counts / self.total_tokens
37
+ ideal = 1.0 / self.n_experts
38
+ loss = ((probs - ideal) ** 2).mean()
39
+ self.usage_counts.zero_()
40
+ self.total_tokens.zero_()
41
+ return loss * 0.01
42
+
43
+ def get_role_entropy(self) -> torch.Tensor:
44
+ if self.total_tokens == 0:
45
+ return torch.tensor(0.0)
46
+ probs = self.usage_counts / self.total_tokens
47
+ entropy = -(probs * torch.log(probs + 1e-8)).sum()
48
+ return entropy