| """ |
| needle_test.py - "Needle in a Haystack" test for memory validation. |
| |
| This is the KILLER TEST that proves if RippleGPT can retain long-term information |
| through the Ripple Field (ALiBi-style attention) mechanism. |
| |
| The test: |
| 1. Places a "needle" (SECRET_PASSWORD = "bananas") at the beginning of a long text |
| 2. Adds hundreds of lines of Python code as "haystack" |
| 3. Asks the model to remember the password |
| 4. Measures if it can retrieve the information |
| |
⚠️ TECHNICAL NOTE - MEMORY COMPLEXITY: O(T²)
──────────────────────────────────────────────
RippleGPT uses full quadratic attention.

For T context tokens:
  • Memory ≈ T² × 4 bytes × n_heads × n_layers
  • T=1000 → ~4MB per head
  • T=3000 → ~36MB per head
  • T=8000 → ~256MB per head

The BENEFIT of Ripple Field is NOT memory efficiency,
but rather EXTRAPOLATION: train on 256, infer on 1024+.
| |
| Usage: |
| python validation/memory/needle_test.py --config medium |
| python validation/memory/needle_test.py --config large --depths 100 200 500 1000 |
| """ |
|
|
| import os |
| import sys |
| import time |
| import pickle |
| import argparse |
| import json |
| from datetime import datetime |
| from typing import List, Dict, Tuple |
| import random |
|
|
| import torch |
| import psutil |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) |
|
|
| from src.model import RippleGPT |
| from src.config import RippleConfig |
| from validation.memory.model_configs import get_config |
|
|
| |
# Data, checkpoint and results directories, resolved relative to this file
# so the script behaves the same regardless of the current working directory.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')
RESULTS_DIR = os.path.join(os.path.dirname(__file__), 'results')


# Device preference: CUDA first, then Apple MPS, then CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
| |
| |
| |
|
|
# (variable_name, secret_value) pairs used as "needles". One pair is planted
# verbatim at the start of each prompt; the model must later recall the
# value when asked for the name.
NEEDLES = [
    ("SECRET_PASSWORD", "bananas"),
    ("API_KEY", "sk-abc123xyz789"),
    ("DATABASE_URL", "postgres://localhost:5432/mydb"),
    ("ADMIN_PASSWORD", "super_secret_2024"),
    ("MAGIC_NUMBER", "42"),
]
|
|
|
|
| |
| |
| |
|
|
# Distractor material: generic, needle-free Python snippets that are sampled
# (with replacement) and concatenated to build the "haystack" surrounding
# the planted needle. Each element is one multi-line code fragment.
HAYSTACK_SNIPPETS = [
    '''
def process_data(items):
    """Process a list of items."""
    result = []
    for item in items:
        if item.is_valid():
            result.append(item.transform())
    return result
''',
    '''
class DataProcessor:
    def __init__(self, config):
        self.config = config
        self.cache = {}

    def process(self, data):
        if data.id in self.cache:
            return self.cache[data.id]
        result = self._compute(data)
        self.cache[data.id] = result
        return result
''',
    '''
def calculate_metrics(values):
    total = sum(values)
    count = len(values)
    mean = total / count if count > 0 else 0
    variance = sum((x - mean) ** 2 for x in values) / count if count > 0 else 0
    return {"mean": mean, "variance": variance, "total": total}
''',
    '''
async def fetch_data(url):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status == 200:
                return await response.json()
            raise Exception(f"Error: {response.status}")
''',
    '''
def validate_input(data):
    if not isinstance(data, dict):
        raise TypeError("Expected dict")
    required = ["name", "email", "age"]
    for field in required:
        if field not in data:
            raise ValueError(f"Missing field: {field}")
    return True
''',
    '''
class Logger:
    def __init__(self, name):
        self.name = name
        self.level = "INFO"

    def log(self, message, level="INFO"):
        timestamp = datetime.now().isoformat()
        print(f"[{timestamp}] [{level}] {self.name}: {message}")
''',
    '''
def merge_configs(*configs):
    result = {}
    for config in configs:
        for key, value in config.items():
            if key in result and isinstance(result[key], dict):
                result[key] = merge_configs(result[key], value)
            else:
                result[key] = value
    return result
''',
    '''
def fibonacci(n):
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b
''',
]
|
|
|
|
def generate_haystack(num_lines: int) -> str:
    """Builds distractor code totalling roughly ``num_lines`` lines.

    Whole snippets are sampled (with replacement) from HAYSTACK_SNIPPETS
    until their combined line count reaches the target, so the result may
    slightly overshoot ``num_lines``.
    """
    parts = []
    remaining = num_lines

    while remaining > 0:
        piece = random.choice(HAYSTACK_SNIPPETS)
        parts.append(piece)
        remaining -= piece.count('\n')

    return '\n'.join(parts)
