""" GLADIUS v2.0 — Cold Memory: Fossil Record Native vector store for long-term memory. NOT HEKTOR. GLADIUS-native 640-dim embeddings as keys, stored as register_buffers so they persist with checkpoints but are not trainable. Architecture analogy: sedimentary rock. Hot memory is lava (active, volatile). Warm memory is magma (slow-moving, adaptive). Cold memory is fossil record — compressed knowledge from past experience that can be retrieved when relevant. Capacity: 8192 fossils (configurable). Importance-weighted eviction. Retrieval: Cosine similarity, top-k nearest neighbors. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class FossilStore(nn.Module): """ Native vector store for GLADIUS long-term memory. All storage is in register_buffers — persists with checkpoint, no gradients, GPU-friendly. Circular buffer with importance-weighted eviction: when full, the least important fossil gets overwritten. Metadata per fossil: [step_written, domain_id, importance_score, layer_origin] """ def __init__(self, config, capacity: int = 8192): super().__init__() self.capacity = capacity self.hidden_dim = config.hidden_dim # Fossil storage — register_buffer = checkpoint-persistent, non-trainable self.register_buffer('fossil_keys', torch.zeros(capacity, config.hidden_dim)) self.register_buffer('fossil_values', torch.zeros(capacity, config.hidden_dim)) self.register_buffer( 'fossil_metadata', torch.zeros(capacity, 4) ) # [step_written, domain_id, importance_score, layer_origin] self.register_buffer('fossil_count', torch.tensor(0, dtype=torch.long)) self.register_buffer('fossil_write_head', torch.tensor(0, dtype=torch.long)) @torch.no_grad() def store( self, key: torch.Tensor, value: torch.Tensor, metadata: Optional[torch.Tensor] = None, ) -> int: """ Write a single fossil to the store. Args: key: (hidden_dim,) — the embedding key for retrieval value: (hidden_dim,) — the content to retrieve metadata: (4,) — [step_written, domain_id, importance_score, layer_origin] If None, defaults to [0, 0, 1.0, 0] Returns: slot index where the fossil was written """ if metadata is None: metadata = torch.tensor( [0.0, 0.0, 1.0, 0.0], device=key.device, dtype=key.dtype ) # Ensure tensors are on same device key = key.detach().to(self.fossil_keys.device) value = value.detach().to(self.fossil_values.device) metadata = metadata.detach().to(self.fossil_metadata.device) if self.fossil_count.item() < self.capacity: # Still have empty slots — use circular write head slot = self.fossil_write_head.item() self.fossil_count.add_(1) else: # Full — evict lowest importance fossil importance_scores = self.fossil_metadata[:, 2] # Column 2 = importance slot = importance_scores.argmin().item() # Write self.fossil_keys[slot] = key self.fossil_values[slot] = value self.fossil_metadata[slot] = metadata # Advance write head (wrap around) self.fossil_write_head.fill_((slot + 1) % self.capacity) return slot @torch.no_grad() def retrieve( self, query: torch.Tensor, top_k: int = 4 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Retrieve top-k fossils by cosine similarity. Args: query: (batch, hidden_dim) — query vectors Returns: values: (batch, top_k, hidden_dim) — retrieved fossil values scores: (batch, top_k) — similarity scores metadata: (batch, top_k, 4) — metadata for each retrieved fossil """ B = query.shape[0] device = query.device count = self.fossil_count.item() # Empty store — return zeros if count == 0: return ( torch.zeros(B, top_k, self.hidden_dim, device=device), torch.zeros(B, top_k, device=device), torch.zeros(B, top_k, 4, device=device), ) # Only search over populated slots active_keys = self.fossil_keys[:count] # (count, hidden_dim) active_values = self.fossil_values[:count] active_metadata = self.fossil_metadata[:count] # Cosine similarity: normalize then matmul query_norm = F.normalize(query, dim=-1) # (B, D) keys_norm = F.normalize(active_keys, dim=-1) # (count, D) # Similarity: (B, count) similarity = torch.matmul(query_norm, keys_norm.T) # Handle case where count < top_k actual_k = min(top_k, count) top_scores, top_indices = similarity.topk(actual_k, dim=-1) # (B, actual_k) # Gather values and metadata # top_indices: (B, actual_k) — indices into active storage gathered_values = active_values[top_indices] # (B, actual_k, D) gathered_metadata = active_metadata[top_indices] # (B, actual_k, 4) # Pad if count < top_k if actual_k < top_k: pad_size = top_k - actual_k gathered_values = F.pad(gathered_values, (0, 0, 0, pad_size)) top_scores = F.pad(top_scores, (0, pad_size)) gathered_metadata = F.pad(gathered_metadata, (0, 0, 0, pad_size)) return gathered_values, top_scores, gathered_metadata @torch.no_grad() def compress( self, hot_keys: torch.Tensor, hot_values: torch.Tensor, hot_usage: torch.Tensor, importance_threshold: float = 0.7, step: int = 0, domain_id: int = 0, ) -> int: """ Archive hot memory slots to cold storage. Filters by usage (top 25% by usage count) and stores qualifying slots as fossils. Called during hot→cold archival. Args: hot_keys: (num_slots, hidden_dim) — hot memory keys hot_values: (num_slots, hidden_dim) — hot memory values hot_usage: (num_slots,) — usage count per slot importance_threshold: minimum relative usage to qualify step: current training step (for metadata) domain_id: current domain (for metadata) Returns: Number of fossils written """ num_slots = hot_keys.shape[0] if num_slots == 0 or hot_usage.sum() == 0: return 0 # Top 25% by usage count threshold_count = max(1, num_slots // 4) _, top_indices = hot_usage.topk(threshold_count) # Normalize usage to [0, 1] for importance scoring usage_max = hot_usage.max() if usage_max > 0: normalized_usage = hot_usage / usage_max else: return 0 # Filter by importance threshold written = 0 for idx in top_indices: idx = idx.item() importance = normalized_usage[idx].item() if importance < importance_threshold: continue # Skip empty slots (all zeros) if hot_keys[idx].abs().sum() < 1e-8: continue metadata = torch.tensor( [float(step), float(domain_id), importance, 0.0], device=hot_keys.device, dtype=hot_keys.dtype, ) self.store(hot_keys[idx], hot_values[idx], metadata) written += 1 return written def stats(self) -> dict: """Return diagnostic statistics about the fossil store.""" count = self.fossil_count.item() result = { 'count': count, 'capacity': self.capacity, 'usage_pct': count / self.capacity * 100 if self.capacity > 0 else 0, 'write_head': self.fossil_write_head.item(), } if count > 0: active_metadata = self.fossil_metadata[:count] importance = active_metadata[:, 2] result.update({ 'importance_mean': importance.mean().item(), 'importance_std': importance.std().item() if count > 1 else 0.0, 'importance_min': importance.min().item(), 'importance_max': importance.max().item(), 'oldest_step': active_metadata[:, 0].min().item(), 'newest_step': active_metadata[:, 0].max().item(), }) else: result.update({ 'importance_mean': 0.0, 'importance_std': 0.0, 'importance_min': 0.0, 'importance_max': 0.0, 'oldest_step': 0.0, 'newest_step': 0.0, }) return result def __repr__(self) -> str: count = self.fossil_count.item() return ( f"FossilStore(capacity={self.capacity}, " f"count={count}, " f"usage={count/self.capacity*100:.1f}%, " f"hidden_dim={self.hidden_dim})" )