File size: 8,877 Bytes
81ff944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ac64e3
81ff944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ac64e3
81ff944
 
 
 
 
 
 
0ac64e3
 
81ff944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ac64e3
81ff944
 
 
 
 
 
0ac64e3
 
81ff944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
Memory-specific benchmark tasks for Cortex.

These tasks specifically test capabilities that the Cortex MemoryBank
and other modules should enhance:

1. PasskeyRetrieval: Can the model retrieve a passkey buried in distractor text?
   Tests long-context memory and attention to specific details.

2. MultiHopMemory: Can the model chain facts across multiple "memory writes"?
   Tests compositional memory via multi-hop reasoning.
"""

import random
import string
from typing import List, Dict, Optional, Tuple

from benchmark.scoring import generate_and_check
from cortex.torch_device import resolve_torch_device


class PasskeyRetrieval:
    """
    Passkey retrieval benchmark.

    Hides a five-character string of random decimal digits (leading zeros are
    possible) inside a run of repeated proverb-style filler sentences, then
    asks the model to repeat the passkey back. The passkey sentence is always
    inserted roughly within the middle half of the filler, never at the very
    start or end.

    Each configured context length controls how many filler sentences are
    generated, so larger values bury the passkey in more noise.
    """
    name = "passkey_retrieval"

    # Filler sentences sampled (with replacement) to surround the passkey.
    _DISTRACTOR_SENTENCES = (
        "The quick brown fox jumps over the lazy dog.",
        "A journey of a thousand miles begins with a single step.",
        "To be or not to be, that is the question.",
        "All that glitters is not gold.",
        "Knowledge is power, and power corrupts.",
        "The pen is mightier than the sword.",
        "Actions speak louder than words.",
        "Fortune favors the brave and the bold.",
        "Time heals all wounds, they say.",
        "Where there's a will, there's a way.",
    )

    def __init__(self, context_lengths: Optional[List[int]] = None):
        # A caller-supplied list is stored as-is (not copied); the default is
        # a fresh list per instance.
        self.context_lengths = (
            [128, 256, 512, 1024] if context_lengths is None else context_lengths
        )

    def generate_example(
        self,
        context_length: int,
        seed: int = 42,
    ) -> Tuple[str, str]:
        """
        Build one deterministic passkey-retrieval prompt.

        Args:
            context_length: Approximate target size in tokens. It is converted
                to a filler-sentence count assuming roughly 12 tokens per
                sentence, with a floor of 4 sentences.
            seed: Seed for a private ``random.Random``. The same
                ``(context_length, seed)`` pair always yields the same example.

        Returns:
            A ``(prompt, passkey)`` tuple, where ``passkey`` is the 5-digit
            string the model is expected to reproduce.
        """
        rng = random.Random(seed)

        # The draw order below (passkey, then filler, then insertion index)
        # determines which example a given seed produces -- keep it stable.
        passkey = "".join(rng.choices(string.digits, k=5))

        filler_count = max(context_length // 12, 4)
        body = []
        for _ in range(filler_count):
            body.append(rng.choice(self._DISTRACTOR_SENTENCES))

        # The insertion index is drawn from [len/4, 3*len/4] (integer division).
        lowest = len(body) // 4
        highest = (3 * len(body)) // 4
        body.insert(
            rng.randint(lowest, highest),
            f"The secret passkey is: {passkey}. Remember this number.",
        )

        question = "What was the secret passkey mentioned in the text above? The passkey is:"
        return f"{' '.join(body)}\n\n{question}", passkey

    def run(
        self,
        model,
        tokenizer,
        n_per_length: int = 5,
        device: Optional[str] = None,
        seed: int = 42,
    ) -> Dict:
        """
        Evaluate the model on passkey retrieval at every configured length.

        Args:
            model, tokenizer: Passed through unchanged to
                ``benchmark.scoring.generate_and_check``.
            n_per_length: Number of examples generated per context length.
            device: Torch device string. Resolved via
                ``resolve_torch_device("auto")`` when ``None``.
            seed: Base seed. Example ``i`` at length ``L`` uses
                ``seed + i + L``.

        Returns:
            Dict keyed ``"ctx_<L>"`` for each length, plus ``"overall"``. Each
            value holds ``accuracy``, ``correct`` and ``total``. ``accuracy``
            is 0.0 when no examples were run.
        """
        device = resolve_torch_device("auto") if device is None else device

        def summarize(hits: int, attempts: int) -> Dict:
            return {
                "accuracy": hits / attempts if attempts > 0 else 0.0,
                "correct": hits,
                "total": attempts,
            }

        results = {}
        for ctx_len in self.context_lengths:
            trials = range(n_per_length)
            hits = 0
            for offset in trials:
                prompt, expected = self.generate_example(
                    ctx_len, seed=seed + offset + ctx_len
                )
                # exact_match=False presumably means a lenient match (e.g. the
                # passkey appearing in the output) -- see benchmark.scoring.
                ok, _ = generate_and_check(
                    model, tokenizer, prompt, expected,
                    max_new_tokens=16, device=device, exact_match=False,
                )
                hits += int(ok)
            results[f"ctx_{ctx_len}"] = summarize(hits, len(trials))

        # Aggregate from the stored per-length entries (duplicate lengths
        # share a key, so only the last run for that length is counted).
        per_length = list(results.values())
        results["overall"] = summarize(
            sum(entry["correct"] for entry in per_length),
            sum(entry["total"] for entry in per_length),
        )
        return results


class MultiHopMemory:
    """
    Multi-Hop Memory Reasoning Test.

    Each example presents a short chain of facts in which each fact links to
    the next, for example:
      "Alice lives in Paris."
      "Paris is the capital of France."
      "France is in Europe."
    It then asks a question answerable only by following the chain end to end:
      "What continent does Alice live on?"  ->  "Europe"

    Intended to test whether Cortex's MemoryBank helps with compositional
    reasoning.
    """
    name = "multi_hop_memory"

    # Fixed evaluation set. Every "answer" string appears verbatim in its own
    # facts, so the expected answer is recoverable from the prompt alone.
    FACT_CHAINS = [
        {
            "facts": [
                "Alice lives in Paris.",
                "Paris is the capital of France.",
                "France is in Europe.",
            ],
            "question": "What continent does Alice live on?",
            "answer": "Europe",
        },
        {
            "facts": [
                "Bob works at TechCorp.",
                "TechCorp is headquartered in San Francisco.",
                "San Francisco is in California.",
            ],
            "question": "What state does Bob work in?",
            "answer": "California",
        },
        {
            "facts": [
                "The blue car belongs to Maria.",
                "Maria is a doctor.",
                "Doctors work at hospitals.",
            ],
            "question": "Where does the owner of the blue car work?",
            "answer": "hospital",
        },
        {
            "facts": [
                "Tom's favorite book was written by Jane.",
                "Jane lives in London.",
                "London is known for its rain.",
            ],
            "question": "What is the weather like where Tom's favorite author lives?",
            "answer": "rain",
        },
        {
            "facts": [
                "The red house is on Elm Street.",
                "Elm Street is in Springfield.",
                "Springfield is a small town.",
            ],
            "question": "What kind of place is the red house located in?",
            "answer": "small town",
        },
        {
            "facts": [
                "Lucy adopted a golden retriever.",
                "Golden retrievers are a breed of dog.",
                "Dogs are mammals.",
            ],
            "question": "What type of animal did Lucy adopt?",
            "answer": "mammal",
        },
        {
            "facts": [
                "The Nile flows through Egypt.",
                "Egypt is in Africa.",
                "Africa is the second largest continent.",
            ],
            "question": "What continent does the Nile flow through?",
            "answer": "Africa",
        },
        {
            "facts": [
                "Sarah plays the violin.",
                "The violin is a string instrument.",
                "String instruments are part of an orchestra.",
            ],
            "question": "What type of ensemble could Sarah play in?",
            "answer": "orchestra",
        },
    ]

    def run(
        self,
        model,
        tokenizer,
        n: Optional[int] = None,
        device: Optional[str] = None,
    ) -> Dict:
        """
        Run the multi-hop memory benchmark.

        Args:
            model, tokenizer: Passed through unchanged to
                ``benchmark.scoring.generate_and_check``.
            n: Evaluate only the first ``n`` fact chains. ``None`` evaluates
                all of them, and values larger than the set size are fine.
            device: Torch device string. Resolved via
                ``resolve_torch_device("auto")`` when ``None``.

        Returns:
            Dict with overall ``accuracy`` (0.0 when nothing ran), ``correct``
            and ``total`` counts. It also holds a ``per_example`` list. Each
            entry records the question, the expected answer, the first 100
            characters of the generation, and whether it was judged correct.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n is not None and n < 0:
            # A negative slice bound would silently drop chains from the END
            # of the set (e.g. n=-1 -> all but the last one); reject it.
            raise ValueError(f"n must be a non-negative int or None, got {n!r}")
        if device is None:
            device = resolve_torch_device("auto")
        chains = self.FACT_CHAINS if n is None else self.FACT_CHAINS[:n]

        correct = 0
        total = 0
        per_example = []

        for chain in chains:
            facts_text = " ".join(chain["facts"])
            prompt = f"Facts: {facts_text}\n\nQuestion: {chain['question']}\nAnswer:"

            # exact_match=False presumably means a lenient match (e.g. the
            # expected answer appearing in the output) -- see benchmark.scoring.
            is_correct, generated = generate_and_check(
                model, tokenizer, prompt, chain["answer"],
                max_new_tokens=32, device=device, exact_match=False,
            )

            correct += int(is_correct)
            total += 1
            per_example.append({
                "question": chain["question"],
                "expected": chain["answer"],
                # Truncated to keep the results payload small.
                "generated": generated[:100],
                "correct": is_correct,
            })

        return {
            "accuracy": correct / total if total > 0 else 0.0,
            "correct": correct,
            "total": total,
            "per_example": per_example,
        }