| """rmm_server.py — Serves an RMM (Recombinant Memory Model) on HTTP. |
| |
| Endpoints: |
| POST /navigate — navigator retrieval (learned topology) |
| POST /blend — navigator + cosine interleaved |
| POST /decode — vector-to-text via meaning decoder |
| POST /synthesize — full pipeline (navigate + decode + blend) |
| POST /attention — attention weight visualization |
| GET /health |
| |
| Usage: |
| python rmm_server.py --port 8127 --spine spine.json --nav-dir memory-nav-out --dec-dir meaning-decoder-out |
| |
| The navigator learns the emotional geography of the entity's spine — |
| it navigates to the RIGHT region of memory-space for each query. |
| The meaning decoder generates text from the navigator's synthesized |
| response vector — a meaning microscope for the entity's embedding space. |
| """ |
| import argparse, json, pickle, re, sys, time |
| from http.server import HTTPServer, BaseHTTPRequestHandler |
| from socketserver import ThreadingMixIn |
| from pathlib import Path |
|
|
# --- Command-line configuration ---------------------------------------------
# Parsed at import time; the rest of the script reads args.spine, MODEL_DIR,
# DECODER_DIR and PORT as module-level globals.
parser = argparse.ArgumentParser(description="RMM Server")
for _flag, _options in (
    ("--port", dict(type=int, default=8127)),
    ("--spine", dict(type=str, default="spine.json", help="Path to spine JSON file")),
    ("--nav-dir", dict(type=str, default="memory-nav-out", help="Navigator weights directory")),
    ("--dec-dir", dict(type=str, default="meaning-decoder-out", help="Decoder weights directory")),
):
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

PORT = args.port
MODEL_DIR = Path(args.nav_dir)
DECODER_DIR = Path(args.dec_dir)
|
|
| |
# Navigator architecture.  MemoryNavigator is built from these values and then
# loaded from a checkpoint, so they presumably must match the trained weights.
SPINE_DIM, QUERY_DIM = 3072, 384
N_HEADS, N_LAYERS, D_MODEL = 8, 3, 512

# Decoder architecture: built-in defaults, each overridable by the key of the
# same name in <dec-dir>/config.json.
_DEC_DEFAULTS = {
    "d_model": 384,
    "n_heads": 6,
    "n_layers": 6,
    "n_prefix": 12,
    "max_seq": 128,
    "vocab": 8192,
}
_dec_settings = dict(_DEC_DEFAULTS)
_dec_version = 2  # used when there is no config.json at all
_dec_config_path = DECODER_DIR / "config.json"
_dc = None
if _dec_config_path.exists():
    _dc = json.loads(_dec_config_path.read_text())
    _dec_settings.update({key: _dc[key] for key in _DEC_DEFAULTS if key in _dc})
    # A config file without a "version" key is treated as version 1.
    _dec_version = _dc.get("version", 1)

DEC_D_MODEL = _dec_settings["d_model"]
DEC_N_HEADS = _dec_settings["n_heads"]
DEC_N_LAYERS = _dec_settings["n_layers"]
DEC_N_PREFIX = _dec_settings["n_prefix"]
DEC_MAX_SEQ = _dec_settings["max_seq"]
DEC_VOCAB = _dec_settings["vocab"]

if _dc is not None:
    print(f"[rmm] decoder config: d={DEC_D_MODEL} h={DEC_N_HEADS} L={DEC_N_LAYERS} pfx={DEC_N_PREFIX}")
|
|
print(f"[rmm] loading navigator from {MODEL_DIR} ...")

# Fail fast: verify the weights directory before paying for the heavyweight
# torch / sentence-transformers imports below.  is_dir() also rejects a plain
# file with that name, which would only fail later when the weights are opened.
if not MODEL_DIR.is_dir():
    print(f"ERROR: {MODEL_DIR} not found", file=sys.stderr)
    sys.exit(1)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sentence_transformers import SentenceTransformer
