Upload 2 files
Browse files- memory_forest.py +382 -0
- memory_forest_docs.py +1020 -0
memory_forest.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#############################################################################################################################################
|
| 2 |
+
#||||- - - |6.25.2025| - - - || MEMORY FOREST || - - - |memory_forest.py| - - -||||#
|
| 3 |
+
#############################################################################################################################################
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
import math
|
| 9 |
+
from collections import defaultdict, deque
|
| 10 |
+
from typing import List, Dict, Tuple, Optional
|
| 11 |
+
|
| 12 |
+
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8

#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 🌲 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#

def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Sanitize a tensor: NaN -> 0, then clamp all values to [min_val, max_val].

    +inf clamps to max_val and -inf clamps to min_val.

    Fix: the previous version replaced *both* +inf and -inf with max_val
    (``torch.isinf`` matches both signs), so -inf incorrectly became +1e6.
    ``torch.clamp`` alone already maps each infinity to the correct bound.
    """
    # Replace NaNs with 0 of the same dtype/device.
    tensor = torch.where(
        torch.isnan(tensor),
        torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype),
        tensor,
    )
    # clamp sends +inf -> max_val and -inf -> min_val.
    return torch.clamp(tensor, min_val, max_val)
|
| 22 |
+
|
| 23 |
+
def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Cosine similarity along ``dim`` with keepdim=True, safe near zero norms.

    Each operand's norm is clamped to at least ``eps`` before dividing, so a
    zero vector yields similarity 0 rather than NaN. Inputs are promoted to
    float32 first. The reduced dimension is kept (size 1); callers squeeze it.
    """
    if a.dtype != torch.float32:
        a = a.float()
    if b.dtype != torch.float32:
        b = b.float()
    # Clamp each norm separately, mirroring the per-operand stabilization.
    denominator = a.norm(dim=dim, keepdim=True).clamp_min(eps) * b.norm(dim=dim, keepdim=True).clamp_min(eps)
    dot = (a * b).sum(dim=dim, keepdim=True)
    return dot / denominator
|
| 31 |
+
|
| 32 |
+
#############################################################################################################################################
|
| 33 |
+
###################################################- - - ASSOCIATIVE HASH BUCKET - - -###################################################
|
| 34 |
+
|
| 35 |
+
class AssociativeHashBucket(nn.Module):
    """Fixed-capacity associative memory bucket with learnable hash projections.

    Holds up to ``bucket_size`` embeddings in flat buffers. Storing an item
    either merges it into a sufficiently similar occupied slot (EMA blend) or
    writes it to the next free slot, evicting the least-accessed slot once the
    bucket is full. Retrieval ranks occupied slots by cosine similarity.

    NOTE(review): ``item_hashes`` is written and ``decay_rate`` is declared,
    but neither is read anywhere in this class — presumably consumed by
    external code; confirm before removing.
    """

    def __init__(self, bucket_size=64, embedding_dim=128, num_hash_functions=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.embedding_dim = embedding_dim
        self.num_hash_functions = num_hash_functions

        # Learnable scalar hash functions: each projects an embedding to one value.
        self.hash_projections = nn.ModuleList([
            nn.Linear(embedding_dim, 1, bias=True) for _ in range(num_hash_functions)
        ])

        # Persistent (non-trainable) state: slot contents, per-slot hash
        # signatures, occupancy flags, and access counters used for eviction.
        self.register_buffer('stored_items', torch.zeros(bucket_size, embedding_dim))
        self.register_buffer('item_hashes', torch.zeros(bucket_size, num_hash_functions))
        self.register_buffer('occupancy', torch.zeros(bucket_size, dtype=torch.bool))
        self.register_buffer('access_counts', torch.zeros(bucket_size))

        # Learnable merge threshold (clamped to [0.1, 0.95] at use time).
        self.similarity_threshold = nn.Parameter(torch.tensor(0.7))
        self.decay_rate = nn.Parameter(torch.tensor(0.99))

        # Index of the next never-used slot; plain int, NOT saved in state_dict.
        self.storage_pointer = 0

    def compute_hash_signature(self, item_embedding):
        """Return tanh hash signature(s) for ``item_embedding``.

        A 1-D input yields shape (num_hash_functions,); a batch of B > 1
        yields (B, num_hash_functions) — the final squeeze(0) only removes a
        size-1 batch dimension.
        """
        x = item_embedding
        if x.dim() == 1:
            x = x.unsqueeze(0)
        signatures = []
        for hash_proj in self.hash_projections:
            sig = torch.tanh(hash_proj(x)).squeeze(-1)  # (B,)
            signatures.append(sig)
        sigs = torch.stack(signatures, dim=-1)  # (B, num_hash)
        return sigs.squeeze(0)

    def store_item(self, item_embedding, item_id=None):
        """Store one embedding (or a batch); return the slot index used per item.

        For each item: if some occupied slot is more similar than the clamped
        threshold, blend the item into that slot (90/10 EMA) instead of
        inserting; otherwise write to the next free slot, or — when the bucket
        is full — evict the occupied slot with the fewest accesses.

        Args:
            item_embedding: (embedding_dim,) or (B, embedding_dim) tensor.
            item_id: unused; kept for interface compatibility.

        Returns:
            List of int slot indices, one per input item.
        """
        if item_embedding.dim() == 1:
            item_embedding = item_embedding.unsqueeze(0)

        batch_size = item_embedding.shape[0]
        stored_items = []

        for i in range(batch_size):
            embedding = item_embedding[i]
            hash_sig = self.compute_hash_signature(embedding)

            if self.occupancy.any():
                # Compare only against occupied slots.
                similarities = safe_cosine_similarity(
                    embedding.unsqueeze(0),
                    self.stored_items[self.occupancy],
                    dim=-1
                ).squeeze()

                threshold = torch.clamp(self.similarity_threshold, 0.1, 0.95)
                if similarities.numel() > 0 and similarities.max() > threshold:
                    # Merge into the most similar slot rather than inserting anew.
                    best_idx = self.occupancy.nonzero(as_tuple=False)[similarities.argmax()]
                    self.stored_items[best_idx] = 0.9 * self.stored_items[best_idx] + 0.1 * embedding
                    self.access_counts[best_idx] += 1
                    stored_items.append(int(best_idx.item()))
                    continue

            if self.storage_pointer >= self.bucket_size:
                # Bucket full: evict the occupied slot with the fewest accesses.
                if self.occupancy.any():
                    rel_idx = self.access_counts[self.occupancy].argmin()
                    evict_idx = self.occupancy.nonzero(as_tuple=False)[rel_idx]
                else:
                    evict_idx = torch.tensor(0)
            else:
                # Use the next never-written slot and advance the pointer.
                evict_idx = torch.tensor(self.storage_pointer)
                self.storage_pointer = min(self.storage_pointer + 1, self.bucket_size)

            self.stored_items[evict_idx] = embedding
            self.item_hashes[evict_idx] = hash_sig.squeeze()
            self.occupancy[evict_idx] = True
            self.access_counts[evict_idx] = 1
            stored_items.append(int(evict_idx.item()))

        return stored_items

    def retrieve_similar(self, query_embedding, top_k=5):
        """Return (items, similarities) for up to ``top_k`` most similar slots.

        Returns a pair of empty lists when the bucket is empty; otherwise a
        (k, embedding_dim) tensor of stored items and a (k,) tensor of cosine
        similarities sorted descending. Bumps access counts of the hits so
        frequently-retrieved slots survive eviction longer.
        """
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)

        if not self.occupancy.any():
            return [], []

        valid_items = self.stored_items[self.occupancy]
        valid_indices = self.occupancy.nonzero(as_tuple=False).squeeze(-1)

        if valid_items.numel() == 0:
            return [], []

        similarities = safe_cosine_similarity(
            query_embedding.expand(valid_items.shape[0], -1),
            valid_items,
            dim=-1
        ).squeeze(-1)  # (N,)

        if similarities.numel() == 0:
            return [], []

        k = min(top_k, similarities.size(0))
        top_sims, top_indices = torch.topk(similarities, k)

        retrieved_items = valid_items[top_indices]
        retrieved_indices = valid_indices[top_indices]

        # Record the retrievals for the LFU-style eviction policy.
        for idx in retrieved_indices:
            self.access_counts[idx] += 1

        return retrieved_items, top_sims

    def get_bucket_stats(self):
        """Return a small dict of bucket health metrics (plain Python values)."""
        return {
            'occupancy_rate': self.occupancy.float().mean().item(),
            'total_accesses': self.access_counts.sum().item(),
            'avg_similarity': self.similarity_threshold.item(),
            'storage_pointer': self.storage_pointer
        }
