File size: 10,376 Bytes
dd41762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Titans-style neural long-term memory.

Key insight: The hidden state IS a neural network.
Updates happen via self-supervised learning during inference.

Based on: https://arxiv.org/abs/2501.00663
"""

from __future__ import annotations

import hashlib
import warnings
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch import Tensor

from ..config import MemoryConfig


class NeuralMemory(nn.Module):
    """
    Titans-style neural long-term memory.

    The memory is a small neural network that updates its weights
    during inference via gradient descent (test-time training).

    Example:
        >>> config = MemoryConfig(dim=256)
        >>> memory = NeuralMemory(config)
        >>> result = memory.observe("Python uses indentation")
        >>> print(f"Surprise: {result['surprise']:.3f}")
    """

    # Upper bound on the number of "token" positions produced by _encode_text.
    _MAX_SEQ_LEN = 64
    # Number of most recent surprise scores kept for get_stats()["avg_surprise"].
    _SURPRISE_WINDOW = 100

    def __init__(self, config: MemoryConfig | int | None = None, **kwargs: Any) -> None:
        """
        Build the memory network.

        Args:
            config: A config object exposing ``dim``, ``learning_rate`` and
                ``device``; an ``int`` (legacy form, interpreted as ``dim``);
                or ``None`` to build a ``MemoryConfig`` from ``kwargs``.
            **kwargs: Forwarded to ``MemoryConfig`` when ``config`` is ``None``
                or an ``int``. Not used when a config object is passed.
        """
        super().__init__()

        # Handle both config object and legacy positional args
        if config is None:
            config = MemoryConfig(**kwargs)
        elif isinstance(config, int):
            # Legacy: NeuralMemory(dim=256) or NeuralMemory(256)
            config = MemoryConfig(dim=config, **kwargs)

        self.config = config
        self.dim = config.dim

        # The memory IS a neural network
        hidden_dim = config.dim * 4
        self.memory_net = nn.Sequential(
            nn.Linear(config.dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, config.dim),
        )

        # Target projection for self-supervised learning
        self.target_proj = nn.Linear(config.dim, config.dim)

        # Learnable learning rate (meta-learning). float() guards against an
        # integer learning rate, which could not be wrapped in a Parameter.
        self.lr = nn.Parameter(torch.tensor(float(config.learning_rate)))

        # Observation counter and rolling window of recent surprise scores
        self._observation_count = 0
        self._recent_surprises: list[float] = []

        # Move to device
        self.to(config.device)

    @staticmethod
    def _has_sequence(x: Tensor) -> bool:
        """Return True when ``x`` has a sequence axis (dim 1) with >= 2 steps."""
        return x.dim() >= 3 and x.shape[1] > 1

    def _encode_text(self, text: str) -> Tensor:
        """
        Encode text to a deterministic tensor representation.

        Uses a simple hash-based encoding for demo purposes.
        In production, would use a proper encoder (e.g., sentence-transformers).

        Args:
            text: The string to encode.

        Returns:
            Float32 tensor of shape ``[1, seq_len, dim]`` on ``config.device``,
            where ``seq_len`` is ``len(text)`` clamped to ``[1, _MAX_SEQ_LEN]``.
            The same text always produces the same tensor.
        """
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        n_bytes = len(digest)

        # Expand the hash bits to fill the dimension: each entry is -0.1 or +0.1.
        base = [
            (((digest[i % n_bytes] >> ((i // n_bytes) % 8)) & 1) * 2 - 1) * 0.1
            for i in range(self.dim)
        ]

        # Add variation based on character positions.
        for i, char in enumerate(text[: self.dim]):
            base[i] += (ord(char) / 255.0 - 0.5) * 0.2

        # Treat each character as a "token"; always emit at least one position
        # so that downstream indexing works for the empty string.
        seq_len = max(1, min(len(text), self._MAX_SEQ_LEN))

        # Build on CPU, then move once to the target device.
        encoded = torch.tensor(base, dtype=torch.float32).repeat(seq_len, 1)

        # Positional jitter drawn from a private generator seeded by the text
        # hash: keeps the encoding reproducible and leaves the global RNG alone.
        generator = torch.Generator()
        generator.manual_seed(int.from_bytes(digest[:8], "big") & 0x7FFF_FFFF_FFFF_FFFF)
        encoded += torch.randn(seq_len, self.dim, generator=generator) * 0.01

        # Mark each position with its own character on a rotating feature index.
        for pos, char in enumerate(text[:seq_len]):
            encoded[pos, pos % self.dim] += ord(char) / 255.0

        return encoded.unsqueeze(0).to(self.config.device)

    def forward(self, x: Tensor, learn: bool = True) -> Tensor:
        """
        Process input and optionally update memory weights.

        Args:
            x: Input tensor, normally [batch, seq, dim]. Inputs without a
                sequence axis of length >= 2 are passed through the network
                without any weight update.
            learn: Whether to update memory weights (test-time training)

        Returns:
            Memory-augmented representation (computed before any weight update)
        """
        if learn:
            # Cut the input off from any upstream graph and make sure the
            # memory weights are differentiable for the update step.
            x = x.detach()
            for param in self.memory_net.parameters():
                param.requires_grad_(True)

        # Query the memory
        memory_output: Tensor = self.memory_net(x)

        if learn and self._has_sequence(x):
            # Self-supervised objective: predict next token representation
            loss = self._compute_surprise_tensor(x, memory_output)

            # The loss carries no graph when called under torch.no_grad().
            if loss.requires_grad:
                # Update memory weights (this is the key innovation)
                self._update_weights(loss)

        return memory_output

    def _compute_surprise_tensor(self, x: Tensor, pred: Tensor) -> Tensor:
        """
        Compute surprise as prediction error (returns tensor for gradients).

        Args:
            x: Input tensor, normally [batch, seq, dim].
            pred: Output of ``memory_net`` for ``x``.

        Returns:
            Scalar MSE between ``pred`` at step t and the projected input at
            step t + 1, or a zero scalar when there is no next step to predict.
        """
        if not self._has_sequence(x):
            return torch.tensor(0.0, device=x.device, requires_grad=True)

        # Target: shifted input projected
        target = self.target_proj(x[:, 1:, :])
        prediction = pred[:, :-1, :]

        return functional.mse_loss(prediction, target)

    def _compute_surprise(self, x: Tensor, pred: Tensor) -> float:
        """
        Compute surprise score (0 to 1 range).

        Returns 0.5 when the input has no next step to predict.
        """
        with torch.no_grad():
            if not self._has_sequence(x):
                return 0.5

            target = self.target_proj(x[:, 1:, :])
            prediction = pred[:, :-1, :]
            mse = functional.mse_loss(prediction, target).item()

            # Convert to 0-1 range using sigmoid-like scaling
            surprise = 2.0 / (1.0 + torch.exp(torch.tensor(-mse * 10)).item()) - 1.0
            return float(max(0.0, min(1.0, surprise)))

    def _update_weights(self, loss: Tensor) -> None:
        """
        The key innovation: gradient descent during forward pass.

        Applies one SGD step (step size ``self.lr``) to ``memory_net``. The
        update is best-effort: if autograd cannot produce gradients, a
        ``RuntimeWarning`` is emitted and the weights are left unchanged.
        """
        params = list(self.memory_net.parameters())
        try:
            grads = torch.autograd.grad(loss, params, create_graph=False, allow_unused=True)
        except RuntimeError as err:
            warnings.warn(
                f"NeuralMemory weight update skipped: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
            return

        with torch.no_grad():
            for param, grad in zip(params, grads):
                if grad is not None:
                    param -= self.lr * grad

    def observe(self, content: str | Tensor, learning_rate: float | None = None) -> dict[str, Any]:
        """
        Feed content to memory, triggering test-time learning.

        Args:
            content: Text string or tensor to learn from
            learning_rate: Optional override for learning rate; applies to this
                call only and is restored afterwards, even if the call raises

        Returns:
            dict with surprise score, weight delta, and metadata
        """
        # Encode if string
        x = self._encode_text(content) if isinstance(content, str) else content

        # Handle learning rate override in place so the parameter keeps its
        # dtype and device.
        original_lr = None
        if learning_rate is not None:
            original_lr = self.lr.detach().clone()
            with torch.no_grad():
                self.lr.fill_(learning_rate)

        try:
            # Store initial weights for delta calculation
            initial_weights = {
                name: param.detach().clone()
                for name, param in self.memory_net.named_parameters()
            }

            # Forward with learning
            output = self.forward(x, learn=True)

            # Calculate metrics
            surprise = self._compute_surprise(x, output)
            with torch.no_grad():
                weight_delta = sum(
                    (param - initial_weights[name]).abs().sum().item()
                    for name, param in self.memory_net.named_parameters()
                )
        finally:
            # Restore learning rate
            if original_lr is not None:
                with torch.no_grad():
                    self.lr.copy_(original_lr)

        # Update stats, keeping only the most recent window of surprises
        self._observation_count += 1
        self._recent_surprises.append(surprise)
        del self._recent_surprises[: -self._SURPRISE_WINDOW]

        return {
            "surprise": surprise,
            "weight_delta": weight_delta,
            "patterns_activated": [f"pattern_{self._observation_count}"],
            "learned": weight_delta > 1e-6,
        }

    def infer(self, query: str | Tensor, temperature: float = 1.0) -> dict[str, Any]:
        """
        Query memory using learned representations (no learning).

        Args:
            query: Text string or tensor to query
            temperature: Not used currently, for API compatibility

        Returns:
            dict with response tensor, confidence, and the first (up to) 10
            output features of the first position under "attention_weights"
        """
        del temperature  # Unused, kept for API compatibility
        x = self._encode_text(query) if isinstance(query, str) else query

        with torch.no_grad():
            output = self.forward(x, learn=False)
            confidence = 1.0 - self._compute_surprise(x, output)

        # Only slice when there is a first batch item and first position.
        has_first_step = output.dim() >= 3 and output.shape[0] > 0 and output.shape[1] > 0

        return {
            "response": output,
            "confidence": max(0.0, min(1.0, confidence)),
            "attention_weights": output[0, 0, :10].tolist() if has_first_step else [],
        }

    def surprise(self, content: str | Tensor) -> float:
        """
        Measure how surprising/novel content is WITHOUT learning.

        Args:
            content: Text string or tensor to evaluate

        Returns:
            Surprise score between 0 (familiar) and 1 (novel)
        """
        x = self._encode_text(content) if isinstance(content, str) else content

        with torch.no_grad():
            output = self.memory_net(x)
            return self._compute_surprise(x, output)

    def get_weight_hash(self) -> str:
        """
        Get hash of current weights for change detection.

        Returns:
            16-character hex hash of weights
        """
        with torch.no_grad():
            state = self.memory_net.state_dict()
            flat = torch.cat([v.flatten().cpu() for v in state.values()])
            # Use string representation instead of numpy to avoid numpy dependency
            data_str = str(flat.tolist())
            hash_bytes = hashlib.sha256(data_str.encode()).digest()
            return hash_bytes[:8].hex()

    def get_stats(self) -> dict[str, Any]:
        """Get memory statistics."""
        recent = self._recent_surprises
        return {
            "total_observations": self._observation_count,
            "weight_parameters": sum(p.numel() for p in self.memory_net.parameters()),
            "avg_surprise": sum(recent) / len(recent) if recent else 0.0,
            "learning_rate": self.lr.item(),
            "dimension": self.dim,
        }