|
|
class MemoryNavigator(nn.Module):
    """Map a query embedding to a unit-length vector in spine space.

    The query (QUERY_DIM) is projected to D_MODEL and used as a length-1 target
    sequence for N_LAYERS transformer decoder layers whose "memory" is the
    projected bank of spine vectors (SPINE_DIM each).  The final state is
    layer-normalised, projected back to SPINE_DIM and L2-normalised.
    """

    def __init__(self):
        super().__init__()
        self.query_proj = nn.Sequential(
            nn.Linear(QUERY_DIM, D_MODEL), nn.LayerNorm(D_MODEL), nn.GELU(),
        )
        self.mem_proj = nn.Linear(SPINE_DIM, D_MODEL, bias=False)
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=D_MODEL, nhead=N_HEADS,
                dim_feedforward=D_MODEL * 4, dropout=0.0, batch_first=True,
            )
            for _ in range(N_LAYERS)
        ])
        self.out_proj = nn.Linear(D_MODEL, SPINE_DIM, bias=False)
        self.norm = nn.LayerNorm(D_MODEL)

    @staticmethod
    def _cross_attn_weights(layer, x, m):
        """Return the cross-attention weights `layer` applies to input `x`.

        The cross-attention query inside a TransformerDecoderLayer is the
        output of its self-attention sub-layer, not the raw layer input, so
        that sub-layer is replayed here (in either norm order) before asking
        the attention module for its weights.
        """
        if layer.norm_first:
            sa_in = layer.norm1(x)
            h = x + layer.dropout1(layer.self_attn(sa_in, sa_in, sa_in, need_weights=False)[0])
            attn_query = layer.norm2(h)
        else:
            h = x + layer.dropout1(layer.self_attn(x, x, x, need_weights=False)[0])
            attn_query = layer.norm1(h)
        _, weights = layer.multihead_attn(attn_query, m, m, need_weights=True)
        return weights.detach()

    def forward(self, q, mem_keys, return_attn=False):
        """Run the navigator.

        Args:
            q: query embeddings, (B, QUERY_DIM).
            mem_keys: memory bank, (N, SPINE_DIM); shared by every batch row.
            return_attn: also return one cross-attention weight tensor per layer.

        Returns:
            The L2-normalised (B, SPINE_DIM) response vectors, or a tuple
            ``(vectors, attn_weights)`` when ``return_attn`` is true.
        """
        x = self.query_proj(q).unsqueeze(1)  # (B, 1, D_MODEL)
        m = self.mem_proj(mem_keys).unsqueeze(0).expand(x.shape[0], -1, -1)
        attn_weights = []
        for layer in self.layers:
            if return_attn:
                # Side computation only; the forward pass below is unchanged.
                attn_weights.append(self._cross_attn_weights(layer, x, m))
            x = layer(x, m)
        out = F.normalize(self.out_proj(self.norm(x).squeeze(1)), dim=-1)
        if return_attn:
            return out, attn_weights
        return out
|
|
DEV = "cpu"
model = MemoryNavigator().to(DEV)
model.load_state_dict(torch.load(MODEL_DIR / "navigator.pt", map_location=DEV, weights_only=True))
model.eval()

mem_vecs = torch.tensor(np.load(MODEL_DIR / "mem_vecs.npy"), dtype=torch.float32)
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only point --nav-dir at artefacts you produced yourself or otherwise trust.
with open(MODEL_DIR / "mem_texts.pkl", "rb") as f:
    mem_texts = pickle.load(f)

# Retrieval indexes mem_texts with row numbers of mem_vecs, so the two must
# line up; a mismatch would otherwise surface as an IndexError per request.
if mem_vecs.shape[0] != len(mem_texts):
    print(
        f"ERROR: mem_vecs.npy has {mem_vecs.shape[0]} rows but mem_texts.pkl has {len(mem_texts)} entries",
        file=sys.stderr,
    )
    sys.exit(1)

