# memory_forest / memory_forest.py
# Provenance: uploaded by 1990two ("Upload 2 files", commit 233f515 verified)
#############################################################################################################################################
#||||- - - |6.25.2025| - - - || MEMORY FOREST || - - - |memory_forest.py| - - -||||#
#############################################################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from collections import defaultdict, deque
from typing import List, Dict, Tuple, Optional
SAFE_MIN = -1e6  # lower clamp bound for sanitized tensors
SAFE_MAX = 1e6   # upper clamp bound for sanitized tensors
EPS = 1e-8       # division-safety epsilon for vector norms
#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#
def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Sanitize a tensor: NaN -> 0, +inf -> max_val, -inf -> min_val, then clamp.

    BUG FIX: the original replaced *all* infinities (torch.isinf matches both
    signs) with max_val, so -inf ended up at +1e6 after the clamp instead of
    the lower bound. +inf and -inf are now handled separately.
    """
    tensor = torch.where(torch.isnan(tensor), torch.zeros_like(tensor), tensor)
    tensor = torch.where(torch.isposinf(tensor), torch.full_like(tensor, max_val), tensor)
    tensor = torch.where(torch.isneginf(tensor), torch.full_like(tensor, min_val), tensor)
    return torch.clamp(tensor, min_val, max_val)


def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Cosine similarity along `dim` with eps-clamped norms (no divide-by-zero).

    Inputs are promoted to float32 if needed. The reduced dimension is kept
    (keepdim=True), so callers typically squeeze the result.
    """
    if a.dtype != torch.float32:
        a = a.float()
    if b.dtype != torch.float32:
        b = b.float()
    a_norm = torch.norm(a, dim=dim, keepdim=True).clamp(min=eps)
    b_norm = torch.norm(b, dim=dim, keepdim=True).clamp(min=eps)
    return torch.sum(a * b, dim=dim, keepdim=True) / (a_norm * b_norm)
