msrcam committed on
Commit
78ae5ee
·
verified ·
1 Parent(s): 77ccd8a

Upload tests/test_session5_cross.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tests/test_session5_cross.py +166 -0
tests/test_session5_cross.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Session 5 Cross-Session Test — Loads from session 4f checkpoint
3
+ Tests: old facts survive + new facts learned
4
+ """
5
+ import json, os, sys, time, torch
6
+ from datetime import datetime
7
+
8
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
+
10
+ sys.path.insert(0, "/workspace")
11
+ from persistent_absorber import ModelManager, QuizGenerator, check_personality, CONSOLIDATION_EPOCHS
12
+
13
# Checkpoint directory the test loads the model from, and the directory
# the JSON summary is written to.
CHECKPOINT = "/workspace/checkpoints/claudia_session_20260323_0623"
LOG_DIR = "/workspace/logs"

# User messages taught to the model during this session, in order.
SESSION5_MESSAGES = [
    "My neighbor Dave is a firefighter in Denver. He's been doing it for 15 years.",
    "My colleague Rina is a data scientist at Google. She's based in Mountain View.",
    "I just bought a Tesla Model 3. It's midnight blue.",
]

# Recall probes as (question, expected keywords, origin) triples.
# origin is "old" for session 4 facts and "new" for session 5 facts.
SESSION5_RECALL = [
    # --- session 4 ---
    ("What does Jordan do?", ["marine biologist"], "old"),
    ("Where does Elena live?", ["portland"], "old"),
    ("What is Marcus's job?", ["architect"], "old"),
    ("Where does Marcus live?", ["seattle"], "old"),
    ("What does Priya do?", ["neurosurgeon"], "old"),
    ("Where does Matt work?", ["novamind"], "old"),
    # --- session 5 ---
    ("What does Matt's neighbor Dave do?", ["firefighter"], "new"),
    ("Where does Dave live?", ["denver"], "new"),
    ("What does Rina do?", ["data scientist"], "new"),
    ("Where does Rina work?", ["google"], "new"),
    ("What car did Matt buy?", ["tesla", "model 3"], "new"),
]
36
+
37
def score_answer(answer, keywords):
    """Return the fraction of ``keywords`` that occur in ``answer``.

    Matching is a case-insensitive substring test on both sides.

    Args:
        answer: Generated text to inspect.
        keywords: Substrings expected to appear in the answer.

    Returns:
        A float between 0.0 and 1.0. An empty ``keywords`` list scores 0.0
        (nothing can be matched) instead of raising ZeroDivisionError.
    """
    if not keywords:
        return 0.0
    haystack = answer.lower()
    hits = sum(1 for keyword in keywords if keyword.lower() in haystack)
    return hits / len(keywords)
40
+
41
def quick_verify_entities(mm, entities):
    """Return the names of entities whose stored facts the model fails to recall.

    For each entity the model is asked "What does <name> do?" when a
    truthy ``job`` is recorded and "Where does <name> live?" when a truthy
    ``city`` is recorded. An entity is confused as soon as one expected
    value is missing (case-insensitive substring test) from the reply.

    Args:
        mm: Object exposing ``generate(messages, max_new_tokens=...)`` that
            returns the reply text.
        entities: Mapping of entity name to a dict that may hold ``"job"``
            and/or ``"city"`` strings.

    Returns:
        Set of entity names for which at least one probe failed.
    """

    def _recalls(question, expected):
        # One single-turn probe; True when the expected value is echoed back.
        reply = mm.generate(
            [{"role": "user", "content": question}], max_new_tokens=100
        )
        return expected.lower() in reply.lower()

    confused = set()
    for name, info in entities.items():
        job = info.get("job")
        if job and not _recalls(f"What does {name} do?", job):
            confused.add(name)
            # Already flagged: a second generation cannot change the result,
            # so skip the city probe.
            continue
        city = info.get("city")
        if city and not _recalls(f"Where does {name} live?", city):
            confused.add(name)
    return confused
53
+
54
def main():
    """Run the session 5 cross-session recall test end to end.

    The steps match the ``[n/5]`` banners printed to stdout:

    1. Load the model from ``CHECKPOINT``.
    2. Pre-test the "old" questions of ``SESSION5_RECALL``.
    3. For each message in ``SESSION5_MESSAGES``: generate a reply, build
       quizzes, absorb the exchange plus the positive quizzes (phase 1),
       then absorb the contrastive quizzes that mention an entity reported
       by ``quick_verify_entities`` (phase 2).
    4. Ask every question in ``SESSION5_RECALL`` and tally the passes.
    5. Write a JSON summary to ``LOG_DIR/session5_cross_results.json``.
    """
    pass_threshold = 0.5  # minimum keyword fraction for a PASS

    def _pct(correct, total):
        # Guard against an empty group so the summary never divides by zero.
        return f"{correct/total:.0%}" if total else "n/a"

    def _is_contrastive(quiz):
        # Contrastive quizzes are those whose assistant reply starts with "no.".
        return quiz["messages"][1]["content"].lower().startswith("no.")

    banner = "=" * 60
    print(banner)
    print("SESSION 5 CROSS-SESSION TEST")
    print(f"Loading from: {CHECKPOINT}")
    print(f"Time: {datetime.now().isoformat()}")
    print(banner)

    # Load from checkpoint (thinker only, ~63GB)
    print("\n[1/5] Loading from checkpoint...")
    mm = ModelManager(model_path=CHECKPOINT, checkpoint_path=CHECKPOINT)
    mm.load()
    quiz_gen = QuizGenerator(mm)
    torch.cuda.empty_cache()
    print(f" VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Pre-test: do old facts survive the save+reload?
    print("\n[2/5] Pre-test: old facts from session 4...")
    old_questions = [q for q in SESSION5_RECALL if q[2] == "old"]
    for question, keywords, _ in old_questions:
        ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=150)
        status = "PASS" if score_answer(ans, keywords) >= pass_threshold else "FAIL"
        print(f" [{status}] {question}")
        print(f" {ans[:100]}")
    torch.cuda.empty_cache()

    # Teach session 5 facts
    print("\n[3/5] Teaching session 5 facts...")
    conv_buf = []
    for i, msg in enumerate(SESSION5_MESSAGES):
        print(f"\n S5 Message {i+1}: {msg}")
        conv_buf.append({"role": "user", "content": msg})
        response = mm.generate(conv_buf)
        conv_buf.append({"role": "assistant", "content": response})
        print(f" Claudia: {response[:100]}...")

        quizzes = quiz_gen.generate(msg, response)
        exchange = {"messages": [
            {"role": "user", "content": msg},
            {"role": "assistant", "content": response},
        ]}
        positive = [qp for qp in quizzes if not _is_contrastive(qp)]
        contrastive = [qp for qp in quizzes if _is_contrastive(qp)]

        # Phase 1: the raw exchange plus the positive quizzes.
        phase1 = [exchange] + positive
        torch.cuda.empty_cache()
        loss1 = mm.absorb(phase1)
        print(f" Phase 1: {len(phase1)} items, loss={loss1:.4f}")

        # Phase 2: only contrastive quizzes that mention a confused entity.
        if not contrastive:
            continue
        torch.cuda.empty_cache()
        confused = quick_verify_entities(mm, quiz_gen.known_entities)
        if not confused:
            continue
        targeted = [
            qp for qp in contrastive
            if any(
                n.lower() in (qp["messages"][0]["content"] + " " +
                              qp["messages"][1]["content"]).lower()
                for n in confused
            )
        ]
        if targeted:
            torch.cuda.empty_cache()
            loss2 = mm.absorb(targeted)
            print(f" Phase 2: {len(targeted)} targeted, loss={loss2:.4f}")

    # Full recall test
    print("\n[4/5] FULL RECALL TEST (old + new)")
    print(banner)

    old_correct = 0
    old_total = 0
    new_correct = 0
    new_total = 0

    torch.cuda.empty_cache()
    for question, keywords, qtype in SESSION5_RECALL:
        ans = mm.generate([{"role": "user", "content": question}], max_new_tokens=200)
        passed = score_answer(ans, keywords) >= pass_threshold

        # Anything not tagged "old" is counted as a new fact.
        if qtype == "old":
            old_total += 1
            old_correct += passed
        else:
            new_total += 1
            new_correct += passed

        status = "PASS" if passed else "FAIL"
        print(f" [{status}] ({qtype}) {question}")
        print(f" {ans[:120]}")

    total = old_correct + new_correct
    total_q = old_total + new_total
    print(banner)
    print(f"OLD FACTS (session 4): {old_correct}/{old_total} ({_pct(old_correct, old_total)})")
    print(f"NEW FACTS (session 5): {new_correct}/{new_total} ({_pct(new_correct, new_total)})")
    print(f"TOTAL: {total}/{total_q} ({_pct(total, total_q)})")
    print(banner)

    # Save results
    print("\n[5/5] Done.")
    results = {
        "session": "5_cross",
        "old_recall": f"{old_correct}/{old_total}",
        "new_recall": f"{new_correct}/{new_total}",
        "total": f"{total}/{total_q}",
        "timestamp": datetime.now().isoformat(),
    }
    # Create the log directory first so a missing folder cannot discard the
    # outcome of the whole run at the very last step.
    os.makedirs(LOG_DIR, exist_ok=True)
    with open(os.path.join(LOG_DIR, "session5_cross_results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("Results saved. CROSS-SESSION TEST COMPLETE.")


if __name__ == "__main__":
    main()