# Per-memory emotional weight / salience, read from the spine when available.
spine_path = Path(args.spine)
ew_list, sal_list = [], []
if spine_path.exists():
    spine = json.loads(spine_path.read_text(encoding="utf-8", errors="ignore"))
    for _entry in spine.get("memories", []):
        ew_list.append(_entry.get("emotional_weight", 5))
        sal_list.append(_entry.get("salience", 0.5))
    if len(ew_list) != len(mem_texts):
        # The lists are indexed by memory position; when the counts differ the
        # positions cannot be trusted, so the spine metadata is dropped.
        print(
            f"[rmm] WARNING: spine has {len(ew_list)} memories but the navigator has "
            f"{len(mem_texts)}; using neutral weights"
        )
        ew_list, sal_list = [], []
if len(ew_list) != len(mem_texts):  # no spine file, or spine rejected above
    ew_list = [5] * len(mem_texts)
    sal_list = [0.5] * len(mem_texts)

embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Second copy of the memory bank, embedded with the query encoder, for the
# plain cosine baseline.  Encoded in chunks of 256 texts.
print("[rmm] embedding memories in MiniLM space...")
_mini_embs = [
    torch.tensor(
        embedder.encode(mem_texts[s:s + 256], normalize_embeddings=True, show_progress_bar=False),
        dtype=torch.float32,
    )
    for s in range(0, len(mem_texts), 256)
]
mem_mini = torch.cat(_mini_embs, dim=0)

n_params = sum(p.numel() for p in model.parameters())
print(f"[rmm] navigator {n_params/1e6:.1f}M params, {len(mem_texts)} memories on {DEV}")
|
|
| |
# --- Optional meaning decoder -------------------------------------------------
# These stay None when no decoder checkpoint is present; decode_vector and the
# HTTP handler test decoder_model against None.
decoder_model = None
decoder_tk = None
dec_eot_id = None

if not (DECODER_DIR / "decoder.pt").exists():
    print(f"[rmm] decoder not found at {DECODER_DIR} — /decode and /synthesize disabled")
else:
    from tokenizers import Tokenizer as HFTokenizer

    # Width of the hidden layer in the vector-to-prefix projection.
    _proj_hidden = 768 if _dec_version >= 2 else 512

    class MeaningDecoder(nn.Module):
        """Causal transformer that generates tokens from a spine-space vector.

        The vector is projected to DEC_N_PREFIX embeddings that are placed in
        front of the token embeddings; the stack runs under a causal mask and
        the output head shares its weight matrix with the token embedding.
        """

        def __init__(self):
            super().__init__()
            self.n_prefix = DEC_N_PREFIX
            projection = [nn.Linear(SPINE_DIM, _proj_hidden), nn.GELU()]
            if _dec_version >= 2:
                # p=0.0 is a no-op; it shifts the index of the final Linear
                # inside the Sequential (presumably to match v2 checkpoints).
                projection.append(nn.Dropout(0.0))
            projection.append(nn.Linear(_proj_hidden, DEC_N_PREFIX * DEC_D_MODEL))
            self.vec_proj = nn.Sequential(*projection)
            self.tok_emb = nn.Embedding(DEC_VOCAB, DEC_D_MODEL)
            self.pos_emb = nn.Embedding(DEC_N_PREFIX + DEC_MAX_SEQ + 1, DEC_D_MODEL)
            block = nn.TransformerEncoderLayer(
                d_model=DEC_D_MODEL,
                nhead=DEC_N_HEADS,
                dim_feedforward=DEC_D_MODEL * 4,
                dropout=0.0,
                batch_first=True,
                norm_first=True,
            )
            self.transformer = nn.TransformerEncoder(block, num_layers=DEC_N_LAYERS)
            self.ln_f = nn.LayerNorm(DEC_D_MODEL)
            self.head = nn.Linear(DEC_D_MODEL, DEC_VOCAB, bias=False)
            self.head.weight = self.tok_emb.weight  # tied input/output embeddings
            self._logit_scale = DEC_D_MODEL ** -0.5

        def forward(self, vec, tokens=None):
            """Return scaled logits for every position (prefix slots, then tokens).

            Args:
                vec: conditioning vectors, (B, SPINE_DIM).
                tokens: optional (B, T) token ids generated so far.
            """
            batch = vec.shape[0]
            seq = self.vec_proj(vec).reshape(batch, self.n_prefix, DEC_D_MODEL)
            has_tokens = tokens is not None and tokens.shape[1] > 0
            if has_tokens:
                seq = torch.cat([seq, self.tok_emb(tokens)], dim=1)
            length = seq.shape[1]
            positions = torch.arange(length, device=vec.device)
            seq = seq + self.pos_emb(positions)
            causal = nn.Transformer.generate_square_subsequent_mask(length, device=vec.device)
            hidden = self.ln_f(self.transformer(seq, mask=causal))
            return self.head(hidden) * self._logit_scale

    decoder_model = MeaningDecoder().to(DEV)
    decoder_model.load_state_dict(
        torch.load(DECODER_DIR / "decoder.pt", map_location=DEV, weights_only=True)
    )
    decoder_model.eval()
    decoder_tk = HFTokenizer.from_file(str(DECODER_DIR / "tokenizer.json"))
    dec_eot_id = decoder_tk.token_to_id("<eot>")
    dec_params = sum(p.numel() for p in decoder_model.parameters())
    print(f"[rmm] decoder {dec_params/1e6:.1f}M params loaded (eot={dec_eot_id})")
