| """ |
| GLADIUS v2.0 — Cold Memory: Fossil Record |
| |
| Native vector store for long-term memory. NOT HEKTOR. |
| GLADIUS-native 640-dim embeddings as keys, stored as register_buffers |
| so they persist with checkpoints but are not trainable. |
| |
| Architecture analogy: sedimentary rock. Hot memory is lava (active, volatile). |
| Warm memory is magma (slow-moving, adaptive). Cold memory is fossil record — |
| compressed knowledge from past experience that can be retrieved when relevant. |
| |
| Capacity: 8192 fossils (configurable). Importance-weighted eviction. |
| Retrieval: Cosine similarity, top-k nearest neighbors. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
|
|
|
|
class FossilStore(nn.Module):
    """
    Native vector store for GLADIUS long-term ("cold") memory.

    All storage lives in ``register_buffer`` tensors, so it is saved and
    restored with the module's ``state_dict``, carries no gradients and
    moves with the module across devices.

    Write policy: slots are filled sequentially until ``capacity`` fossils
    are stored; after that every new fossil overwrites the slot whose
    importance score (metadata column 2) is currently the lowest.

    Buffers:
        fossil_keys:       (capacity, hidden_dim) retrieval keys
        fossil_values:     (capacity, hidden_dim) retrieved content
        fossil_metadata:   (capacity, 4) rows of
                           [step_written, domain_id, importance_score, layer_origin]
        fossil_count:      scalar long, number of occupied slots
        fossil_write_head: scalar long, next slot used while filling up
    """

    def __init__(self, config, capacity: int = 8192):
        """
        Args:
            config: object exposing ``hidden_dim`` (embedding width).
            capacity: maximum number of fossils kept in the store.
        """
        super().__init__()
        self.capacity = capacity
        self.hidden_dim = config.hidden_dim

        # Buffers (not Parameters): checkpointed, but never trained.
        self.register_buffer('fossil_keys', torch.zeros(capacity, config.hidden_dim))
        self.register_buffer('fossil_values', torch.zeros(capacity, config.hidden_dim))
        self.register_buffer('fossil_metadata', torch.zeros(capacity, 4))
        self.register_buffer('fossil_count', torch.tensor(0, dtype=torch.long))
        self.register_buffer('fossil_write_head', torch.tensor(0, dtype=torch.long))

    @torch.no_grad()
    def store(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        metadata: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Write a single fossil to the store.

        Args:
            key: (hidden_dim,) — the embedding key for retrieval
            value: (hidden_dim,) — the content to retrieve
            metadata: (4,) — [step_written, domain_id, importance_score, layer_origin]
                If None, defaults to [0, 0, 1.0, 0]

        Returns:
            slot index where the fossil was written
        """
        if metadata is None:
            # Build the default directly in the metadata buffer's dtype/device
            # rather than the key's: a low-precision key dtype must never
            # dictate how metadata is represented.
            metadata = torch.tensor(
                [0.0, 0.0, 1.0, 0.0],
                device=self.fossil_metadata.device,
                dtype=self.fossil_metadata.dtype,
            )

        key = key.detach().to(self.fossil_keys.device)
        value = value.detach().to(self.fossil_values.device)
        metadata = metadata.detach().to(self.fossil_metadata.device)

        if self.fossil_count.item() < self.capacity:
            # Still filling up: append at the write head.
            slot = self.fossil_write_head.item()
            self.fossil_count.add_(1)
        else:
            # Full: evict the fossil with the lowest importance score
            # (first one on ties, per argmin).
            slot = self.fossil_metadata[:, 2].argmin().item()

        # Index assignment casts to the buffer dtype if the inputs differ.
        self.fossil_keys[slot] = key
        self.fossil_values[slot] = value
        self.fossil_metadata[slot] = metadata

        self.fossil_write_head.fill_((slot + 1) % self.capacity)

        return slot

    @torch.no_grad()
    def retrieve(
        self, query: torch.Tensor, top_k: int = 4
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Retrieve top-k fossils by cosine similarity.

        The query may live on a different device or use a different floating
        dtype than the store (e.g. under mixed precision); it is cast to the
        key buffer's dtype/device for the search and the results are returned
        on the query's device, in the store's dtype.

        When fewer than ``top_k`` fossils exist, the outputs are zero-padded
        up to ``top_k``. Padded entries have score 0.0, which can be larger
        than a genuine negative cosine similarity, so callers that care
        should not assume scores are sorted across the padding boundary.

        Args:
            query: (batch, hidden_dim) — query vectors
            top_k: number of neighbours to return per query

        Returns:
            values: (batch, top_k, hidden_dim) — retrieved fossil values
            scores: (batch, top_k) — similarity scores
            metadata: (batch, top_k, 4) — metadata for each retrieved fossil
        """
        batch = query.shape[0]
        out_device = query.device
        count = self.fossil_count.item()

        if count == 0:
            # Empty store: return zeros with the same dtypes the populated
            # path would produce.
            return (
                torch.zeros(
                    batch, top_k, self.hidden_dim,
                    device=out_device, dtype=self.fossil_values.dtype,
                ),
                torch.zeros(
                    batch, top_k,
                    device=out_device, dtype=self.fossil_keys.dtype,
                ),
                torch.zeros(
                    batch, top_k, 4,
                    device=out_device, dtype=self.fossil_metadata.dtype,
                ),
            )

        # Only the occupied prefix of each buffer holds real fossils.
        active_keys = self.fossil_keys[:count]
        active_values = self.fossil_values[:count]
        active_metadata = self.fossil_metadata[:count]

        # matmul requires matching dtype and device on both operands.
        query = query.detach().to(device=active_keys.device, dtype=active_keys.dtype)

        # Cosine similarity == dot product of L2-normalised vectors.
        query_norm = F.normalize(query, dim=-1)
        keys_norm = F.normalize(active_keys, dim=-1)
        similarity = torch.matmul(query_norm, keys_norm.T)  # (batch, count)

        actual_k = min(top_k, count)
        top_scores, top_indices = similarity.topk(actual_k, dim=-1)

        # Advanced indexing with (batch, k) indices -> (batch, k, feature).
        gathered_values = active_values[top_indices]
        gathered_metadata = active_metadata[top_indices]

        if actual_k < top_k:
            # Zero-pad along the k dimension so output shapes are fixed.
            pad_size = top_k - actual_k
            gathered_values = F.pad(gathered_values, (0, 0, 0, pad_size))
            top_scores = F.pad(top_scores, (0, pad_size))
            gathered_metadata = F.pad(gathered_metadata, (0, 0, 0, pad_size))

        return (
            gathered_values.to(out_device),
            top_scores.to(out_device),
            gathered_metadata.to(out_device),
        )

    @torch.no_grad()
    def compress(
        self,
        hot_keys: torch.Tensor,
        hot_values: torch.Tensor,
        hot_usage: torch.Tensor,
        importance_threshold: float = 0.7,
        step: int = 0,
        domain_id: int = 0,
    ) -> int:
        """
        Archive hot memory slots to cold storage.

        Candidates are the top 25% of slots by usage (at least one slot).
        A candidate is stored when its usage relative to the maximum usage
        is at least ``importance_threshold`` and its key is not all-zero.
        That relative usage becomes the fossil's importance score.

        Args:
            hot_keys: (num_slots, hidden_dim) — hot memory keys
            hot_values: (num_slots, hidden_dim) — hot memory values
            hot_usage: (num_slots,) — usage count per slot
            importance_threshold: minimum relative usage to qualify
            step: current training step (for metadata)
            domain_id: current domain (for metadata)

        Returns:
            Number of fossils written
        """
        num_slots = hot_keys.shape[0]
        if num_slots == 0 or hot_usage.sum() == 0:
            return 0

        usage_max = hot_usage.max()
        if not (usage_max > 0):
            return 0

        threshold_count = max(1, num_slots // 4)
        _, top_indices = hot_usage.topk(threshold_count)
        normalized_usage = hot_usage / usage_max

        # Pull everything needed for filtering to the host in one go instead
        # of synchronising once per candidate inside the loop.
        slot_ids = top_indices.tolist()
        importances = normalized_usage[top_indices].tolist()
        key_mass = hot_keys[slot_ids].abs().reshape(len(slot_ids), -1).sum(dim=1)
        is_empty = (key_mass < 1e-8).tolist()

        meta_device = self.fossil_metadata.device
        meta_dtype = self.fossil_metadata.dtype

        written = 0
        for idx, importance, empty in zip(slot_ids, importances, is_empty):
            if importance < importance_threshold:
                continue

            # Skip slots that were never written (all-zero key).
            if empty:
                continue

            # Use the metadata buffer's dtype, not hot_keys.dtype: with
            # fp16/bf16 hot memory the step counter would otherwise be
            # rounded or overflow to inf before being stored.
            metadata = torch.tensor(
                [float(step), float(domain_id), importance, 0.0],
                device=meta_device,
                dtype=meta_dtype,
            )
            self.store(hot_keys[idx], hot_values[idx], metadata)
            written += 1

        return written

    def stats(self) -> dict:
        """Return diagnostic statistics about the fossil store.

        Keys: count, capacity, usage_pct, write_head, importance_mean,
        importance_std, importance_min, importance_max, oldest_step,
        newest_step. Importance/step statistics are 0.0 for an empty store.
        """
        count = self.fossil_count.item()
        result = {
            'count': count,
            'capacity': self.capacity,
            'usage_pct': count / self.capacity * 100 if self.capacity > 0 else 0,
            'write_head': self.fossil_write_head.item(),
        }

        if count > 0:
            active_metadata = self.fossil_metadata[:count]
            importance = active_metadata[:, 2]
            result.update({
                'importance_mean': importance.mean().item(),
                # std of a single sample is undefined (NaN); report 0.0.
                'importance_std': importance.std().item() if count > 1 else 0.0,
                'importance_min': importance.min().item(),
                'importance_max': importance.max().item(),
                'oldest_step': active_metadata[:, 0].min().item(),
                'newest_step': active_metadata[:, 0].max().item(),
            })
        else:
            result.update({
                'importance_mean': 0.0,
                'importance_std': 0.0,
                'importance_min': 0.0,
                'importance_max': 0.0,
                'oldest_step': 0.0,
                'newest_step': 0.0,
            })

        return result

    def __repr__(self) -> str:
        count = self.fossil_count.item()
        # Same zero-capacity guard as stats(), so repr() never raises.
        usage = count / self.capacity * 100 if self.capacity > 0 else 0
        return (
            f"FossilStore(capacity={self.capacity}, "
            f"count={count}, "
            f"usage={usage:.1f}%, "
            f"hidden_dim={self.hidden_dim})"
        )
|
|