#############################################################################################################################################
#||||- - - |6.25.2025| - - - || MEMORY FOREST || - - - |memory_forest.py| - - -||||#
#############################################################################################################################################
"""Memory forest: soft decision trees that route inputs to associative memory buckets.

Each :class:`MemoryDecisionTree` routes a feature vector to one of its leaves;
every leaf is mapped to an :class:`AssociativeHashBucket`, which stores encoded
embeddings and answers cosine-similarity lookups. :class:`MemoryForest` ties the
trees, the buckets and a shared feature encoder together.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from collections import defaultdict, deque
from typing import List, Dict, Tuple, Optional

SAFE_MIN = -1e6
SAFE_MAX = 1e6
EPS = 1e-8

#||||- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 𓅸 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -||||#


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Return ``tensor`` with NaNs replaced by 0 and every value clamped.

    NaN -> 0, +inf -> ``max_val``, -inf -> ``min_val``; finite values are
    clamped into ``[min_val, max_val]``.
    """
    # clamp already sends +inf/-inf to the matching bound; only NaN needs an
    # explicit replacement because clamp propagates NaN unchanged.
    zero = torch.zeros((), device=tensor.device, dtype=tensor.dtype)
    tensor = torch.where(torch.isnan(tensor), zero, tensor)
    return torch.clamp(tensor, min_val, max_val)


