# RippleGPT-Nano — validation/memory/needle_test.py
# (Uploaded via huggingface_hub; revision 148b631, verified)
"""
needle_test.py - "Needle in a Haystack" test for memory validation.
This is the KILLER TEST that proves if RippleGPT can retain long-term information
through the Ripple Field (ALiBi-style attention) mechanism.
The test:
1. Places a "needle" (SECRET_PASSWORD = "bananas") at the beginning of a long text
2. Adds hundreds of lines of Python code as "haystack"
3. Asks the model to remember the password
4. Measures if it can retrieve the information
⚠️ TECHNICAL NOTE - MEMORY COMPLEXITY: O(T²)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RippleGPT uses full quadratic attention.
For T context tokens:
  • Memory ≈ T² × 4 bytes × n_heads × n_layers
  • T=1000 → ~4MB per head
  • T=3000 → ~36MB per head
  • T=8000 → ~256MB per head
The BENEFIT of Ripple Field is NOT memory efficiency,
but rather EXTRAPOLATION: train on 256, infer on 1024+.
Usage:
python validation/memory/needle_test.py --config medium
python validation/memory/needle_test.py --config large --depths 100 200 500 1000
"""
import os
import sys
import time
import pickle
import argparse
import json
from datetime import datetime
from typing import List, Dict, Tuple
import random
import torch
import psutil
# Add root directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.model import RippleGPT
from src.config import RippleConfig
from validation.memory.model_configs import get_config
# Directories (all resolved relative to this file's location)
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')          # char-level vocab (meta.pkl)
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')   # trained model checkpoints
RESULTS_DIR = os.path.join(os.path.dirname(__file__), 'results')    # JSON test reports
# Device preference: CUDA, then Apple MPS, then CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# ============================================================================
# NEEDLES - Information to be retrieved
# ============================================================================
NEEDLES = [
    # (name, value) pairs. One pair is planted verbatim as `NAME = "value"`
    # at the start of each prompt; the model must reproduce the value after
    # reading the haystack.
    ("SECRET_PASSWORD", "bananas"),
    ("API_KEY", "sk-abc123xyz789"),
    ("DATABASE_URL", "postgres://localhost:5432/mydb"),
    ("ADMIN_PASSWORD", "super_secret_2024"),
    ("MAGIC_NUMBER", "42"),
]
# ============================================================================
# HAYSTACK - Distraction code
# ============================================================================
HAYSTACK_SNIPPETS = [
'''
def process_data(items):
"""Process a list of items."""
result = []
for item in items:
if item.is_valid():
result.append(item.transform())
return result
''',
'''
class DataProcessor:
def __init__(self, config):
self.config = config
self.cache = {}
def process(self, data):
if data.id in self.cache:
return self.cache[data.id]
result = self._compute(data)
self.cache[data.id] = result
return result
''',
'''
def calculate_metrics(values):
total = sum(values)
count = len(values)
mean = total / count if count > 0 else 0
variance = sum((x - mean) ** 2 for x in values) / count if count > 0 else 0
return {"mean": mean, "variance": variance, "total": total}
''',
'''
async def fetch_data(url):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
return await response.json()
raise Exception(f"Error: {response.status}")
''',
'''
def validate_input(data):
if not isinstance(data, dict):
raise TypeError("Expected dict")
required = ["name", "email", "age"]
for field in required:
if field not in data:
raise ValueError(f"Missing field: {field}")
return True
''',
'''
class Logger:
def __init__(self, name):
self.name = name
self.level = "INFO"
def log(self, message, level="INFO"):
timestamp = datetime.now().isoformat()
print(f"[{timestamp}] [{level}] {self.name}: {message}")
''',
'''
def merge_configs(*configs):
result = {}
for config in configs:
for key, value in config.items():
if key in result and isinstance(result[key], dict):
result[key] = merge_configs(result[key], value)
else:
result[key] = value
return result
''',
'''
def fibonacci(n):
if n <= 1:
return n
a, b = 0, 1
for _ in range(2, n + 1):
a, b = b, a + b
return b
''',
]
def generate_haystack(num_lines: int) -> str:
    """Build distractor code totalling approximately *num_lines* lines.

    Random snippets are drawn (with replacement) from HAYSTACK_SNIPPETS and
    accumulated until their cumulative newline count reaches the target,
    then joined with newlines. The result may overshoot by up to one snippet.
    """
    chosen = []
    produced = 0
    while produced < num_lines:
        piece = random.choice(HAYSTACK_SNIPPETS)
        chosen.append(piece)
        produced += piece.count('\n')
    return '\n'.join(chosen)
