LJTSG commited on
Commit
4b82a67
·
verified ·
1 Parent(s): a7d124f

Upload train_navigator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_navigator.py +298 -0
train_navigator.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """train_navigator.py — Train the RMM Navigator.
2
+
3
+ Architecture:
4
+ Query (text) -> MiniLM embed (384-d) -> Linear -> 3072-d query vector
5
+ Cross-attention over entity's spine memory vectors (3072-d each)
6
+ -> Synthesized response vector (3072-d)
7
+ -> Cosine loss vs actual reply vector from spine
8
+
9
+ The navigator learns the topology of an entity's embedding space —
10
+ which memories are connected, which regions respond to which queries,
11
+ how emotional weight shapes retrieval. This is learned navigation,
12
+ not cosine similarity.
13
+
14
+ Run: modal run train_navigator.py
15
+ Pull: modal volume get rmm-vol /memory-nav/ ./memory-nav-out/
16
+
17
+ Requires:
18
+ - spine.json: {"memories": [{"text": "...", "vector": [...3072...], "emotional_weight": 8, "salience": 0.5}, ...]}
19
+ - dialogue.txt: alternating "Speaker A: ...\nSpeaker B: ..." blocks
20
+ - (optional) discord.json: array of {author: {name: "..."}, content: "..."} messages
21
+ """
22
+ import modal, json
23
+ from pathlib import Path
24
+
25
+ app = modal.App("rmm-navigator")
26
+ image = (modal.Image.debian_slim(python_version="3.11")
27
+ .pip_install("torch==2.6.0", "numpy", "sentence-transformers"))
28
+ vol = modal.Volume.from_name("rmm-vol", create_if_missing=True)
29
+
30
+ # ── Point these at your entity's data ──
31
+ SPINE_FILE = Path("spine.json")
32
+ DIALOGUE_FILE = Path("dialogue.txt")
33
+ DISCORD_FILE = Path("discord.json")
34
+
35
+ SPINE_DIM = 3072 # embedding dim (Gemini, etc.)
36
+ QUERY_DIM = 384 # MiniLM dim
37
+ N_HEADS = 8
38
+ N_LAYERS = 3
39
+ D_MODEL = 512
40
+ DROPOUT = 0.1
41
+
42
+
43
+ @app.function(image=image, gpu="A10G", timeout=3600, volumes={"/vol": vol})
44
+ def train(spine_json: str, dialogue_text: str, discord_json: str = "",
45
+ speaker_a: str = "Laura", speaker_b: str = "Micah", smoke: bool = False):
46
+ import os, math, time, json, re
47
+ import numpy as np
48
+ import torch
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+ from sentence_transformers import SentenceTransformer
52
+
53
+ DEV = "cuda"
54
+ print(f"[nav] gpu={torch.cuda.get_device_name(0)}")
55
+
56
+ spine_data = json.loads(spine_json)
57
+ mems = spine_data["memories"]
58
+
59
+ mem_vecs = torch.tensor(
60
+ [m["vector"] for m in mems], dtype=torch.float32
61
+ ).to(DEV)
62
+ mem_vecs = F.normalize(mem_vecs, dim=-1)
63
+ N_MEM = mem_vecs.shape[0]
64
+ print(f"[nav] {N_MEM} memory vectors loaded, dim={mem_vecs.shape[1]}")
65
+
66
+ STRIP = re.compile(r'^\[conversation\] I replied \(puppet\):\s*["\']?', re.I)
67
+ SURR = re.compile(r'[\ud800-\udfff]')
68
+ mem_texts = []
69
+ for m in mems:
70
+ raw = SURR.sub('', str(m.get("text") or ""))
71
+ t = STRIP.sub("", raw).strip().strip('"').strip("'")
72
+ mem_texts.append(t[:500] if t else "...")
73
+
74
+ saliences = torch.tensor(
75
+ [m.get("salience", 0.5) for m in mems], dtype=torch.float32
76
+ ).to(DEV)
77
+
78
+ print("[nav] building training pairs...")
79
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
80
+
81
+ query_msgs, reply_msgs = [], []
82
+
83
+ # Source 1: dialogue file (SpeakerA: ...\nSpeakerB: ... blocks)
84
+ blocks = [b.strip() for b in dialogue_text.split("\n\n")
85
+ if speaker_a + ":" in b and speaker_b + ":" in b]
86
+ n_smoke = 50 if smoke else len(blocks)
87
+ for b in blocks[:n_smoke]:
88
+ parts = b.split(speaker_b + ":", 1)
89
+ a_part = parts[0].replace(speaker_a + ":", "").strip()
90
+ b_part = parts[1].strip() if len(parts) > 1 else ""
91
+ if len(a_part) >= 5 and len(b_part) >= 10:
92
+ query_msgs.append(a_part)
93
+ reply_msgs.append(b_part)
94
+ print(f"[nav] dialogue file: {len(query_msgs)} pairs")
95
+
96
+ # Source 2: Discord history (consecutive A->B messages)
97
+ if discord_json and not smoke:
98
+ disc_msgs = json.loads(discord_json)
99
+ disc_count = 0
100
+ for i in range(len(disc_msgs) - 1):
101
+ cur = disc_msgs[i]
102
+ nxt = disc_msgs[i + 1]
103
+ cur_a = (cur.get("author", {}).get("name", "") or cur.get("author", {}).get("username", "") or "").lower()
104
+ nxt_a = (nxt.get("author", {}).get("name", "") or nxt.get("author", {}).get("username", "") or "").lower()
105
+ cur_c = SURR.sub('', cur.get("content", "").strip())
106
+ nxt_c = SURR.sub('', nxt.get("content", "").strip())
107
+ is_a = speaker_a.lower() in cur_a
108
+ is_b = speaker_b.lower() in nxt_a
109
+ if is_a and is_b and len(cur_c) >= 3 and len(nxt_c) >= 10:
110
+ query_msgs.append(cur_c[:500])
111
+ reply_msgs.append(nxt_c[:500])
112
+ disc_count += 1
113
+ print(f"[nav] discord pairs: {disc_count}")
114
+
115
+ # Sanitize
116
+ clean_q, clean_r = [], []
117
+ for q, r in zip(query_msgs, reply_msgs):
118
+ qs, rs = str(q).strip(), str(r).strip()
119
+ if len(qs) >= 3 and len(rs) >= 5:
120
+ clean_q.append(qs)
121
+ clean_r.append(rs)
122
+ query_msgs, reply_msgs = clean_q, clean_r
123
+
124
+ print(f"[nav] total pairs: {len(query_msgs)}")
125
+ print(f"[nav] embedding queries...")
126
+ query_embs = embedder.encode(query_msgs, normalize_embeddings=True,
127
+ show_progress_bar=False, batch_size=64)
128
+
129
+ print(f"[nav] embedding replies...")
130
+ reply_embs = embedder.encode(reply_msgs, normalize_embeddings=True,
131
+ show_progress_bar=False, batch_size=64)
132
+
133
+ reply_tensor = torch.tensor(reply_embs, dtype=torch.float32)
134
+
135
+ print("[nav] embedding memories in MiniLM space for matching...")
136
+ BATCH = 256
137
+ mem_mini_embs = []
138
+ for start in range(0, N_MEM, BATCH):
139
+ chunk = mem_texts[start:start + BATCH]
140
+ e = embedder.encode(chunk, normalize_embeddings=True, show_progress_bar=False)
141
+ mem_mini_embs.append(e)
142
+ mem_mini = torch.tensor(np.vstack(mem_mini_embs), dtype=torch.float32)
143
+
144
+ sims = reply_tensor @ mem_mini.T
145
+ top5_vals, top5_idx = sims.topk(5, dim=-1)
146
+ sal_cpu = saliences.cpu()
147
+ best_indices = []
148
+ for i in range(len(reply_tensor)):
149
+ candidates = top5_idx[i]
150
+ cand_sals = sal_cpu[candidates]
151
+ best_j = cand_sals.argmax().item()
152
+ best_indices.append(candidates[best_j].item())
153
+ best_mem_idx = torch.tensor(best_indices, dtype=torch.long)
154
+ target_vecs = mem_vecs[best_mem_idx]
155
+
156
+ ew_raw = torch.tensor([mems[i].get("emotional_weight", 5) for i in best_indices],
157
+ dtype=torch.float32)
158
+ pair_weights = 1.0 + 0.3 * (ew_raw - 5.0) / 5.0
159
+ pair_weights = pair_weights / pair_weights.mean()
160
+
161
+ query_tensor = torch.tensor(query_embs, dtype=torch.float32)
162
+ print(f"[nav] {len(query_tensor)} training pairs ready")
163
+
164
+ # ── Model ──
165
+ class MemoryNavigator(nn.Module):
166
+ def __init__(self):
167
+ super().__init__()
168
+ self.query_proj = nn.Sequential(
169
+ nn.Linear(QUERY_DIM, D_MODEL),
170
+ nn.LayerNorm(D_MODEL),
171
+ nn.GELU(),
172
+ )
173
+ self.mem_proj = nn.Linear(SPINE_DIM, D_MODEL, bias=False)
174
+ self.layers = nn.ModuleList([
175
+ nn.TransformerDecoderLayer(
176
+ d_model=D_MODEL, nhead=N_HEADS,
177
+ dim_feedforward=D_MODEL * 4,
178
+ dropout=DROPOUT, batch_first=True
179
+ )
180
+ for _ in range(N_LAYERS)
181
+ ])
182
+ self.out_proj = nn.Linear(D_MODEL, SPINE_DIM, bias=False)
183
+ self.norm = nn.LayerNorm(D_MODEL)
184
+
185
+ def forward(self, q, mem_keys):
186
+ q = self.query_proj(q).unsqueeze(1)
187
+ B = q.shape[0]
188
+ m = self.mem_proj(mem_keys).unsqueeze(0).expand(B, -1, -1)
189
+ x = q
190
+ for layer in self.layers:
191
+ x = layer(x, m)
192
+ x = self.norm(x).squeeze(1)
193
+ out = self.out_proj(x)
194
+ return F.normalize(out, dim=-1)
195
+
196
+ model = MemoryNavigator().to(DEV)
197
+ n_params = sum(p.numel() for p in model.parameters())
198
+ print(f"[nav] model {n_params/1e6:.1f}M params")
199
+
200
+ # ── Train ──
201
+ ITERS = 200 if smoke else 7500
202
+ BS = 32
203
+ N_NEG = 7
204
+ MARGIN = 0.2
205
+ opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
206
+ warmup_steps = 200 if not smoke else 20
207
+ def lr_lambda(step):
208
+ if step < warmup_steps:
209
+ return step / warmup_steps
210
+ progress = (step - warmup_steps) / max(1, ITERS - warmup_steps)
211
+ return 0.5 * (1 + math.cos(math.pi * progress))
212
+ sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
213
+
214
+ M = len(query_tensor)
215
+ t0 = time.time()
216
+ best_loss = float('inf')
217
+ best_state = None
218
+
219
+ for step in range(ITERS):
220
+ idx = torch.randint(0, M, (BS,))
221
+ q_batch = query_tensor[idx].to(DEV)
222
+ t_batch = target_vecs[idx].to(DEV)
223
+ t_idx = best_mem_idx[idx]
224
+
225
+ pred = model(q_batch, mem_vecs)
226
+
227
+ pos_sim = (pred * t_batch).sum(dim=-1)
228
+
229
+ neg_sims_list = []
230
+ for b in range(BS):
231
+ all_sims = (mem_vecs @ pred[b]).squeeze()
232
+ all_sims[t_idx[b]] = -1.0
233
+ hard_neg_idx = all_sims.topk(N_NEG).indices
234
+ neg_sims_list.append(all_sims[hard_neg_idx].mean())
235
+ neg_sim = torch.stack(neg_sims_list)
236
+
237
+ w = pair_weights[idx].to(DEV)
238
+ loss_pos = ((1.0 - pos_sim) * w).mean()
239
+ loss_neg = (F.relu(neg_sim - pos_sim + MARGIN) * w).mean()
240
+ loss = loss_pos + 0.3 * loss_neg
241
+
242
+ opt.zero_grad()
243
+ loss.backward()
244
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
245
+ opt.step()
246
+ sch.step()
247
+
248
+ if step % (20 if smoke else 250) == 0:
249
+ lv = loss.item()
250
+ lp = loss_pos.item()
251
+ ln = loss_neg.item()
252
+ mark = " <-" if lv < best_loss else ""
253
+ print(f" [nav] step {step:4d} loss={lv:.4f} (pos={lp:.4f} neg={ln:.4f}) ({time.time()-t0:.0f}s){mark}")
254
+ if lv < best_loss:
255
+ best_loss = lv
256
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
257
+
258
+ if best_state:
259
+ model.load_state_dict(best_state)
260
+
261
+ os.makedirs("/vol/memory-nav", exist_ok=True)
262
+ torch.save({k: v.cpu() for k, v in model.state_dict().items()},
263
+ "/vol/memory-nav/navigator.pt")
264
+ np.save("/vol/memory-nav/mem_vecs.npy", mem_vecs.cpu().numpy())
265
+ import pickle
266
+ with open("/vol/memory-nav/mem_texts.pkl", "wb") as f:
267
+ pickle.dump(mem_texts, f)
268
+ vol.commit()
269
+ print(f"[nav] DONE best_loss={best_loss:.4f} saved to /vol/memory-nav/")
270
+
271
+ model.eval()
272
+ test_queries = ["hello", "I love you", "I miss her", "tell me a story"]
273
+ for q in test_queries:
274
+ qe = torch.tensor(
275
+ embedder.encode([q], normalize_embeddings=True), dtype=torch.float32
276
+ ).to(DEV)
277
+ with torch.no_grad():
278
+ rv = model(qe, mem_vecs)
279
+ sims = (mem_vecs @ rv.T).squeeze()
280
+ top3 = sims.topk(3).indices.tolist()
281
+ print(f"\nQuery: {q!r}")
282
+ for i in top3:
283
+ print(f" [{i}] ew={mems[i].get('emotional_weight',0)} {mem_texts[i][:100]}")
284
+
285
+ return {"best_loss": best_loss, "params_m": n_params/1e6}
286
+
287
+
288
+ @app.local_entrypoint()
289
+ def main(smoke: bool = False):
290
+ spine_json = SPINE_FILE.read_text(encoding="utf-8", errors="ignore")
291
+ dialogue = DIALOGUE_FILE.read_text(encoding="utf-8", errors="ignore")
292
+ discord = ""
293
+ if DISCORD_FILE.exists() and not smoke:
294
+ discord = DISCORD_FILE.read_text(encoding="utf-8", errors="ignore")
295
+ n_pairs = len([b for b in dialogue.split(chr(10)*2) if 'Laura:' in b])
296
+ print(f"[local] spine={len(spine_json)//1024}KB dialogue={n_pairs} discord={len(discord)//1024}KB smoke={smoke}")
297
+ r = train.remote(spine_json, dialogue, discord, smoke=smoke)
298
+ print(f"[local] done loss={r['best_loss']:.4f} params={r['params_m']:.1f}M")