# Refactor device handling in benchmark and cortex modules to use the
# resolve_torch_device function for improved device selection.
0ac64e3 | """ | |
| Memory-specific benchmark tasks for Cortex. | |
| These tasks specifically test capabilities that the Cortex MemoryBank | |
| and other modules should enhance: | |
| 1. PasskeyRetrieval: Can the model retrieve a passkey buried in distractor text? | |
| Tests long-context memory and attention to specific details. | |
| 2. MultiHopMemory: Can the model chain facts across multiple "memory writes"? | |
| Tests compositional memory via multi-hop reasoning. | |
| """ | |
import random
import string
from typing import List, Dict, Optional, Tuple

from benchmark.scoring import generate_and_check
from cortex.torch_device import resolve_torch_device
class PasskeyRetrieval:
    """
    Passkey Retrieval Test.

    Embeds a random numeric passkey at a random position within distractor
    text, then asks the model to retrieve it. Tests the model's ability to
    attend to and remember specific details within noise.

    Difficulty scales with context length (more distractor text = harder).
    """

    name = "passkey_retrieval"

    def __init__(
        self,
        context_lengths: Optional[List[int]] = None,
        passkey_length: int = 5,
    ):
        """
        Args:
            context_lengths: Approximate context sizes (in tokens) to test.
                Defaults to [128, 256, 512, 1024].
            passkey_length: Number of digits in the generated passkey.
                Default 5 matches the original benchmark, so existing
                results remain reproducible.
        """
        if context_lengths is None:
            self.context_lengths = [128, 256, 512, 1024]
        else:
            self.context_lengths = context_lengths
        self.passkey_length = passkey_length

    def generate_example(
        self,
        context_length: int,
        seed: int = 42,
    ) -> Tuple[str, str]:
        """
        Generate a passkey retrieval example.

        Args:
            context_length: Target context size; controls how much
                distractor text surrounds the passkey.
            seed: RNG seed, so each example is deterministic.

        Returns:
            (prompt, expected_passkey)
        """
        rng = random.Random(seed)

        # Generate the passkey (digits only, so substring matching is robust).
        passkey = "".join(rng.choices(string.digits, k=self.passkey_length))

        # Stock sentences used purely as noise around the hidden passkey.
        distractor_templates = [
            "The quick brown fox jumps over the lazy dog.",
            "A journey of a thousand miles begins with a single step.",
            "To be or not to be, that is the question.",
            "All that glitters is not gold.",
            "Knowledge is power, and power corrupts.",
            "The pen is mightier than the sword.",
            "Actions speak louder than words.",
            "Fortune favors the brave and the bold.",
            "Time heals all wounds, they say.",
            "Where there's a will, there's a way.",
        ]

        # Build distractor text to target an approximate token count.
        # Rough estimate: 1 sentence ≈ 10-15 tokens; floor of 4 sentences
        # keeps the insertion window below non-degenerate.
        num_sentences = max(context_length // 12, 4)
        sentences = [rng.choice(distractor_templates) for _ in range(num_sentences)]

        # Insert the passkey somewhere in the middle half of the text, so it
        # is never trivially at the very start or end.
        insert_pos = rng.randint(len(sentences) // 4, 3 * len(sentences) // 4)
        passkey_sentence = f"The secret passkey is: {passkey}. Remember this number."
        sentences.insert(insert_pos, passkey_sentence)

        text = " ".join(sentences)
        prompt = text + "\n\nWhat was the secret passkey mentioned in the text above? The passkey is:"
        return prompt, passkey

    def run(
        self,
        model,
        tokenizer,
        n_per_length: int = 5,
        device: Optional[str] = None,
        seed: int = 42,
    ) -> Dict:
        """
        Run the passkey retrieval benchmark.

        Args:
            model: The model under test (passed through to generate_and_check).
            tokenizer: Matching tokenizer.
            n_per_length: Number of examples per context length.
            device: Torch device string; resolved automatically when None.
            seed: Base RNG seed; each example derives its own seed from it.

        Returns:
            Dict with per-context-length results plus an "overall" summary,
            each entry holding "accuracy", "correct", and "total".
        """
        if device is None:
            device = resolve_torch_device("auto")

        results = {}
        for ctx_len in self.context_lengths:
            correct = 0
            total = 0
            for i in range(n_per_length):
                # Offset the seed per example and per context length so no
                # two examples within a run are identical.
                prompt, expected = self.generate_example(ctx_len, seed=seed + i + ctx_len)
                is_correct, generated = generate_and_check(
                    model, tokenizer, prompt, expected,
                    max_new_tokens=16, device=device, exact_match=False,
                )
                correct += int(is_correct)
                total += 1
            results[f"ctx_{ctx_len}"] = {
                "accuracy": correct / total if total > 0 else 0.0,
                "correct": correct,
                "total": total,
            }

        # Aggregate accuracy across all context lengths.
        total_correct = sum(r["correct"] for r in results.values())
        total_count = sum(r["total"] for r in results.values())
        results["overall"] = {
            "accuracy": total_correct / total_count if total_count > 0 else 0.0,
            "correct": total_correct,
            "total": total_count,
        }
        return results
class MultiHopMemory:
    """
    Multi-Hop Memory Reasoning Test.

    Presents a short chain of facts, e.g.:
        "Alice lives in Paris."
        "Paris is the capital of France."
        "France is in Europe."
    then asks a question whose answer requires chaining them together
    ("What continent does Alice live on?"). Tests whether Cortex's
    MemoryBank helps with compositional reasoning.
    """

    name = "multi_hop_memory"

    # Fixed evaluation set: each entry is a 3-fact chain plus a question
    # whose answer requires combining at least two of the facts.
    FACT_CHAINS = [
        {
            "facts": [
                "Alice lives in Paris.",
                "Paris is the capital of France.",
                "France is in Europe.",
            ],
            "question": "What continent does Alice live on?",
            "answer": "Europe",
        },
        {
            "facts": [
                "Bob works at TechCorp.",
                "TechCorp is headquartered in San Francisco.",
                "San Francisco is in California.",
            ],
            "question": "What state does Bob work in?",
            "answer": "California",
        },
        {
            "facts": [
                "The blue car belongs to Maria.",
                "Maria is a doctor.",
                "Doctors work at hospitals.",
            ],
            "question": "Where does the owner of the blue car work?",
            "answer": "hospital",
        },
        {
            "facts": [
                "Tom's favorite book was written by Jane.",
                "Jane lives in London.",
                "London is known for its rain.",
            ],
            "question": "What is the weather like where Tom's favorite author lives?",
            "answer": "rain",
        },
        {
            "facts": [
                "The red house is on Elm Street.",
                "Elm Street is in Springfield.",
                "Springfield is a small town.",
            ],
            "question": "What kind of place is the red house located in?",
            "answer": "small town",
        },
        {
            "facts": [
                "Lucy adopted a golden retriever.",
                "Golden retrievers are a breed of dog.",
                "Dogs are mammals.",
            ],
            "question": "What type of animal did Lucy adopt?",
            "answer": "mammal",
        },
        {
            "facts": [
                "The Nile flows through Egypt.",
                "Egypt is in Africa.",
                "Africa is the second largest continent.",
            ],
            "question": "What continent does the Nile flow through?",
            "answer": "Africa",
        },
        {
            "facts": [
                "Sarah plays the violin.",
                "The violin is a string instrument.",
                "String instruments are part of an orchestra.",
            ],
            "question": "What type of ensemble could Sarah play in?",
            "answer": "orchestra",
        },
    ]

    def run(
        self,
        model,
        tokenizer,
        n: Optional[int] = None,
        device: Optional[str] = None,
    ) -> Dict:
        """
        Run the multi-hop memory benchmark.

        Args:
            model: The model under test (passed through to generate_and_check).
            tokenizer: Matching tokenizer.
            n: Optional cap on the number of fact chains evaluated; None
                means evaluate all of them.
            device: Torch device string; resolved automatically when None.

        Returns:
            Dict with "accuracy", "correct", "total", and "per_example"
            (question/expected/generated/correct for each chain).
        """
        target_device = resolve_torch_device("auto") if device is None else device
        selected = self.FACT_CHAINS if n is None else self.FACT_CHAINS[:n]

        records = []
        num_correct = 0
        for entry in selected:
            joined_facts = " ".join(entry["facts"])
            query = f"Facts: {joined_facts}\n\nQuestion: {entry['question']}\nAnswer:"
            ok, output = generate_and_check(
                model,
                tokenizer,
                query,
                entry["answer"],
                max_new_tokens=32,
                device=target_device,
                exact_match=False,
            )
            num_correct += int(ok)
            records.append({
                "question": entry["question"],
                "expected": entry["answer"],
                # Keep only a prefix of the generation to bound result size.
                "generated": output[:100],
                "correct": ok,
            })

        total_seen = len(records)
        return {
            "accuracy": num_correct / total_seen if total_seen else 0.0,
            "correct": num_correct,
            "total": total_seen,
            "per_example": records,
        }