|
|
|
|
def create_needle_prompt(needle_name: str, needle_value: str, haystack_lines: int) -> Tuple[str, str]:
    """
    Assembles needle + haystack + retrieval question into one prompt.

    The needle assignment comes first, followed by ~haystack_lines lines of
    distractor code, and finally a question that primes the model to
    complete the needle's value.

    Returns:
        (full_prompt, expected_answer)
    """
    header = f'{needle_name} = "{needle_value}"\n\n'
    body = generate_haystack(haystack_lines)
    footer = (
        f'\n\n# Question: What is the value of {needle_name}?'
        f'\n# Answer: {needle_name} = "'
    )

    return header + body + footer, needle_value
|
|
|
|
| |
| |
| |
|
|
def load_model(config_name: str) -> Tuple[RippleGPT, callable, callable]:
    """Loads a trained model plus its character-level encode/decode helpers.

    Args:
        config_name: Model size name ('small', 'medium', ...), used to
            locate the checkpoint file under CKPT_DIR.

    Returns:
        (model, encode, decode) where encode maps str -> list[int] and
        decode maps list[int] -> str.

    Raises:
        FileNotFoundError: If neither the best nor the final checkpoint exists.
    """
    # Prefer the best-validation checkpoint; fall back to the final one.
    best_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_best.pt')
    final_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_final.pt')

    if os.path.exists(best_path):
        ckpt_path = best_path
    elif os.path.exists(final_path):
        ckpt_path = final_path
    else:
        raise FileNotFoundError(
            f"Checkpoint not found for config '{config_name}'\n"
            f"Run: python validation/memory/train_large.py --config {config_name}"
        )

    print(f"π¦ Loading model from: {ckpt_path}")

    # weights_only=False because the checkpoint pickles a config object
    # alongside the weights. NOTE(review): only safe for trusted, locally
    # produced checkpoints — unpickling arbitrary files executes code.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']

    model = RippleGPT(config)

    # Strip the torch.compile wrapper prefix so keys match the raw module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()

    # Character vocabulary produced by the data-prep step (stoi/itos maps).
    with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)

    stoi = meta['stoi']
    itos = meta['itos']

    # Out-of-vocabulary characters map to '?' (or ' ', or token 0) instead
    # of raising, so arbitrary prompt text can always be encoded.
    unknown = stoi.get('?', stoi.get(' ', 0))
    encode = lambda s: [stoi.get(c, unknown) for c in s]
    decode = lambda l: ''.join([itos.get(i, '?') for i in l])

    print(f"  β Model loaded ({model.get_num_params()/1e6:.2f}M params)")

    return model, encode, decode
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def run_needle_test(
    model: RippleGPT,
    encode,
    decode,
    needle_name: str,
    needle_value: str,
    haystack_lines: int,
    max_gen_tokens: int = 50
) -> Dict:
    """
    Executes a single needle-in-a-haystack test.

    Builds a prompt with the needle at the start, generates a completion,
    and checks whether the needle value appears in the generated text.

    Args:
        model: Trained RippleGPT model (already on DEVICE, in eval mode).
        encode: str -> list[int] tokenizer.
        decode: list[int] -> str detokenizer.
        needle_name: Variable name planted at the start of the prompt.
        needle_value: Value the model must recall.
        haystack_lines: Approximate number of distractor-code lines.
        max_gen_tokens: Number of tokens to sample after the prompt.

    Returns:
        Dict with match flags, timing and memory statistics.
    """
    prompt, _ = create_needle_prompt(needle_name, needle_value, haystack_lines)

    input_ids = encode(prompt)
    num_input_tokens = len(input_ids)

    # Snapshot memory before generation: GPU allocator stats on CUDA,
    # process RSS otherwise (MPS/CPU have no per-device peak counter here).
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated() / 1e6
    else:
        mem_before = psutil.Process().memory_info().rss / 1e6

    x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)

    start_time = time.time()
    # Low temperature + small top_k: near-greedy decoding for recall tests.
    output = model.generate(x, max_new_tokens=max_gen_tokens, temperature=0.1, top_k=5)
    gen_time = time.time() - start_time

    if DEVICE == 'cuda':
        mem_after = torch.cuda.max_memory_allocated() / 1e6
    else:
        mem_after = psutil.Process().memory_info().rss / 1e6

    full_output = decode(output[0].tolist())
    generated = full_output[len(prompt):]

    # Exact match: the full value appears verbatim in the generation.
    exact_match = needle_value in generated
    # Partial match: any 5-character window of the value appears; values of
    # 5 chars or fewer fall back to exact containment.
    partial_match = any(
        needle_value[i:i+5] in generated
        for i in range(len(needle_value) - 4)
    ) if len(needle_value) > 4 else needle_value in generated

    # Guard against a zero-duration clock reading on very fast generations,
    # which would otherwise raise ZeroDivisionError.
    tokens_per_second = max_gen_tokens / gen_time if gen_time > 0 else float('inf')

    return {
        'needle_name': needle_name,
        'needle_value': needle_value,
        'haystack_lines': haystack_lines,
        'input_tokens': num_input_tokens,
        'generated': generated[:100],  # truncated sample for the report
        'exact_match': exact_match,
        'partial_match': partial_match,
        'generation_time': gen_time,
        'tokens_per_second': tokens_per_second,
        'memory_mb': mem_after - mem_before,
        'peak_memory_mb': mem_after
    }