|
| 151 |
+
|
| 152 |
+
###########################################################################################################################################
|
| 153 |
+
################################################- - - MEMORY DECISION TREE - - -#######################################################
|
| 154 |
+
|
| 155 |
+
class MemoryDecisionTree(nn.Module):
    """Soft binary decision tree that routes feature vectors to leaf buckets.

    Nodes are stored in a flat heap layout (children of node i are 2i+1 and
    2i+2). Each node has a learnable linear split scored through a
    temperature-scaled sigmoid. Routing descends ``max_depth - 1`` levels, so
    the reachable leaves are the nodes at indices
    [2^(max_depth-1) - 1, 2^max_depth - 2].
    """

    def __init__(self, input_dim, max_depth=6, min_samples_split=2):
        super().__init__()
        self.input_dim = input_dim
        self.max_depth = max_depth
        # NOTE(review): min_samples_split is stored but never used in this class.
        self.min_samples_split = min_samples_split

        # Full binary tree with max_depth levels.
        max_nodes = 2**max_depth - 1

        # Per-node split hyperplane, bias, and sigmoid temperature.
        self.split_weights = nn.Parameter(torch.randn(max_nodes, input_dim) * 0.1)
        self.split_biases = nn.Parameter(torch.zeros(max_nodes))
        self.split_temperatures = nn.Parameter(torch.ones(max_nodes))
        with torch.no_grad():
            # Sharpen initial splits slightly and break bias symmetry.
            self.split_temperatures.data.mul_(0.6)
            self.split_biases.data.add_(0.01 * torch.randn_like(self.split_biases))

        # Bookkeeping: whether a node has ever routed traffic, and how many
        # samples passed through it.
        self.register_buffer('node_active', torch.zeros(max_nodes, dtype=torch.bool))
        self.register_buffer('node_samples', torch.zeros(max_nodes))

        # leaf index -> bucket index, and its inverse (a bucket may own several
        # leaves). Plain Python containers, NOT saved in state_dict.
        self.leaf_to_bucket = {}
        self.bucket_to_leaves = defaultdict(list)

        # The root node is always active.
        self.node_active[0] = True

    def get_node_split(self, node_idx, x):
        """Probability of routing right at ``node_idx`` for each row of x.

        Out-of-range node indices return 0 (i.e. always route left).
        """
        if node_idx >= len(self.split_weights):
            return torch.zeros(x.shape[0], device=x.device)

        weights = self.split_weights[node_idx]
        bias = self.split_biases[node_idx]
        # Clamp temperature to keep the sigmoid well-conditioned.
        temp = torch.clamp(self.split_temperatures[node_idx], 0.1, 10.0)

        split_score = torch.matmul(x, weights) + bias
        split_prob = torch.sigmoid(split_score / temp)

        return split_prob

    def route_to_leaf(self, x, deterministic=False):
        """Route each row of x down the tree.

        Args:
            x: (B, input_dim) feature tensor.
            deterministic: threshold split probability at 0.5 if True,
                otherwise sample the direction from a Bernoulli.

        Returns:
            (leaf_indices (B,), paths (B, max_depth)) — ``paths`` holds the
            0/1 left/right choices; only the first max_depth-1 columns are
            filled, the last stays 0.
        """
        batch_size = x.shape[0]
        device = x.device

        current_nodes = torch.zeros(batch_size, dtype=torch.long, device=device)
        paths = torch.zeros(batch_size, self.max_depth, dtype=torch.long, device=device)

        for depth in range(self.max_depth - 1):
            split_probs = torch.zeros(batch_size, device=device)

            # Inactive nodes keep probability 0, i.e. deterministically go left.
            for i in range(batch_size):
                node_idx = int(current_nodes[i].item())
                if self.node_active[node_idx]:
                    split_probs[i] = self.get_node_split(node_idx, x[i:i+1]).squeeze()

            if deterministic:
                go_right = (split_probs > 0.5).long()
            else:
                go_right = torch.bernoulli(split_probs).long()

            paths[:, depth] = go_right

            # Heap-style child step: left child = 2n+1, right child = 2n+2.
            current_nodes = current_nodes * 2 + 1 + go_right

        return current_nodes, paths

    def assign_leaf_to_bucket(self, leaf_idx, bucket_idx):
        """Record a leaf -> bucket mapping (and its reverse)."""
        self.leaf_to_bucket[int(leaf_idx.item())] = int(bucket_idx)
        self.bucket_to_leaves[int(bucket_idx)].append(int(leaf_idx.item()))

    def get_bucket_for_input(self, x, deterministic=True):
        """Map each row of x to a bucket index (unassigned leaves map to 0)."""
        leaf_nodes, _ = self.route_to_leaf(x, deterministic=deterministic)

        bucket_assignments = []
        for leaf in leaf_nodes:
            # Leaves with no recorded mapping fall back to bucket 0.
            bucket_idx = self.leaf_to_bucket.get(int(leaf.item()), 0)
            bucket_assignments.append(bucket_idx)

        return torch.tensor(bucket_assignments, device=x.device)

    def update_node_statistics(self, x, rewards):
        """Replay deterministic routes and nudge split biases by reward.

        For each sample whose reward exceeds 0.5, the bias of every node on
        its path is pushed (+/- 0.01, in-place on .data) toward the direction
        actually taken, reinforcing successful routes. Visited nodes are also
        marked active and their sample counters incremented.
        """
        leaf_nodes, paths = self.route_to_leaf(x, deterministic=True)

        for i in range(x.shape[0]):
            current_node = 0
            # Accept rewards as tensors or plain numbers.
            reward = rewards[i].item() if torch.is_tensor(rewards[i]) else rewards[i]

            for depth in range(self.max_depth - 1):
                if current_node < len(self.node_samples):
                    self.node_samples[current_node] += 1
                    self.node_active[current_node] = True

                    if reward > 0.5:
                        direction = paths[i, depth]
                        if direction == 1:
                            self.split_biases.data[current_node] += 0.01
                        else:
                            self.split_biases.data[current_node] -= 0.01

                # depth < paths.shape[1] always holds here; the `else 0` arm is
                # defensive dead code (and would break on .item() if ever taken).
                direction = paths[i, depth] if depth < paths.shape[1] else 0
                current_node = current_node * 2 + 1 + int(direction.item())

                if current_node >= 2**self.max_depth - 1:
                    break