|
|
|
|
def _penalize_repeats(logits, recent_tokens, rep_penalty):
    """Lower, in place, the logits of recently generated tokens.

    Positive logits are divided by the penalty and negative ones multiplied,
    so the score always moves down; dividing a negative logit would instead
    make the repeated token more likely.
    """
    if not recent_tokens or rep_penalty == 1.0:
        return
    idx = torch.tensor(sorted(set(recent_tokens)), dtype=torch.long, device=logits.device)
    picked = logits[idx]
    logits[idx] = torch.where(picked > 0, picked / rep_penalty, picked * rep_penalty)


def _sample_top_p(logits, top_p):
    """Draw one token id from the nucleus (top-p) of softmax(logits)."""
    probs = F.softmax(logits, dim=-1)
    sorted_p, sorted_idx = torch.sort(probs, descending=True)
    # Drop a token once the mass of the tokens ranked above it exceeds top_p.
    mass_before = sorted_p.cumsum(0) - sorted_p
    sorted_p[mass_before > top_p] = 0
    sorted_p = sorted_p / sorted_p.sum()
    return sorted_idx[torch.multinomial(sorted_p, 1)].item()


def decode_vector(vec_3072, max_len=80, temp=0.7, top_p=0.9, rep_penalty=1.3):
    """Sample text from the meaning decoder, conditioned on a spine-space vector.

    Args:
        vec_3072: conditioning vector, 1-D or already batched as (1, dim).
        max_len: maximum number of tokens to generate; clamped to at least 1
            and to the number of positions the decoder has embeddings for.
        temp: sampling temperature; values near or below zero are clamped to
            a small positive number instead of dividing by zero.
        top_p: nucleus-sampling probability mass.
        rep_penalty: penalty applied to tokens seen in the last 64 outputs.

    Returns:
        The decoded, stripped text, or None when no decoder is loaded.
    """
    if decoder_model is None:
        return None

    v = vec_3072.unsqueeze(0) if vec_3072.dim() == 1 else vec_3072
    temp = max(float(temp), 1e-4)
    # Positions left for tokens after the prefix slots; going past this would
    # index outside the positional-embedding table.
    capacity = decoder_model.pos_emb.num_embeddings - decoder_model.n_prefix
    max_len = max(1, min(int(max_len), capacity))

    generated = []
    while len(generated) < max_len:
        # First step: no tokens yet, the model sees only the prefix.
        tok_in = torch.tensor([generated], dtype=torch.long, device=DEV) if generated else None
        with torch.no_grad():
            next_logits = decoder_model(v, tok_in)[0, -1, :]
        _penalize_repeats(next_logits, generated[-64:], rep_penalty)
        nxt = _sample_top_p(next_logits / temp, top_p)
        # Checked on every step, including the first, so <eot> is never emitted.
        if dec_eot_id is not None and nxt == dec_eot_id:
            break
        generated.append(nxt)
    return decoder_tk.decode(generated).strip()
