|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import math |
|
|
from collections import defaultdict, deque |
|
|
from typing import List, Dict, Tuple, Optional |
|
|
|
|
|
SAFE_MIN = -1e6


SAFE_MAX = 1e6


EPS = 1e-8


def make_safe(tensor, min_val=SAFE_MIN, max_val=SAFE_MAX):
    """Replace non-finite entries and clamp `tensor` into [min_val, max_val].

    NaN -> 0, +inf -> max_val, -inf -> min_val, then everything is clamped.

    Bug fix: the previous version replaced *both* +inf and -inf with
    max_val, so -inf came out as +1e6 (sign inverted); `torch.nan_to_num`
    handles the two infinities separately.
    """
    tensor = torch.nan_to_num(tensor, nan=0.0, posinf=max_val, neginf=min_val)
    return torch.clamp(tensor, min_val, max_val)
|
|
|
|
|
def safe_cosine_similarity(a, b, dim=-1, eps=EPS):
    """Cosine similarity of `a` and `b` along `dim`, keeping the reduced dim.

    Both inputs are promoted to float32 first, and each norm is clamped to
    at least `eps` so zero vectors never cause a division by zero.  Note the
    result retains the reduced dimension (size 1), matching keepdim=True.
    """
    a = a if a.dtype == torch.float32 else a.float()
    b = b if b.dtype == torch.float32 else b.float()

    dot = torch.sum(a * b, dim=dim, keepdim=True)
    denom = (
        torch.norm(a, dim=dim, keepdim=True).clamp(min=eps)
        * torch.norm(b, dim=dim, keepdim=True).clamp(min=eps)
    )
    return dot / denom
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AssociativeHashBucket(nn.Module):
    """Fixed-capacity associative memory with learned LSH-style signatures.

    Embeddings live in ``bucket_size`` slots backed by registered buffers.
    A new item is merged (EMA) into the most similar occupied slot when the
    cosine similarity exceeds a clamped learnable threshold; otherwise it
    takes the next free slot, or evicts the least-accessed slot once full.
    """

    def __init__(self, bucket_size=64, embedding_dim=128, num_hash_functions=4):
        super().__init__()
        self.bucket_size = bucket_size
        self.embedding_dim = embedding_dim
        self.num_hash_functions = num_hash_functions

        # One scalar projection per hash function; tanh of each output is one
        # soft signature component.
        self.hash_projections = nn.ModuleList([
            nn.Linear(embedding_dim, 1, bias=True) for _ in range(num_hash_functions)
        ])

        # Non-trainable storage state (moves with the module across devices).
        self.register_buffer('stored_items', torch.zeros(bucket_size, embedding_dim))
        self.register_buffer('item_hashes', torch.zeros(bucket_size, num_hash_functions))
        self.register_buffer('occupancy', torch.zeros(bucket_size, dtype=torch.bool))
        self.register_buffer('access_counts', torch.zeros(bucket_size))

        # Learnable knobs; the threshold is clamped to [0.1, 0.95] at use time.
        # NOTE(review): decay_rate is not used by any method in this class.
        self.similarity_threshold = nn.Parameter(torch.tensor(0.7))
        self.decay_rate = nn.Parameter(torch.tensor(0.99))

        # Next never-used slot index while the bucket is still filling.
        self.storage_pointer = 0

    def compute_hash_signature(self, item_embedding):
        """Return the tanh hash signature of `item_embedding`.

        Accepts (D,) or (B, D).  Returns (num_hash_functions,) for a single
        item; for B > 1 the result is (B, num_hash_functions) (only a leading
        batch dim of size 1 is squeezed away).
        """
        x = item_embedding
        if x.dim() == 1:
            x = x.unsqueeze(0)
        signatures = []
        for hash_proj in self.hash_projections:
            sig = torch.tanh(hash_proj(x)).squeeze(-1)
            signatures.append(sig)
        sigs = torch.stack(signatures, dim=-1)
        return sigs.squeeze(0)

    def store_item(self, item_embedding, item_id=None):
        """Store one embedding (D,) or a batch (B, D); return slot indices used.

        Fixes relative to the previous revision:
        - embeddings and signatures are detached before being written into
          the buffers, so the memory no longer retains autograd graphs (an
          in-place write of a grad-tracking value would otherwise make the
          buffer itself require grad);
        - merging into an existing slot also refreshes that slot's hash
          signature, which previously went stale after the EMA update.
        """
        if item_embedding.dim() == 1:
            item_embedding = item_embedding.unsqueeze(0)

        batch_size = item_embedding.shape[0]
        stored_items = []

        for i in range(batch_size):
            # Buffers are bookkeeping state, not part of the training graph.
            embedding = item_embedding[i].detach()
            hash_sig = self.compute_hash_signature(embedding).detach()

            if self.occupancy.any():
                similarities = safe_cosine_similarity(
                    embedding.unsqueeze(0),
                    self.stored_items[self.occupancy],
                    dim=-1
                ).squeeze()

                threshold = torch.clamp(self.similarity_threshold, 0.1, 0.95)
                if similarities.numel() > 0 and similarities.max() > threshold:
                    # Merge into the closest occupied slot via an EMA update.
                    occupied = self.occupancy.nonzero(as_tuple=False)
                    best_idx = int(occupied[similarities.argmax()].item())
                    self.stored_items[best_idx] = (
                        0.9 * self.stored_items[best_idx] + 0.1 * embedding
                    )
                    # Keep the signature consistent with the merged content.
                    self.item_hashes[best_idx] = self.compute_hash_signature(
                        self.stored_items[best_idx]
                    ).detach()
                    self.access_counts[best_idx] += 1
                    stored_items.append(best_idx)
                    continue

            if self.storage_pointer >= self.bucket_size:
                # Bucket full: evict the least-accessed occupied slot.
                if self.occupancy.any():
                    rel_idx = self.access_counts[self.occupancy].argmin()
                    evict_idx = int(
                        self.occupancy.nonzero(as_tuple=False)[rel_idx].item()
                    )
                else:
                    evict_idx = 0
            else:
                evict_idx = self.storage_pointer
                self.storage_pointer = min(self.storage_pointer + 1, self.bucket_size)

            self.stored_items[evict_idx] = embedding
            self.item_hashes[evict_idx] = hash_sig.squeeze()
            self.occupancy[evict_idx] = True
            self.access_counts[evict_idx] = 1
            stored_items.append(evict_idx)

        return stored_items

    def retrieve_similar(self, query_embedding, top_k=5):
        """Return (items, similarities) for up to `top_k` nearest stored slots.

        Returns a pair of empty lists when the bucket holds nothing.  Access
        counts of the returned slots are incremented so frequently-retrieved
        items survive eviction longer.
        """
        if query_embedding.dim() == 1:
            query_embedding = query_embedding.unsqueeze(0)

        if not self.occupancy.any():
            return [], []

        valid_items = self.stored_items[self.occupancy]
        valid_indices = self.occupancy.nonzero(as_tuple=False).squeeze(-1)

        if valid_items.numel() == 0:
            return [], []

        similarities = safe_cosine_similarity(
            query_embedding.expand(valid_items.shape[0], -1),
            valid_items,
            dim=-1
        ).squeeze(-1)

        if similarities.numel() == 0:
            return [], []

        k = min(top_k, similarities.size(0))
        top_sims, top_indices = torch.topk(similarities, k)

        retrieved_items = valid_items[top_indices]
        retrieved_indices = valid_indices[top_indices]

        # Bump usage counters for the slots we just served.
        for idx in retrieved_indices:
            self.access_counts[idx] += 1

        return retrieved_items, top_sims

    def get_bucket_stats(self):
        """Return summary statistics for monitoring.

        NOTE(review): 'avg_similarity' actually reports the learnable merge
        threshold, not an observed average similarity; the key name is kept
        for backward compatibility.
        """
        return {
            'occupancy_rate': self.occupancy.float().mean().item(),
            'total_accesses': self.access_counts.sum().item(),
            'avg_similarity': self.similarity_threshold.item(),
            'storage_pointer': self.storage_pointer
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryDecisionTree(nn.Module):
    """Soft binary decision tree that routes inputs to leaf buckets.

    Nodes use an implicit heap layout: node i has children 2i+1 (left) and
    2i+2 (right); the leaves of a depth-d tree occupy indices
    [2^(d-1)-1, 2^d-2].  Only the root starts active — deeper nodes are
    activated as samples reach them via `update_node_statistics`.
    """

    def __init__(self, input_dim, max_depth=6, min_samples_split=2):
        super().__init__()
        self.input_dim = input_dim
        self.max_depth = max_depth
        # NOTE(review): min_samples_split is stored but not used by any
        # method in this class.
        self.min_samples_split = min_samples_split

        max_nodes = 2**max_depth - 1

        # Per-node linear split: sigmoid((w . x + b) / temperature).
        self.split_weights = nn.Parameter(torch.randn(max_nodes, input_dim) * 0.1)
        self.split_biases = nn.Parameter(torch.zeros(max_nodes))
        self.split_temperatures = nn.Parameter(torch.ones(max_nodes))
        with torch.no_grad():
            # Sharpen the splits slightly and break bias symmetry.
            self.split_temperatures.data.mul_(0.6)
            self.split_biases.data.add_(0.01 * torch.randn_like(self.split_biases))

        # Visit bookkeeping (non-trainable state).
        self.register_buffer('node_active', torch.zeros(max_nodes, dtype=torch.bool))
        self.register_buffer('node_samples', torch.zeros(max_nodes))

        # leaf node index -> bucket id, and the reverse mapping.
        self.leaf_to_bucket = {}
        self.bucket_to_leaves = defaultdict(list)

        self.node_active[0] = True

    def get_node_split(self, node_idx, x):
        """Probability of routing `x` (batch, input_dim) right at `node_idx`.

        Out-of-range node indices yield 0, i.e. always go left.
        """
        if node_idx >= len(self.split_weights):
            return torch.zeros(x.shape[0], device=x.device)

        weights = self.split_weights[node_idx]
        bias = self.split_biases[node_idx]
        # Clamp so the sigmoid temperature never collapses or flattens out.
        temp = torch.clamp(self.split_temperatures[node_idx], 0.1, 10.0)

        split_score = torch.matmul(x, weights) + bias
        split_prob = torch.sigmoid(split_score / temp)

        return split_prob

    def route_to_leaf(self, x, deterministic=False):
        """Route each row of `x` from the root to a leaf.

        Returns (leaf_indices of shape (batch,), paths of shape
        (batch, max_depth)) where paths[:, d] holds the 0/1 left/right
        decision at depth d; only the first max_depth-1 columns are filled.
        """
        batch_size = x.shape[0]
        device = x.device

        current_nodes = torch.zeros(batch_size, dtype=torch.long, device=device)
        paths = torch.zeros(batch_size, self.max_depth, dtype=torch.long, device=device)

        for depth in range(self.max_depth - 1):
            split_probs = torch.zeros(batch_size, device=device)

            # Inactive nodes keep probability 0 and thus route left.
            for i in range(batch_size):
                node_idx = int(current_nodes[i].item())
                if self.node_active[node_idx]:
                    split_probs[i] = self.get_node_split(node_idx, x[i:i+1]).squeeze()

            if deterministic:
                go_right = (split_probs > 0.5).long()
            else:
                go_right = torch.bernoulli(split_probs).long()

            paths[:, depth] = go_right

            # Heap indexing: left child 2i+1, right child 2i+2.
            current_nodes = current_nodes * 2 + 1 + go_right

        return current_nodes, paths

    def assign_leaf_to_bucket(self, leaf_idx, bucket_idx):
        """Bind a leaf node (0-dim tensor) to an integer bucket id."""
        self.leaf_to_bucket[int(leaf_idx.item())] = int(bucket_idx)
        self.bucket_to_leaves[int(bucket_idx)].append(int(leaf_idx.item()))

    def get_bucket_for_input(self, x, deterministic=True):
        """Return the bucket id for each row of `x` (unmapped leaves -> 0)."""
        leaf_nodes, _ = self.route_to_leaf(x, deterministic=deterministic)

        bucket_assignments = []
        for leaf in leaf_nodes:
            bucket_idx = self.leaf_to_bucket.get(int(leaf.item()), 0)
            bucket_assignments.append(bucket_idx)

        return torch.tensor(bucket_assignments, device=x.device)

    def update_node_statistics(self, x, rewards):
        """Update visit counts and nudge split biases along each sample's path.

        A reward > 0.5 reinforces the direction actually taken at every node
        on the sample's (deterministic) path.

        Fix relative to the previous revision: the dead fallback
        `paths[i, depth] if depth < paths.shape[1] else 0` could bind
        `direction` to a plain int, on which the subsequent `.item()` call
        would crash; `direction` is now read exactly once per depth as an
        int (the fallback was unreachable because paths always has
        max_depth columns while depth stops at max_depth-2).
        """
        leaf_nodes, paths = self.route_to_leaf(x, deterministic=True)
        max_nodes = 2**self.max_depth - 1

        for i in range(x.shape[0]):
            current_node = 0
            reward = rewards[i].item() if torch.is_tensor(rewards[i]) else rewards[i]

            for depth in range(self.max_depth - 1):
                if current_node < len(self.node_samples):
                    self.node_samples[current_node] += 1
                    self.node_active[current_node] = True

                direction = int(paths[i, depth].item())

                if reward > 0.5:
                    # Reinforce the branch that led to a successful retrieval.
                    self.split_biases.data[current_node] += 0.01 if direction == 1 else -0.01

                current_node = current_node * 2 + 1 + direction

                if current_node >= max_nodes:
                    break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryForest(nn.Module):
    """An ensemble of routing trees over a shared pool of associative buckets.

    Features are encoded once, then each tree independently routes them to a
    bucket; retrieval merges candidates from all trees and keeps the best
    `top_k` by similarity.
    """

    def __init__(self, input_dim, num_trees=5, max_depth=6, bucket_size=64, embedding_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_trees = num_trees
        self.embedding_dim = embedding_dim

        self.trees = nn.ModuleList([
            MemoryDecisionTree(input_dim, max_depth) for _ in range(num_trees)
        ])

        # NOTE(review): each depth-d tree exposes 2^(d-1) leaves, so only the
        # first num_trees * 2^(d-1) buckets ever receive an assignment; the
        # count is kept as-is for backward compatibility.
        self.num_buckets = num_trees * (2**max_depth)
        self.buckets = nn.ModuleList([
            AssociativeHashBucket(bucket_size, embedding_dim) for _ in range(self.num_buckets)
        ])

        # Shared projection from raw features into the memory embedding space.
        self.feature_encoder = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

        self._initialize_bucket_assignments()

    def _initialize_bucket_assignments(self):
        """Map each tree's leaves onto consecutive bucket ids."""
        next_bucket = 0
        for tree in self.trees:
            # Leaves of a depth-d heap tree occupy ids [2^(d-1)-1, 2^d-2].
            first_leaf = 2**(tree.max_depth - 1) - 1
            last_leaf = 2**tree.max_depth - 2
            for leaf in range(first_leaf, last_leaf + 1):
                if next_bucket >= self.num_buckets:
                    continue
                tree.assign_leaf_to_bucket(torch.tensor(leaf), next_bucket)
                next_bucket += 1

    def store(self, features, items=None):
        """Encode `features` and store them via every tree's routing.

        Returns a list of (bucket_id, slot_indices) pairs, one per
        (tree, row) combination.
        """
        if items is None:
            items = features

        embeddings = self.feature_encoder(features)
        results = []

        for tree in self.trees:
            # Stochastic routing during storage spreads items across leaves.
            assignments = tree.get_bucket_for_input(features, deterministic=False)
            for row, bucket_id in enumerate(assignments.tolist()):
                if bucket_id >= len(self.buckets):
                    continue
                slot = self.buckets[bucket_id].store_item(embeddings[row])
                results.append((bucket_id, slot))

        return results

    def retrieve(self, query_features, top_k=5):
        """Retrieve up to `top_k` nearest stored items for each query row.

        Returns a list of (items, similarities) tensor pairs; empty tensors
        when no tree produced any candidate for that row.
        """
        query_embeddings = self.feature_encoder(query_features)

        # row index -> list of (item, float similarity, similarity tensor).
        votes = defaultdict(list)

        for tree in self.trees:
            # Deterministic routing so queries revisit the same buckets.
            assignments = tree.get_bucket_for_input(query_features, deterministic=True)
            for row, bucket_id in enumerate(assignments.tolist()):
                if bucket_id >= len(self.buckets):
                    continue
                items, sims = self.buckets[bucket_id].retrieve_similar(
                    query_embeddings[row], top_k=top_k
                )
                if len(items) == 0:
                    continue
                sim_values = sims.detach().cpu().tolist()
                for item, sim_tensor, sim_value in zip(items, sims, sim_values):
                    votes[row].append((item, sim_value, sim_tensor))

        results = []
        for row in range(query_features.shape[0]):
            candidates = votes.get(row, [])
            if candidates:
                # Rank the pooled candidates by similarity, best first.
                candidates.sort(key=lambda entry: entry[1], reverse=True)
                best = candidates[:top_k]
                results.append((
                    torch.stack([entry[0] for entry in best]),
                    torch.stack([entry[2] for entry in best]),
                ))
            else:
                results.append((torch.tensor([]), torch.tensor([])))

        return results

    def update_routing(self, features, retrieval_success):
        """Propagate retrieval feedback into every tree's split statistics."""
        for tree in self.trees:
            tree.update_node_statistics(features, retrieval_success)

    def get_forest_stats(self):
        """Collect per-bucket and per-tree monitoring statistics."""
        bucket_stats = []
        for idx, bucket in enumerate(self.buckets):
            entry = bucket.get_bucket_stats()
            entry['bucket_id'] = idx
            bucket_stats.append(entry)

        tree_stats = [
            {
                'tree_id': idx,
                'active_nodes': tree.node_active.sum().item(),
                'total_samples': tree.node_samples.sum().item(),
                'max_depth': tree.max_depth,
            }
            for idx, tree in enumerate(self.trees)
        ]

        return {
            'num_trees': self.num_trees,
            'num_buckets': self.num_buckets,
            'bucket_stats': bucket_stats,
            'tree_stats': tree_stats,
        }

    def forward(self, features, items=None, mode='store'):
        """Dispatch to `store` or `retrieve` depending on `mode`."""
        if mode == 'store':
            return self.store(features, items)
        if mode == 'retrieve':
            return self.retrieve(features)
        raise ValueError("Mode must be 'store' or 'retrieve'")
|
|
|
|
|
|