#############################################################################################################################################
###################################################- - - ASSOCIATIVE HASH BUCKET - - -###################################################
class AssociativeHashBucket(nn.Module):
    """Fixed-capacity associative memory bucket with soft hash signatures.

    Up to `bucket_size` embeddings are held in non-trainable buffers. A new
    item whose cosine similarity to an occupied slot exceeds the (clamped)
    learned threshold is merged into that slot by EMA instead of taking a new
    one; otherwise it fills the next free slot, or evicts the least-accessed
    occupied slot once the bucket is full.
    """

    def __init__(self, bucket_size=64, embedding_dim=128, num_hash_functions=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.embedding_dim = embedding_dim
        self.num_hash_functions = num_hash_functions
        # One scalar projection per hash function; tanh of each output forms
        # one coordinate of an item's soft hash signature.
        self.hash_projections = nn.ModuleList([
            nn.Linear(embedding_dim, 1, bias=True) for _ in range(num_hash_functions)
        ])
        # Slot state as buffers: moved by .to(device), saved in state_dict.
        self.register_buffer('stored_items', torch.zeros(bucket_size, embedding_dim))
        self.register_buffer('item_hashes', torch.zeros(bucket_size, num_hash_functions))
        self.register_buffer('occupancy', torch.zeros(bucket_size, dtype=torch.bool))
        self.register_buffer('access_counts', torch.zeros(bucket_size))
        # Learnable merge threshold; clamped to [0.1, 0.95] where used.
        self.similarity_threshold = nn.Parameter(torch.tensor(0.7))
        # NOTE(review): decay_rate is declared but never read in this class.
        self.decay_rate = nn.Parameter(torch.tensor(0.99))
        # Next fresh slot index. Plain int: NOT persisted in state_dict, so it
        # resets to 0 after loading a checkpoint — TODO confirm intended.
        self.storage_pointer = 0

    def compute_hash_signature(self, item_embedding):
        """Return the tanh hash signature: (num_hash,) for a single embedding,
        (B, num_hash) for a batch."""
        x = item_embedding
        if x.dim() == 1:
            x = x.unsqueeze(0)
        signatures = []
        for hash_proj in self.hash_projections:
            sig = torch.tanh(hash_proj(x)).squeeze(-1)  # (B,)
            signatures.append(sig)
        sigs = torch.stack(signatures, dim=-1)  # (B, num_hash)
        # Collapses back to (num_hash,) for an originally unbatched input.
        return sigs.squeeze(0)

    def store_item(self, item_embedding, item_id=None):
        """Store one embedding or a batch; return the slot index used per item.

        `item_id` is accepted but currently unused.
        """
        if item_embedding.dim() == 1:
            item_embedding = item_embedding.unsqueeze(0)
        batch_size = item_embedding.shape[0]
        # NOTE(review): local list shadows the buffer name `self.stored_items`.
        stored_items = []
        for i in range(batch_size):
            embedding = item_embedding[i]
            hash_sig = self.compute_hash_signature(embedding)
            # Merge path: fold into the most similar occupied slot when its
            # cosine similarity beats the clamped learned threshold.
            if self.occupancy.any():
                similarities = safe_cosine_similarity(
                    embedding.unsqueeze(0),
                    self.stored_items[self.occupancy],
                    dim=-1
                ).squeeze()
                threshold = torch.clamp(self.similarity_threshold, 0.1, 0.95)
                if similarities.numel() > 0 and similarities.max() > threshold:
                    # Map argmax over occupied slots back to an absolute index.
                    best_idx = self.occupancy.nonzero(as_tuple=False)[similarities.argmax()]
                    # EMA merge; the slot's hash signature is NOT refreshed here.
                    self.stored_items[best_idx] = 0.9 * self.stored_items[best_idx] + 0.1 * embedding
                    self.access_counts[best_idx] += 1
                    stored_items.append(int(best_idx.item()))
                    continue
            # Insert path: next free slot, or evict the least-accessed occupied
            # slot once the bucket has filled up.
            if self.storage_pointer >= self.bucket_size:
                if self.occupancy.any():
                    rel_idx = self.access_counts[self.occupancy].argmin()
                    evict_idx = self.occupancy.nonzero(as_tuple=False)[rel_idx]
                else:
                    evict_idx = torch.tensor(0)
            else:
                evict_idx = torch.tensor(self.storage_pointer)
            # Pointer saturates at bucket_size; afterwards inserts always evict.
            self.storage_pointer = min(self.storage_pointer + 1, self.bucket_size)
            self.stored_items[evict_idx] = embedding
            self.item_hashes[evict_idx] = hash_sig.squeeze()
            self.occupancy[evict_idx] = True
            self.access_counts[evict_idx] = 1
            stored_items.append(int(evict_idx.item()))
        return stored_items

    def retrieve_similar(self, query_embedding, top_k=5):
        """Return (items, similarities) for the top_k occupied slots most
        cosine-similar to `query_embedding`; ([], []) when the bucket is empty."""
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if not self.occupancy.any():
            return [], []
        valid_items = self.stored_items[self.occupancy]
        valid_indices = self.occupancy.nonzero(as_tuple=False).squeeze(-1)
        if valid_items.numel() == 0:
            return [], []
        similarities = safe_cosine_similarity(
            query_embedding.expand(valid_items.shape[0], -1),
            valid_items,
            dim=-1
        ).squeeze(-1)  # (N,)
        if similarities.numel() == 0:
            return [], []
        k = min(top_k, similarities.size(0))
        top_sims, top_indices = torch.topk(similarities, k)
        retrieved_items = valid_items[top_indices]
        retrieved_indices = valid_indices[top_indices]
        # Retrieval counts as access, protecting these slots from eviction.
        for idx in retrieved_indices:
            self.access_counts[idx] += 1
        return retrieved_items, top_sims

    def get_bucket_stats(self):
        """Summary statistics for monitoring/debugging."""
        return {
            'occupancy_rate': self.occupancy.float().mean().item(),
            'total_accesses': self.access_counts.sum().item(),
            # NOTE(review): despite the key name, this is the learned
            # threshold parameter, not a measured average similarity.
            'avg_similarity': self.similarity_threshold.item(),
            'storage_pointer': self.storage_pointer
        }
###########################################################################################################################################
################################################- - - MEMORY DECISION TREE - - -#######################################################
class MemoryDecisionTree(nn.Module):
    """Soft binary decision tree that routes inputs to leaf nodes.

    Nodes are laid out in heap order (root 0, children of i at 2i+1 / 2i+2)
    with 2**max_depth - 1 slots of split parameters. Routing takes
    max_depth - 1 soft binary decisions, so reachable leaves occupy heap
    indices [2**(max_depth-1) - 1, 2**max_depth - 2]. Leaves are mapped to
    external bucket ids via assign_leaf_to_bucket.
    """

    def __init__(self, input_dim, max_depth=6, min_samples_split=2):
        super().__init__()
        self.input_dim = input_dim
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        max_nodes = 2**max_depth - 1
        # Per-node linear split: sigmoid((w·x + b) / temperature).
        self.split_weights = nn.Parameter(torch.randn(max_nodes, input_dim) * 0.1)
        self.split_biases = nn.Parameter(torch.zeros(max_nodes))
        self.split_temperatures = nn.Parameter(torch.ones(max_nodes))
        with torch.no_grad():
            # Sharpen splits slightly and break bias symmetry at init.
            self.split_temperatures.data.mul_(0.6)
            self.split_biases.data.add_(0.01 * torch.randn_like(self.split_biases))
        self.register_buffer('node_active', torch.zeros(max_nodes, dtype=torch.bool))
        self.register_buffer('node_samples', torch.zeros(max_nodes))
        # Plain-dict leaf <-> bucket maps (not persisted in state_dict).
        self.leaf_to_bucket = {}
        self.bucket_to_leaves = defaultdict(list)
        self.node_active[0] = True  # root starts active

    def get_node_split(self, node_idx, x):
        """Probability of routing right at `node_idx` for each row of x.

        Out-of-range nodes route hard-left (probability 0).
        """
        if node_idx >= len(self.split_weights):
            return torch.zeros(x.shape[0], device=x.device)
        weights = self.split_weights[node_idx]
        bias = self.split_biases[node_idx]
        # Temperature clamped to keep the sigmoid well-conditioned.
        temp = torch.clamp(self.split_temperatures[node_idx], 0.1, 10.0)
        split_score = torch.matmul(x, weights) + bias
        return torch.sigmoid(split_score / temp)

    def route_to_leaf(self, x, deterministic=False):
        """Route each row of x down the tree.

        Returns (leaf_nodes, paths): leaf_nodes holds the final heap index per
        row; paths[:, d] is the left(0)/right(1) decision at depth d (only the
        first max_depth - 1 columns are filled).
        """
        batch_size = x.shape[0]
        device = x.device
        current_nodes = torch.zeros(batch_size, dtype=torch.long, device=device)
        paths = torch.zeros(batch_size, self.max_depth, dtype=torch.long, device=device)
        for depth in range(self.max_depth - 1):
            split_probs = torch.zeros(batch_size, device=device)
            # Inactive nodes keep probability 0, i.e. always route left.
            for i in range(batch_size):
                node_idx = int(current_nodes[i].item())
                if self.node_active[node_idx]:
                    split_probs[i] = self.get_node_split(node_idx, x[i:i+1]).squeeze()
            if deterministic:
                go_right = (split_probs > 0.5).long()
            else:
                go_right = torch.bernoulli(split_probs).long()
            paths[:, depth] = go_right
            # Heap-order child: left = 2i+1, right = 2i+2.
            current_nodes = current_nodes * 2 + 1 + go_right
        return current_nodes, paths

    def assign_leaf_to_bucket(self, leaf_idx, bucket_idx):
        """Bind a leaf (heap index; tensor or plain int accepted) to a bucket id."""
        # Generalized: the original required a tensor and crashed on ints.
        leaf = int(leaf_idx.item()) if torch.is_tensor(leaf_idx) else int(leaf_idx)
        self.leaf_to_bucket[leaf] = int(bucket_idx)
        self.bucket_to_leaves[int(bucket_idx)].append(leaf)

    def get_bucket_for_input(self, x, deterministic=True):
        """Bucket id per row of x; leaves with no mapping default to bucket 0."""
        leaf_nodes, _ = self.route_to_leaf(x, deterministic=deterministic)
        bucket_assignments = [
            self.leaf_to_bucket.get(int(leaf.item()), 0) for leaf in leaf_nodes
        ]
        return torch.tensor(bucket_assignments, device=x.device)

    def update_node_statistics(self, x, rewards):
        """Replay deterministic routes, bump node usage counters, and nudge
        split biases toward the taken direction for rewarded (> 0.5) samples."""
        leaf_nodes, paths = self.route_to_leaf(x, deterministic=True)
        for i in range(x.shape[0]):
            current_node = 0
            reward = rewards[i].item() if torch.is_tensor(rewards[i]) else rewards[i]
            for depth in range(self.max_depth - 1):
                if current_node < len(self.node_samples):
                    self.node_samples[current_node] += 1
                    self.node_active[current_node] = True
                    if reward > 0.5:
                        # Reinforce the decision actually taken at this node.
                        if paths[i, depth] == 1:
                            self.split_biases.data[current_node] += 0.01
                        else:
                            self.split_biases.data[current_node] -= 0.01
                # BUG FIX: the original bound `direction` to the plain int 0 in
                # the fallback branch, then unconditionally called
                # `direction.item()`, which would raise AttributeError; coerce
                # to int before descending instead.
                direction = int(paths[i, depth].item()) if depth < paths.shape[1] else 0
                current_node = current_node * 2 + 1 + direction
                if current_node >= 2**self.max_depth - 1:
                    break
###########################################################################################################################################
##################################################- - - MEMORY FOREST - - -############################################################
class MemoryForest(nn.Module):
    """Ensemble of decision trees routing inputs into associative hash buckets.

    Each tree independently routes an input to a leaf, and each reachable leaf
    is bound to one bucket. `store` writes the encoded embedding into every
    tree's chosen bucket; `retrieve` pools candidates from every tree's bucket
    and re-ranks them globally by similarity.
    """

    def __init__(self, input_dim, num_trees=5, max_depth=6, bucket_size=64, embedding_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_trees = num_trees
        self.embedding_dim = embedding_dim
        self.trees = nn.ModuleList([
            MemoryDecisionTree(input_dim, max_depth) for _ in range(num_trees)
        ])
        # NOTE(review): each tree only exposes 2**(max_depth-1) reachable
        # leaves, so half of these buckets never receive a leaf assignment.
        self.num_buckets = num_trees * (2**max_depth)
        self.buckets = nn.ModuleList([
            AssociativeHashBucket(bucket_size, embedding_dim) for _ in range(self.num_buckets)
        ])
        # Maps raw input features into the embedding space buckets store.
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        self._initialize_bucket_assignments()

    def _initialize_bucket_assignments(self):
        """Give every reachable leaf of every tree its own distinct bucket id."""
        bucket_idx = 0
        for tree_idx, tree in enumerate(self.trees):
            # Reachable leaves occupy heap indices [2**(d-1)-1, 2**d-2].
            start_leaf = 2**(tree.max_depth - 1) - 1
            end_leaf = 2**tree.max_depth - 2
            for leaf in range(start_leaf, end_leaf + 1):
                if bucket_idx < self.num_buckets:
                    tree.assign_leaf_to_bucket(torch.tensor(leaf), bucket_idx)
                    bucket_idx += 1

    def store(self, features, items=None):
        """Encode `features` and store each embedding in every tree's routed bucket.

        Returns a list of (bucket_idx, slot_indices) tuples, one per
        (tree, input) pair. `items` defaults to `features` but is otherwise
        unused here — the encoded embedding is what gets stored.
        """
        if items is None:
            items = features
        embeddings = self.feature_encoder(features)
        storage_results = []
        for tree in self.trees:
            # Stochastic routing at store time spreads items across leaves.
            bucket_assignments = tree.get_bucket_for_input(features, deterministic=False)
            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    stored_idx = self.buckets[b_idx].store_item(embeddings[i])
                    storage_results.append((b_idx, stored_idx))
        return storage_results

    def retrieve(self, query_features, top_k=5):
        """Return, per query row, a (stacked_items, stacked_similarities) pair
        pooled over all trees; a pair of empty tensors when nothing matched."""
        query_embeddings = self.feature_encoder(query_features)
        bucket_votes = defaultdict(list)
        for tree in self.trees:
            # Deterministic routing at query time for reproducible lookups.
            bucket_assignments = tree.get_bucket_for_input(query_features, deterministic=True)
            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    retrieved_items, similarities = self.buckets[b_idx].retrieve_similar(
                        query_embeddings[i], top_k=top_k
                    )
                    if len(retrieved_items) > 0:
                        # Keep a Python float for sorting, the tensor for output.
                        float_sims = similarities.detach().cpu().tolist()
                        for itm, sim_t, sim_f in zip(retrieved_items, similarities, float_sims):
                            bucket_votes[i].append((itm, sim_f, sim_t))
        final_results = []
        for query_idx in range(query_features.shape[0]):
            if query_idx in bucket_votes and len(bucket_votes[query_idx]) > 0:
                candidates = bucket_votes[query_idx]
                # Global re-rank across all trees' buckets by similarity.
                candidates.sort(key=lambda x: x[1], reverse=True)
                top_candidates = candidates[:top_k]
                items = [c[0] for c in top_candidates]
                sims_t = [c[2] for c in top_candidates]
                final_results.append((torch.stack(items), torch.stack(sims_t)))
            else:
                final_results.append((torch.tensor([]), torch.tensor([])))
        return final_results

    def update_routing(self, features, retrieval_success):
        """Feed retrieval outcomes back into every tree's split statistics."""
        for tree in self.trees:
            tree.update_node_statistics(features, retrieval_success)

    def get_forest_stats(self):
        """Aggregate per-bucket and per-tree diagnostics into one dict."""
        stats = {
            'num_trees': self.num_trees,
            'num_buckets': self.num_buckets,
            'bucket_stats': [],
            'tree_stats': []
        }
        for i, bucket in enumerate(self.buckets):
            bucket_stat = bucket.get_bucket_stats()
            bucket_stat['bucket_id'] = i
            stats['bucket_stats'].append(bucket_stat)
        for i, tree in enumerate(self.trees):
            tree_stat = {
                'tree_id': i,
                'active_nodes': tree.node_active.sum().item(),
                'total_samples': tree.node_samples.sum().item(),
                'max_depth': tree.max_depth
            }
            stats['tree_stats'].append(tree_stat)
        return stats

    def forward(self, features, items=None, mode='store'):
        """Dispatch to store() or retrieve() based on `mode`.

        Note: retrieve mode uses retrieve()'s default top_k.
        """
        if mode == 'store':
            return self.store(features, items)
        elif mode == 'retrieve':
            return self.retrieve(features)
        else:
            raise ValueError("Mode must be 'store' or 'retrieve'")