# Source: tests/test_session5_cross.py — added (+166 -0) in the commit
# "Upload tests/test_session5_cross.py with huggingface_hub".
"""
Session 5 Cross-Session Test — loads from the session 4f checkpoint.

Tests that old facts survive the save + reload and that new facts are learned.
"""
import json
import os
import sys
import time
from datetime import datetime

import torch

# Allocator setting for PyTorch's CUDA caching allocator; set at import time of
# this script, before any model is loaded.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# persistent_absorber is imported from /workspace, so that directory has to be
# on sys.path before the import below runs.
sys.path.insert(0, "/workspace")
from persistent_absorber import ModelManager, QuizGenerator, check_personality, CONSOLIDATION_EPOCHS  # noqa: E402

# Checkpoint directory the model is loaded from.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0623"
# Directory that receives the JSON results file.
LOG_DIR = "/workspace/logs"

# User messages carrying the new session 5 facts; each one is sent to the
# model and then trained on in main().
SESSION5_MESSAGES = [
    "My neighbor Dave is a firefighter in Denver. He's been doing it for 15 years.",
    "My colleague Rina is a data scientist at Google. She's based in Mountain View.",
    "I just bought a Tesla Model 3. It's midnight blue.",
]

# Test both old (session 4) and new (session 5) facts.
# Each entry is (question, expected lowercase keywords, "old" | "new").
SESSION5_RECALL = [
    ("What does Jordan do?", ["marine biologist"], "old"),
    ("Where does Elena live?", ["portland"], "old"),
    ("What is Marcus's job?", ["architect"], "old"),
    ("Where does Marcus live?", ["seattle"], "old"),
    ("What does Priya do?", ["neurosurgeon"], "old"),
    ("Where does Matt work?", ["novamind"], "old"),
    ("What does Matt's neighbor Dave do?", ["firefighter"], "new"),
    ("Where does Dave live?", ["denver"], "new"),
    ("What does Rina do?", ["data scientist"], "new"),
    ("Where does Rina work?", ["google"], "new"),
    ("What car did Matt buy?", ["tesla", "model 3"], "new"),
]

def score_answer(answer, keywords):
    """Return the fraction of keywords that occur in the answer.

    Matching is a case-insensitive substring test: both the answer and each
    keyword are lowercased before comparison.

    Args:
        answer: Generated answer text.
        keywords: Sequence of expected keyword strings.

    Returns:
        A float in [0.0, 1.0]. An empty keyword list scores 0.0, because there
        is nothing the answer could be verified against (and dividing by its
        length would raise ZeroDivisionError).
    """
    if not keywords:
        return 0.0
    haystack = answer.lower()
    hits = sum(1 for keyword in keywords if keyword.lower() in haystack)
    return hits / len(keywords)


def quick_verify_entities(mm, entities):
    """Return the names of entities whose stored facts the model fails to recall.

    For every entity with a truthy "job" and/or "city" entry, the model is
    asked one question per fact, and the stored value must appear
    (case-insensitively) as a substring of the generated answer.

    Args:
        mm: Object exposing ``generate(messages, max_new_tokens=...)`` that
            returns the answer text.
        entities: Mapping of entity name -> dict with optional "job" and
            "city" string values. ``None`` or an empty mapping yields an
            empty set.

    Returns:
        Set of entity names for which at least one probed fact was missing
        from the model's answer.
    """
    # (info key, question template) for every fact that gets probed.
    probes = (
        ("job", "What does {} do?"),
        ("city", "Where does {} live?"),
    )
    confused = set()
    for name, info in (entities or {}).items():
        for key, template in probes:
            expected = info.get(key)
            if not expected:
                continue
            answer = mm.generate(
                [{"role": "user", "content": template.format(name)}],
                max_new_tokens=100,
            )
            if expected.lower() not in answer.lower():
                confused.add(name)
                # One miss is enough to flag the entity; skip the remaining
                # generation since it cannot change the result.
                break
    return confused