|
|
|
|
# Matches the bookkeeping prefix at the start of a stored memory, together with
# any following whitespace and one optional opening quote; case-insensitive.
STRIP = re.compile(
    r'^\[conversation\] I replied \(puppet\):\s*["\']?',
    flags=re.IGNORECASE,
)
|
|
|
|
def navigate(query: str, top_k: int = 6, ew_boost: bool = True):
    """Retrieve memories near the navigator's response vector for `query`.

    The query is embedded, passed through the navigator, and every memory
    vector is scored by its dot product with the response vector.  With
    `ew_boost` the score is scaled by 1 + 0.15 * (emotional_weight - 5) / 5.
    The best 4 * top_k candidates are then filtered greedily: a candidate is
    skipped when more than 60% of the words in its first 200 characters
    already occur in an accepted result.

    Args:
        query: free-text query.
        top_k: maximum number of results; negative values are treated as 0.
        ew_boost: whether to apply the emotional-weight scaling.

    Returns:
        A list of dicts with keys text, emotional_weight, salience,
        similarity, score and idx, best first.
    """
    top_k = max(int(top_k), 0)
    qe = torch.tensor(
        embedder.encode([query], normalize_embeddings=True),
        dtype=torch.float32,
    ).to(DEV)
    with torch.no_grad():
        rv = model(qe, mem_vecs)
        # reshape(-1) keeps a 1-D result even when there is a single memory.
        sims = (mem_vecs @ rv.T).reshape(-1)

    if ew_boost:
        ew_t = torch.tensor(ew_list, dtype=torch.float32)
        scored = sims * (1.0 + 0.15 * (ew_t - 5.0) / 5.0)
    else:
        scored = sims

    n_cand = min(top_k * 4, len(mem_texts), scored.numel())
    cand_idx = scored.topk(n_cand).indices.tolist()

    # Accepted results as (index, cleaned text, word set of its lowercased head).
    picked = []
    for i in cand_idx:
        if len(picked) >= top_k:
            break
        text = STRIP.sub("", mem_texts[i]).strip().strip('"').strip("'")
        words = set(text[:200].lower().split())
        denom = max(len(words), 1)
        if any(len(words & prev_words) / denom > 0.6 for _, _, prev_words in picked):
            continue
        picked.append((i, text, words))

    return [
        {
            "text": text[:400],
            "emotional_weight": ew_list[i],
            "salience": sal_list[i],
            "similarity": float(sims[i]),
            "score": float(scored[i]),
            "idx": i,
        }
        for i, text, _ in picked
    ]
|
|
|
|
def raw_cosine(query: str, top_k: int = 4):
    """Baseline retrieval: rank memories by similarity in the query-encoder space.

    Scores are dot products between the query embedding and the pre-computed
    embeddings of the memory texts (both produced with normalize_embeddings).

    Args:
        query: free-text query.
        top_k: maximum number of results; clamped to the range
            [0, number of memories] so an oversized request cannot make
            topk raise.

    Returns:
        A list of dicts with keys text, emotional_weight, salience,
        similarity, idx and source ("cosine"), best first.
    """
    qe = torch.tensor(
        embedder.encode([query], normalize_embeddings=True),
        dtype=torch.float32,
    )
    # reshape(-1) keeps a 1-D result even when there is a single memory.
    sims = (mem_mini @ qe.T).reshape(-1)
    k = max(0, min(int(top_k), sims.numel()))
    return [
        {
            "text": STRIP.sub("", mem_texts[i]).strip().strip('"').strip("'")[:400],
            "emotional_weight": ew_list[i],
            "salience": sal_list[i],
            "similarity": float(sims[i]),
            "idx": i,
            "source": "cosine",
        }
        for i in sims.topk(k).indices.tolist()
    ]
