import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional
from pathlib import Path
import yaml

from src.council.sentinel import Sentinel
from src.council.experts import EXPERT_REGISTRY
from src.council.attention import AttentionLayer
from src.memory.json_library import JSONLibrary
from src.sandbox.terminal import TerminalSandbox


class SHOREKEEPER(nn.Module):
    """Mixture-of-experts transformer LM with optional tokenizer, JSON memory,
    and a terminal sandbox.

    Architecture per layer: an AttentionLayer (pre-norm + residual handled
    internally) followed by a pre-normed MoE FFN sub-layer whose expert
    routing is produced by ``Sentinel``. The token embedding and LM head
    share weights (weight tying).
    """

    def __init__(self, config_path: str = "configs/model.yaml"):
        """Build the model from YAML configs.

        Args:
            config_path: Path to the model config; ``configs/memory.yaml``
                and ``configs/sandbox.yaml`` are always read in addition.
        """
        super().__init__()
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)
        with open("configs/memory.yaml", "r") as f:
            self.memory_config = yaml.safe_load(f)
        with open("configs/sandbox.yaml", "r") as f:
            self.sandbox_config = yaml.safe_load(f)

        model_cfg = self.config["model"]
        self.dim = model_cfg["dim"]
        self.vocab_size = model_cfg["vocab_size"]
        self.max_seq_len = model_cfg["seq_len"]

        self.token_embedding = nn.Embedding(model_cfg["vocab_size"], model_cfg["dim"])

        # Instantiate one expert module per configured member, looked up by name.
        self.experts = nn.ModuleDict()
        for expert_info in model_cfg["experts"]["members"]:
            name = expert_info["name"]
            expert_class = EXPERT_REGISTRY[name]
            self.experts[name] = expert_class(model_cfg["dim"], model_cfg["expert_dim"])

        # Sentinel is the router: picks n_activated of n_experts per token.
        self.sentinel = Sentinel(model_cfg["dim"], model_cfg["n_experts"], model_cfg["n_activated"])
        self.expert_names = model_cfg["experts"]["members"]
        # Ordered list so router indices i map to expert_list[i] in forward().
        self.expert_list = [self.experts[e["name"]] for e in self.expert_names]

        self.layers = nn.ModuleList([
            AttentionLayer(model_cfg) for _ in range(model_cfg["n_layers"])
        ])
        self.moe_norms = nn.ModuleList([
            nn.RMSNorm(model_cfg["dim"]) for _ in range(model_cfg["n_layers"])
        ])
        self.norm = nn.RMSNorm(model_cfg["dim"])
        self.lm_head = nn.Linear(model_cfg["dim"], model_cfg["vocab_size"], bias=False)
        # Weight tying: embedding and LM head share one parameter tensor.
        self.token_embedding.weight = self.lm_head.weight

        self.memory = JSONLibrary(self.memory_config["memory"]["path"])

        # Both the tokenizer and the sandbox are optional; the model degrades
        # gracefully (chat/run_command report unavailability) instead of failing.
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            print(f"Warning: Tokenizer not available: {e}")
            self.tokenizer = None
        try:
            self.sandbox = TerminalSandbox(self.sandbox_config["sandbox"])
        except Exception as e:
            print(f"Warning: Sandbox not available: {e}")
            self.sandbox = None
        self.conversation_history = []

    def forward(self, tokens: torch.Tensor, role_hints: Optional[torch.Tensor] = None):
        """Run the transformer and return next-token logits.

        Args:
            tokens: Integer token ids, shape (batch, seq).
            role_hints: Optional routing hint tensor forwarded to the Sentinel
                router unchanged.

        Returns:
            Logits of shape (batch, seq, vocab_size).
        """
        x = self.token_embedding(tokens)
        for layer, moe_norm in zip(self.layers, self.moe_norms):
            # Attention sub-layer (pre-norm + residual inside AttentionLayer)
            x = layer(x)
            # MoE FFN sub-layer with pre-norm + residual
            h = moe_norm(x)
            B, T, C = h.shape
            h_flat = h.view(-1, C)
            # Sentinel returns, per flattened token, the weights and indices of
            # its activated experts.
            weights, indices = self.sentinel(h_flat, role_hints)
            out_flat = torch.zeros_like(h_flat)
            for i, expert in enumerate(self.expert_list):
                # Tokens that routed to expert i in any of their activated slots.
                mask = (indices == i).any(dim=-1)
                if mask.any():
                    expert_out = expert(h_flat[mask])
                    # Sum the routing weight over the slots that chose expert i.
                    expert_weights = (weights[mask] * (indices[mask] == i).float()).sum(dim=-1, keepdim=True)
                    out_flat[mask] += expert_out * expert_weights
            x = x + out_flat.view(B, T, C)
        logits = self.lm_head(self.norm(x))
        return logits

    def generate(self, input_ids, max_new_tokens=64, temperature=0.8, do_sample=True, pad_token_id=None):
        """Autoregressively generate up to ``max_new_tokens`` tokens.

        Args:
            input_ids: Prompt as a tensor, a list of ids, or a single id.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature for sampling.
            do_sample: Sample from the distribution if True, else greedy argmax.
            pad_token_id: Stop-token id; generation halts when it is produced.

        Returns:
            Tensor of shape (batch, prompt_len + generated_len).
        """
        # Fix: remember and restore training mode instead of leaving the
        # module permanently in eval() after a mid-training generation call.
        was_training = self.training
        self.eval()
        device = next(self.parameters()).device
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor([input_ids]) if isinstance(input_ids, list) else torch.tensor([[input_ids]])
        input_ids = input_ids.to(device)
        generated = input_ids
        for _ in range(max_new_tokens):
            with torch.no_grad():
                logits = self.forward(generated)
            next_token_logits = logits[:, -1, :] / temperature
            if do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            # Fix: compare against None explicitly — the old truthiness test
            # (`if pad_token_id and ...`) never stopped when the stop id was 0.
            if pad_token_id is not None and next_token.item() == pad_token_id:
                break
        if was_training:
            self.train()
        return generated

    def chat(self, user_message: str, max_new_tokens: int = 128) -> str:
        """One conversational turn: prompt = memory context + last 3 turns
        + the new user message; the response is stored in memory and in
        ``conversation_history``.

        Returns:
            The assistant's decoded reply, or an availability notice when the
            tokenizer could not be loaded.
        """
        if self.tokenizer is None:
            return "Tokenizer not available. Install transformers to use chat."
        memory_context = self.memory.get_context_string(limit=5)
        # Build prompt: memory + last 3 turns + current message
        history = ""
        for turn in self.conversation_history[-3:]:
            history += f"User: {turn['user']}\nAssistant: {turn['assistant']}\n"
        prompt = f"{memory_context}\n{history}User: {user_message}\nAssistant:"
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            # Fix: clamp to >=1 — a non-positive max_length (when
            # max_new_tokens >= max_seq_len) would crash the tokenizer.
            max_length=max(1, self.max_seq_len - max_new_tokens)
        )
        input_ids = inputs["input_ids"].to(next(self.parameters()).device)
        output_ids = self.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id
        )
        # Decode only the newly generated suffix, not the echoed prompt.
        new_tokens = output_ids[0, input_ids.shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        self.memory.store({"user": user_message, "assistant": response}, "conversation_history")
        self.conversation_history.append({"user": user_message, "assistant": response})
        return response

    def remember(self, fact: str, category: str = "important_facts"):
        """Persist a free-form fact into the JSON memory under ``category``."""
        return self.memory.store({"fact": fact}, category)

    def recall(self, query: str) -> List[Dict]:
        """Look up memory entries matching ``query``."""
        return self.memory.recall(query)

    def run_command(self, command: str) -> str:
        """Execute a shell command in the sandbox; returns exit code + output,
        or an availability notice if the sandbox failed to initialize."""
        if self.sandbox:
            output, code = self.sandbox.execute(command)
            return f"Exit code: {code}\n{output}"
        else:
            return "Sandbox not available. Install Docker to use this feature."

    def create_project(self, name: str) -> str:
        """Create a named project inside the sandbox, if available."""
        if self.sandbox:
            return self.sandbox.create_project(name)
        else:
            return "Sandbox not available. Install Docker to use this feature."
class MemoryEfficientSHOREKEEPER(SHOREKEEPER):
    """SHOREKEEPER variant that best-effort replaces large ``nn.Linear``
    layers with bitsandbytes 4-bit (NF4) layers when CUDA is available."""

    def __init__(self, config_path: str = "configs/model.yaml", use_4bit: bool = True):
        """Args:
            config_path: Forwarded to ``SHOREKEEPER.__init__``.
            use_4bit: Attempt 4-bit quantization after construction
                (only takes effect when CUDA is available).
        """
        self.use_4bit = use_4bit
        super().__init__(config_path)
        if use_4bit and torch.cuda.is_available():
            self._apply_4bit_quantization()

    def _apply_4bit_quantization(self):
        """Swap every ``nn.Linear`` with ``in_features >= 1024`` for a
        ``bnb.nn.Linear4bit``. Best-effort: failures are reported, never raised.
        """
        # Fix: the original bare `except: pass` swallowed every error
        # (including KeyboardInterrupt) silently; report failures instead
        # while keeping the best-effort contract.
        try:
            import bitsandbytes as bnb
        except ImportError as e:
            print(f"Warning: 4-bit quantization skipped, bitsandbytes not available: {e}")
            return
        try:
            # Fix: snapshot the targets first — replacing modules via setattr
            # while iterating named_modules() mutates the tree being traversed.
            targets = [
                (name, module)
                for name, module in self.named_modules()
                if isinstance(module, nn.Linear) and module.in_features >= 1024
            ]
            for name, module in targets:
                new_layer = bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    quant_type="nf4",
                    compute_dtype=torch.bfloat16
                )
                # NOTE(review): assigning raw .data presumably relies on
                # bitsandbytes quantizing on the subsequent .cuda()/.to()
                # move — confirm against the bnb Linear4bit docs.
                new_layer.weight.data = module.weight.data
                if module.bias is not None:
                    new_layer.bias.data = module.bias.data
                parent = self._get_parent_module(name)
                # Fix: explicit None check — an empty ModuleList/Sequential
                # parent is falsy under plain truthiness. Top-level modules
                # (parent is None, e.g. the weight-tied lm_head) are skipped,
                # matching the original behavior.
                if parent is not None:
                    setattr(parent, name.split('.')[-1], new_layer)
            print(" ✓ 4-bit quantization applied")
        except Exception as e:
            print(f"Warning: 4-bit quantization failed: {e}")

    def _get_parent_module(self, name):
        """Resolve the parent module of a dotted ``named_modules`` path.

        Returns None for top-level attributes (no dot in ``name``).
        """
        parts = name.split('.')
        if len(parts) == 1:
            return None
        obj = self
        for part in parts[:-1]:
            obj = getattr(obj, part)
        return obj