|
|
|
|
def run_full_test_suite(
    model,
    encode,
    decode,
    depths: Tuple[int, ...] = (50, 100, 200, 500),
    num_trials: int = 3
) -> Dict:
    """
    Executes the full test suite at several haystack depths.

    For each depth, runs `num_trials` needle tests with randomly chosen
    needles and aggregates exact/partial accuracy plus performance stats.
    The `depths` default is a tuple (not a list) to avoid the shared
    mutable-default-argument pitfall.

    Args:
        model: Trained RippleGPT model.
        encode: str -> list[int] tokenizer from load_model().
        decode: list[int] -> str detokenizer from load_model().
        depths: Haystack sizes (lines of distractor code) to test.
        num_trials: Number of needle tests per depth (must be >= 1).

    Returns:
        Dict with per-depth results and an overall summary.

    Raises:
        ValueError: If num_trials < 1 or depths is empty (both previously
            crashed later with ZeroDivisionError).
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")
    if not depths:
        raise ValueError("depths must contain at least one depth")

    results = {
        'depths': {},
        'summary': {}
    }

    all_exact = 0
    all_partial = 0
    total_tests = 0

    for depth in depths:
        print(f"\nπ Testing depth: {depth} lines")
        print("-" * 50)

        depth_results = []
        exact_count = 0
        partial_count = 0

        for trial in range(num_trials):
            # Pick a random needle so results don't overfit to one string.
            needle_name, needle_value = random.choice(NEEDLES)

            result = run_needle_test(
                model, encode, decode,
                needle_name, needle_value,
                depth
            )

            depth_results.append(result)

            if result['exact_match']:
                exact_count += 1
            if result['partial_match']:
                partial_count += 1

            status = "β" if result['exact_match'] else ("β οΈ" if result['partial_match'] else "β")
            print(f"   {status} {needle_name}: {result['generated'][:30]}...")

        results['depths'][depth] = {
            'trials': depth_results,
            'exact_accuracy': exact_count / num_trials,
            'partial_accuracy': partial_count / num_trials,
            'avg_tokens': sum(r['input_tokens'] for r in depth_results) / num_trials,
            'avg_memory_mb': sum(r['peak_memory_mb'] for r in depth_results) / num_trials,
            'avg_tokens_per_sec': sum(r['tokens_per_second'] for r in depth_results) / num_trials
        }

        all_exact += exact_count
        all_partial += partial_count
        total_tests += num_trials

    results['summary'] = {
        'total_tests': total_tests,
        'overall_exact_accuracy': all_exact / total_tests,
        'overall_partial_accuracy': all_partial / total_tests,
    }

    return results
|
|
|
|
def print_results(results: Dict, config_name: str):
    """Renders the suite results as a formatted console report."""
    rule = "=" * 70
    sep = "-" * 70

    print("\n" + rule)
    print(f"π§ NEEDLE IN A HAYSTACK RESULTS - Model: {config_name.upper()}")
    print(rule)

    print("\nπ Results by Depth:")
    print(sep)
    print(f"{'Depth':<10} {'Exact':<10} {'Partial':<10} {'Tokens':<12} {'Memory':<12} {'Speed':<12}")
    print(sep)

    # One aligned row per tested depth.
    for depth, data in results['depths'].items():
        row = (
            f"{depth:<10} {data['exact_accuracy']*100:>6.1f}% {data['partial_accuracy']*100:>6.1f}% "
            f"{data['avg_tokens']:>8.0f} {data['avg_memory_mb']:>8.1f}MB "
            f"{data['avg_tokens_per_sec']:>8.1f}t/s"
        )
        print(row)

    print(sep)

    summary = results['summary']
    print("\nπ SUMMARY:")
    print(f"   Total tests: {summary['total_tests']}")
    print(f"   Exact accuracy: {summary['overall_exact_accuracy']*100:.1f}%")
    print(f"   Partial accuracy: {summary['overall_partial_accuracy']*100:.1f}%")

    # Verdict thresholds: >= 70% exact is a pass, >= 40% is promising.
    accuracy = summary['overall_exact_accuracy']
    if accuracy >= 0.7:
        verdict = "\nπ VERDICT: EXCELLENT! Ripple architecture retains long-term memory!"
    elif accuracy >= 0.4:
        verdict = "\nβ οΈ VERDICT: PROMISING. Partial retention, but needs adjustments."
    else:
        verdict = "\nβ VERDICT: More training needed for long-term retention."
    print(verdict)

    print(rule)
|
|
|
def main():
    """CLI entry point: parse args, estimate memory, run the suite, save results.

    Returns:
        0 if overall exact accuracy >= 50% (also the process exit code),
        1 otherwise or when the checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description='Needle in a Haystack Test')
    parser.add_argument('--config', type=str, default='medium',
                        choices=['small', 'medium', 'large', 'xlarge'])
    parser.add_argument('--depths', type=int, nargs='+', default=[50, 100, 200, 500],
                        help='Depths to test (lines of code)')
    parser.add_argument('--trials', type=int, default=3, help='Tests per depth')
    parser.add_argument('--no-save', action='store_true')
    args = parser.parse_args()

    print("=" * 70)
    print("π¬ NEEDLE IN A HAYSTACK TEST - RippleGPT Memory Validation")
    print("=" * 70)

    # Rough token forecast for the deepest test: ~10 tokens per haystack
    # line. NOTE(review): heuristic — confirm against the char-level data.
    max_depth = max(args.depths)

    estimated_tokens = max_depth * 10

    # (n_heads, n_layers) per config name; presumably mirrors
    # validation/memory/model_configs — verify if those change.
    config_params = {
        'small': (6, 6),
        'medium': (8, 8),
        'large': (12, 12),
        'xlarge': (16, 16)
    }
    n_heads, n_layers = config_params.get(args.config, (8, 8))

    # O(T^2) attention estimate: T^2 x 4 bytes x heads x layers, in MB.
    estimated_mem_mb = (estimated_tokens ** 2) * 4 * n_heads * n_layers / 1e6

    print(f"\nβ οΈ TECHNICAL NOTE: Memory Complexity O(TΒ²)")
    print(f"   β’ Max depth: {max_depth} lines (~{estimated_tokens} tokens)")
    print(f"   β’ Model: {args.config} ({n_heads} heads Γ {n_layers} layers)")
    print(f"   β’ Estimated attention memory: ~{estimated_mem_mb:.1f} MB")

    if estimated_mem_mb > 1000:
        print(f"   β οΈ WARNING: High estimated memory! May cause OOM.")
        print(f"   π‘ Consider using smaller --depths or smaller model.")

    # Missing checkpoint is an expected user error: report and exit with 1.
    try:
        model, encode, decode = load_model(args.config)
    except FileNotFoundError as e:
        print(f"\nβ {e}")
        return 1

    results = run_full_test_suite(
        model, encode, decode,
        depths=args.depths,
        num_trials=args.trials
    )

    # Attach run metadata for the saved JSON report.
    results['metadata'] = {
        'config': args.config,
        'timestamp': datetime.now().isoformat(),
        'device': DEVICE
    }

    print_results(results, args.config)

    if not args.no_save:
        os.makedirs(RESULTS_DIR, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_path = os.path.join(RESULTS_DIR, f'needle_test_{args.config}_{timestamp}.json')

        # Best-effort JSON coercion: keep JSON-native values as-is and
        # stringify everything else (e.g. numpy/torch scalars in results).
        def make_serializable(obj):
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [make_serializable(v) for v in obj]
            elif isinstance(obj, (bool, int, float, str, type(None))):
                return obj
            else:
                return str(obj)

        with open(results_path, 'w') as f:
            json.dump(make_serializable(results), f, indent=2)

        print(f"\nπΎ Results saved to: {results_path}")

    # Exit code doubles as a pass/fail signal for CI-style usage.
    return 0 if results['summary']['overall_exact_accuracy'] >= 0.5 else 1
|
|
|
|
if __name__ == '__main__':
    # Use sys.exit so the return code propagates reliably; the bare exit()
    # builtin is injected by the `site` module for interactive sessions and
    # may be absent (e.g. when run with python -S).
    sys.exit(main())
|
|