Update memory.py
memory.py (changed)
@@ -17,11 +17,13 @@ class CognitiveMemory(nn.Module):
         self.consolidation_threshold = 0.7
 
         # Memory projection layers
-        self.key_proj = nn.Linear(context_size, 64)
-        self.value_proj = nn.Linear(context_size, 64)
+        self.key_proj = nn.Linear(1, 64)  # Changed from context_size to 1
+        self.value_proj = nn.Linear(1, 64)  # Changed from context_size to 1
 
     def add_memory(self, context: torch.Tensor, activation: float):
         """Store new memory with adaptive importance"""
+        # Ensure context is 1D tensor with single value
+        context = context.reshape(-1)
         importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
         self.memory_queue.append({
             'context': context.detach(),
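The first hunk narrows both projections to a single input feature and flattens the stored context before weighting it with importance = sigmoid(activation * 0.5 + 0.2). A minimal sketch of that storage step, assuming memory_queue is an ordinary Python list (its initialization is outside this diff) and that the importance weight is kept alongside the context (the rest of the appended dict is cut off in the hunk):

import torch

memory_queue = []  # assumption: a plain list; the constructor is not shown in this commit

def add_memory(context: torch.Tensor, activation: float) -> None:
    """Mirror of the patched method: flatten the context, weight it, store it."""
    context = context.reshape(-1)  # 1D tensor with a single value, as in the patch
    importance = torch.sigmoid(torch.tensor(activation * 0.5 + 0.2))
    memory_queue.append({
        'context': context.detach(),
        'importance': importance,  # assumption: remaining fields are not visible in the hunk
    })

add_memory(torch.tensor(0.42), activation=1.0)    # importance = sigmoid(0.7)  ~ 0.67
add_memory(torch.tensor(-0.10), activation=-2.0)  # importance = sigmoid(-0.8) ~ 0.31

Because the projections now expect a single input feature, each stored context is effectively a scalar; a richer context would require the reshape calls and the Linear in-features to agree.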
@@ -46,9 +48,14 @@ class CognitiveMemory(nn.Module):
         if not self.memory_queue:
             return torch.zeros_like(query)
 
-
-
+        # Ensure query is 1D tensor with single value
+        query = query.reshape(1, 1)
+        memories = torch.stack([m['context'].reshape(1, 1) for m in self.memory_queue])
+
+        keys = self.key_proj(memories)
+        values = self.value_proj(memories)
         query_proj = self.key_proj(query)
 
-        scores = F.softmax(keys
-
+        scores = F.softmax(torch.matmul(keys, query_proj.transpose(0, 1)), dim=0)
+        retrieved = torch.matmul(scores.transpose(0, 1), values)
+        return retrieved.squeeze(0)
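The second hunk turns retrieval into a small attention step: every stored context is projected into 64-dimensional key and value vectors, the query is projected with the same key projection, a softmax over the key-query similarities gives one weight per memory, and the weights combine the values. A self-contained sketch of that pattern with the shapes written out; the enclosing method's signature is not visible in the hunk, so the names here are illustrative, and the final line is written as a broadcast weighted sum (a common way to express the attention read-out) rather than the patch's matmul/transpose form.

import torch
import torch.nn as nn
import torch.nn.functional as F

key_proj = nn.Linear(1, 64)    # same shapes as the patched projections
value_proj = nn.Linear(1, 64)

# Three example stored contexts, each a single value, as add_memory now guarantees
stored = [torch.tensor([0.3]), torch.tensor([0.7]), torch.tensor([1.2])]

query = torch.tensor([[0.5]])                               # [1, 1]
memories = torch.stack([c.reshape(1, 1) for c in stored])   # [N, 1, 1]

keys = key_proj(memories)       # [N, 1, 64]
values = value_proj(memories)   # [N, 1, 64]
query_proj = key_proj(query)    # [1, 64]

# One similarity score per stored memory, normalized across the memory axis
scores = F.softmax(torch.matmul(keys, query_proj.transpose(0, 1)), dim=0)  # [N, 1, 1]

# Attention read-out: weighted sum of the value vectors
retrieved = (scores * values).sum(dim=0)  # [1, 64]
print(retrieved.shape)

With nn.Linear(1, 64) the keys depend only on each memory's single stored value, so the softmax effectively ranks memories by how their scalar context projects against the query.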