|
|
|
|
def blend(query: str, top_k: int = 6):
    """Interleave navigator and cosine results, two navigator slots per cosine slot.

    Both retrievers are asked for `top_k` results.  Each round consumes the
    next two navigator hits and the next cosine hit; a hit whose memory index
    has already been taken is dropped but still uses up its slot.  Navigator
    hits are tagged with source "navigator".

    Returns:
        At most `top_k` result dicts in interleaved order.
    """
    nav_results = navigate(query, top_k=top_k, ew_boost=True)
    for hit in nav_results:
        hit["source"] = "navigator"
    cos_results = raw_cosine(query, top_k=top_k)

    merged = []
    seen_idx = set()

    def take(hit):
        if hit["idx"] not in seen_idx:
            seen_idx.add(hit["idx"])
            merged.append(hit)

    nav_pos = cos_pos = 0
    n_nav, n_cos = len(nav_results), len(cos_results)
    while len(merged) < top_k and (nav_pos < n_nav or cos_pos < n_cos):
        for hit in nav_results[nav_pos:nav_pos + 2]:
            take(hit)
        nav_pos += 2
        for hit in cos_results[cos_pos:cos_pos + 1]:
            take(hit)
        cos_pos += 1
    return merged[:top_k]
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """JSON-over-HTTP front end for the navigator and the optional decoder.

    POST endpoints take a JSON object with ``query`` and optional ``top_k``;
    /decode and /synthesize also read ``max_len`` and ``temperature``, and
    /decode accepts a raw ``vector`` instead of a query.  JSON responses carry
    permissive CORS headers and an ``elapsed`` field on success.
    """

    _POST_ROUTES = ("/navigate", "/blend", "/attention", "/decode", "/synthesize")

    def log_message(self, fmt, *args):
        """Suppress the base class's per-request log line; handlers print their own."""

    def _cors(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "Content-Type")

    def _send_body(self, resp, status=200):
        """Send pre-encoded JSON bytes with CORS and length headers."""
        self.send_response(status)
        self._cors()
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()
        self.wfile.write(resp)

    def _send_json(self, payload, status=200):
        """Serialise `payload` and send it as the response body."""
        self._send_body(json.dumps(payload).encode(), status)

    def _read_json_body(self):
        """Return the request body parsed as a JSON object, or None if unusable."""
        try:
            length = int(self.headers.get("Content-Length", 0))
            if length < 0:  # read(-1) would block until the client disconnects
                return None
            body = json.loads(self.rfile.read(length))
        except (TypeError, ValueError):  # bad header, bad UTF-8, or bad JSON
            return None
        return body if isinstance(body, dict) else None

    @staticmethod
    def _embed_query(query):
        """Embed `query` with the sentence encoder as a (1, dim) float tensor on DEV."""
        return torch.tensor(
            embedder.encode([query], normalize_embeddings=True),
            dtype=torch.float32,
        ).to(DEV)

    def do_OPTIONS(self):
        self.send_response(200)
        self._cors()
        self.end_headers()

    # Each _post_* method takes the same arguments and returns
    # (http_status, result_dict, log_summary).

    def _post_decode(self, body, query, top_k, max_len, temperature):
        if decoder_model is None:
            return 200, {"error": "decoder not loaded"}, repr("")
        vec_data = body.get("vector")
        if vec_data:
            try:
                v = torch.tensor([vec_data], dtype=torch.float32).to(DEV)
            except (TypeError, ValueError, RuntimeError):
                v = None
            if v is None or v.dim() != 2 or v.shape[-1] != SPINE_DIM:
                return 400, {"error": f"vector must be a flat list of {SPINE_DIM} numbers"}, repr("")
            v = F.normalize(v, dim=-1)
        elif query:
            with torch.no_grad():
                v = model(self._embed_query(query), mem_vecs)
        else:
            return 200, {"error": "provide query or vector"}, repr("")
        text = decode_vector(v.squeeze(0), max_len=max_len, temp=temperature)
        return 200, {"text": text}, repr((text or "")[:60])

    def _post_synthesize(self, body, query, top_k, max_len, temperature):
        mems = blend(query, top_k)
        synth_text = None
        if decoder_model is not None and query:
            with torch.no_grad():
                rv = model(self._embed_query(query), mem_vecs)
            synth_text = decode_vector(rv.squeeze(0), max_len=max_len, temp=temperature)
        summary = f"synth={repr((synth_text or '')[:60])} + {len(mems)} mems"
        return 200, {"synthesized": synth_text, "memories": mems}, summary

    def _post_attention(self, body, query, top_k, max_len, temperature):
        with torch.no_grad():
            _, attn_list = model(self._embed_query(query), mem_vecs, return_attn=True)
        # One weight vector over the memories per layer, averaged across layers.
        avg_attn = torch.stack([a.reshape(-1) for a in attn_list]).mean(0)
        k = min(top_k, avg_attn.numel())
        mems = []
        for i in avg_attn.topk(k).indices.tolist():
            t = STRIP.sub("", mem_texts[i]).strip().strip('"').strip("'")
            mems.append({
                "text": t[:400],
                "emotional_weight": ew_list[i],
                "attention": float(avg_attn[i]),
                "idx": i,
            })
        return 200, {"attended": mems}, f"{len(mems)} results"

    def _post_blend(self, body, query, top_k, max_len, temperature):
        mems = blend(query, top_k)
        return 200, {"memories": mems}, f"{len(mems)} results"

    def _post_navigate(self, body, query, top_k, max_len, temperature):
        mems = navigate(query, top_k)
        return 200, {"memories": mems}, f"{len(mems)} results"

    def do_POST(self):
        if self.path not in self._POST_ROUTES:
            self.send_response(404)
            self.end_headers()
            return

        body = self._read_json_body()
        if body is None:
            self._send_json({"error": "request body must be a JSON object"}, status=400)
            return
        query = body.get("query", "")
        if not isinstance(query, str):
            self._send_json({"error": "query must be a string"}, status=400)
            return
        max_len, temperature = 80, 0.7
        try:
            top_k = max(0, int(body.get("top_k", 6)))
            if self.path in ("/decode", "/synthesize"):
                max_len = int(body.get("max_len", 80))
                temperature = float(body.get("temperature", 0.7))
        except (TypeError, ValueError):
            self._send_json({"error": "top_k, max_len and temperature must be numeric"}, status=400)
            return

        route = {
            "/decode": self._post_decode,
            "/synthesize": self._post_synthesize,
            "/attention": self._post_attention,
            "/blend": self._post_blend,
            "/navigate": self._post_navigate,
        }[self.path]
        t0 = time.time()
        try:
            status, result, summary = route(body, query, top_k, max_len, temperature)
        except Exception as exc:
            # Request boundary: report the failure instead of letting the
            # connection drop without a response.
            print(f"[rmm] {self.path} failed: {exc!r}")
            self._send_json({"error": "internal error"}, status=500)
            return

        elapsed = time.time() - t0
        result["elapsed"] = elapsed
        self._send_json(result, status=status)
        print(f"[rmm] {self.path} {repr(query[:40])} -> {summary} ({elapsed:.2f}s)")

    def do_GET(self):
        if self.path == "/health":
            self._send_body(b'{"status":"ok"}')
        else:
            # Without a response the client would wait on an empty reply.
            self._send_json({"error": "not found"}, status=404)
|
|
|
|
class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that serves each request on its own thread.

    The worker threads are daemonic, so they do not keep the process alive
    once the main thread exits.
    """

    daemon_threads = True
|
|
if __name__ == "__main__":
    # NOTE: binds every interface ("0.0.0.0"), and responses allow any CORS
    # origin, so the endpoints are reachable from other hosts, not only localhost.
    server = ThreadedHTTPServer(("0.0.0.0", PORT), Handler)
    print(f"[rmm] listening on http://localhost:{PORT}")
    print("[rmm] endpoints: /navigate /blend /decode /synthesize /attention /health")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        print("[rmm] shutting down")
    finally:
        server.server_close()  # release the listening socket
|
|