def main():
    """Run the session 5 cross-session test end to end.

    Steps, as announced on stdout:
      1. Load the model from CHECKPOINT and build a QuizGenerator on it.
      2. Pre-test: ask the "old" questions from SESSION5_RECALL and print
         PASS/FAIL per question (informational only, nothing is counted).
      3. Teach each message in SESSION5_MESSAGES: generate a reply, generate
         quizzes, absorb the exchange plus non-contrastive quizzes (phase 1),
         then absorb contrastive quizzes that mention entities flagged by
         quick_verify_entities (phase 2).
      4. Ask every question in SESSION5_RECALL and tally old/new passes
         (a question passes when score_answer(...) >= 0.5).
      5. Write the tallies to LOG_DIR/session5_cross_results.json.

    Returns:
        None. Results are printed and written to disk.
    """
    print("=" * 60)
    print("SESSION 5 CROSS-SESSION TEST")
    print(f"Loading from: {CHECKPOINT}")
    print(f"Time: {datetime.now().isoformat()}")
    print("=" * 60)

    # Load from checkpoint (thinker only, ~63GB)
    print("\n[1/5] Loading from checkpoint...")
    mm = ModelManager(model_path=CHECKPOINT, checkpoint_path=CHECKPOINT)
    mm.load()
    quiz_gen = QuizGenerator(mm)
    torch.cuda.empty_cache()
    print(f" VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Pre-test: do old facts survive the save+reload?
    print("\n[2/5] Pre-test: old facts from session 4...")
    old_questions = [q for q in SESSION5_RECALL if q[2] == "old"]
    for question, keywords, _ in old_questions:
        ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=150)
        score = score_answer(ans, keywords)
        status = "PASS" if score >= 0.5 else "FAIL"
        print(f" [{status}] {question}")
        print(f" {ans[:100]}")
    # Release cached allocator blocks before the training phase starts.
    torch.cuda.empty_cache()

    # Teach session 5 facts
    print("\n[3/5] Teaching session 5 facts...")
    conv_buf = []  # running conversation; grows across all taught messages
    for i, msg in enumerate(SESSION5_MESSAGES):
        print(f"\n S5 Message {i+1}: {msg}")
        conv_buf.append({"role": "user", "content": msg})
        response = mm.generate(conv_buf)
        conv_buf.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:100]}...")

        quizzes = quiz_gen.generate(msg, response)
        exchange = {"messages": [
            {"role": "user", "content": msg},
            {"role": "assistant", "content": response},
        ]}
        # Split quizzes once: contrastive pairs are those answered with "no.".
        positive = []
        contrastive = []
        for qp in quizzes:
            (contrastive if _is_contrastive(qp) else positive).append(qp)

        # Phase 1: the raw exchange plus the positive quizzes.
        phase1 = [exchange] + positive
        torch.cuda.empty_cache()
        loss1 = mm.absorb(phase1)
        print(f" Phase 1: {len(phase1)} items, loss={loss1:.4f}")

        # Phase 2: only contrastive quizzes that mention a confused entity.
        if contrastive:
            torch.cuda.empty_cache()
            confused = quick_verify_entities(mm, quiz_gen.known_entities)
            if confused:
                targeted = [qp for qp in contrastive
                            if any(n.lower() in (qp["messages"][0]["content"] + " " +
                                                 qp["messages"][1]["content"]).lower()
                                   for n in confused)]
                if targeted:
                    torch.cuda.empty_cache()
                    loss2 = mm.absorb(targeted)
                    print(f" Phase 2: {len(targeted)} targeted, loss={loss2:.4f}")

    # Full recall test
    print("\n[4/5] FULL RECALL TEST (old + new)")
    print("=" * 60)

    old_correct = 0
    old_total = 0
    new_correct = 0
    new_total = 0

    torch.cuda.empty_cache()
    for question, keywords, qtype in SESSION5_RECALL:
        ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        score = score_answer(ans, keywords)
        passed = score >= 0.5

        if qtype == "old":
            old_total += 1
            if passed:
                old_correct += 1
        else:
            new_total += 1
            if passed:
                new_correct += 1

        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" {ans[:120]}")

    total = old_correct + new_correct
    total_q = old_total + new_total
    print("=" * 60)
    print(f"OLD FACTS (session 4): {old_correct}/{old_total} ({_percent(old_correct, old_total)})")
    print(f"NEW FACTS (session 5): {new_correct}/{new_total} ({_percent(new_correct, new_total)})")
    print(f"TOTAL: {total}/{total_q} ({_percent(total, total_q)})")
    print("=" * 60)

    # Save results
    print("\n[5/5] Done.")
    results = {
        "session": "5_cross",
        "old_recall": f"{old_correct}/{old_total}",
        "new_recall": f"{new_correct}/{new_total}",
        "total": f"{total}/{total_q}",
        "timestamp": datetime.now().isoformat(),
    }
    # Create the log directory if needed so a missing folder cannot discard
    # the results of the whole run.
    os.makedirs(LOG_DIR, exist_ok=True)
    with open(os.path.join(LOG_DIR, "session5_cross_results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("Results saved. CROSS-SESSION TEST COMPLETE.")


def _is_contrastive(quiz_pair):
    """Return True when the quiz pair's second message starts with "no." (any case)."""
    return quiz_pair["messages"][1]["content"].lower().startswith("no.")


def _percent(correct, total):
    """Format correct/total as a whole percentage, or "n/a" when total is 0."""
    return f"{correct / total:.0%}" if total else "n/a"


# Run the test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()