"""
Memory-specific benchmark tasks for Cortex.

These tasks specifically test capabilities that the Cortex MemoryBank and
other modules should enhance:

1. PasskeyRetrieval: Can the model retrieve a passkey buried in distractor
   text? Tests long-context memory and attention to specific details.

2. MultiHopMemory: Can the model chain facts across multiple "memory
   writes"? Tests compositional memory via multi-hop reasoning.
"""

import random
import string
from typing import List, Dict, Optional, Tuple

from benchmark.scoring import generate_and_check
from cortex.torch_device import resolve_torch_device


class PasskeyRetrieval:
    """
    Passkey Retrieval Test.

    Embeds a random 5-digit passkey at a random position in the middle half
    of some distractor text, then asks the model to retrieve it. Tests the
    model's ability to attend to and remember a specific detail within noise.

    Difficulty is meant to scale with context length (more distractor text
    = harder).
    """

    name = "passkey_retrieval"

    # Filler sentences used to pad the prompt. Order matters: sentences are
    # drawn with a seeded RNG, so reordering this sequence would change every
    # generated example.
    DISTRACTOR_TEMPLATES = (
        "The quick brown fox jumps over the lazy dog.",
        "A journey of a thousand miles begins with a single step.",
        "To be or not to be, that is the question.",
        "All that glitters is not gold.",
        "Knowledge is power, and power corrupts.",
        "The pen is mightier than the sword.",
        "Actions speak louder than words.",
        "Fortune favors the brave and the bold.",
        "Time heals all wounds, they say.",
        "Where there's a will, there's a way.",
    )

    # Rough estimate used to turn a target token count into a sentence count
    # (1 sentence is roughly 10-15 tokens).
    TOKENS_PER_SENTENCE = 12

    # Lower bound on the number of distractor sentences per example.
    MIN_SENTENCES = 4

    def __init__(self, context_lengths: Optional[List[int]] = None):
        """
        Args:
            context_lengths: Approximate prompt sizes (in tokens) to test.
                Defaults to [128, 256, 512, 1024]. The sequence is copied so
                later mutation by the caller does not affect this task.
        """
        if context_lengths is None:
            self.context_lengths = [128, 256, 512, 1024]
        else:
            self.context_lengths = list(context_lengths)

    def generate_example(
        self,
        context_length: int,
        seed: int = 42,
    ) -> Tuple[str, str]:
        """
        Generate a passkey retrieval example.

        Args:
            context_length: Approximate target prompt length in tokens. It is
                only used to choose how many distractor sentences to emit.
            seed: Seed for a private ``random.Random`` instance, so the same
                (context_length, seed) pair always yields the same example.

        Returns:
            (prompt, expected_passkey)
        """
        rng = random.Random(seed)

        # The passkey is drawn first, so for a given seed it is identical
        # regardless of context_length.
        passkey = "".join(rng.choices(string.digits, k=5))

        num_sentences = max(
            context_length // self.TOKENS_PER_SENTENCE, self.MIN_SENTENCES
        )
        sentences = [
            rng.choice(self.DISTRACTOR_TEMPLATES) for _ in range(num_sentences)
        ]

        # Insert the passkey in the middle half of the text. With at least
        # MIN_SENTENCES fillers it is never the first or last sentence.
        insert_pos = rng.randint(len(sentences) // 4, 3 * len(sentences) // 4)
        passkey_sentence = f"The secret passkey is: {passkey}. Remember this number."
        sentences.insert(insert_pos, passkey_sentence)

        text = " ".join(sentences)
        prompt = text + "\n\nWhat was the secret passkey mentioned in the text above? The passkey is:"
        return prompt, passkey

    def run(
        self,
        model,
        tokenizer,
        n_per_length: int = 5,
        device: Optional[str] = None,
        seed: int = 42,
    ) -> Dict:
        """
        Run the passkey retrieval benchmark.

        Args:
            model: Passed through to ``generate_and_check``.
            tokenizer: Passed through to ``generate_and_check``.
            n_per_length: Number of examples per context length.
            device: Device string. When None, it is resolved with
                ``resolve_torch_device("auto")``.
            seed: Base seed. Example ``i`` at length ``L`` uses
                ``seed + i + L``.

        Returns:
            Dict with one entry per distinct context length, keyed
            ``"ctx_<length>"``, plus an ``"overall"`` entry aggregating them.
            Each entry has ``"accuracy"``, ``"correct"`` and ``"total"``.
        """
        if device is None:
            device = resolve_torch_device("auto")

        results = {}
        # Deduplicate while preserving order. A repeated length reuses the
        # same seeds and its result would overwrite the earlier "ctx_<L>"
        # entry, so running it again would only waste model calls.
        for ctx_len in dict.fromkeys(self.context_lengths):
            correct = 0
            total = 0

            for i in range(n_per_length):
                # NOTE: seeds collide across lengths when n_per_length exceeds
                # the gap between two context lengths. Colliding examples then
                # share the same passkey. The formula is kept so existing
                # results stay reproducible.
                prompt, expected = self.generate_example(
                    ctx_len, seed=seed + i + ctx_len
                )
                # exact_match=False: presumably a lenient (e.g. containment)
                # check inside generate_and_check -- confirm in benchmark.scoring.
                is_correct, _generated = generate_and_check(
                    model,
                    tokenizer,
                    prompt,
                    expected,
                    max_new_tokens=16,
                    device=device,
                    exact_match=False,
                )
                correct += int(is_correct)
                total += 1

            results[f"ctx_{ctx_len}"] = {
                "accuracy": correct / total if total > 0 else 0.0,
                "correct": correct,
                "total": total,
            }

        # Overall accuracy across all context lengths.
        total_correct = sum(r["correct"] for r in results.values())
        total_count = sum(r["total"] for r in results.values())
        results["overall"] = {
            "accuracy": total_correct / total_count if total_count > 0 else 0.0,
            "correct": total_correct,
            "total": total_count,
        }

        return results


class MultiHopMemory:
    """
    Multi-Hop Memory Reasoning Test.

    Presents a short chain of facts, e.g.:
        "Alice lives in Paris."
        "Paris is the capital of France."
        "France is in Europe."
    then asks a question that needs the facts chained together:
        "What continent does Alice live on?"  ->  "Europe"

    Intended to probe whether Cortex's MemoryBank helps with compositional
    reasoning.
    """

    name = "multi_hop_memory"

    # Each chain: the facts shown to the model, the question asked, and the
    # answer string passed to generate_and_check as the expected value.
    FACT_CHAINS = [
        {
            "facts": [
                "Alice lives in Paris.",
                "Paris is the capital of France.",
                "France is in Europe.",
            ],
            "question": "What continent does Alice live on?",
            "answer": "Europe",
        },
        {
            "facts": [
                "Bob works at TechCorp.",
                "TechCorp is headquartered in San Francisco.",
                "San Francisco is in California.",
            ],
            "question": "What state does Bob work in?",
            "answer": "California",
        },
        {
            "facts": [
                "The blue car belongs to Maria.",
                "Maria is a doctor.",
                "Doctors work at hospitals.",
            ],
            "question": "Where does the owner of the blue car work?",
            "answer": "hospital",
        },
        {
            "facts": [
                "Tom's favorite book was written by Jane.",
                "Jane lives in London.",
                "London is known for its rain.",
            ],
            "question": "What is the weather like where Tom's favorite author lives?",
            "answer": "rain",
        },
        {
            "facts": [
                "The red house is on Elm Street.",
                "Elm Street is in Springfield.",
                "Springfield is a small town.",
            ],
            "question": "What kind of place is the red house located in?",
            "answer": "small town",
        },
        {
            # NOTE(review): "dog" is also a reasonable answer to this
            # question but would presumably be scored wrong; consider
            # rewording the question (changes the benchmark).
            "facts": [
                "Lucy adopted a golden retriever.",
                "Golden retrievers are a breed of dog.",
                "Dogs are mammals.",
            ],
            "question": "What type of animal did Lucy adopt?",
            "answer": "mammal",
        },
        {
            "facts": [
                "The Nile flows through Egypt.",
                "Egypt is in Africa.",
                "Africa is the second largest continent.",
            ],
            "question": "What continent does the Nile flow through?",
            "answer": "Africa",
        },
        {
            "facts": [
                "Sarah plays the violin.",
                "The violin is a string instrument.",
                "String instruments are part of an orchestra.",
            ],
            "question": "What type of ensemble could Sarah play in?",
            "answer": "orchestra",
        },
    ]

    def run(
        self,
        model,
        tokenizer,
        n: Optional[int] = None,
        device: Optional[str] = None,
    ) -> Dict:
        """
        Run the multi-hop memory benchmark.

        Args:
            model: Passed through to ``generate_and_check``.
            tokenizer: Passed through to ``generate_and_check``.
            n: If given, only ``FACT_CHAINS[:n]`` is evaluated.
            device: Device string. When None, it is resolved with
                ``resolve_torch_device("auto")``.

        Returns:
            Dict with ``"accuracy"``, ``"correct"``, ``"total"`` and
            ``"per_example"``, a list of dicts holding the question, the
            expected answer, the first 100 characters of the generation and
            the correctness flag.
        """
        if device is None:
            device = resolve_torch_device("auto")

        chains = self.FACT_CHAINS
        if n is not None:
            chains = chains[:n]

        correct = 0
        total = 0
        per_example = []

        for chain in chains:
            facts_text = " ".join(chain["facts"])
            prompt = f"Facts: {facts_text}\n\nQuestion: {chain['question']}\nAnswer:"

            is_correct, generated = generate_and_check(
                model,
                tokenizer,
                prompt,
                chain["answer"],
                max_new_tokens=32,
                device=device,
                exact_match=False,
            )
            correct += int(is_correct)
            total += 1

            per_example.append({
                "question": chain["question"],
                "expected": chain["answer"],
                # Truncated to keep reports compact.
                "generated": generated[:100],
                "correct": is_correct,
            })

        return {
            "accuracy": correct / total if total > 0 else 0.0,
            "correct": correct,
            "total": total,
            "per_example": per_example,
        }