def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Cosine similarity of ``a`` and ``b`` along ``dim``, norms floored at ``eps``.

    Both inputs are cast to float32. The reduced dimension is kept (size 1), so
    the result broadcasts back against the inputs.
    """
    if a.dtype != torch.float32:
        a = a.float()
    if b.dtype != torch.float32:
        b = b.float()
    a_norm = torch.norm(a, dim=dim, keepdim=True).clamp(min=eps)
    b_norm = torch.norm(b, dim=dim, keepdim=True).clamp(min=eps)
    return torch.sum(a * b, dim=dim, keepdim=True) / (a_norm * b_norm)


#############################################################################################################################################
###################################################- - - ASSOCIATIVE HASH BUCKET - - -###################################################

class AssociativeHashBucket(nn.Module):
    """Fixed-capacity store of embedding vectors with similarity-based merging.

    A new item that is cosine-similar (above a clamped threshold) to an
    existing item is blended into it; otherwise it takes the next free slot or,
    once the bucket is full, evicts the least-accessed occupied slot.
    """

    def __init__(self, bucket_size=64, embedding_dim=128, num_hash_functions=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.embedding_dim = embedding_dim
        self.num_hash_functions = num_hash_functions
        self.hash_projections = nn.ModuleList([
            nn.Linear(embedding_dim, 1, bias=True) for _ in range(num_hash_functions)
        ])
        # Memory state lives in buffers: saved with the module but not trained.
        self.register_buffer('stored_items', torch.zeros(bucket_size, embedding_dim))
        self.register_buffer('item_hashes', torch.zeros(bucket_size, num_hash_functions))
        self.register_buffer('occupancy', torch.zeros(bucket_size, dtype=torch.bool))
        self.register_buffer('access_counts', torch.zeros(bucket_size))
        self.similarity_threshold = nn.Parameter(torch.tensor(0.7))
        # NOTE(review): decay_rate is not read anywhere in this module.
        self.decay_rate = nn.Parameter(torch.tensor(0.99))
        # Next never-used slot; equals bucket_size once every slot was filled once.
        self.storage_pointer = 0

    def compute_hash_signature(self, item_embedding):
        """Return tanh-squashed projections of the embedding, one per hash function.

        A 1-D input yields shape ``(num_hash_functions,)``; a ``(B, D)`` input
        yields ``(B, num_hash_functions)`` (squeezed to 1-D when B == 1).
        """
        x = item_embedding
        if x.dim() == 1:
            x = x.unsqueeze(0)
        signatures = []
        for hash_proj in self.hash_projections:
            sig = torch.tanh(hash_proj(x)).squeeze(-1)  # (B,)
            signatures.append(sig)
        sigs = torch.stack(signatures, dim=-1)  # (B, num_hash)
        return sigs.squeeze(0)

    def _find_merge_slot(self, embedding):
        """Return the occupied slot most similar to ``embedding`` if above threshold, else None."""
        if not self.occupancy.any():
            return None
        occupied = self.occupancy.nonzero(as_tuple=False).squeeze(-1)
        sims = safe_cosine_similarity(
            embedding.unsqueeze(0), self.stored_items[occupied], dim=-1
        ).squeeze(-1)
        threshold = torch.clamp(self.similarity_threshold, 0.1, 0.95)
        best = int(sims.argmax())
        if sims[best] > threshold:
            return int(occupied[best])
        return None

    def _claim_slot(self):
        """Return the slot for a new item: next unused slot, else the least-accessed occupied one."""
        if self.storage_pointer < self.bucket_size:
            slot = self.storage_pointer
            self.storage_pointer += 1
            return slot
        occupied = self.occupancy.nonzero(as_tuple=False).squeeze(-1)
        if occupied.numel() == 0:
            return 0
        return int(occupied[self.access_counts[occupied].argmin()])

    def store_item(self, item_embedding, item_id=None):
        """Store one embedding or a batch of embeddings.

        Args:
            item_embedding: ``(embedding_dim,)`` or ``(B, embedding_dim)`` tensor.
            item_id: accepted for API compatibility; not used.

        Returns:
            List of slot indices (one per input row) the rows were written to.
        """
        if item_embedding.dim() == 1:
            item_embedding = item_embedding.unsqueeze(0)
        # Buffers are plain memory, not part of the autograd graph. Writing a
        # grad-carrying tensor into them would attach the caller's graph to the
        # buffer (leaking memory and breaking later backward passes).
        item_embedding = item_embedding.detach()
        stored_slots = []
        for embedding in item_embedding:
            slot = self._find_merge_slot(embedding)
            if slot is not None:
                # Close enough to an existing item: blend instead of duplicating.
                self.stored_items[slot] = 0.9 * self.stored_items[slot] + 0.1 * embedding
                self.access_counts[slot] += 1
                stored_slots.append(slot)
                continue
            slot = self._claim_slot()
            with torch.no_grad():
                hash_sig = self.compute_hash_signature(embedding)
            self.stored_items[slot] = embedding
            self.item_hashes[slot] = hash_sig
            self.occupancy[slot] = True
            self.access_counts[slot] = 1
            stored_slots.append(slot)
        return stored_slots

    def retrieve_similar(self, query_embedding, top_k=5):
        """Return up to ``top_k`` stored items most cosine-similar to a single query.

        Args:
            query_embedding: one query, ``(embedding_dim,)`` or ``(1, embedding_dim)``.
            top_k: maximum number of items to return.

        Returns:
            ``(items, similarities)`` of shapes ``(k, embedding_dim)`` and ``(k,)``,
            or ``([], [])`` when the bucket is empty. The access count of every
            returned slot is incremented.
        """
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)
        if not self.occupancy.any():
            return [], []
        valid_items = self.stored_items[self.occupancy]
        valid_indices = self.occupancy.nonzero(as_tuple=False).squeeze(-1)
        if valid_items.numel() == 0:
            return [], []
        similarities = safe_cosine_similarity(
            query_embedding.expand(valid_items.shape[0], -1), valid_items, dim=-1
        ).squeeze(-1)  # (N,)
        k = min(top_k, similarities.size(0))
        top_sims, top_indices = torch.topk(similarities, k)
        # topk indices are distinct, so one vectorized increment is exact.
        self.access_counts[valid_indices[top_indices]] += 1
        return valid_items[top_indices], top_sims

    def get_bucket_stats(self):
        """Return occupancy/access summary; 'avg_similarity' reports the similarity threshold."""
        return {
            'occupancy_rate': self.occupancy.float().mean().item(),
            'total_accesses': self.access_counts.sum().item(),
            'avg_similarity': self.similarity_threshold.item(),
            'storage_pointer': self.storage_pointer
        }


###########################################################################################################################################
################################################- - - MEMORY DECISION TREE - - -#######################################################

class MemoryDecisionTree(nn.Module):
    """Soft binary decision tree stored as flat heap arrays.

    Node ``n`` has children ``2n + 1`` (left) and ``2n + 2`` (right). Routing
    takes ``max_depth - 1`` steps, so leaves are nodes
    ``2**(max_depth-1) - 1 .. 2**max_depth - 2``. Inactive nodes have split
    probability 0 and therefore always route left.
    """

    def __init__(self, input_dim, max_depth=6, min_samples_split=2):
        super().__init__()
        self.input_dim = input_dim
        self.max_depth = max_depth
        # NOTE(review): min_samples_split is stored but not read in this module.
        self.min_samples_split = min_samples_split
        max_nodes = 2**max_depth - 1
        self.split_weights = nn.Parameter(torch.randn(max_nodes, input_dim) * 0.1)
        self.split_biases = nn.Parameter(torch.zeros(max_nodes))
        self.split_temperatures = nn.Parameter(torch.ones(max_nodes))
        with torch.no_grad():
            self.split_temperatures.data.mul_(0.6)
            self.split_biases.data.add_(0.01 * torch.randn_like(self.split_biases))
        self.register_buffer('node_active', torch.zeros(max_nodes, dtype=torch.bool))
        self.register_buffer('node_samples', torch.zeros(max_nodes))
        self.leaf_to_bucket = {}
        self.bucket_to_leaves = defaultdict(list)
        self.node_active[0] = True

    def get_node_split(self, node_idx, x):
        """Return P(go right) at ``node_idx`` for each row of ``x`` (zeros past the last node)."""
        if node_idx >= len(self.split_weights):
            return torch.zeros(x.shape[0], device=x.device)
        weights = self.split_weights[node_idx]
        bias = self.split_biases[node_idx]
        temp = torch.clamp(self.split_temperatures[node_idx], 0.1, 10.0)
        split_score = torch.matmul(x, weights) + bias
        split_prob = torch.sigmoid(split_score / temp)
        return split_prob

    def route_to_leaf(self, x, deterministic=False):
        """Route every row of ``x`` (``(B, input_dim)``) from the root to a leaf.

        Args:
            x: input batch.
            deterministic: go right iff P(right) > 0.5; otherwise sample with
                ``torch.bernoulli``.

        Returns:
            ``(leaf_nodes, paths)``: leaf index per row ``(B,)`` and the taken
            directions ``(B, max_depth)`` (1 = right). Only the first
            ``max_depth - 1`` columns of ``paths`` are written.
        """
        batch_size = x.shape[0]
        device = x.device
        current_nodes = torch.zeros(batch_size, dtype=torch.long, device=device)
        paths = torch.zeros(batch_size, self.max_depth, dtype=torch.long, device=device)
        for depth in range(self.max_depth - 1):
            # Gather each row's current node parameters and score the whole batch at once.
            weights = self.split_weights[current_nodes]  # (B, input_dim)
            biases = self.split_biases[current_nodes]  # (B,)
            temps = torch.clamp(self.split_temperatures[current_nodes], 0.1, 10.0)
            probs = torch.sigmoid(((x * weights).sum(dim=-1) + biases) / temps)
            active = self.node_active[current_nodes]
            split_probs = torch.where(active, probs, torch.zeros_like(probs))
            if deterministic:
                go_right = (split_probs > 0.5).long()
            else:
                go_right = torch.bernoulli(split_probs).long()
            paths[:, depth] = go_right
            current_nodes = current_nodes * 2 + 1 + go_right
        return current_nodes, paths

    def assign_leaf_to_bucket(self, leaf_idx, bucket_idx):
        """Map leaf ``leaf_idx`` (int or 0-d tensor) to bucket ``bucket_idx``."""
        leaf = int(leaf_idx)
        bucket = int(bucket_idx)
        self.leaf_to_bucket[leaf] = bucket
        self.bucket_to_leaves[bucket].append(leaf)

    def get_bucket_for_input(self, x, deterministic=True):
        """Return the bucket index (long tensor ``(B,)``) for each row; unmapped leaves give 0."""
        leaf_nodes, _ = self.route_to_leaf(x, deterministic=deterministic)
        buckets = [self.leaf_to_bucket.get(leaf, 0) for leaf in leaf_nodes.tolist()]
        return torch.tensor(buckets, dtype=torch.long, device=x.device)

    def update_node_statistics(self, x, rewards):
        """Activate and count the internal nodes on each row's deterministic path.

        For rows whose reward exceeds 0.5, each visited node's bias is nudged by
        0.01 towards the direction that row took (+ for right, - for left).
        ``rewards`` may be a tensor or a sequence of numbers, indexed per row.
        """
        with torch.no_grad():
            # Paths are computed once, before any node on them is activated.
            _, paths = self.route_to_leaf(x, deterministic=True)
            for i in range(x.shape[0]):
                reward = rewards[i].item() if torch.is_tensor(rewards[i]) else rewards[i]
                node = 0
                for depth in range(self.max_depth - 1):
                    direction = int(paths[i, depth])
                    self.node_samples[node] += 1
                    self.node_active[node] = True
                    if reward > 0.5:
                        self.split_biases.data[node] += 0.01 if direction == 1 else -0.01
                    node = node * 2 + 1 + direction


###########################################################################################################################################
##################################################- - - MEMORY FOREST - - -############################################################

class MemoryForest(nn.Module):
    """Ensemble of routing trees whose leaves index associative memory buckets.

    NOTE(review): each tree has ``2**(max_depth-1)`` leaves, so only
    ``num_trees * 2**(max_depth-1)`` of the ``num_trees * 2**max_depth``
    allocated buckets are ever mapped to a leaf. The count is kept unchanged so
    existing state_dicts still load.
    """

    def __init__(self, input_dim, num_trees=5, max_depth=6, bucket_size=64, embedding_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_trees = num_trees
        self.embedding_dim = embedding_dim
        self.trees = nn.ModuleList([
            MemoryDecisionTree(input_dim, max_depth) for _ in range(num_trees)
        ])
        self.num_buckets = num_trees * (2**max_depth)
        self.buckets = nn.ModuleList([
            AssociativeHashBucket(bucket_size, embedding_dim) for _ in range(self.num_buckets)
        ])
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        self._initialize_bucket_assignments()

    def _initialize_bucket_assignments(self):
        """Give every leaf of every tree its own bucket, numbered consecutively."""
        bucket_idx = 0
        for tree in self.trees:
            start_leaf = 2**(tree.max_depth - 1) - 1
            end_leaf = 2**tree.max_depth - 2
            for leaf in range(start_leaf, end_leaf + 1):
                if bucket_idx < self.num_buckets:
                    tree.assign_leaf_to_bucket(leaf, bucket_idx)
                    bucket_idx += 1

    def store(self, features, items=None):
        """Encode ``features`` (``(B, input_dim)``) and store each row once per tree.

        Routing is stochastic. ``items`` is accepted for API compatibility but
        not stored; the encoded features are.

        Returns:
            List of ``(bucket_idx, slot_indices)`` pairs, one per (tree, row).
        """
        # Stored embeddings are detached by the buckets, so no graph is needed here.
        with torch.no_grad():
            embeddings = self.feature_encoder(features)
        storage_results = []
        for tree in self.trees:
            bucket_assignments = tree.get_bucket_for_input(features, deterministic=False)
            for i, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx < len(self.buckets):
                    stored_idx = self.buckets[b_idx].store_item(embeddings[i])
                    storage_results.append((b_idx, stored_idx))
        return storage_results

    def retrieve(self, query_features, top_k=5):
        """Return the ``top_k`` most similar stored items per query row.

        Each tree routes the query deterministically to one bucket; candidates
        from all trees are pooled and ranked by similarity. Because ``store``
        writes each row into one bucket per tree, the pooled results may contain
        near-duplicates.

        Returns:
            One ``(items, similarities)`` pair per query row; two empty tensors
            when no tree produced a candidate.
        """
        query_embeddings = self.feature_encoder(query_features)
        # candidates[q] holds (item, similarity_as_float, similarity_tensor) tuples.
        candidates = defaultdict(list)
        for tree in self.trees:
            bucket_assignments = tree.get_bucket_for_input(query_features, deterministic=True)
            for q, b_idx in enumerate(bucket_assignments.tolist()):
                if b_idx >= len(self.buckets):
                    continue
                items, sims = self.buckets[b_idx].retrieve_similar(query_embeddings[q], top_k=top_k)
                if len(items) == 0:
                    continue
                sim_floats = sims.detach().cpu().tolist()
                candidates[q].extend(zip(items, sim_floats, sims))
        results = []
        for q in range(query_features.shape[0]):
            ranked = sorted(candidates.get(q, []), key=lambda c: c[1], reverse=True)[:top_k]
            if ranked:
                results.append((
                    torch.stack([c[0] for c in ranked]),
                    torch.stack([c[2] for c in ranked]),
                ))
            else:
                results.append((torch.tensor([]), torch.tensor([])))
        return results

    def update_routing(self, features, retrieval_success):
        """Feed per-row retrieval rewards into every tree's node statistics."""
        for tree in self.trees:
            tree.update_node_statistics(features, retrieval_success)

    def get_forest_stats(self):
        """Return per-bucket and per-tree summary dictionaries."""
        stats = {
            'num_trees': self.num_trees,
            'num_buckets': self.num_buckets,
            'bucket_stats': [],
            'tree_stats': []
        }
        for i, bucket in enumerate(self.buckets):
            bucket_stat = bucket.get_bucket_stats()
            bucket_stat['bucket_id'] = i
            stats['bucket_stats'].append(bucket_stat)
        for i, tree in enumerate(self.trees):
            tree_stat = {
                'tree_id': i,
                'active_nodes': tree.node_active.sum().item(),
                'total_samples': tree.node_samples.sum().item(),
                'max_depth': tree.max_depth
            }
            stats['tree_stats'].append(tree_stat)
        return stats

    def forward(self, features, items=None, mode='store', top_k=5):
        """Dispatch to :meth:`store` or :meth:`retrieve` (``top_k`` used for retrieval)."""
        if mode == 'store':
            return self.store(features, items)
        elif mode == 'retrieve':
            return self.retrieve(features, top_k=top_k)
        else:
            raise ValueError("Mode must be 'store' or 'retrieve'")