Upload train_navigator.py with huggingface_hub
Browse files- train_navigator.py +298 -0
train_navigator.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""train_navigator.py — Train the RMM Navigator.
|
| 2 |
+
|
| 3 |
+
Architecture:
|
| 4 |
+
Query (text) -> MiniLM embed (384-d) -> Linear -> 3072-d query vector
|
| 5 |
+
Cross-attention over entity's spine memory vectors (3072-d each)
|
| 6 |
+
-> Synthesized response vector (3072-d)
|
| 7 |
+
-> Cosine loss vs actual reply vector from spine
|
| 8 |
+
|
| 9 |
+
The navigator learns the topology of an entity's embedding space —
|
| 10 |
+
which memories are connected, which regions respond to which queries,
|
| 11 |
+
how emotional weight shapes retrieval. This is learned navigation,
|
| 12 |
+
not cosine similarity.
|
| 13 |
+
|
| 14 |
+
Run: modal run train_navigator.py
|
| 15 |
+
Pull: modal volume get rmm-vol /memory-nav/ ./memory-nav-out/
|
| 16 |
+
|
| 17 |
+
Requires:
|
| 18 |
+
- spine.json: {"memories": [{"text": "...", "vector": [...3072...], "emotional_weight": 8, "salience": 0.5}, ...]}
|
| 19 |
+
- dialogue.txt: alternating "Speaker A: ...\nSpeaker B: ..." blocks
|
| 20 |
+
- (optional) discord.json: array of {author: {name: "..."}, content: "..."} messages
|
| 21 |
+
"""
|
| 22 |
+
import modal, json
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
app = modal.App("rmm-navigator")
|
| 26 |
+
image = (modal.Image.debian_slim(python_version="3.11")
|
| 27 |
+
.pip_install("torch==2.6.0", "numpy", "sentence-transformers"))
|
| 28 |
+
vol = modal.Volume.from_name("rmm-vol", create_if_missing=True)
|
| 29 |
+
|
| 30 |
+
# ── Point these at your entity's data ──
|
| 31 |
+
SPINE_FILE = Path("spine.json")
|
| 32 |
+
DIALOGUE_FILE = Path("dialogue.txt")
|
| 33 |
+
DISCORD_FILE = Path("discord.json")
|
| 34 |
+
|
| 35 |
+
SPINE_DIM = 3072 # embedding dim (Gemini, etc.)
|
| 36 |
+
QUERY_DIM = 384 # MiniLM dim
|
| 37 |
+
N_HEADS = 8
|
| 38 |
+
N_LAYERS = 3
|
| 39 |
+
D_MODEL = 512
|
| 40 |
+
DROPOUT = 0.1
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@app.function(image=image, gpu="A10G", timeout=3600, volumes={"/vol": vol})
|
| 44 |
+
def train(spine_json: str, dialogue_text: str, discord_json: str = "",
|
| 45 |
+
speaker_a: str = "Laura", speaker_b: str = "Micah", smoke: bool = False):
|
| 46 |
+
import os, math, time, json, re
|
| 47 |
+
import numpy as np
|
| 48 |
+
import torch
|
| 49 |
+
import torch.nn as nn
|
| 50 |
+
import torch.nn.functional as F
|
| 51 |
+
from sentence_transformers import SentenceTransformer
|
| 52 |
+
|
| 53 |
+
DEV = "cuda"
|
| 54 |
+
print(f"[nav] gpu={torch.cuda.get_device_name(0)}")
|
| 55 |
+
|
| 56 |
+
spine_data = json.loads(spine_json)
|
| 57 |
+
mems = spine_data["memories"]
|
| 58 |
+
|
| 59 |
+
mem_vecs = torch.tensor(
|
| 60 |
+
[m["vector"] for m in mems], dtype=torch.float32
|
| 61 |
+
).to(DEV)
|
| 62 |
+
mem_vecs = F.normalize(mem_vecs, dim=-1)
|
| 63 |
+
N_MEM = mem_vecs.shape[0]
|
| 64 |
+
print(f"[nav] {N_MEM} memory vectors loaded, dim={mem_vecs.shape[1]}")
|
| 65 |
+
|
| 66 |
+
STRIP = re.compile(r'^\[conversation\] I replied \(puppet\):\s*["\']?', re.I)
|
| 67 |
+
SURR = re.compile(r'[\ud800-\udfff]')
|
| 68 |
+
mem_texts = []
|
| 69 |
+
for m in mems:
|
| 70 |
+
raw = SURR.sub('', str(m.get("text") or ""))
|
| 71 |
+
t = STRIP.sub("", raw).strip().strip('"').strip("'")
|
| 72 |
+
mem_texts.append(t[:500] if t else "...")
|
| 73 |
+
|
| 74 |
+
saliences = torch.tensor(
|
| 75 |
+
[m.get("salience", 0.5) for m in mems], dtype=torch.float32
|
| 76 |
+
).to(DEV)
|
| 77 |
+
|
| 78 |
+
print("[nav] building training pairs...")
|
| 79 |
+
embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
| 80 |
+
|
| 81 |
+
query_msgs, reply_msgs = [], []
|
| 82 |
+
|
| 83 |
+
# Source 1: dialogue file (SpeakerA: ...\nSpeakerB: ... blocks)
|
| 84 |
+
blocks = [b.strip() for b in dialogue_text.split("\n\n")
|
| 85 |
+
if speaker_a + ":" in b and speaker_b + ":" in b]
|
| 86 |
+
n_smoke = 50 if smoke else len(blocks)
|
| 87 |
+
for b in blocks[:n_smoke]:
|
| 88 |
+
parts = b.split(speaker_b + ":", 1)
|
| 89 |
+
a_part = parts[0].replace(speaker_a + ":", "").strip()
|
| 90 |
+
b_part = parts[1].strip() if len(parts) > 1 else ""
|
| 91 |
+
if len(a_part) >= 5 and len(b_part) >= 10:
|
| 92 |
+
query_msgs.append(a_part)
|
| 93 |
+
reply_msgs.append(b_part)
|
| 94 |
+
print(f"[nav] dialogue file: {len(query_msgs)} pairs")
|
| 95 |
+
|
| 96 |
+
# Source 2: Discord history (consecutive A->B messages)
|
| 97 |
+
if discord_json and not smoke:
|
| 98 |
+
disc_msgs = json.loads(discord_json)
|
| 99 |
+
disc_count = 0
|
| 100 |
+
for i in range(len(disc_msgs) - 1):
|
| 101 |
+
cur = disc_msgs[i]
|
| 102 |
+
nxt = disc_msgs[i + 1]
|
| 103 |
+
cur_a = (cur.get("author", {}).get("name", "") or cur.get("author", {}).get("username", "") or "").lower()
|
| 104 |
+
nxt_a = (nxt.get("author", {}).get("name", "") or nxt.get("author", {}).get("username", "") or "").lower()
|
| 105 |
+
cur_c = SURR.sub('', cur.get("content", "").strip())
|
| 106 |
+
nxt_c = SURR.sub('', nxt.get("content", "").strip())
|
| 107 |
+
is_a = speaker_a.lower() in cur_a
|
| 108 |
+
is_b = speaker_b.lower() in nxt_a
|
| 109 |
+
if is_a and is_b and len(cur_c) >= 3 and len(nxt_c) >= 10:
|
| 110 |
+
query_msgs.append(cur_c[:500])
|
| 111 |
+
reply_msgs.append(nxt_c[:500])
|
| 112 |
+
disc_count += 1
|
| 113 |
+
print(f"[nav] discord pairs: {disc_count}")
|
| 114 |
+
|
| 115 |
+
# Sanitize
|
| 116 |
+
clean_q, clean_r = [], []
|
| 117 |
+
for q, r in zip(query_msgs, reply_msgs):
|
| 118 |
+
qs, rs = str(q).strip(), str(r).strip()
|
| 119 |
+
if len(qs) >= 3 and len(rs) >= 5:
|
| 120 |
+
clean_q.append(qs)
|
| 121 |
+
clean_r.append(rs)
|
| 122 |
+
query_msgs, reply_msgs = clean_q, clean_r
|
| 123 |
+
|
| 124 |
+
print(f"[nav] total pairs: {len(query_msgs)}")
|
| 125 |
+
print(f"[nav] embedding queries...")
|
| 126 |
+
query_embs = embedder.encode(query_msgs, normalize_embeddings=True,
|
| 127 |
+
show_progress_bar=False, batch_size=64)
|
| 128 |
+
|
| 129 |
+
print(f"[nav] embedding replies...")
|
| 130 |
+
reply_embs = embedder.encode(reply_msgs, normalize_embeddings=True,
|
| 131 |
+
show_progress_bar=False, batch_size=64)
|
| 132 |
+
|
| 133 |
+
reply_tensor = torch.tensor(reply_embs, dtype=torch.float32)
|
| 134 |
+
|
| 135 |
+
print("[nav] embedding memories in MiniLM space for matching...")
|
| 136 |
+
BATCH = 256
|
| 137 |
+
mem_mini_embs = []
|
| 138 |
+
for start in range(0, N_MEM, BATCH):
|
| 139 |
+
chunk = mem_texts[start:start + BATCH]
|
| 140 |
+
e = embedder.encode(chunk, normalize_embeddings=True, show_progress_bar=False)
|
| 141 |
+
mem_mini_embs.append(e)
|
| 142 |
+
mem_mini = torch.tensor(np.vstack(mem_mini_embs), dtype=torch.float32)
|
| 143 |
+
|
| 144 |
+
sims = reply_tensor @ mem_mini.T
|
| 145 |
+
top5_vals, top5_idx = sims.topk(5, dim=-1)
|
| 146 |
+
sal_cpu = saliences.cpu()
|
| 147 |
+
best_indices = []
|
| 148 |
+
for i in range(len(reply_tensor)):
|
| 149 |
+
candidates = top5_idx[i]
|
| 150 |
+
cand_sals = sal_cpu[candidates]
|
| 151 |
+
best_j = cand_sals.argmax().item()
|
| 152 |
+
best_indices.append(candidates[best_j].item())
|
| 153 |
+
best_mem_idx = torch.tensor(best_indices, dtype=torch.long)
|
| 154 |
+
target_vecs = mem_vecs[best_mem_idx]
|
| 155 |
+
|
| 156 |
+
ew_raw = torch.tensor([mems[i].get("emotional_weight", 5) for i in best_indices],
|
| 157 |
+
dtype=torch.float32)
|
| 158 |
+
pair_weights = 1.0 + 0.3 * (ew_raw - 5.0) / 5.0
|
| 159 |
+
pair_weights = pair_weights / pair_weights.mean()
|
| 160 |
+
|
| 161 |
+
query_tensor = torch.tensor(query_embs, dtype=torch.float32)
|
| 162 |
+
print(f"[nav] {len(query_tensor)} training pairs ready")
|
| 163 |
+
|
| 164 |
+
# ── Model ──
|
| 165 |
+
class MemoryNavigator(nn.Module):
|
| 166 |
+
def __init__(self):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.query_proj = nn.Sequential(
|
| 169 |
+
nn.Linear(QUERY_DIM, D_MODEL),
|
| 170 |
+
nn.LayerNorm(D_MODEL),
|
| 171 |
+
nn.GELU(),
|
| 172 |
+
)
|
| 173 |
+
self.mem_proj = nn.Linear(SPINE_DIM, D_MODEL, bias=False)
|
| 174 |
+
self.layers = nn.ModuleList([
|
| 175 |
+
nn.TransformerDecoderLayer(
|
| 176 |
+
d_model=D_MODEL, nhead=N_HEADS,
|
| 177 |
+
dim_feedforward=D_MODEL * 4,
|
| 178 |
+
dropout=DROPOUT, batch_first=True
|
| 179 |
+
)
|
| 180 |
+
for _ in range(N_LAYERS)
|
| 181 |
+
])
|
| 182 |
+
self.out_proj = nn.Linear(D_MODEL, SPINE_DIM, bias=False)
|
| 183 |
+
self.norm = nn.LayerNorm(D_MODEL)
|
| 184 |
+
|
| 185 |
+
def forward(self, q, mem_keys):
|
| 186 |
+
q = self.query_proj(q).unsqueeze(1)
|
| 187 |
+
B = q.shape[0]
|
| 188 |
+
m = self.mem_proj(mem_keys).unsqueeze(0).expand(B, -1, -1)
|
| 189 |
+
x = q
|
| 190 |
+
for layer in self.layers:
|
| 191 |
+
x = layer(x, m)
|
| 192 |
+
x = self.norm(x).squeeze(1)
|
| 193 |
+
out = self.out_proj(x)
|
| 194 |
+
return F.normalize(out, dim=-1)
|
| 195 |
+
|
| 196 |
+
model = MemoryNavigator().to(DEV)
|
| 197 |
+
n_params = sum(p.numel() for p in model.parameters())
|
| 198 |
+
print(f"[nav] model {n_params/1e6:.1f}M params")
|
| 199 |
+
|
| 200 |
+
# ── Train ──
|
| 201 |
+
ITERS = 200 if smoke else 7500
|
| 202 |
+
BS = 32
|
| 203 |
+
N_NEG = 7
|
| 204 |
+
MARGIN = 0.2
|
| 205 |
+
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
|
| 206 |
+
warmup_steps = 200 if not smoke else 20
|
| 207 |
+
def lr_lambda(step):
|
| 208 |
+
if step < warmup_steps:
|
| 209 |
+
return step / warmup_steps
|
| 210 |
+
progress = (step - warmup_steps) / max(1, ITERS - warmup_steps)
|
| 211 |
+
return 0.5 * (1 + math.cos(math.pi * progress))
|
| 212 |
+
sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
|
| 213 |
+
|
| 214 |
+
M = len(query_tensor)
|
| 215 |
+
t0 = time.time()
|
| 216 |
+
best_loss = float('inf')
|
| 217 |
+
best_state = None
|
| 218 |
+
|
| 219 |
+
for step in range(ITERS):
|
| 220 |
+
idx = torch.randint(0, M, (BS,))
|
| 221 |
+
q_batch = query_tensor[idx].to(DEV)
|
| 222 |
+
t_batch = target_vecs[idx].to(DEV)
|
| 223 |
+
t_idx = best_mem_idx[idx]
|
| 224 |
+
|
| 225 |
+
pred = model(q_batch, mem_vecs)
|
| 226 |
+
|
| 227 |
+
pos_sim = (pred * t_batch).sum(dim=-1)
|
| 228 |
+
|
| 229 |
+
neg_sims_list = []
|
| 230 |
+
for b in range(BS):
|
| 231 |
+
all_sims = (mem_vecs @ pred[b]).squeeze()
|
| 232 |
+
all_sims[t_idx[b]] = -1.0
|
| 233 |
+
hard_neg_idx = all_sims.topk(N_NEG).indices
|
| 234 |
+
neg_sims_list.append(all_sims[hard_neg_idx].mean())
|
| 235 |
+
neg_sim = torch.stack(neg_sims_list)
|
| 236 |
+
|
| 237 |
+
w = pair_weights[idx].to(DEV)
|
| 238 |
+
loss_pos = ((1.0 - pos_sim) * w).mean()
|
| 239 |
+
loss_neg = (F.relu(neg_sim - pos_sim + MARGIN) * w).mean()
|
| 240 |
+
loss = loss_pos + 0.3 * loss_neg
|
| 241 |
+
|
| 242 |
+
opt.zero_grad()
|
| 243 |
+
loss.backward()
|
| 244 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 245 |
+
opt.step()
|
| 246 |
+
sch.step()
|
| 247 |
+
|
| 248 |
+
if step % (20 if smoke else 250) == 0:
|
| 249 |
+
lv = loss.item()
|
| 250 |
+
lp = loss_pos.item()
|
| 251 |
+
ln = loss_neg.item()
|
| 252 |
+
mark = " <-" if lv < best_loss else ""
|
| 253 |
+
print(f" [nav] step {step:4d} loss={lv:.4f} (pos={lp:.4f} neg={ln:.4f}) ({time.time()-t0:.0f}s){mark}")
|
| 254 |
+
if lv < best_loss:
|
| 255 |
+
best_loss = lv
|
| 256 |
+
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 257 |
+
|
| 258 |
+
if best_state:
|
| 259 |
+
model.load_state_dict(best_state)
|
| 260 |
+
|
| 261 |
+
os.makedirs("/vol/memory-nav", exist_ok=True)
|
| 262 |
+
torch.save({k: v.cpu() for k, v in model.state_dict().items()},
|
| 263 |
+
"/vol/memory-nav/navigator.pt")
|
| 264 |
+
np.save("/vol/memory-nav/mem_vecs.npy", mem_vecs.cpu().numpy())
|
| 265 |
+
import pickle
|
| 266 |
+
with open("/vol/memory-nav/mem_texts.pkl", "wb") as f:
|
| 267 |
+
pickle.dump(mem_texts, f)
|
| 268 |
+
vol.commit()
|
| 269 |
+
print(f"[nav] DONE best_loss={best_loss:.4f} saved to /vol/memory-nav/")
|
| 270 |
+
|
| 271 |
+
model.eval()
|
| 272 |
+
test_queries = ["hello", "I love you", "I miss her", "tell me a story"]
|
| 273 |
+
for q in test_queries:
|
| 274 |
+
qe = torch.tensor(
|
| 275 |
+
embedder.encode([q], normalize_embeddings=True), dtype=torch.float32
|
| 276 |
+
).to(DEV)
|
| 277 |
+
with torch.no_grad():
|
| 278 |
+
rv = model(qe, mem_vecs)
|
| 279 |
+
sims = (mem_vecs @ rv.T).squeeze()
|
| 280 |
+
top3 = sims.topk(3).indices.tolist()
|
| 281 |
+
print(f"\nQuery: {q!r}")
|
| 282 |
+
for i in top3:
|
| 283 |
+
print(f" [{i}] ew={mems[i].get('emotional_weight',0)} {mem_texts[i][:100]}")
|
| 284 |
+
|
| 285 |
+
return {"best_loss": best_loss, "params_m": n_params/1e6}
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@app.local_entrypoint()
|
| 289 |
+
def main(smoke: bool = False):
|
| 290 |
+
spine_json = SPINE_FILE.read_text(encoding="utf-8", errors="ignore")
|
| 291 |
+
dialogue = DIALOGUE_FILE.read_text(encoding="utf-8", errors="ignore")
|
| 292 |
+
discord = ""
|
| 293 |
+
if DISCORD_FILE.exists() and not smoke:
|
| 294 |
+
discord = DISCORD_FILE.read_text(encoding="utf-8", errors="ignore")
|
| 295 |
+
n_pairs = len([b for b in dialogue.split(chr(10)*2) if 'Laura:' in b])
|
| 296 |
+
print(f"[local] spine={len(spine_json)//1024}KB dialogue={n_pairs} discord={len(discord)//1024}KB smoke={smoke}")
|
| 297 |
+
r = train.remote(spine_json, dialogue, discord, smoke=smoke)
|
| 298 |
+
print(f"[local] done loss={r['best_loss']:.4f} params={r['params_m']:.1f}M")
|