# src/shorekeeper.py — SHOREKEEPER: MoE transformer with per-layer attention
# and expert routing, JSON-backed memory, and an optional terminal sandbox.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from pathlib import Path
import yaml
from src.council.sentinel import Sentinel
from src.council.experts import EXPERT_REGISTRY
from src.council.attention import AttentionLayer
from src.memory.json_library import JSONLibrary
from src.sandbox.terminal import TerminalSandbox
class SHOREKEEPER(nn.Module):
    """Mixture-of-experts transformer with a council of named experts.

    Each of the ``n_layers`` blocks applies an attention sub-layer followed by
    a MoE feed-forward sub-layer whose tokens are routed by a ``Sentinel``
    gate to the top-k activated experts. The model also carries a JSON-backed
    memory, an optional GPT-2 tokenizer, and an optional terminal sandbox.

    Configuration is loaded from YAML files at construction time.
    """

    def __init__(self, config_path: str = "configs/model.yaml"):
        """Build the model from YAML configs.

        Args:
            config_path: path to the model config; memory and sandbox configs
                are read from fixed paths under ``configs/``.
        """
        super().__init__()
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)
        with open("configs/memory.yaml", "r") as f:
            self.memory_config = yaml.safe_load(f)
        with open("configs/sandbox.yaml", "r") as f:
            self.sandbox_config = yaml.safe_load(f)

        model_cfg = self.config["model"]
        self.dim = model_cfg["dim"]
        self.vocab_size = model_cfg["vocab_size"]
        self.max_seq_len = model_cfg["seq_len"]

        self.token_embedding = nn.Embedding(model_cfg["vocab_size"], model_cfg["dim"])

        # Instantiate each expert named in the config via the registry.
        self.experts = nn.ModuleDict()
        for expert_info in model_cfg["experts"]["members"]:
            name = expert_info["name"]
            expert_class = EXPERT_REGISTRY[name]
            self.experts[name] = expert_class(model_cfg["dim"], model_cfg["expert_dim"])

        # Sentinel gate: routes each token to n_activated of n_experts.
        self.sentinel = Sentinel(model_cfg["dim"], model_cfg["n_experts"], model_cfg["n_activated"])
        self.expert_names = model_cfg["experts"]["members"]
        # Ordered list of expert modules; index i here must match the index
        # the Sentinel emits for expert i.
        self.expert_list = [self.experts[e["name"]] for e in self.expert_names]

        self.layers = nn.ModuleList([
            AttentionLayer(model_cfg) for _ in range(model_cfg["n_layers"])
        ])
        self.moe_norms = nn.ModuleList([
            nn.RMSNorm(model_cfg["dim"]) for _ in range(model_cfg["n_layers"])
        ])
        self.norm = nn.RMSNorm(model_cfg["dim"])
        self.lm_head = nn.Linear(model_cfg["dim"], model_cfg["vocab_size"], bias=False)
        # Weight tying: input embedding and output projection share one matrix.
        self.token_embedding.weight = self.lm_head.weight

        self.memory = JSONLibrary(self.memory_config["memory"]["path"])

        # Tokenizer and sandbox are optional, best-effort dependencies: a
        # failure downgrades the feature to None rather than crashing init.
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            print(f"Warning: Tokenizer not available: {e}")
            self.tokenizer = None
        try:
            self.sandbox = TerminalSandbox(self.sandbox_config["sandbox"])
        except Exception as e:
            print(f"Warning: Sandbox not available: {e}")
            self.sandbox = None

        self.conversation_history = []

    def forward(self, tokens: torch.Tensor, role_hints: Optional[torch.Tensor] = None):
        """Run the full stack: embed -> [attention + MoE FFN] x n_layers -> logits.

        Args:
            tokens: (batch, seq) integer token ids.
            role_hints: optional routing hints forwarded to the Sentinel.

        Returns:
            (batch, seq, vocab_size) logits.
        """
        x = self.token_embedding(tokens)
        for layer, moe_norm in zip(self.layers, self.moe_norms):
            # Attention sub-layer (pre-norm + residual inside AttentionLayer)
            x = layer(x)
            # MoE FFN sub-layer with pre-norm + residual
            h = moe_norm(x)
            B, T, C = h.shape
            h_flat = h.view(-1, C)
            weights, indices = self.sentinel(h_flat, role_hints)
            out_flat = torch.zeros_like(h_flat)
            # Dispatch: each expert processes only the tokens routed to it;
            # its output is scaled by the (summed) gate weight for that expert.
            for i, expert in enumerate(self.expert_list):
                mask = (indices == i).any(dim=-1)
                if mask.any():
                    expert_out = expert(h_flat[mask])
                    expert_weights = (weights[mask] * (indices[mask] == i).float()).sum(dim=-1, keepdim=True)
                    out_flat[mask] += expert_out * expert_weights
            x = x + out_flat.view(B, T, C)
        logits = self.lm_head(self.norm(x))
        return logits

    def generate(self, input_ids, max_new_tokens=64, temperature=0.8, do_sample=True, pad_token_id=None):
        """Autoregressively sample up to ``max_new_tokens`` continuation tokens.

        Args:
            input_ids: tensor of ids, a list of ids, or a single id.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature for sampling.
            do_sample: multinomial sampling if True, greedy argmax otherwise.
            pad_token_id: stop early when this id is produced (used as EOS).

        Returns:
            (batch, prompt_len + generated_len) tensor including the prompt.
        """
        self.eval()
        device = next(self.parameters()).device
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor([input_ids]) if isinstance(input_ids, list) else torch.tensor([[input_ids]])
        input_ids = input_ids.to(device)
        generated = input_ids
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # BUGFIX: clamp the context window so we never feed the model
                # more than max_seq_len tokens as generation grows.
                context = generated[:, -self.max_seq_len:]
                logits = self.forward(context)
                next_token_logits = logits[:, -1, :] / temperature
                if do_sample:
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=1)
                # BUGFIX: the old truthiness test (`if pad_token_id and ...`)
                # never stopped when the pad/eos id was 0.
                if pad_token_id is not None and next_token.item() == pad_token_id:
                    break
        return generated

    def chat(self, user_message: str, max_new_tokens: int = 128) -> str:
        """Answer one chat turn, conditioning on memory and recent history.

        The turn is also persisted to memory and appended to the in-process
        conversation history. Returns the decoded assistant reply.
        """
        if self.tokenizer is None:
            return "Tokenizer not available. Install transformers to use chat."
        memory_context = self.memory.get_context_string(limit=5)
        # Build prompt: memory + last 3 turns + current message
        history = ""
        for turn in self.conversation_history[-3:]:
            history += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
        prompt = f"{memory_context}\n{history}User: {user_message}\nAssistant:"
        inputs = self.tokenizer(
            prompt, return_tensors="pt", truncation=True,
            # Guard against a non-positive budget when max_new_tokens >= seq_len.
            max_length=max(1, self.max_seq_len - max_new_tokens)
        )
        input_ids = inputs["input_ids"].to(next(self.parameters()).device)
        output_ids = self.generate(
            input_ids, max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the newly generated tail, not the prompt.
        new_tokens = output_ids[0, input_ids.shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        self.memory.store({"user": user_message, "assistant": response}, "conversation_history")
        self.conversation_history.append({"user": user_message, "assistant": response})
        return response

    def remember(self, fact: str, category: str = "important_facts"):
        """Persist a free-text fact into the JSON memory under ``category``."""
        return self.memory.store({"fact": fact}, category)

    def recall(self, query: str) -> List[Dict]:
        """Look up memory entries matching ``query``."""
        return self.memory.recall(query)

    def run_command(self, command: str) -> str:
        """Execute a shell command in the sandbox; returns exit code + output."""
        if self.sandbox:
            output, code = self.sandbox.execute(command)
            return f"Exit code: {code}\n{output}"
        else:
            return "Sandbox not available. Install Docker to use this feature."

    def create_project(self, name: str) -> str:
        """Create a new project directory inside the sandbox."""
        if self.sandbox:
            return self.sandbox.create_project(name)
        else:
            return "Sandbox not available. Install Docker to use this feature."
class MemoryEfficientSHOREKEEPER(SHOREKEEPER):
    """SHOREKEEPER variant that optionally swaps large ``nn.Linear`` layers
    for bitsandbytes 4-bit layers when CUDA is available."""

    def __init__(self, config_path: str = "configs/model.yaml", use_4bit: bool = True):
        """Build the model, then (optionally) quantize large linear layers.

        Args:
            config_path: forwarded to ``SHOREKEEPER.__init__``.
            use_4bit: attempt 4-bit quantization (requires CUDA + bitsandbytes).
        """
        # Plain bool attribute; safe to set before Module.__init__ (it is
        # neither a Parameter nor a sub-Module).
        self.use_4bit = use_4bit
        super().__init__(config_path)
        if use_4bit and torch.cuda.is_available():
            self._apply_4bit_quantization()

    def _apply_4bit_quantization(self):
        """Replace every large ``nn.Linear`` (in_features >= 1024) with a
        bitsandbytes ``Linear4bit`` layer.

        Best-effort: on failure the model is left (partially) unquantized and
        a warning is printed. The original code used a bare ``except: pass``
        that silently hid every error, including real bugs.
        """
        try:
            import bitsandbytes as bnb
        except ImportError as e:
            print(f"Warning: bitsandbytes not available, skipping 4-bit quantization: {e}")
            return
        # BUGFIX: snapshot the targets first — mutating the module tree via
        # setattr while iterating named_modules() can skip or revisit entries.
        targets = [
            (name, module)
            for name, module in self.named_modules()
            if isinstance(module, nn.Linear) and module.in_features >= 1024
        ]
        try:
            for name, module in targets:
                new_layer = bnb.nn.Linear4bit(
                    module.in_features, module.out_features,
                    bias=module.bias is not None,
                    quant_type="nf4", compute_dtype=torch.bfloat16
                )
                # NOTE(review): this copies fp weights into the 4-bit layer;
                # bitsandbytes performs the actual quantization when the
                # layer is moved to CUDA — confirm against bnb version in use.
                new_layer.weight.data = module.weight.data
                if module.bias is not None:
                    new_layer.bias.data = module.bias.data
                parent = self._get_parent_module(name)
                # Top-level (un-dotted) modules have no parent and are left as-is.
                if parent:
                    setattr(parent, name.split('.')[-1], new_layer)
            print(" ✓ 4-bit quantization applied")
        except Exception as e:
            print(f"Warning: 4-bit quantization failed: {e}")

    def _get_parent_module(self, name):
        """Return the module owning dotted attribute path ``name``,
        or None when ``name`` is top-level (no dot)."""
        parts = name.split('.')
        if len(parts) == 1:
            return None
        obj = self
        for part in parts[:-1]:
            obj = getattr(obj, part)
        return obj