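"""train_navigator.py — Train the RMM Navigator.

Architecture:
    Query (text) -> MiniLM embedding (384-d) -> Linear + LayerNorm + GELU -> 512-d query token
    Spine memory vectors (3072-d each) -> Linear -> 512-d memory tokens
    3 Transformer decoder layers: the query token cross-attends over all memory tokens
    -> LayerNorm -> Linear -> L2-normalized response vector (3072-d, in spine space)

    Training target: for each (query, reply) pair, the 5 spine memories whose
    MiniLM text embedding is closest to the reply are found, and the one with
    the highest salience is used as the target memory vector.
    Loss: emotional-weight-scaled (1 - cosine) to the target memory, plus
    0.3 x a margin hinge against the 7 hardest negative memories.

The navigator learns the topology of an entity's embedding space —
which memories are connected, which regions respond to which queries,
how emotional weight shapes retrieval. Retrieval is learned navigation
rather than plain cosine similarity between query and memory.

Run:  modal run train_navigator.py
Pull: modal volume get rmm-vol /memory-nav/ ./memory-nav-out/

Requires:
    - spine.json: {"memories": [{"text": "...", "vector": [...3072...], "emotional_weight": 8, "salience": 0.5}, ...]}
    - dialogue.txt: blocks separated by blank lines, each containing a "Speaker A: ..."
      turn followed by a "Speaker B: ..." reply (default speakers: Laura / Micah)
    - (optional) discord.json: array of {author: {name: "..."}, content: "..."} messages
      (author.username is used when author.name is missing)
"""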
| import modal, json |
| from pathlib import Path |
|
|
app = modal.App("rmm-navigator")

# Container for the remote training job: Python 3.11 with a pinned torch,
# plus numpy and sentence-transformers (used for the MiniLM query encoder).
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "torch==2.6.0",
    "numpy",
    "sentence-transformers",
)

# Persistent volume that receives the trained artifacts (mounted at /vol
# by the training function).
vol = modal.Volume.from_name("rmm-vol", create_if_missing=True)

# Local input files, read by the local entrypoint and shipped to the remote job.
SPINE_FILE, DIALOGUE_FILE, DISCORD_FILE = (
    Path("spine.json"),
    Path("dialogue.txt"),
    Path("discord.json"),
)

# Model hyper-parameters.
SPINE_DIM = 3072  # length of each spine memory vector
QUERY_DIM = 384   # all-MiniLM-L6-v2 sentence-embedding size
N_HEADS = 8       # attention heads per decoder layer
N_LAYERS = 3      # number of TransformerDecoderLayer blocks
D_MODEL = 512     # internal width of the navigator
DROPOUT = 0.1
|
|
|
|
@app.function(image=image, gpu="A10G", timeout=3600, volumes={"/vol": vol})
def train(spine_json: str, dialogue_text: str, discord_json: str = "",
          speaker_a: str = "Laura", speaker_b: str = "Micah", smoke: bool = False):
    """Train the memory navigator remotely and save its artifacts to the volume.

    Builds (query, target-memory) pairs from dialogue text and, optionally, a
    Discord export. Trains a cross-attention navigator that maps a MiniLM
    query embedding to a unit-norm vector in spine space. Writes the weights,
    the normalized memory vectors and the cleaned memory texts to
    ``/vol/memory-nav/``.

    Args:
        spine_json: JSON text with a ``"memories"`` list. Each memory needs
            ``"vector"`` (SPINE_DIM floats). It may also carry ``"text"``,
            ``"salience"`` (default 0.5) and ``"emotional_weight"`` (default 5).
        dialogue_text: Blank-line separated blocks. Each block used must
            contain a ``"<speaker_a>:"`` turn and a ``"<speaker_b>:"`` reply.
        discord_json: Optional JSON array of messages with ``author.name`` /
            ``author.username`` and ``content``. A speaker_a message directly
            followed by a speaker_b message becomes an extra pair. Ignored in
            smoke mode.
        speaker_a: Name (substring match for Discord) of the querying speaker.
        speaker_b: Name (substring match for Discord) of the replying speaker.
        smoke: Quick run: first 50 dialogue blocks, no Discord, 200 iterations.

    Returns:
        ``{"best_loss": float, "params_m": float}``. ``best_loss`` is the
        lowest mean training loss over a logging window.

    Raises:
        ValueError: The spine has no memories, its vectors are not SPINE_DIM
            long, or no training pairs could be built.
    """
    import os, math, time, json, re, pickle
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from sentence_transformers import SentenceTransformer

    DEV = "cuda"
    print(f"[nav] gpu={torch.cuda.get_device_name(0)}")

    # ---- 1. Spine memories: unit-norm vectors, cleaned texts, saliences ----
    spine_data = json.loads(spine_json)
    mems = spine_data["memories"]
    if not mems:
        raise ValueError("[nav] spine contains no memories")

    mem_vecs = torch.tensor(
        [m["vector"] for m in mems], dtype=torch.float32
    ).to(DEV)
    if mem_vecs.shape[1] != SPINE_DIM:
        # mem_proj / out_proj are sized by SPINE_DIM; fail early with a clear
        # message instead of a shape error inside the first forward pass.
        raise ValueError(
            f"[nav] memory vectors have dim {mem_vecs.shape[1]}, "
            f"expected SPINE_DIM={SPINE_DIM}"
        )
    mem_vecs = F.normalize(mem_vecs, dim=-1)
    N_MEM = mem_vecs.shape[0]
    print(f"[nav] {N_MEM} memory vectors loaded, dim={mem_vecs.shape[1]}")

    # Drop the "[conversation] I replied (puppet):" prefix and any surrogate
    # code points from memory texts; empty texts become "...".
    STRIP = re.compile(r'^\[conversation\] I replied \(puppet\):\s*["\']?', re.I)
    SURR = re.compile(r'[\ud800-\udfff]')
    mem_texts = []
    for m in mems:
        raw = SURR.sub('', str(m.get("text") or ""))
        t = STRIP.sub("", raw).strip().strip('"').strip("'")
        mem_texts.append(t[:500] if t else "...")

    saliences = torch.tensor(
        [m.get("salience", 0.5) for m in mems], dtype=torch.float32
    ).to(DEV)

    # ---- 2. (query, reply) text pairs ----
    print("[nav] building training pairs...")
    embedder = SentenceTransformer("all-MiniLM-L6-v2")

    query_msgs, reply_msgs = [], []

    # Dialogue file: everything before the first "<speaker_b>:" in a block is
    # the query (with "<speaker_a>:" labels removed); the rest is the reply.
    blocks = [b.strip() for b in dialogue_text.split("\n\n")
              if speaker_a + ":" in b and speaker_b + ":" in b]
    n_blocks = 50 if smoke else len(blocks)
    for b in blocks[:n_blocks]:
        parts = b.split(speaker_b + ":", 1)
        a_part = parts[0].replace(speaker_a + ":", "").strip()
        b_part = parts[1].strip() if len(parts) > 1 else ""
        if len(a_part) >= 5 and len(b_part) >= 10:
            query_msgs.append(a_part)
            reply_msgs.append(b_part)
    print(f"[nav] dialogue file: {len(query_msgs)} pairs")

    # Discord export: consecutive speaker_a -> speaker_b messages form a pair.
    # "author" / "content" may be null in exports, so treat None as empty.
    if discord_json and not smoke:
        disc_msgs = json.loads(discord_json)

        def _author(msg):
            author = msg.get("author") or {}
            return (author.get("name") or author.get("username") or "").lower()

        def _content(msg):
            return SURR.sub('', (msg.get("content") or "").strip())

        a_key, b_key = speaker_a.lower(), speaker_b.lower()
        disc_count = 0
        for cur, nxt in zip(disc_msgs, disc_msgs[1:]):
            cur_c, nxt_c = _content(cur), _content(nxt)
            if (a_key in _author(cur) and b_key in _author(nxt)
                    and len(cur_c) >= 3 and len(nxt_c) >= 10):
                query_msgs.append(cur_c[:500])
                reply_msgs.append(nxt_c[:500])
                disc_count += 1
        print(f"[nav] discord pairs: {disc_count}")

    # Final length filter over both sources.
    pairs = [(str(q).strip(), str(r).strip()) for q, r in zip(query_msgs, reply_msgs)]
    pairs = [(q, r) for q, r in pairs if len(q) >= 3 and len(r) >= 5]
    query_msgs = [q for q, _ in pairs]
    reply_msgs = [r for _, r in pairs]

    print(f"[nav] total pairs: {len(query_msgs)}")
    if not query_msgs:
        raise ValueError(
            f"[nav] no training pairs built; check that speaker_a={speaker_a!r} "
            f"and speaker_b={speaker_b!r} match the input files"
        )

    print("[nav] embedding queries...")
    query_embs = embedder.encode(query_msgs, normalize_embeddings=True,
                                 show_progress_bar=False, batch_size=64)

    print("[nav] embedding replies...")
    reply_embs = embedder.encode(reply_msgs, normalize_embeddings=True,
                                 show_progress_bar=False, batch_size=64)
    reply_tensor = torch.tensor(reply_embs, dtype=torch.float32)

    print("[nav] embedding memories in MiniLM space for matching...")
    BATCH = 256
    mem_mini = torch.tensor(np.vstack([
        embedder.encode(mem_texts[s:s + BATCH], normalize_embeddings=True,
                        show_progress_bar=False)
        for s in range(0, N_MEM, BATCH)
    ]), dtype=torch.float32).to(DEV)

    # ---- 3. Target memory per pair ----
    # Among the k memories whose MiniLM embedding is closest to the reply,
    # pick the most salient. Similarities are computed in chunks so the full
    # (pairs x memories) matrix is never materialized; k is clamped so tiny
    # spines do not make topk fail.
    k_match = min(5, N_MEM)
    top_chunks = []
    for s in range(0, len(reply_tensor), 1024):
        chunk_sims = reply_tensor[s:s + 1024].to(DEV) @ mem_mini.T
        top_chunks.append(chunk_sims.topk(k_match, dim=-1).indices)
    top_idx = torch.cat(top_chunks)                            # (P, k)
    best_j = saliences[top_idx].argmax(dim=-1, keepdim=True)   # first max on ties
    best_mem_idx = top_idx.gather(1, best_j).squeeze(1)        # (P,) memory ids
    target_vecs = mem_vecs[best_mem_idx]
    best_indices = best_mem_idx.tolist()

    # Per-pair loss weight from the target's emotional weight: 5 -> 1.0,
    # each point away from 5 shifts the weight by 0.06; renormalized to mean 1.
    ew_raw = torch.tensor([mems[i].get("emotional_weight", 5) for i in best_indices],
                          dtype=torch.float32)
    pair_weights = 1.0 + 0.3 * (ew_raw - 5.0) / 5.0
    pair_weights = (pair_weights / pair_weights.mean()).to(DEV)

    query_tensor = torch.tensor(query_embs, dtype=torch.float32).to(DEV)
    print(f"[nav] {len(query_tensor)} training pairs ready")

    # ---- 4. Model ----
    class MemoryNavigator(nn.Module):
        """Single query token that cross-attends over projected spine memories."""

        def __init__(self):
            super().__init__()
            self.query_proj = nn.Sequential(
                nn.Linear(QUERY_DIM, D_MODEL),
                nn.LayerNorm(D_MODEL),
                nn.GELU(),
            )
            self.mem_proj = nn.Linear(SPINE_DIM, D_MODEL, bias=False)
            self.layers = nn.ModuleList([
                nn.TransformerDecoderLayer(
                    d_model=D_MODEL, nhead=N_HEADS,
                    dim_feedforward=D_MODEL * 4,
                    dropout=DROPOUT, batch_first=True
                )
                for _ in range(N_LAYERS)
            ])
            self.out_proj = nn.Linear(D_MODEL, SPINE_DIM, bias=False)
            self.norm = nn.LayerNorm(D_MODEL)

        def forward(self, q, mem_keys):
            # q: (B, QUERY_DIM); mem_keys: (N_MEM, SPINE_DIM) shared by the batch.
            q = self.query_proj(q).unsqueeze(1)                         # (B, 1, D)
            B = q.shape[0]
            m = self.mem_proj(mem_keys).unsqueeze(0).expand(B, -1, -1)  # (B, N, D)
            x = q
            for layer in self.layers:
                x = layer(x, m)
            x = self.norm(x).squeeze(1)
            out = self.out_proj(x)
            return F.normalize(out, dim=-1)                             # (B, SPINE_DIM)

    model = MemoryNavigator().to(DEV)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[nav] model {n_params/1e6:.1f}M params")

    # ---- 5. Training ----
    ITERS = 200 if smoke else 7500
    BS = 32
    N_NEG = 7
    MARGIN = 0.2
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    warmup_steps = 200 if not smoke else 20

    def lr_lambda(step):
        # Linear warm-up, then cosine decay to 0 at ITERS.
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / max(1, ITERS - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    M = len(query_tensor)
    n_neg = min(N_NEG, N_MEM - 1)  # cannot mine more negatives than non-targets
    log_every = 20 if smoke else 250
    t0 = time.time()
    best_loss = float('inf')
    best_state = None
    # Running sums of (loss, pos, neg) since the last log. Checkpoints are
    # chosen on the window mean rather than a single noisy 32-sample batch.
    window = torch.zeros(3, device=DEV)
    window_n = 0

    for step in range(ITERS):
        idx = torch.randint(0, M, (BS,), device=DEV)
        q_batch = query_tensor[idx]
        t_batch = target_vecs[idx]
        t_idx = best_mem_idx[idx]
        w = pair_weights[idx]

        pred = model(q_batch, mem_vecs)
        pos_sim = (pred * t_batch).sum(dim=-1)
        loss_pos = ((1.0 - pos_sim) * w).mean()

        if n_neg > 0:
            # Similarity to every memory, with the target masked to -1 so it
            # is never selected as its own hard negative.
            all_sims = (pred @ mem_vecs.T).scatter(1, t_idx.unsqueeze(1), -1.0)
            neg_sim = all_sims.topk(n_neg, dim=-1).values.mean(dim=-1)
            loss_neg = (F.relu(neg_sim - pos_sim + MARGIN) * w).mean()
        else:
            loss_neg = torch.zeros((), device=DEV)
        loss = loss_pos + 0.3 * loss_neg

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sch.step()

        window += torch.stack([loss.detach(), loss_pos.detach(), loss_neg.detach()])
        window_n += 1

        # Also evaluate on the last step so the final weights are a candidate.
        if step % log_every == 0 or step == ITERS - 1:
            lv, lp, ln = (window / window_n).tolist()
            window.zero_()
            window_n = 0
            improved = lv < best_loss
            mark = " <-" if improved else ""
            print(f" [nav] step {step:4d} loss={lv:.4f} (pos={lp:.4f} neg={ln:.4f}) ({time.time()-t0:.0f}s){mark}")
            if improved:
                best_loss = lv
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # ---- 6. Save artifacts ----
    os.makedirs("/vol/memory-nav", exist_ok=True)
    torch.save({k: v.cpu() for k, v in model.state_dict().items()},
               "/vol/memory-nav/navigator.pt")
    np.save("/vol/memory-nav/mem_vecs.npy", mem_vecs.cpu().numpy())
    with open("/vol/memory-nav/mem_texts.pkl", "wb") as f:
        pickle.dump(mem_texts, f)
    vol.commit()
    print(f"[nav] DONE best_loss={best_loss:.4f} saved to /vol/memory-nav/")

    # ---- 7. Qualitative check: nearest memories for a few probe queries ----
    model.eval()
    k_show = min(3, N_MEM)
    test_queries = ["hello", "I love you", "I miss her", "tell me a story"]
    for q in test_queries:
        qe = torch.tensor(
            embedder.encode([q], normalize_embeddings=True), dtype=torch.float32
        ).to(DEV)
        with torch.no_grad():
            rv = model(qe, mem_vecs)
        probe_sims = (mem_vecs @ rv.T).squeeze(-1)
        top = probe_sims.topk(k_show).indices.tolist()
        print(f"\nQuery: {q!r}")
        for i in top:
            print(f" [{i}] ew={mems[i].get('emotional_weight',0)} {mem_texts[i][:100]}")

    return {"best_loss": best_loss, "params_m": n_params/1e6}
|
|
|
|
@app.local_entrypoint()
def main(smoke: bool = False, speaker_a: str = "Laura", speaker_b: str = "Micah"):
    """Read the local input files and launch ``train`` remotely.

    Reads SPINE_FILE and DIALOGUE_FILE. Also reads DISCORD_FILE when it
    exists and this is not a smoke run. Prints a size summary, runs
    ``train.remote`` and prints the returned loss and parameter count.

    Args:
        smoke: Forwarded to ``train``; also skips reading the Discord export.
        speaker_a: Querying speaker name, forwarded to ``train``.
        speaker_b: Replying speaker name, forwarded to ``train``.
    """
    spine_json = SPINE_FILE.read_text(encoding="utf-8", errors="ignore")
    dialogue = DIALOGUE_FILE.read_text(encoding="utf-8", errors="ignore")
    discord = ""
    if DISCORD_FILE.exists() and not smoke:
        discord = DISCORD_FILE.read_text(encoding="utf-8", errors="ignore")

    # Count candidate blocks with the same filter train() applies: both
    # speaker labels must be present. Previously this counted only "Laura:",
    # which disagreed with train() and ignored custom speaker names.
    n_pairs = sum(
        1 for b in dialogue.split("\n\n")
        if speaker_a + ":" in b and speaker_b + ":" in b
    )
    print(f"[local] spine={len(spine_json)//1024}KB dialogue={n_pairs} discord={len(discord)//1024}KB smoke={smoke}")
    r = train.remote(spine_json, dialogue, discord,
                     speaker_a=speaker_a, speaker_b=speaker_b, smoke=smoke)
    print(f"[local] done loss={r['best_loss']:.4f} params={r['params_m']:.1f}M")
|
|