""" Titans-style neural long-term memory. Key insight: The hidden state IS a neural network. Updates happen via self-supervised learning during inference. Based on: https://arxiv.org/abs/2501.00663 """ from __future__ import annotations import hashlib from typing import Any import torch import torch.nn as nn import torch.nn.functional as functional from torch import Tensor from ..config import MemoryConfig class NeuralMemory(nn.Module): """ Titans-style neural long-term memory. The memory is a small neural network that updates its weights during inference via gradient descent (test-time training). Example: >>> config = MemoryConfig(dim=256) >>> memory = NeuralMemory(config) >>> result = memory.observe("Python uses indentation") >>> print(f"Surprise: {result['surprise']:.3f}") """ def __init__(self, config: MemoryConfig | int | None = None, **kwargs: Any) -> None: super().__init__() # Handle both config object and legacy positional args if config is None: config = MemoryConfig(**kwargs) elif isinstance(config, int): # Legacy: NeuralMemory(dim=256) or NeuralMemory(256) config = MemoryConfig(dim=config, **kwargs) self.config = config self.dim = config.dim # The memory IS a neural network self.memory_net = nn.Sequential( nn.Linear(config.dim, config.dim * 4), nn.GELU(), nn.LayerNorm(config.dim * 4), nn.Linear(config.dim * 4, config.dim), ) # Target projection for self-supervised learning self.target_proj = nn.Linear(config.dim, config.dim) # Learnable learning rate (meta-learning) self.lr = nn.Parameter(torch.tensor(config.learning_rate)) # Observation counter self._observation_count = 0 self._recent_surprises: list[float] = [] # Move to device self.to(config.device) def _encode_text(self, text: str) -> Tensor: """ Encode text to tensor representation. Uses a simple but deterministic encoding for demo purposes. In production, would use a proper encoder (e.g., sentence-transformers). """ # Create deterministic embedding from text text_bytes = text.encode("utf-8") hash_bytes = hashlib.sha256(text_bytes).digest() # Expand hash to fill dimension values = [] for i in range(self.dim): byte_idx = i % len(hash_bytes) bit_offset = (i // len(hash_bytes)) % 8 val = ((hash_bytes[byte_idx] >> bit_offset) & 1) * 2 - 1 # -1 or 1 values.append(val * 0.1) # Add variation based on character positions for i, char in enumerate(text[: self.dim]): idx = i % self.dim values[idx] += (ord(char) / 255.0 - 0.5) * 0.2 tensor = torch.tensor(values, dtype=torch.float32, device=self.config.device) # Shape: [1, seq_len, dim] - treat each character as a "token" seq_len = min(len(text), 64) # Cap sequence length tensor = tensor.unsqueeze(0).unsqueeze(0).expand(1, seq_len, -1).clone() # Add positional variation for i in range(seq_len): if i < len(text): tensor[0, i, :] += torch.randn(self.dim, device=self.config.device) * 0.01 tensor[0, i, i % self.dim] += ord(text[i]) / 255.0 return tensor def forward(self, x: Tensor, learn: bool = True) -> Tensor: """ Process input and optionally update memory weights. Args: x: Input tensor [batch, seq, dim] learn: Whether to update memory weights (test-time training) Returns: Memory-augmented representation """ # Ensure requires_grad for learning if learn: x = x.detach().requires_grad_(False) for param in self.memory_net.parameters(): param.requires_grad_(True) # Query the memory memory_output: Tensor = self.memory_net(x) if learn and x.shape[1] > 1: # Self-supervised objective: predict next token representation loss = self._compute_surprise_tensor(x, memory_output) if loss.requires_grad: # Update memory weights (this is the key innovation) self._update_weights(loss) return memory_output def _compute_surprise_tensor(self, x: Tensor, pred: Tensor) -> Tensor: """ Compute surprise as prediction error (returns tensor for gradients). """ if x.shape[1] <= 1: return torch.tensor(0.0, device=x.device, requires_grad=True) # Target: shifted input projected target = self.target_proj(x[:, 1:, :]) prediction = pred[:, :-1, :] return functional.mse_loss(prediction, target) def _compute_surprise(self, x: Tensor, pred: Tensor) -> float: """ Compute surprise score (0 to 1 range). """ with torch.no_grad(): if x.shape[1] <= 1: return 0.5 target = self.target_proj(x[:, 1:, :]) prediction = pred[:, :-1, :] mse = functional.mse_loss(prediction, target).item() # Convert to 0-1 range using sigmoid-like scaling surprise = 2.0 / (1.0 + torch.exp(torch.tensor(-mse * 10)).item()) - 1.0 return float(max(0.0, min(1.0, surprise))) def _update_weights(self, loss: Tensor) -> None: """The key innovation: gradient descent during forward pass.""" try: grads = torch.autograd.grad( loss, list(self.memory_net.parameters()), create_graph=False, allow_unused=True ) with torch.no_grad(): for param, grad in zip(self.memory_net.parameters(), grads): if grad is not None: param -= self.lr * grad except RuntimeError: # Gradient computation failed, skip update pass def observe(self, content: str | Tensor, learning_rate: float | None = None) -> dict[str, Any]: """ Feed content to memory, triggering test-time learning. Args: content: Text string or tensor to learn from learning_rate: Optional override for learning rate Returns: dict with surprise score, weight delta, and metadata """ # Handle learning rate override original_lr = None if learning_rate is not None: original_lr = self.lr.data.clone() self.lr.data = torch.tensor(learning_rate, device=self.config.device) # Encode if string x = self._encode_text(content) if isinstance(content, str) else content # Store initial weights for delta calculation initial_weights = { name: param.clone() for name, param in self.memory_net.named_parameters() } # Forward with learning output = self.forward(x, learn=True) # Calculate metrics surprise = self._compute_surprise(x, output) weight_delta = sum( (param - initial_weights[name]).abs().sum().item() for name, param in self.memory_net.named_parameters() ) # Restore learning rate if original_lr is not None: self.lr.data = original_lr # Update stats self._observation_count += 1 self._recent_surprises.append(surprise) if len(self._recent_surprises) > 100: self._recent_surprises.pop(0) return { "surprise": surprise, "weight_delta": weight_delta, "patterns_activated": [f"pattern_{self._observation_count}"], "learned": weight_delta > 1e-6, } def infer(self, query: str | Tensor, temperature: float = 1.0) -> dict[str, Any]: """ Query memory using learned representations (no learning). Args: query: Text string or tensor to query temperature: Not used currently, for API compatibility Returns: dict with response tensor and confidence """ del temperature # Unused, kept for API compatibility x = self._encode_text(query) if isinstance(query, str) else query with torch.no_grad(): output = self.forward(x, learn=False) confidence = 1.0 - self._compute_surprise(x, output) return { "response": output, "confidence": max(0.0, min(1.0, confidence)), "attention_weights": output[0, 0, :10].tolist() if output.dim() >= 3 else [], } def surprise(self, content: str | Tensor) -> float: """ Measure how surprising/novel content is WITHOUT learning. Args: content: Text string or tensor to evaluate Returns: Surprise score between 0 (familiar) and 1 (novel) """ x = self._encode_text(content) if isinstance(content, str) else content with torch.no_grad(): output = self.memory_net(x) return self._compute_surprise(x, output) def get_weight_hash(self) -> str: """ Get hash of current weights for change detection. Returns: 16-character hex hash of weights """ with torch.no_grad(): state = self.memory_net.state_dict() flat = torch.cat([v.flatten().cpu() for v in state.values()]) # Use string representation instead of numpy to avoid numpy dependency data_str = str(flat.tolist()) hash_bytes = hashlib.sha256(data_str.encode()).digest() return hash_bytes[:8].hex() def get_stats(self) -> dict[str, Any]: """Get memory statistics.""" return { "total_observations": self._observation_count, "weight_parameters": sum(p.numel() for p in self.memory_net.parameters()), "avg_surprise": ( sum(self._recent_surprises) / len(self._recent_surprises) if self._recent_surprises else 0.0 ), "learning_rate": self.lr.item(), "dimension": self.dim, }