|
| 256 |
+
|
| 257 |
+
###########################################################################################################################################
|
| 258 |
+
##################################################- - - MEMORY FOREST - - -############################################################
|
| 259 |
+
|
| 260 |
+
class MemoryForest(nn.Module):
    """Ensemble of decision trees routing items into associative hash buckets.

    Raw features are embedded once by a shared MLP; each tree independently
    routes the *raw* features to one of its buckets, where the embedding is
    stored or queried. Retrieval pools candidates from all trees and re-ranks
    them globally by cosine similarity.
    """

    def __init__(self, input_dim, num_trees=5, max_depth=6, bucket_size=64, embedding_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_trees = num_trees
        self.embedding_dim = embedding_dim

        self.trees = nn.ModuleList([
            MemoryDecisionTree(input_dim, max_depth) for _ in range(num_trees)
        ])

        # NOTE(review): each tree only reaches 2^(max_depth-1) leaves, so this
        # allocates roughly twice as many buckets as ever get assigned leaves.
        self.num_buckets = num_trees * (2**max_depth)
        self.buckets = nn.ModuleList([
            AssociativeHashBucket(bucket_size, embedding_dim) for _ in range(self.num_buckets)
        ])

        # Shared encoder mapping raw features into the bucket embedding space.
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

        self._initialize_bucket_assignments()

    def _initialize_bucket_assignments(self):
        """Give each tree's reachable leaves consecutive, disjoint bucket ids."""
        bucket_idx = 0
        for tree_idx, tree in enumerate(self.trees):
            # Reachable leaves live at heap indices [2^(d-1)-1, 2^d-2].
            start_leaf = 2**(tree.max_depth - 1) - 1
            end_leaf = 2**tree.max_depth - 2
            for leaf in range(start_leaf, end_leaf + 1):
                if bucket_idx < self.num_buckets:
                    tree.assign_leaf_to_bucket(torch.tensor(leaf), bucket_idx)
                    bucket_idx += 1

    def store(self, features, items=None):
        """Encode ``features`` and store the embeddings once per tree.

        Args:
            features: (B, input_dim) tensor.
            items: defaults to ``features``; NOTE(review) — not otherwise used
                in this method, apparently kept for interface compatibility.

        Returns:
            List of (bucket_index, slot_index_list) pairs, one per (tree, item).
        """
        if items is None:
            items = features

        embeddings = self.feature_encoder(features)

        storage_results = []

        for tree in self.trees:
            # Stochastic routing at store time spreads items across leaves.
            bucket_assignments = tree.get_bucket_for_input(features, deterministic=False)

            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    stored_idx = self.buckets[b_idx].store_item(embeddings[i])
                    storage_results.append((b_idx, stored_idx))

        return storage_results

    def retrieve(self, query_features, top_k=5):
        """Retrieve up to ``top_k`` stored embeddings per query row.

        Each tree votes with its bucket's candidates; candidates are pooled
        per query and re-ranked globally by similarity.

        Returns:
            One (items, similarities) tensor pair per query row, best first;
            empty tensors when nothing was found for that row.
        """
        query_embeddings = self.feature_encoder(query_features)

        # query index -> list of (item_tensor, float_similarity, tensor_similarity).
        bucket_votes = defaultdict(list)

        for tree in self.trees:
            # Deterministic routing at query time for reproducibility.
            bucket_assignments = tree.get_bucket_for_input(query_features, deterministic=True)

            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    retrieved_items, similarities = self.buckets[b_idx].retrieve_similar(
                        query_embeddings[i], top_k=top_k
                    )

                    if len(retrieved_items) > 0:
                        # Keep both the float (for sorting) and the tensor
                        # (for the returned stack) form of each similarity.
                        float_sims = similarities.detach().cpu().tolist()
                        for itm, sim_t, sim_f in zip(retrieved_items, similarities, float_sims):
                            bucket_votes[i].append((itm, sim_f, sim_t))

        final_results = []
        for query_idx in range(query_features.shape[0]):
            if query_idx in bucket_votes and len(bucket_votes[query_idx]) > 0:
                candidates = bucket_votes[query_idx]
                # Sort by the float similarity, best first.
                candidates.sort(key=lambda x: x[1], reverse=True)

                top_candidates = candidates[:top_k]
                items = [c[0] for c in top_candidates]
                sims_t = [c[2] for c in top_candidates]
                final_results.append((torch.stack(items), torch.stack(sims_t)))
            else:
                final_results.append((torch.tensor([]), torch.tensor([])))

        return final_results

    def update_routing(self, features, retrieval_success):
        """Propagate retrieval feedback to every tree's split statistics."""
        for tree in self.trees:
            tree.update_node_statistics(features, retrieval_success)

    def get_forest_stats(self):
        """Return a nested dict of per-bucket and per-tree diagnostics."""
        stats = {
            'num_trees': self.num_trees,
            'num_buckets': self.num_buckets,
            'bucket_stats': [],
            'tree_stats': []
        }

        for i, bucket in enumerate(self.buckets):
            bucket_stat = bucket.get_bucket_stats()
            bucket_stat['bucket_id'] = i
            stats['bucket_stats'].append(bucket_stat)

        for i, tree in enumerate(self.trees):
            tree_stat = {
                'tree_id': i,
                'active_nodes': tree.node_active.sum().item(),
                'total_samples': tree.node_samples.sum().item(),
                'max_depth': tree.max_depth
            }
            stats['tree_stats'].append(tree_stat)

        return stats

    def forward(self, features, items=None, mode='store'):
        """Dispatch to store() or retrieve() based on ``mode``.

        Note: retrieve mode uses retrieve()'s default top_k.

        Raises:
            ValueError: if mode is neither 'store' nor 'retrieve'.
        """
        if mode == 'store':
            return self.store(features, items)
        elif mode == 'retrieve':
            return self.retrieve(features)
        else:
            raise ValueError("Mode must be 'store' or 'retrieve'")
|
| 382 |
+
|
memory_forest_docs.py
ADDED
|
@@ -0,0 +1,1020 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
##############################################################################################################################################
|
| 2 |
+
#||||- - - |6.25.2025| - - - || MEMORY FOREST || - - - |1990two| - - -|||| #
|
| 3 |
+
##############################################################################################################################################
|
| 4 |
+
"""
|
| 5 |
+
Mathematical Foundation & Conceptual Documentation
|
| 6 |
+
-------------------------------------------------
|
| 7 |
+
|
| 8 |
+
CORE PRINCIPLE:
|
| 9 |
+
Combines decision tree routing with associative hash buckets to create scalable
|
| 10 |
+
memory systems that learn optimal organization patterns. Instead of searching
|
| 11 |
+
all memory linearly, learned decision trees route queries to relevant memory
|
| 12 |
+
buckets, creating hierarchical, adaptive memory organization.
|
| 13 |
+
|
| 14 |
+
MATHEMATICAL FOUNDATION:
|
| 15 |
+
=======================
|
| 16 |
+
|
| 17 |
+
1. DECISION TREE ROUTING:
|
| 18 |
+
Split Function: s(x, θ) = σ((w·x + b)/τ)
|
| 19 |
+
|
| 20 |
+
Where:
|
| 21 |
+
- x: input feature vector
|
| 22 |
+
- w, b: learnable split parameters
|
| 23 |
+
- τ: temperature parameter (controls split sharpness)
|
| 24 |
+
- σ: sigmoid function
|
| 25 |
+
- s(x,θ) ∈ [0,1]: routing probability (left vs right)
|
| 26 |
+
|
| 27 |
+
2. HIERARCHICAL ROUTING:
|
| 28 |
+
Path to leaf: p = [s₀, s₁, ..., s_{d-1}] for depth d
|
| 29 |
+
Leaf index: L(x) = Σᵢ sᵢ × 2^i (binary path encoding)
|
| 30 |
+
Bucket assignment: B(x) = TreeToBucket[L(x)]
|
| 31 |
+
|
| 32 |
+
3. ASSOCIATIVE MEMORY OPERATIONS:
|
| 33 |
+
Hash Functions: h_k(x) = tanh(W_k·x + b_k) for k = 1..K
|
| 34 |
+
Hash Signature: H(x) = [h₁(x), h₂(x), ..., h_K(x)]
|
| 35 |
+
Similarity: sim(x,y) = cosine(H(x), H(y))
|
| 36 |
+
|
| 37 |
+
4. MEMORY STORAGE:
|
| 38 |
+
Storage Condition: sim(x, stored) < θ_similarity
|
| 39 |
+
Eviction Policy: LRU based on access_count[i]
|
| 40 |
+
Update Rule: x_stored ← α·x_stored + (1-α)·x_new for similar items
|
| 41 |
+
|
| 42 |
+
5. ENSEMBLE RETRIEVAL:
|
| 43 |
+
Tree Votes: V_t(x) = {items from bucket B_t(x)}
|
| 44 |
+
Similarity Scores: S(q,i) = cosine_similarity(q, i)
|
| 45 |
+
Final Ranking: rank = argmax_i Σ_t w_t × S(q,i) × I(i ∈ V_t)
|
| 46 |
+
|
| 47 |
+
Where w_t are tree importance weights.
|
| 48 |
+
|
| 49 |
+
6. ADAPTIVE LEARNING:
|
| 50 |
+
Success Feedback: R(query, retrieval) ∈ [0,1]
|
| 51 |
+
Tree Update: θ_t ← θ_t + η·∇_θ log P(correct_path|R)
|
| 52 |
+
Split Reinforcement: bias_node ← bias_node + α·sign(R - 0.5)
|
| 53 |
+
|
| 54 |
+
CONCEPTUAL REASONING:
|
| 55 |
+
====================
|
| 56 |
+
|
| 57 |
+
WHY DECISION TREES + HASH BUCKETS?
|
| 58 |
+
- Linear search over large memories is O(n) - doesn't scale
|
| 59 |
+
- Fixed hash functions don't adapt to data distribution
|
| 60 |
+
- Decision trees provide hierarchical, learned routing (O(log n))
|
| 61 |
+
- Hash buckets enable efficient similarity-based storage/retrieval
|
| 62 |
+
- Combination creates adaptive, scalable associative memory
|
| 63 |
+
|
| 64 |
+
KEY INNOVATIONS:
|
| 65 |
+
1. **Learned Routing**: Decision trees adapt splits based on retrieval success
|
| 66 |
+
2. **Hierarchical Organization**: Multi-level memory structure (trees → buckets → items)
|
| 67 |
+
3. **Ensemble Retrieval**: Multiple trees vote on best memories
|
| 68 |
+
4. **Adaptive Hash Functions**: Learnable hash functions with Hebbian updates
|
| 69 |
+
5. **Success-Based Learning**: Trees optimize for retrieval performance
|
| 70 |
+
|
| 71 |
+
APPLICATIONS:
|
| 72 |
+
- Large-scale information retrieval systems
|
| 73 |
+
- Adaptive caching and content distribution
|
| 74 |
+
- Knowledge base organization and query
|
| 75 |
+
- Recommender systems with hierarchical user models
|
| 76 |
+
- Scientific literature search and organization
|
| 77 |
+
|
| 78 |
+
COMPLEXITY ANALYSIS:
|
| 79 |
+
- Storage: O(log T + B) where T=trees, B=bucket_size
|
| 80 |
+
- Retrieval: O(T × log T + k × B) where k=top_k results
|
| 81 |
+
- Tree Update: O(log T) per feedback sample
|
| 82 |
+
- Memory: O(T × 2^D + N × E) where D=depth, N=items, E=embedding_dim
|
| 83 |
+
- Scalability: Sub-linear in number of stored items
|
| 84 |
+
|
| 85 |
+
BIOLOGICAL INSPIRATION:
|
| 86 |
+
- Hippocampal place cell organization for spatial memory
|
| 87 |
+
- Cortical hierarchical feature extraction and routing
|
| 88 |
+
- Cerebellar learned motor program selection
|
| 89 |
+
- Associative memory formation in neural circuits
|
| 90 |
+
- Synaptic plasticity for adaptive connection strengths
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
from __future__ import annotations
|
| 94 |
+
import torch
|
| 95 |
+
import torch.nn as nn
|
| 96 |
+
import torch.nn.functional as F
|
| 97 |
+
import numpy as np
|
| 98 |
+
import math
|
| 99 |
+
from collections import defaultdict, deque
|
| 100 |
+
from typing import List, Dict, Tuple, Optional
|
| 101 |
+
|
| 102 |
+
SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8

#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#

def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Sanitize a tensor: NaN -> 0, then clamp all values to [min_val, max_val].

    +inf clamps to max_val and -inf clamps to min_val.

    Fix: the previous version replaced *both* +inf and -inf with max_val
    (``torch.isinf`` matches both signs), so -inf incorrectly became +1e6.
    ``torch.clamp`` alone already maps each infinity to the correct bound.
    """
    # Replace NaNs with 0 of the same dtype/device.
    tensor = torch.where(
        torch.isnan(tensor),
        torch.tensor(0.0, device=tensor.device, dtype=tensor.dtype),
        tensor,
    )
    # clamp sends +inf -> max_val and -inf -> min_val.
    return torch.clamp(tensor, min_val, max_val)
|
| 112 |
+
|
| 113 |
+
def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Numerically stable cosine similarity computation.

    Computes cosine similarity between vectors with proper normalization
    and numerical stability checks to prevent division by zero.

    Mathematical Details:
    - cosine(a,b) = (a·b) / (||a|| ||b||)
    - Handles zero vectors gracefully (clamped norms yield similarity 0)
    - Clamps each norm to at least `eps` for stability
    - The reduced dimension is kept (keepdim=True); callers squeeze it.

    Args:
        a, b: Input tensors (promoted to float32 if needed).
        dim: Dimension along which to compute similarity.
        eps: Minimum norm value for numerical stability.

    Returns:
        Cosine similarity values in [-1, 1], with `dim` kept as size 1.
    """
    if a.dtype != torch.float32:
        a = a.float()
    if b.dtype != torch.float32:
        b = b.float()
    a_norm = torch.norm(a, dim=dim, keepdim=True).clamp(min=eps)
    b_norm = torch.norm(b, dim=dim, keepdim=True).clamp(min=eps)
    return torch.sum(a * b, dim=dim, keepdim=True) / (a_norm * b_norm)
|
| 139 |
+
|
| 140 |
+
###########################################################################################################################################
|
| 141 |
+
#################################################- - - ASSOCIATIVE HASH BUCKET - - -###################################################
|
| 142 |
+
|
| 143 |
+
class AssociativeHashBucket(nn.Module):
    """Associative memory bucket with learnable hash functions and similarity clustering.

    Implements a memory bucket that stores items with learned hash signatures
    and retrieves similar items based on cosine similarity. Features adaptive
    hash functions, similarity-based clustering, and a least-accessed
    eviction policy.

    Mathematical Framework:
        - Hash functions: h_k(x) = tanh(W_k·x + b_k) for k = 1..K
        - Similarity threshold: store only if max_sim(x, stored) < θ
        - Retrieval: rank by cosine similarity in embedding space
        - Eviction: slot with the minimum access count is replaced
          (NOTE(review): this is count-based / LFU-like, not true LRU,
          despite the "LRU" naming used in method docs)

    The bucket learns to cluster similar items together and adapts
    its hash functions based on storage and retrieval patterns.
    """
    def __init__(self, bucket_size: int = 64, embedding_dim: int = 128,
                 num_hash_functions: int = 4):
        super().__init__()
        self.bucket_size = bucket_size
        self.embedding_dim = embedding_dim
        self.num_hash_functions = num_hash_functions

        # Learnable hash functions (linear projections with tanh nonlinearity
        # applied in compute_hash_signature; each projects to a scalar).
        self.hash_projections = nn.ModuleList([
            nn.Linear(embedding_dim, 1, bias=True) for _ in range(num_hash_functions)
        ])

        # Storage buffers for items and metadata (non-trainable state that
        # moves with the module across devices).
        self.register_buffer('stored_items', torch.zeros(bucket_size, embedding_dim))
        self.register_buffer('item_hashes', torch.zeros(bucket_size, num_hash_functions))
        self.register_buffer('occupancy', torch.zeros(bucket_size, dtype=torch.bool))
        self.register_buffer('access_counts', torch.zeros(bucket_size))

        # Associative memory parameters. similarity_threshold is clamped to
        # [0.1, 0.95] at use time in store_item.
        self.similarity_threshold = nn.Parameter(torch.tensor(0.7))
        # NOTE(review): decay_rate is declared but not referenced by any
        # method of this class in the visible code — confirm intended use.
        self.decay_rate = nn.Parameter(torch.tensor(0.99))

        # Next free slot index; saturates at bucket_size once the bucket
        # has been filled at least once.
        self.storage_pointer = 0

    def compute_hash_signature(self, item_embedding):
        """Compute hash signature for item using learnable hash functions.

        Applies K learned hash functions to generate a signature vector
        that captures important features for similarity matching.

        Mathematical Details:
            - Each hash function: h_k(x) = tanh(W_k·x + b_k)
            - Signature: [h₁(x), h₂(x), ..., h_K(x)]
            - Tanh provides bounded, differentiable hash values

        Args:
            item_embedding: Input embedding tensor [embedding_dim] or
                [batch_size, embedding_dim]

        Returns:
            Hash signature [num_hash_functions] or [batch_size, num_hash_functions]
        """
        x = item_embedding
        if x.dim() == 1:
            # Promote single item to a batch of one for the linear layers.
            x = x.unsqueeze(0)

        signatures = []
        for hash_proj in self.hash_projections:
            sig = torch.tanh(hash_proj(x)).squeeze(-1)  # [batch_size]
            signatures.append(sig)

        sigs = torch.stack(signatures, dim=-1)  # [batch_size, num_hash_functions]
        # Restore the caller's original (unbatched) shape when given a 1-D input.
        return sigs.squeeze(0) if item_embedding.dim() == 1 else sigs

    def store_item(self, item_embedding, item_id=None):
        """Store item in bucket with similarity-based clustering and eviction.

        Storage Strategy:
            1. Check similarity to existing items
            2. If a similar item exists, merge into it (clustering)
            3. Otherwise, store as new item
            4. Evict the least-accessed slot when the bucket is full

        Mathematical Details:
            - Similarity check: max_i cos_sim(x, stored_i) > θ
            - Update rule: stored_i ← 0.9·stored_i + 0.1·x
            - Eviction: remove item with minimum access_count

        Args:
            item_embedding: Item to store [embedding_dim] or [batch_size, embedding_dim]
            item_id: Optional item identifier (currently unused in the body)

        Returns:
            List of storage indices where items were placed
        """
        if item_embedding.dim() == 1:
            item_embedding = item_embedding.unsqueeze(0)

        batch_size = item_embedding.shape[0]
        # NOTE: local list of slot indices — shadows the name of the
        # `stored_items` buffer (accessed via self.), not the buffer itself.
        stored_items = []

        for i in range(batch_size):
            embedding = item_embedding[i]
            hash_sig = self.compute_hash_signature(embedding)

            # Check similarity to existing items (similarity-based clustering)
            if self.occupancy.any():
                similarities = safe_cosine_similarity(
                    embedding.unsqueeze(0),
                    self.stored_items[self.occupancy],
                    dim=-1
                ).squeeze()

                threshold = torch.clamp(self.similarity_threshold, 0.1, 0.95)
                if similarities.numel() > 0 and similarities.max() > threshold:
                    # Update existing similar item (weighted average) rather
                    # than consuming a new slot.
                    best_idx = self.occupancy.nonzero(as_tuple=False)[similarities.argmax()]
                    self.stored_items[best_idx] = 0.9 * self.stored_items[best_idx] + 0.1 * embedding
                    self.access_counts[best_idx] += 1
                    stored_items.append(int(best_idx.item()))
                    continue

            # Store as new item
            if self.storage_pointer >= self.bucket_size:
                # Bucket full - evict the least-accessed occupied slot.
                if self.occupancy.any():
                    rel_idx = self.access_counts[self.occupancy].argmin()
                    evict_idx = self.occupancy.nonzero(as_tuple=False)[rel_idx]
                else:
                    evict_idx = torch.tensor(0)
            else:
                evict_idx = torch.tensor(self.storage_pointer)
                # Pointer saturates at bucket_size; further stores go through
                # the eviction branch above.
                self.storage_pointer = min(self.storage_pointer + 1, self.bucket_size)

            # Store item and metadata
            self.stored_items[evict_idx] = embedding
            self.item_hashes[evict_idx] = hash_sig.squeeze()
            self.occupancy[evict_idx] = True
            self.access_counts[evict_idx] = 1
            stored_items.append(int(evict_idx.item()))

        return stored_items

    def retrieve_similar(self, query_embedding, top_k: int = 5):
        """Retrieve most similar items to query based on cosine similarity.

        Retrieval Process:
            1. Compute similarities to all stored items
            2. Rank by similarity score
            3. Return top-k most similar items
            4. Update access counts for retrieved items

        Args:
            query_embedding: Query vector [embedding_dim] or
                [batch_size, embedding_dim]
                NOTE(review): the `expand` below only broadcasts a size-1
                batch; a batch_size > 1 query would fail — confirm callers
                always pass a single query (MemoryForest does).
            top_k: Number of most similar items to return

        Returns:
            Tuple of (retrieved_items, similarity_scores); empty lists when
            the bucket holds nothing.
        """
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)

        if not self.occupancy.any():
            return [], []

        # Get valid stored items
        valid_items = self.stored_items[self.occupancy]
        valid_indices = self.occupancy.nonzero(as_tuple=False).squeeze(-1)

        if valid_items.numel() == 0:
            return [], []

        # Compute cosine similarities against every occupied slot.
        similarities = safe_cosine_similarity(
            query_embedding.expand(valid_items.shape[0], -1),
            valid_items,
            dim=-1
        ).squeeze(-1)  # [num_valid_items]

        if similarities.numel() == 0:
            return [], []

        # Get top-k most similar items (k capped by how many are stored).
        k = min(top_k, similarities.size(0))
        top_sims, top_indices = torch.topk(similarities, k)

        retrieved_items = valid_items[top_indices]
        retrieved_indices = valid_indices[top_indices]

        # Update access counts for retrieved items (eviction-policy maintenance)
        for idx in retrieved_indices:
            self.access_counts[idx] += 1

        return retrieved_items, top_sims

    def get_bucket_stats(self):
        """Get comprehensive bucket statistics for monitoring and analysis.

        Returns:
            Dictionary containing occupancy, access patterns, and configuration
            info. NOTE(review): 'avg_similarity' reports the learnable
            similarity_threshold parameter, not an observed average similarity.
        """
        return {
            'occupancy_rate': self.occupancy.float().mean().item(),
            'total_accesses': self.access_counts.sum().item(),
            'avg_similarity': self.similarity_threshold.item(),
            'storage_pointer': self.storage_pointer
        }
|
| 350 |
+
|
| 351 |
+
###########################################################################################################################################
|
| 352 |
+
################################################- - - MEMORY DECISION TREE - - -#######################################################
|
| 353 |
+
|
| 354 |
+
class MemoryDecisionTree(nn.Module):
    """Learned decision tree for adaptive memory routing with success-based updates.

    Implements a binary decision tree where each internal node learns a split
    function based on retrieval success feedback. Trees adapt their routing
    decisions to maximize memory retrieval performance.

    Mathematical Framework:
        - Split functions: s(x) = σ((w·x + b)/τ) where σ is sigmoid
        - Path encoding: binary path through tree to leaf
        - Success feedback: R ∈ [0,1] from retrieval quality
        - Parameter updates: small fixed-step bias nudges on successful paths

    The tree learns to route queries to memory buckets where similar
    items are most likely to be found, adapting based on retrieval success.
    Nodes use heap indexing: children of node i are 2i+1 (left) and
    2i+2 (right).
    """
    def __init__(self, input_dim: int, max_depth: int = 6, min_samples_split: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.max_depth = max_depth
        # NOTE(review): min_samples_split is stored but never read by any
        # visible method — confirm whether it is dead configuration.
        self.min_samples_split = min_samples_split

        # Maximum number of internal nodes (2^max_depth - 1)
        max_nodes = 2**max_depth - 1

        # Learnable split functions for each potential node.
        self.split_weights = nn.Parameter(torch.randn(max_nodes, input_dim) * 0.1)
        self.split_biases = nn.Parameter(torch.zeros(max_nodes))
        self.split_temperatures = nn.Parameter(torch.ones(max_nodes))

        # Initialize parameters for stable splits
        with torch.no_grad():
            self.split_temperatures.data.mul_(0.6)  # Lower temp = sharper splits
            # Tiny random bias breaks left/right symmetry at init.
            self.split_biases.data.add_(0.01 * torch.randn_like(self.split_biases))

        # Node tracking and statistics (buffers: saved with state_dict,
        # not trained by gradients).
        self.register_buffer('node_active', torch.zeros(max_nodes, dtype=torch.bool))
        self.register_buffer('node_samples', torch.zeros(max_nodes))

        # Bucket assignment mappings (plain Python state — not saved in
        # state_dict).
        self.leaf_to_bucket = {}
        self.bucket_to_leaves = defaultdict(list)

        # Initialize root node as active
        self.node_active[0] = True

    def get_node_split(self, node_idx: int, x):
        """Compute split probability for node given input.

        Mathematical Details:
            - Split score: s = w·x + b
            - Temperature scaling: s' = s/τ (τ clamped to [0.1, 10.0])
            - Probability: p = σ(s') — probability of routing right

        Args:
            node_idx: Index of tree node
            x: Input feature vector [batch_size, input_dim]

        Returns:
            Split probabilities [batch_size] (probability of going right);
            zeros for out-of-range node indices.
        """
        if node_idx >= len(self.split_weights):
            return torch.zeros(x.shape[0], device=x.device)

        weights = self.split_weights[node_idx]
        bias = self.split_biases[node_idx]
        # Clamp temperature so the sigmoid never saturates to a hard step
        # nor flattens to near-uniform.
        temp = torch.clamp(self.split_temperatures[node_idx], 0.1, 10.0)

        split_score = torch.matmul(x, weights) + bias
        split_prob = torch.sigmoid(split_score / temp)

        return split_prob

    def route_to_leaf(self, x, deterministic: bool = False):
        """Route input through tree to leaf node.

        Tree Traversal:
            - Start at root (index 0)
            - At each node, compute split probability
            - Go left (2*i + 1) or right (2*i + 2)
            - The loop runs max_depth - 1 steps, so leaves land in the
              index range [2^(max_depth-1)-1, 2^max_depth-2].
              NOTE(review): paths is allocated with max_depth columns but
              only the first max_depth-1 are ever written; the last column
              stays zero.

        Args:
            x: Input features [batch_size, input_dim]
            deterministic: If True, use deterministic splits (p > 0.5);
                otherwise sample routing decisions via Bernoulli(p).

        Returns:
            Tuple of (leaf_nodes [batch_size], routing_paths
            [batch_size, max_depth])
        """
        batch_size = x.shape[0]
        device = x.device

        # Start at root node
        current_nodes = torch.zeros(batch_size, dtype=torch.long, device=device)
        paths = torch.zeros(batch_size, self.max_depth, dtype=torch.long, device=device)

        # Traverse tree to leaf depth
        for depth in range(self.max_depth - 1):
            split_probs = torch.zeros(batch_size, device=device)

            # Compute split probabilities for current nodes.
            # Inactive nodes keep probability 0 → deterministic routing
            # always goes left through them.
            for i in range(batch_size):
                node_idx = int(current_nodes[i].item())
                if self.node_active[node_idx]:
                    split_probs[i] = self.get_node_split(node_idx, x[i:i+1]).squeeze()

            # Make routing decisions
            if deterministic:
                go_right = (split_probs > 0.5).long()
            else:
                go_right = torch.bernoulli(split_probs).long()

            paths[:, depth] = go_right

            # Update current nodes using heap indexing
            current_nodes = current_nodes * 2 + 1 + go_right

        return current_nodes, paths

    def assign_leaf_to_bucket(self, leaf_idx, bucket_idx: int):
        """Assign tree leaf to memory bucket for storage routing.

        Creates a bidirectional mapping between tree leaves and memory
        buckets to enable routing queries to storage locations.

        Args:
            leaf_idx: Tree leaf index (expected to be a scalar tensor —
                `.item()` is called on it)
            bucket_idx: Memory bucket index
        """
        self.leaf_to_bucket[int(leaf_idx.item())] = int(bucket_idx)
        self.bucket_to_leaves[int(bucket_idx)].append(int(leaf_idx.item()))

    def get_bucket_for_input(self, x, deterministic: bool = True):
        """Route input to appropriate memory bucket via tree traversal.

        Uses the learned routing tree to determine which memory bucket
        should store/retrieve items for the given input. Unassigned leaves
        fall back to bucket 0.

        Args:
            x: Input features [batch_size, input_dim]
            deterministic: Whether to use deterministic routing

        Returns:
            Bucket indices [batch_size]
        """
        leaf_nodes, _ = self.route_to_leaf(x, deterministic=deterministic)

        bucket_assignments = []
        for leaf in leaf_nodes:
            # Default to bucket 0 when the leaf was never mapped.
            bucket_idx = self.leaf_to_bucket.get(int(leaf.item()), 0)
            bucket_assignments.append(bucket_idx)

        return torch.tensor(bucket_assignments, device=x.device)

    def update_node_statistics(self, x, rewards):
        """Update tree parameters based on retrieval success feedback.

        Learning Algorithm:
            1. Trace deterministic path through the tree for each input
            2. For each node on a successful path (reward > 0.5), nudge
               the node bias toward the direction taken (+0.01 right,
               -0.01 left), directly on `.data` — gradient-free
            3. Update sample counts and mark visited nodes active
            NOTE(review): unlike the class docstring's claim, unsuccessful
            paths (reward ≤ 0.5) leave biases unchanged — no weakening
            step exists in this code.

        Args:
            x: Input features [batch_size, input_dim]
            rewards: Retrieval success scores [batch_size] ∈ [0,1]
                (tensor elements or plain floats both accepted)
        """
        leaf_nodes, paths = self.route_to_leaf(x, deterministic=True)

        # Update parameters based on success feedback
        for i in range(x.shape[0]):
            current_node = 0
            reward = rewards[i].item() if torch.is_tensor(rewards[i]) else rewards[i]

            # Traverse path and update nodes
            for depth in range(self.max_depth - 1):
                if current_node < len(self.node_samples):
                    # Update statistics
                    self.node_samples[current_node] += 1
                    self.node_active[current_node] = True

                    # Reinforce successful splits
                    if reward > 0.5:  # Successful retrieval
                        direction = paths[i, depth]
                        if direction == 1:  # Went right - reinforce positive bias
                            self.split_biases.data[current_node] += 0.01
                        else:  # Went left - reinforce negative bias
                            self.split_biases.data[current_node] -= 0.01

                # Move to next node in path.
                # NOTE(review): the `else 0` branch is unreachable here
                # (depth < max_depth-1 ≤ paths.shape[1]); if it were taken,
                # `direction.item()` would fail on the plain int 0.
                direction = paths[i, depth] if depth < paths.shape[1] else 0
                current_node = current_node * 2 + 1 + int(direction.item())

                if current_node >= 2**self.max_depth - 1:
                    break
|
| 564 |
+
|
| 565 |
+
###########################################################################################################################################
|
| 566 |
+
##################################################- - - MEMORY FOREST - - -############################################################
|
| 567 |
+
|
| 568 |
+
class MemoryForest(nn.Module):
    """Complete memory forest system with ensemble routing and associative storage.

    Implements the full Memory Forest architecture combining multiple decision
    trees for routing with associative hash buckets for storage. Uses ensemble
    voting across trees and success-based adaptation of routing decisions.

    System Architecture:
        1. Multiple decision trees learn different routing strategies
        2. Shared memory buckets store items with associative clustering
        3. Feature encoder maps inputs to embedding space
        4. Ensemble retrieval combines votes from all trees
        5. Success feedback adapts tree routing over time

    The system learns to organize memory hierarchically, with trees discovering
    routing patterns and buckets clustering similar items.
    """
    def __init__(self, input_dim: int, num_trees: int = 5, max_depth: int = 6,
                 bucket_size: int = 64, embedding_dim: int = 128):
        super().__init__()
        self.input_dim = input_dim
        self.num_trees = num_trees
        self.embedding_dim = embedding_dim

        # Multiple decision trees for ensemble routing
        self.trees = nn.ModuleList([
            MemoryDecisionTree(input_dim, max_depth) for _ in range(num_trees)
        ])

        # Shared memory buckets across all trees.
        # NOTE(review): each tree only has 2^(max_depth-1) leaves, so this
        # allocates roughly twice as many buckets as leaves can map to —
        # half remain unassigned (see _initialize_bucket_assignments).
        self.num_buckets = num_trees * (2**max_depth)
        self.buckets = nn.ModuleList([
            AssociativeHashBucket(bucket_size, embedding_dim) for _ in range(self.num_buckets)
        ])

        # Feature encoder: maps raw inputs to embedding space
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # Initialize bucket assignments for tree leaves
        self._initialize_bucket_assignments()

    def _initialize_bucket_assignments(self):
        """Initialize mapping from tree leaves to memory buckets.

        Assignment Strategy:
            - Leaves are mapped to consecutive bucket indices, tree by tree,
              so each tree's leaves get a disjoint range of buckets
            - Leaf nodes are in range [2^(D-1)-1, 2^D-2] for depth D,
              matching where route_to_leaf terminates
        """
        bucket_idx = 0
        for tree_idx, tree in enumerate(self.trees):
            # Leaf nodes are in range [2^(D-1)-1, 2^D-2] for depth D
            start_leaf = 2**(tree.max_depth - 1) - 1
            end_leaf = 2**tree.max_depth - 2

            for leaf in range(start_leaf, end_leaf + 1):
                if bucket_idx < self.num_buckets:
                    tree.assign_leaf_to_bucket(torch.tensor(leaf), bucket_idx)
                    bucket_idx += 1

    def store(self, features, items=None):
        """Store items in memory forest using learned routing.

        Storage Process:
            1. Encode features to embedding space
            2. Route through each tree to get bucket assignments
               (stochastic routing: deterministic=False)
            3. Store embeddings in assigned buckets with associative clustering

        Multiple trees may route the same item to different buckets,
        creating redundancy that improves retrieval robustness.

        Args:
            features: Input features [batch_size, input_dim]
            items: Accepted for API symmetry but unused beyond the default —
                the stored payload is always the encoded embedding.
                NOTE(review): confirm whether storing raw `items` was intended.

        Returns:
            List of (bucket_id, storage_indices) tuples — one entry per
            (tree, item) pair, so batch_size * num_trees entries.
        """
        if items is None:
            items = features

        # Encode features to embedding space
        embeddings = self.feature_encoder(features)

        storage_results = []

        # Route through each tree and store in assigned buckets
        for tree in self.trees:
            bucket_assignments = tree.get_bucket_for_input(features, deterministic=False)

            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    stored_idx = self.buckets[b_idx].store_item(embeddings[i])
                    storage_results.append((b_idx, stored_idx))

        return storage_results

    def retrieve(self, query_features, top_k: int = 5):
        """Retrieve similar items using ensemble voting across trees.

        Retrieval Process:
            1. Encode query features to embedding space
            2. Route queries through all trees (deterministic routing)
            3. Retrieve similar items from each candidate bucket
            4. Pool candidates across trees, sort by similarity, keep top-k

        Args:
            query_features: Query feature vectors [batch_size, input_dim]
            top_k: Number of most similar items to return

        Returns:
            List of (retrieved_items, similarity_scores) tensor pairs, one
            per query; empty tensors when no tree produced a candidate.
        """
        query_embeddings = self.feature_encoder(query_features)

        # Candidates pooled per query index across all trees.
        bucket_votes = defaultdict(list)

        for tree in self.trees:
            bucket_assignments = tree.get_bucket_for_input(query_features, deterministic=True)

            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    retrieved_items, similarities = self.buckets[b_idx].retrieve_similar(
                        query_embeddings[i], top_k=top_k
                    )

                    if len(retrieved_items) > 0:
                        # Keep both a Python float (for sorting) and the
                        # original tensor (for the returned scores).
                        float_sims = similarities.detach().cpu().tolist()
                        for itm, sim_t, sim_f in zip(retrieved_items, similarities, float_sims):
                            bucket_votes[i].append((itm, sim_f, sim_t))

        # Aggregate ensemble results
        final_results = []
        for query_idx in range(query_features.shape[0]):
            if query_idx in bucket_votes and len(bucket_votes[query_idx]) > 0:
                # Sort candidates by float similarity, descending.
                candidates = bucket_votes[query_idx]
                candidates.sort(key=lambda x: x[1], reverse=True)

                # Extract top-k results
                top_candidates = candidates[:top_k]
                items = [c[0] for c in top_candidates]
                sims_t = [c[2] for c in top_candidates]
                final_results.append((torch.stack(items), torch.stack(sims_t)))
            else:
                # No results found
                final_results.append((torch.tensor([]), torch.tensor([])))

        return final_results

    def update_routing(self, features, retrieval_success):
        """Update tree routing based on retrieval success feedback.

        Delegates to each tree's gradient-free reinforcement update
        (MemoryDecisionTree.update_node_statistics).

        Args:
            features: Input features that were queried [batch_size, input_dim]
            retrieval_success: Success scores [batch_size] ∈ [0,1]
        """
        for tree in self.trees:
            tree.update_node_statistics(features, retrieval_success)

    def get_forest_stats(self):
        """Get comprehensive statistics about the memory forest state.

        Returns:
            Dictionary with forest-level counts plus per-bucket and per-tree
            statistic dictionaries, for monitoring.
        """
        stats = {
            'num_trees': self.num_trees,
            'num_buckets': self.num_buckets,
            'bucket_stats': [],
            'tree_stats': []
        }

        # Collect bucket statistics
        for i, bucket in enumerate(self.buckets):
            bucket_stat = bucket.get_bucket_stats()
            bucket_stat['bucket_id'] = i
            stats['bucket_stats'].append(bucket_stat)

        # Collect tree statistics
        for i, tree in enumerate(self.trees):
            tree_stat = {
                'tree_id': i,
                'active_nodes': tree.node_active.sum().item(),
                'total_samples': tree.node_samples.sum().item(),
                'max_depth': tree.max_depth
            }
            stats['tree_stats'].append(tree_stat)

        return stats

    def forward(self, features, items=None, mode: str = 'store'):
        """Unified forward interface for storage and retrieval operations.

        Args:
            features: Input feature vectors
            items: Items to store (for store mode)
            mode: 'store' or 'retrieve'
                NOTE(review): retrieve mode uses retrieve()'s default
                top_k=5 — there is no way to override it through forward.

        Returns:
            Storage results or retrieval results based on mode

        Raises:
            ValueError: If mode is not 'store' or 'retrieve'.
        """
        if mode == 'store':
            return self.store(features, items)
        elif mode == 'retrieve':
            return self.retrieve(features)
        else:
            raise ValueError("Mode must be 'store' or 'retrieve'")
|
| 805 |
+
|
| 806 |
+
###########################################################################################################################################
|
| 807 |
+
####################################################- - - DEMO AND TESTING - - -#######################################################
|
| 808 |
+
|
| 809 |
+
def _report_retrievals(results, header, threshold=0.8):
    """Print per-query retrieval quality under *header* and return how many
    queries had a best similarity above *threshold*."""
    success_count = 0
    print(header)
    for i, (items, similarities) in enumerate(results):
        if len(items) > 0:
            best_sim = similarities[0].item()
            success = best_sim > threshold  # threshold for a "good" retrieval
            # FIX: the pass/fail markers were mojibake — the original printed
            # 'β' in BOTH branches of the conditional (a dead conditional);
            # restored to the intended check/cross marks.
            print(f" Query {i}: {len(items)} items, best similarity: {best_sim:.3f} {'✓' if success else '✗'}")
            if success:
                success_count += 1
        else:
            print(f" Query {i}: No items retrieved ✗")
    return success_count

def test_memory_forest():
    """Comprehensive test of Memory Forest functionality and performance.

    Builds a small forest, stores clustered synthetic data, measures
    retrieval quality before and after routing adaptation, and prints
    forest-level statistics. Returns True on completion.
    """
    print(" Testing Memory Forest - Associative Memory with Learned Routing")
    print("=" * 70)

    # Create memory forest system
    input_dim = 64
    embedding_dim = 128
    forest = MemoryForest(
        input_dim=input_dim,
        num_trees=3,
        max_depth=4,
        bucket_size=32,
        embedding_dim=embedding_dim
    )

    print(f"Created Memory Forest:")
    print(f" - Input dimension: {input_dim}")
    print(f" - Embedding dimension: {embedding_dim}")
    print(f" - Number of trees: {forest.num_trees}")
    print(f" - Tree depth: 4")
    print(f" - Total buckets: {forest.num_buckets}")
    print(f" - Bucket capacity: 32 items each")

    # Generate test data with some structure for meaningful clustering
    print(f"\n Generating structured test data...")
    num_items = 100

    # Create clustered data (3 clusters): each item is a cluster center plus noise.
    cluster_centers = torch.randn(3, input_dim) * 2
    test_features = []

    for _ in range(num_items):
        cluster_id = torch.randint(0, 3, (1,)).item()
        noise = torch.randn(input_dim) * 0.5
        item = cluster_centers[cluster_id] + noise
        test_features.append(item)

    test_features = torch.stack(test_features)
    print(f" - Generated {num_items} items in 3 clusters")
    print(f" - Feature dimension: {input_dim}")

    # Test storage
    print(f"\n Testing storage operations...")
    storage_results = forest.store(test_features)

    # NOTE(review): assumes forest.store returns (bucket_id, ...) tuples — confirm.
    unique_buckets = len(set(r[0] for r in storage_results))
    print(f" - Stored {num_items} items")
    print(f" - Used {unique_buckets} different buckets")
    # max(..., 1) guards against a zero-division if storage used no buckets.
    print(f" - Average items per bucket: {len(storage_results) / max(unique_buckets, 1):.1f}")

    # Test retrieval without learning
    print(f"\n Testing retrieval (before learning)...")
    query_features = test_features[:5]  # Use first 5 items as queries

    retrieval_results = forest.retrieve(query_features, top_k=3)

    initial_success_count = _report_retrievals(retrieval_results, "Initial retrieval results:")
    initial_success_rate = initial_success_count / len(query_features)
    print(f" Initial success rate: {initial_success_rate:.1%}")

    # Test adaptive learning
    print(f"\n Testing adaptive learning...")
    print("Simulating retrieval feedback and tree adaptation...")

    # Simulate multiple rounds of feedback
    for round_num in range(3):
        # Random success scores in [0.3, 0.9) — biased toward improvement.
        retrieval_success = torch.rand(len(query_features)) * 0.6 + 0.3

        # Update tree routing based on feedback
        forest.update_routing(query_features, retrieval_success)

        print(f" Round {round_num + 1}: Updated trees with feedback")

    # Test retrieval after learning
    print(f"\n Testing retrieval (after learning)...")
    learned_results = forest.retrieve(query_features, top_k=3)

    learned_success_count = _report_retrievals(learned_results, "Post-learning retrieval results:")
    learned_success_rate = learned_success_count / len(query_features)
    improvement = learned_success_rate - initial_success_rate
    print(f" Post-learning success rate: {learned_success_rate:.1%}")
    print(f" Improvement: {improvement:+.1%}")

    # Analyze forest statistics
    print(f"\n Forest analysis:")
    stats = forest.get_forest_stats()

    avg_bucket_occupancy = np.mean([b['occupancy_rate'] for b in stats['bucket_stats']])
    total_accesses = sum(b['total_accesses'] for b in stats['bucket_stats'])
    active_nodes = sum(t['active_nodes'] for t in stats['tree_stats'])

    print(f" - Average bucket occupancy: {avg_bucket_occupancy:.1%}")
    print(f" - Total bucket accesses: {total_accesses}")
    print(f" - Active tree nodes: {active_nodes}")

    # Test different query types
    print(f"\n Testing query diversity...")

    # Similar query (from stored data)
    similar_query = test_features[10:11]  # Known stored item
    similar_results = forest.retrieve(similar_query, top_k=3)
    similar_best = similar_results[0][1][0].item() if len(similar_results[0][1]) > 0 else 0

    # Random query (not from stored data)
    random_query = torch.randn(1, input_dim)
    random_results = forest.retrieve(random_query, top_k=3)
    random_best = random_results[0][1][0].item() if len(random_results[0][1]) > 0 else 0

    print(f" - Known item query similarity: {similar_best:.3f}")
    print(f" - Random query similarity: {random_best:.3f}")
    # max(..., 0.01) avoids dividing by a zero similarity.
    print(f" - Discrimination ratio: {similar_best / max(random_best, 0.01):.1f}x")

    print(f"\n Memory Forest test completed!")
    # FIX: checklist bullets were mojibake ('β'); restored to check marks.
    print("✓ Hierarchical memory organization with learned routing")
    print("✓ Associative storage with similarity clustering")
    print("✓ Ensemble retrieval across multiple trees")
    print("✓ Adaptive routing based on retrieval success")
    print("✓ Efficient O(log n) routing instead of O(n) search")
    print("✓ Scalable architecture for large memory systems")

    return True
|
| 954 |
+
|
| 955 |
+
def simple_demo():
    """Simple demonstration with clear patterns."""
    banner = "=" * 50
    print("\n" + banner)
    print(" MEMORY FOREST SIMPLE DEMO")
    print(banner)

    # A deliberately tiny forest so routing behaviour is easy to follow.
    forest = MemoryForest(input_dim=8, num_trees=2, max_depth=3, bucket_size=16, embedding_dim=32)

    # Six hand-crafted binary patterns that should cluster together.
    patterns = torch.tensor([
        [1, 0, 1, 0, 1, 0, 1, 0],  # Pattern A (alternating)
        [0, 1, 0, 1, 0, 1, 0, 1],  # Pattern B (inverse alternating)
        [1, 1, 0, 0, 1, 1, 0, 0],  # Pattern C (pairs)
        [0, 0, 1, 1, 0, 0, 1, 1],  # Pattern D (inverse pairs)
        [1, 0, 1, 0, 1, 0, 1, 1],  # Pattern A variant
        [0, 1, 0, 1, 0, 1, 0, 0],  # Pattern B variant
    ], dtype=torch.float32)

    print("Storing 6 distinct patterns...")
    print(" - 2 alternating patterns (A, B)")
    print(" - 2 pair patterns (C, D)")
    print(" - 2 pattern variants")

    # Store all patterns in the forest.
    forest.store(patterns)

    def _show(label, retrievals):
        # Shared reporter for both the clean and the noisy query rounds.
        for idx, (matches, scores) in enumerate(retrievals):
            if len(matches) > 0:
                print(f" {label} {idx}: Found {len(matches)} matches, best similarity: {scores[0].item():.3f}")
            else:
                print(f" {label} {idx}: No matches found")

    # Query with the first four stored patterns verbatim.
    print("\nTesting exact pattern retrieval:")
    _show("Pattern", forest.retrieve(patterns[:4]))

    # Query with lightly perturbed copies of the first two patterns.
    print("\nTesting noisy pattern retrieval:")
    perturbed = patterns[:2] + 0.1 * torch.randn_like(patterns[:2])
    _show("Noisy pattern", forest.retrieve(perturbed))

    # Summarize how the forest laid the patterns out across its buckets.
    stats = forest.get_forest_stats()
    used_buckets = len([b for b in stats['bucket_stats'] if b['occupancy_rate'] > 0])
    print(f"\nForest organization:")
    print(f" - Used {used_buckets} buckets out of {len(stats['bucket_stats'])}")
    print(f" - Trees routed patterns to different memory locations")
    print(f" - Associative clustering groups similar patterns")

    print("\n Demo completed. Memory Forest successfully organized and retrieved patterns.")
|
| 1014 |
+
|
| 1015 |
+
if __name__ == "__main__":
    # Run the full functional test first, then the compact illustrative demo.
    test_memory_forest()
    simple_demo()
|
| 1018 |
+
|
| 1019 |
+
###########################################################################################################################################
|
| 1020 |
+
###########################################################################################################################################
|