def create_needle_prompt(needle_name: str, needle_value: str, haystack_lines: int) -> Tuple[str, str]:
    """Assemble a test prompt: needle first, haystack in the middle, question last.

    The needle is a simple assignment (`NAME = "value"`); the question asks
    the model to complete that same assignment, cueing it with an opening
    double quote.

    Returns:
        (full_prompt, expected_answer) where expected_answer is needle_value.
    """
    header = f'{needle_name} = "{needle_value}"\n\n'
    question = (
        f'\n\n# Question: What is the value of {needle_name}?'
        f'\n# Answer: {needle_name} = "'
    )
    return header + generate_haystack(haystack_lines) + question, needle_value
# ============================================================================
# MODEL
# ============================================================================
def load_model(config_name: str) -> Tuple[RippleGPT, callable, callable]:
    """Load a trained RippleGPT checkpoint plus char-level encode/decode fns.

    Prefers 'ckpt_<name>_best.pt' over 'ckpt_<name>_final.pt' in CKPT_DIR.

    Raises:
        FileNotFoundError: when neither checkpoint file exists.
    """
    candidates = [
        os.path.join(CKPT_DIR, f'ckpt_{config_name}_best.pt'),
        os.path.join(CKPT_DIR, f'ckpt_{config_name}_final.pt'),
    ]
    ckpt_path = next((p for p in candidates if os.path.exists(p)), None)
    if ckpt_path is None:
        raise FileNotFoundError(
            f"Checkpoint not found for config '{config_name}'\n"
            f"Run: python validation/memory/train_large.py --config {config_name}"
        )
    print(f"📦 Loading model from: {ckpt_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    model = RippleGPT(checkpoint['config'])
    state_dict = checkpoint['model']
    # torch.compile prefixes parameter names with '_orig_mod.'; strip it so
    # the uncompiled module accepts the weights.
    prefix = '_orig_mod.'
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    # Character-level vocabulary produced at training time.
    with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    # Unknown characters map to '?' (or ' ', or token 0) on encode.
    fallback = stoi.get('?', stoi.get(' ', 0))
    encode = lambda s: [stoi.get(c, fallback) for c in s]
    decode = lambda l: ''.join([itos.get(i, '?') for i in l])
    print(f" ✅ Model loaded ({model.get_num_params()/1e6:.2f}M params)")
    return model, encode, decode
# ============================================================================
# TESTS
# ============================================================================
@torch.no_grad()
def run_needle_test(
    model: RippleGPT,
    encode,
    decode,
    needle_name: str,
    needle_value: str,
    haystack_lines: int,
    max_gen_tokens: int = 50
) -> Dict:
    """
    Executes a single needle-in-a-haystack test.

    Plants `needle_name = "needle_value"` at the start of a prompt, pads it
    with ~haystack_lines lines of distractor code, asks the model to complete
    the value, and scores the generated continuation.

    Args:
        model: trained RippleGPT (in eval mode, on DEVICE).
        encode/decode: char-level codec from load_model().
        needle_name, needle_value: the secret to plant and retrieve.
        haystack_lines: approximate distractor size in lines.
        max_gen_tokens: tokens to sample for the answer.

    Returns:
        Dict with match flags, timing, token counts and memory statistics.
    """
    # Build the prompt (needle first, question last); the second element of
    # the tuple is just needle_value echoed back, so it is discarded here.
    prompt, _ = create_needle_prompt(needle_name, needle_value, haystack_lines)
    input_ids = encode(prompt)
    num_input_tokens = len(input_ids)
    # Memory snapshot before generation: CUDA allocator stats on GPU,
    # process RSS elsewhere (RSS is noisy but the best cheap proxy).
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated() / 1e6
    else:
        mem_before = psutil.Process().memory_info().rss / 1e6
    # Near-greedy decoding: low temperature + small top_k for determinism.
    x = torch.tensor(input_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    start_time = time.time()
    output = model.generate(x, max_new_tokens=max_gen_tokens, temperature=0.1, top_k=5)
    gen_time = time.time() - start_time
    if DEVICE == 'cuda':
        mem_after = torch.cuda.max_memory_allocated() / 1e6
    else:
        mem_after = psutil.Process().memory_info().rss / 1e6
    # Only the newly generated suffix is scored.
    full_output = decode(output[0].tolist())
    generated = full_output[len(prompt):]
    # exact: the full needle value appears verbatim anywhere in the output.
    exact_match = needle_value in generated
    # partial: any 5-char window of the value appears (tolerates truncation);
    # values of length <= 4 fall back to the exact criterion.
    partial_match = any(
        needle_value[i:i+5] in generated
        for i in range(len(needle_value) - 4)
    ) if len(needle_value) > 4 else needle_value in generated
    return {
        'needle_name': needle_name,
        'needle_value': needle_value,
        'haystack_lines': haystack_lines,
        'input_tokens': num_input_tokens,
        'generated': generated[:100],  # first 100 chars for the report
        'exact_match': exact_match,
        'partial_match': partial_match,
        'generation_time': gen_time,
        # Guard against zero elapsed time (coarse timer / tiny generations).
        'tokens_per_second': max_gen_tokens / max(gen_time, 1e-9),
        'memory_mb': mem_after - mem_before,
        'peak_memory_mb': mem_after
    }
def run_full_test_suite(
    model,
    encode,
    decode,
    depths: List[int] = None,
    num_trials: int = 3
) -> Dict:
    """
    Executes the full test suite at several haystack depths.

    Args:
        model, encode, decode: as returned by load_model().
        depths: haystack sizes (lines of code) to test; None selects the
            default [50, 100, 200, 500].
        num_trials: independent trials per depth, each with a random needle.

    Returns:
        Dict with per-depth trial details and an overall accuracy summary.
    """
    # None sentinel instead of a mutable default list (shared across calls).
    if depths is None:
        depths = [50, 100, 200, 500]
    results = {
        'depths': {},
        'summary': {}
    }
    all_exact = 0
    all_partial = 0
    total_tests = 0
    for depth in depths:
        print(f"\n📏 Testing depth: {depth} lines")
        print("-" * 50)
        depth_results = []
        exact_count = 0
        partial_count = 0
        for trial in range(num_trials):
            # Each trial plants a freshly drawn random needle.
            needle_name, needle_value = random.choice(NEEDLES)
            result = run_needle_test(
                model, encode, decode,
                needle_name, needle_value,
                depth
            )
            depth_results.append(result)
            if result['exact_match']:
                exact_count += 1
            if result['partial_match']:
                partial_count += 1
            status = "✅" if result['exact_match'] else ("⚠️" if result['partial_match'] else "❌")
            print(f" {status} {needle_name}: {result['generated'][:30]}...")
        results['depths'][depth] = {
            'trials': depth_results,
            'exact_accuracy': exact_count / num_trials,
            'partial_accuracy': partial_count / num_trials,
            'avg_tokens': sum(r['input_tokens'] for r in depth_results) / num_trials,
            'avg_memory_mb': sum(r['peak_memory_mb'] for r in depth_results) / num_trials,
            'avg_tokens_per_sec': sum(r['tokens_per_second'] for r in depth_results) / num_trials
        }
        all_exact += exact_count
        all_partial += partial_count
        total_tests += num_trials
    results['summary'] = {
        'total_tests': total_tests,
        'overall_exact_accuracy': all_exact / total_tests,
        'overall_partial_accuracy': all_partial / total_tests,
    }
    return results
def print_results(results: Dict, config_name: str):
    """Render the suite results as a fixed-width table plus a verdict.

    Purely presentational: reads results['depths'] and results['summary'],
    writes to stdout, and returns None.
    """
    heavy = "=" * 70
    light = "-" * 70
    print("\n" + heavy)
    print(f"🧠 NEEDLE IN A HAYSTACK RESULTS - Model: {config_name.upper()}")
    print(heavy)
    print("\n📊 Results by Depth:")
    print(light)
    print(f"{'Depth':<10} {'Exact':<10} {'Partial':<10} {'Tokens':<12} {'Memory':<12} {'Speed':<12}")
    print(light)
    for depth, stats in results['depths'].items():
        row = (
            f"{depth:<10} {stats['exact_accuracy']*100:>6.1f}% {stats['partial_accuracy']*100:>6.1f}% "
            f"{stats['avg_tokens']:>8.0f} {stats['avg_memory_mb']:>8.1f}MB "
            f"{stats['avg_tokens_per_sec']:>8.1f}t/s"
        )
        print(row)
    print(light)
    summary = results['summary']
    print("\n📈 SUMMARY:")
    print(f" Total tests: {summary['total_tests']}")
    print(f" Exact accuracy: {summary['overall_exact_accuracy']*100:.1f}%")
    print(f" Partial accuracy: {summary['overall_partial_accuracy']*100:.1f}%")
    # Three-tier verdict keyed on exact retrieval accuracy.
    exact = summary['overall_exact_accuracy']
    if exact >= 0.7:
        print("\n🎉 VERDICT: EXCELLENT! Ripple architecture retains long-term memory!")
    elif exact >= 0.4:
        print("\n⚠️ VERDICT: PROMISING. Partial retention, but needs adjustments.")
    else:
        print("\n❌ VERDICT: More training needed for long-term retention.")
    print(heavy)
def main():
    """CLI entry point for the needle-in-a-haystack memory test.

    Parses arguments, prints an O(T^2) attention-memory estimate, loads the
    requested checkpoint, runs the suite, prints the results, and (unless
    --no-save) writes them to a timestamped JSON file under RESULTS_DIR.

    Returns:
        int: 0 when overall exact accuracy >= 50%, 1 otherwise (also 1 when
        the checkpoint for --config is missing).
    """
    parser = argparse.ArgumentParser(description='Needle in a Haystack Test')
    parser.add_argument('--config', type=str, default='medium',
                        choices=['small', 'medium', 'large', 'xlarge'])
    parser.add_argument('--depths', type=int, nargs='+', default=[50, 100, 200, 500],
                        help='Depths to test (lines of code)')
    parser.add_argument('--trials', type=int, default=3, help='Tests per depth')
    parser.add_argument('--no-save', action='store_true')
    args = parser.parse_args()
    print("=" * 70)
    print("🔬 NEEDLE IN A HAYSTACK TEST - RippleGPT Memory Validation")
    print("=" * 70)
    # Estimate needed memory before loading anything, so the user gets an
    # OOM warning up front.
    max_depth = max(args.depths)
    # ~10 tokens per line of code, conservative estimate
    estimated_tokens = max_depth * 10
    # Memory formula: T² × 4 bytes × n_heads × n_layers (approx)
    # Configs: small=6×6, medium=8×8, large=12×12, xlarge=16×16
    config_params = {
        'small': (6, 6),
        'medium': (8, 8),
        'large': (12, 12),
        'xlarge': (16, 16)
    }
    # Unknown config names fall back to the 'medium' geometry.
    n_heads, n_layers = config_params.get(args.config, (8, 8))
    # Memory in MB per batch (T² × 4 bytes × n_heads × n_layers / 1e6)
    estimated_mem_mb = (estimated_tokens ** 2) * 4 * n_heads * n_layers / 1e6
    print(f"\n⚠️ TECHNICAL NOTE: Memory Complexity O(T²)")
    print(f" • Max depth: {max_depth} lines (~{estimated_tokens} tokens)")
    print(f" • Model: {args.config} ({n_heads} heads × {n_layers} layers)")
    print(f" • Estimated attention memory: ~{estimated_mem_mb:.1f} MB")
    if estimated_mem_mb > 1000:
        print(f" ⚠️ WARNING: High estimated memory! May cause OOM.")
        print(f" 💡 Consider using smaller --depths or smaller model.")
    # Load model; a missing checkpoint is reported as a failure exit code
    # rather than a traceback.
    try:
        model, encode, decode = load_model(args.config)
    except FileNotFoundError as e:
        print(f"\n❌ {e}")
        return 1
    # Run tests
    results = run_full_test_suite(
        model, encode, decode,
        depths=args.depths,
        num_trials=args.trials
    )
    # Add metadata
    results['metadata'] = {
        'config': args.config,
        'timestamp': datetime.now().isoformat(),
        'device': DEVICE
    }
    # Print results
    print_results(results, args.config)
    # Save results
    if not args.no_save:
        os.makedirs(RESULTS_DIR, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_path = os.path.join(RESULTS_DIR, f'needle_test_{args.config}_{timestamp}.json')
        # Convert to serializable JSON
        def make_serializable(obj):
            # Recursively coerce the results tree: plain JSON scalars and
            # containers pass through; anything else is stringified.
            # NOTE: integer dict keys (the depths) become strings in the
            # JSON output — a json.dump behavior, not a data loss here.
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [make_serializable(v) for v in obj]
            elif isinstance(obj, (bool, int, float, str, type(None))):
                return obj
            else:
                return str(obj)
        with open(results_path, 'w') as f:
            json.dump(make_serializable(results), f, indent=2)
        print(f"\n💾 Results saved to: {results_path}")
    # Exit code doubles as a pass/fail signal for scripted/CI usage.
    return 0 if results['summary']['overall_exact_accuracy'] >= 0.5 else 1
if __name__ == '__main__':
    # sys.exit (not the builtin `exit`) sets the process exit status reliably:
    # `exit` is a site.py REPL convenience and may be absent under `python -S`.
    sys.exit(main())