theapemachine committed
Commit 81ff944 · verified · 1 Parent(s): 7dd817a

Add benchmark harness: memory_tasks.py - Passkey retrieval and multi-hop memory

Files changed (1)
  1. benchmark/memory_tasks.py +265 -0
benchmark/memory_tasks.py ADDED
@@ -0,0 +1,265 @@
+ """
+ Memory-specific benchmark tasks for Cortex.
+
+ These tasks specifically test capabilities that the Cortex MemoryBank
+ and other modules should enhance:
+
+ 1. PasskeyRetrieval: Can the model retrieve a passkey buried in distractor text?
+    Tests long-context memory and attention to specific details.
+
+ 2. MultiHopMemory: Can the model chain facts across multiple "memory writes"?
+    Tests compositional memory via multi-hop reasoning.
+ """
+
+ import random
+ import string
+ from typing import List, Dict, Optional, Tuple
+
+ from benchmark.scoring import generate_and_check
+
+
+ class PasskeyRetrieval:
+     """
+     Passkey Retrieval Test.
+
+     Embeds a random 5-digit passkey at a random position within distractor text,
+     then asks the model to retrieve it. Tests the model's ability to attend to
+     and remember specific details within noise.
+
+     Difficulty scales with context length (more distractor text = harder).
+     """
+     name = "passkey_retrieval"
+
+     def __init__(self, context_lengths: Optional[List[int]] = None):
+         if context_lengths is None:
+             self.context_lengths = [128, 256, 512, 1024]
+         else:
+             self.context_lengths = context_lengths
+
+     def generate_example(
+         self,
+         context_length: int,
+         seed: int = 42,
+     ) -> Tuple[str, str]:
+         """
+         Generate a passkey retrieval example.
+
+         Returns:
+             (prompt, expected_passkey)
+         """
+         rng = random.Random(seed)
+
+         # Generate passkey
+         passkey = "".join(rng.choices(string.digits, k=5))
+
+         # Generate distractor sentences
+         distractor_templates = [
+             "The quick brown fox jumps over the lazy dog.",
+             "A journey of a thousand miles begins with a single step.",
+             "To be or not to be, that is the question.",
+             "All that glitters is not gold.",
+             "Knowledge is power, and power corrupts.",
+             "The pen is mightier than the sword.",
+             "Actions speak louder than words.",
+             "Fortune favors the brave and the bold.",
+             "Time heals all wounds, they say.",
+             "Where there's a will, there's a way.",
+         ]
+
+         # Build distractor text to target approximate token count
+         # Rough estimate: 1 sentence ≈ 10-15 tokens
+         num_sentences = max(context_length // 12, 4)
+
+         sentences = [rng.choice(distractor_templates) for _ in range(num_sentences)]
+
+         # Insert passkey at a random position
+         insert_pos = rng.randint(len(sentences) // 4, 3 * len(sentences) // 4)
+         passkey_sentence = f"The secret passkey is: {passkey}. Remember this number."
+         sentences.insert(insert_pos, passkey_sentence)
+
+         text = " ".join(sentences)
+         prompt = text + "\n\nWhat was the secret passkey mentioned in the text above? The passkey is:"
+
+         return prompt, passkey
+
+     def run(
+         self,
+         model,
+         tokenizer,
+         n_per_length: int = 5,
+         device: str = "cuda",
+         seed: int = 42,
+     ) -> Dict:
+         """
+         Run the passkey retrieval benchmark.
+
+         Returns dict with results per context length.
+         """
+         results = {}
+
+         for ctx_len in self.context_lengths:
+             correct = 0
+             total = 0
+
+             for i in range(n_per_length):
+                 prompt, expected = self.generate_example(ctx_len, seed=seed + i + ctx_len)
+
+                 is_correct, generated = generate_and_check(
+                     model, tokenizer, prompt, expected,
+                     max_new_tokens=16, device=device, exact_match=False,
+                 )
+
+                 correct += int(is_correct)
+                 total += 1
+
+             results[f"ctx_{ctx_len}"] = {
+                 "accuracy": correct / total if total > 0 else 0.0,
+                 "correct": correct,
+                 "total": total,
+             }
+
+         # Overall accuracy
+         total_correct = sum(r["correct"] for r in results.values())
+         total_count = sum(r["total"] for r in results.values())
+         results["overall"] = {
+             "accuracy": total_correct / total_count if total_count > 0 else 0.0,
+             "correct": total_correct,
+             "total": total_count,
+         }
+
+         return results
+
+
+ class MultiHopMemory:
+     """
+     Multi-Hop Memory Reasoning Test.
+
+     Presents a chain of facts:
+         "Alice lives in Paris."
+         "The capital of France is Paris."
+         "Alice's neighbor is Bob."
+     Then asks a question requiring chaining: "Where does Alice's neighbor likely live?"
+
+     Tests whether Cortex's MemoryBank helps with compositional reasoning.
+     """
+     name = "multi_hop_memory"
+
+     FACT_CHAINS = [
+         {
+             "facts": [
+                 "Alice lives in Paris.",
+                 "Paris is the capital of France.",
+                 "France is in Europe.",
+             ],
+             "question": "What continent does Alice live on?",
+             "answer": "Europe",
+         },
+         {
+             "facts": [
+                 "Bob works at TechCorp.",
+                 "TechCorp is headquartered in San Francisco.",
+                 "San Francisco is in California.",
+             ],
+             "question": "What state does Bob work in?",
+             "answer": "California",
+         },
+         {
+             "facts": [
+                 "The blue car belongs to Maria.",
+                 "Maria is a doctor.",
+                 "Doctors work at hospitals.",
+             ],
+             "question": "Where does the owner of the blue car work?",
+             "answer": "hospital",
+         },
+         {
+             "facts": [
+                 "Tom's favorite book was written by Jane.",
+                 "Jane lives in London.",
+                 "London is known for its rain.",
+             ],
+             "question": "What is the weather like where Tom's favorite author lives?",
+             "answer": "rain",
+         },
+         {
+             "facts": [
+                 "The red house is on Elm Street.",
+                 "Elm Street is in Springfield.",
+                 "Springfield is a small town.",
+             ],
+             "question": "What kind of place is the red house located in?",
+             "answer": "small town",
+         },
+         {
+             "facts": [
+                 "Lucy adopted a golden retriever.",
+                 "Golden retrievers are a breed of dog.",
+                 "Dogs are mammals.",
+             ],
+             "question": "What type of animal did Lucy adopt?",
+             "answer": "mammal",
+         },
+         {
+             "facts": [
+                 "The Nile flows through Egypt.",
+                 "Egypt is in Africa.",
+                 "Africa is the second largest continent.",
+             ],
+             "question": "What continent does the Nile flow through?",
+             "answer": "Africa",
+         },
+         {
+             "facts": [
+                 "Sarah plays the violin.",
+                 "The violin is a string instrument.",
+                 "String instruments are part of an orchestra.",
+             ],
+             "question": "What type of ensemble could Sarah play in?",
+             "answer": "orchestra",
+         },
+     ]
+
+     def run(
+         self,
+         model,
+         tokenizer,
+         n: Optional[int] = None,
+         device: str = "cuda",
+     ) -> Dict:
+         """
+         Run multi-hop memory benchmark.
+
+         Returns accuracy and per-example results.
+         """
+         chains = self.FACT_CHAINS
+         if n is not None:
+             chains = chains[:n]
+
+         correct = 0
+         total = 0
+         per_example = []
+
+         for chain in chains:
+             facts_text = " ".join(chain["facts"])
+             prompt = f"Facts: {facts_text}\n\nQuestion: {chain['question']}\nAnswer:"
+
+             is_correct, generated = generate_and_check(
+                 model, tokenizer, prompt, chain["answer"],
+                 max_new_tokens=32, device=device, exact_match=False,
+             )
+
+             correct += int(is_correct)
+             total += 1
+             per_example.append({
+                 "question": chain["question"],
+                 "expected": chain["answer"],
+                 "generated": generated[:100],
+                 "correct": is_correct,
+             })
+
+         return {
+             "accuracy": correct / total if total > 0 else 0.0,
+             "correct": correct,
+             "total": total,
+             "per_example": per_example,
+         }
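
For context, a minimal sketch of how a driver script might invoke both tasks. This is not part of the commit: the Hugging Face `AutoModelForCausalLM`/`AutoTokenizer` loading and the "gpt2" model name are illustrative assumptions; only the `PasskeyRetrieval`/`MultiHopMemory` classes and the `generate_and_check` helper come from the code above.

    # Hedged usage sketch (not part of this commit). Loads a placeholder causal LM
    # and runs both memory benchmarks, printing the result dictionaries as JSON.
    import json

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer  # assumed dependency

    from benchmark.memory_tasks import MultiHopMemory, PasskeyRetrieval

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model name
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    model.eval()

    # Shorter context lengths and fewer samples keep the smoke test quick.
    passkey_results = PasskeyRetrieval(context_lengths=[128, 256]).run(
        model, tokenizer, n_per_length=3, device=device
    )
    multihop_results = MultiHopMemory().run(model, tokenizer, device=device)

    print(json.dumps({"passkey": passkey_results, "multi_hop": multihop_results}, indent=2))

Both `run` methods return plain dictionaries with per-example counts plus an aggregate accuracy, so the same pattern slots into whatever reporting the rest of the benchmark harness uses.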