Gladius / extensions /memory_v2 /cold_memory.py
amuzetnoM's picture
GLADIUS v5.0 — Cognitive kernel with Synthase depth attention, PUP uncertainty, Memory V2, multi-tokenizer architecture
3f42614
"""
GLADIUS v2.0 — Cold Memory: Fossil Record
Native vector store for long-term memory. NOT HEKTOR.
GLADIUS-native 640-dim embeddings as keys, stored as register_buffers
so they persist with checkpoints but are not trainable.
Architecture analogy: sedimentary rock. Hot memory is lava (active, volatile).
Warm memory is magma (slow-moving, adaptive). Cold memory is fossil record —
compressed knowledge from past experience that can be retrieved when relevant.
Capacity: 8192 fossils (configurable). Importance-weighted eviction.
Retrieval: Cosine similarity, top-k nearest neighbors.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class FossilStore(nn.Module):
"""
Native vector store for GLADIUS long-term memory.
All storage is in register_buffers — persists with checkpoint,
no gradients, GPU-friendly. Circular buffer with importance-weighted
eviction: when full, the least important fossil gets overwritten.
Metadata per fossil: [step_written, domain_id, importance_score, layer_origin]
"""
def __init__(self, config, capacity: int = 8192):
super().__init__()
self.capacity = capacity
self.hidden_dim = config.hidden_dim
# Fossil storage — register_buffer = checkpoint-persistent, non-trainable
self.register_buffer('fossil_keys', torch.zeros(capacity, config.hidden_dim))
self.register_buffer('fossil_values', torch.zeros(capacity, config.hidden_dim))
self.register_buffer(
'fossil_metadata', torch.zeros(capacity, 4)
) # [step_written, domain_id, importance_score, layer_origin]
self.register_buffer('fossil_count', torch.tensor(0, dtype=torch.long))
self.register_buffer('fossil_write_head', torch.tensor(0, dtype=torch.long))
@torch.no_grad()
def store(
self,
key: torch.Tensor,
value: torch.Tensor,
metadata: Optional[torch.Tensor] = None,
) -> int:
"""
Write a single fossil to the store.
Args:
key: (hidden_dim,) — the embedding key for retrieval
value: (hidden_dim,) — the content to retrieve
metadata: (4,) — [step_written, domain_id, importance_score, layer_origin]
If None, defaults to [0, 0, 1.0, 0]
Returns:
slot index where the fossil was written
"""
if metadata is None:
metadata = torch.tensor(
[0.0, 0.0, 1.0, 0.0], device=key.device, dtype=key.dtype
)
# Ensure tensors are on same device
key = key.detach().to(self.fossil_keys.device)
value = value.detach().to(self.fossil_values.device)
metadata = metadata.detach().to(self.fossil_metadata.device)
if self.fossil_count.item() < self.capacity:
# Still have empty slots — use circular write head
slot = self.fossil_write_head.item()
self.fossil_count.add_(1)
else:
# Full — evict lowest importance fossil
importance_scores = self.fossil_metadata[:, 2] # Column 2 = importance
slot = importance_scores.argmin().item()
# Write
self.fossil_keys[slot] = key
self.fossil_values[slot] = value
self.fossil_metadata[slot] = metadata
# Advance write head (wrap around)
self.fossil_write_head.fill_((slot + 1) % self.capacity)
return slot
@torch.no_grad()
def retrieve(
self, query: torch.Tensor, top_k: int = 4
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Retrieve top-k fossils by cosine similarity.
Args:
query: (batch, hidden_dim) — query vectors
Returns:
values: (batch, top_k, hidden_dim) — retrieved fossil values
scores: (batch, top_k) — similarity scores
metadata: (batch, top_k, 4) — metadata for each retrieved fossil
"""
B = query.shape[0]
device = query.device
count = self.fossil_count.item()
# Empty store — return zeros
if count == 0:
return (
torch.zeros(B, top_k, self.hidden_dim, device=device),
torch.zeros(B, top_k, device=device),
torch.zeros(B, top_k, 4, device=device),
)
# Only search over populated slots
active_keys = self.fossil_keys[:count] # (count, hidden_dim)
active_values = self.fossil_values[:count]
active_metadata = self.fossil_metadata[:count]
# Cosine similarity: normalize then matmul
query_norm = F.normalize(query, dim=-1) # (B, D)
keys_norm = F.normalize(active_keys, dim=-1) # (count, D)
# Similarity: (B, count)
similarity = torch.matmul(query_norm, keys_norm.T)
# Handle case where count < top_k
actual_k = min(top_k, count)
top_scores, top_indices = similarity.topk(actual_k, dim=-1) # (B, actual_k)
# Gather values and metadata
# top_indices: (B, actual_k) — indices into active storage
gathered_values = active_values[top_indices] # (B, actual_k, D)
gathered_metadata = active_metadata[top_indices] # (B, actual_k, 4)
# Pad if count < top_k
if actual_k < top_k:
pad_size = top_k - actual_k
gathered_values = F.pad(gathered_values, (0, 0, 0, pad_size))
top_scores = F.pad(top_scores, (0, pad_size))
gathered_metadata = F.pad(gathered_metadata, (0, 0, 0, pad_size))
return gathered_values, top_scores, gathered_metadata
@torch.no_grad()
def compress(
self,
hot_keys: torch.Tensor,
hot_values: torch.Tensor,
hot_usage: torch.Tensor,
importance_threshold: float = 0.7,
step: int = 0,
domain_id: int = 0,
) -> int:
"""
Archive hot memory slots to cold storage.
Filters by usage (top 25% by usage count) and stores qualifying
slots as fossils. Called during hot→cold archival.
Args:
hot_keys: (num_slots, hidden_dim) — hot memory keys
hot_values: (num_slots, hidden_dim) — hot memory values
hot_usage: (num_slots,) — usage count per slot
importance_threshold: minimum relative usage to qualify
step: current training step (for metadata)
domain_id: current domain (for metadata)
Returns:
Number of fossils written
"""
num_slots = hot_keys.shape[0]
if num_slots == 0 or hot_usage.sum() == 0:
return 0
# Top 25% by usage count
threshold_count = max(1, num_slots // 4)
_, top_indices = hot_usage.topk(threshold_count)
# Normalize usage to [0, 1] for importance scoring
usage_max = hot_usage.max()
if usage_max > 0:
normalized_usage = hot_usage / usage_max
else:
return 0
# Filter by importance threshold
written = 0
for idx in top_indices:
idx = idx.item()
importance = normalized_usage[idx].item()
if importance < importance_threshold:
continue
# Skip empty slots (all zeros)
if hot_keys[idx].abs().sum() < 1e-8:
continue
metadata = torch.tensor(
[float(step), float(domain_id), importance, 0.0],
device=hot_keys.device,
dtype=hot_keys.dtype,
)
self.store(hot_keys[idx], hot_values[idx], metadata)
written += 1
return written
def stats(self) -> dict:
"""Return diagnostic statistics about the fossil store."""
count = self.fossil_count.item()
result = {
'count': count,
'capacity': self.capacity,
'usage_pct': count / self.capacity * 100 if self.capacity > 0 else 0,
'write_head': self.fossil_write_head.item(),
}
if count > 0:
active_metadata = self.fossil_metadata[:count]
importance = active_metadata[:, 2]
result.update({
'importance_mean': importance.mean().item(),
'importance_std': importance.std().item() if count > 1 else 0.0,
'importance_min': importance.min().item(),
'importance_max': importance.max().item(),
'oldest_step': active_metadata[:, 0].min().item(),
'newest_step': active_metadata[:, 0].max().item(),
})
else:
result.update({
'importance_mean': 0.0,
'importance_std': 0.0,
'importance_min': 0.0,
'importance_max': 0.0,
'oldest_step': 0.0,
'newest_step': 0.0,
})
return result
def __repr__(self) -> str:
count = self.fossil_count.item()
return (
f"FossilStore(capacity={self.capacity}, "
f"count={count}, "
f"usage={count/self.capacity*100:.1f}%, "
f"hidden_dim={self.hidden_dim})"
)