Gladius / extensions /memory_v2 /memory_v2.py
amuzetnoM's picture
GLADIUS v5.0 — Cognitive kernel with Synthase depth attention, PUP uncertainty, Memory V2, multi-tokenizer architecture
3f42614
"""
GLADIUS v2.0 — Three-Temperature Memory V2
The full pipeline, wired. Hot → Warm → Cold, all real.
V1 had: Hot (working), Warm (stub LoRA), Cold (returns zeros).
V2 has: Hot (working), Warm (RealWarmMemory: Locas+Share+EBLoRA), Cold (FossilStore).
+ MemoryDigestionPipeline orchestrating the metabolism.
Drop-in replacement for ThreeTemperatureMemory.
Same API surface: read(), write(), consolidate(), checkpoint(), restore().
New: digest(step), cold_read(query).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect
import logging
from typing import Optional
logger = logging.getLogger(__name__)
# Import from siblings in the kernel package
# These paths work both in the gladius_v2/src/kernel/ tree
# and in the staging/kernel/memory/ tree.
# When deployed to production kernel/, adjust imports accordingly.
try:
from ..memory import HotMemory
from ..warm_memory import RealWarmMemory
except ImportError:
try:
from gladius_v2.src.kernel.memory import HotMemory
from gladius_v2.src.kernel.warm_memory import RealWarmMemory
except ImportError:
# Fallback: allow standalone testing with local definitions
HotMemory = None
RealWarmMemory = None
from .cold_memory import FossilStore
from .memory_pipeline import MemoryDigestionPipeline
class ThreeTemperatureMemoryV2(nn.Module):
    """
    Three-temperature memory with full digestion pipeline.

    Hot: Learned KV cache. Session-lived. Importance-gated writes.
    Warm: Locas GLU-FFN adapters + Share subspace + EBLoRA spectral balance.
    Cold: FossilStore — native vector store, persists with checkpoint.
    Pipeline: MemoryDigestionPipeline — orchestrates hot→warm→cold flow.

    API is backward-compatible with ThreeTemperatureMemory (V1):
    read(), write(), consolidate(), checkpoint(), restore().
    Additional entry points: digest(), cold_read(), diagnostics(), reset_hot().
    """

    def __init__(
        self,
        config,
        cold_capacity: int = 8192,
        cold_scale: float = 0.1,
        consolidate_every: int = 50,
        archive_every: int = 100,
        offload_every: int = 200,
    ):
        """
        Build the three tiers and the pipeline that moves data between them.

        Args:
            config: KernelConfig. This class reads ``config.num_layers`` and,
                optionally, ``config.warm_condition_threshold`` (default 10.0);
                the tier constructors receive the whole object.
            cold_capacity: capacity passed to the cold FossilStore.
            cold_scale: initial value of the learnable cold blend gate.
            consolidate_every: forwarded to the pipeline (hot→warm cadence).
            archive_every: forwarded to the pipeline (hot→cold cadence).
            offload_every: forwarded to the pipeline (warm→cold cadence).

        Raises:
            ImportError: if HotMemory / RealWarmMemory could not be imported
                (the module-level fallback leaves them bound to None).
        """
        # The import fallback at module level binds these names to None when
        # neither package layout resolves. Fail loudly here rather than with
        # an opaque "'NoneType' object is not callable".
        if HotMemory is None or RealWarmMemory is None:
            raise ImportError(
                "ThreeTemperatureMemoryV2 requires HotMemory and RealWarmMemory, "
                "but neither '..memory'/'..warm_memory' nor "
                "'gladius_v2.src.kernel' could be imported."
            )

        super().__init__()
        self.config = config
        self.cold_scale = cold_scale

        # === Hot Memory ===
        self.hot = HotMemory(config)

        # === Warm Memory (the real one) ===
        self.warm = RealWarmMemory(config, num_layers=config.num_layers)

        # === Cold Memory (fossil store) ===
        self.cold = FossilStore(config, capacity=cold_capacity)

        # === Digestion Pipeline ===
        condition_threshold = getattr(config, 'warm_condition_threshold', 10.0)
        self.pipeline = MemoryDigestionPipeline(
            hot=self.hot,
            warm=self.warm,
            cold=self.cold,
            consolidate_every=consolidate_every,
            archive_every=archive_every,
            offload_every=offload_every,
            condition_threshold=condition_threshold,
        )

        # Learnable cold blend gate — starts at cold_scale, can be tuned
        self.cold_gate = nn.Parameter(torch.tensor(cold_scale))

    def read(self, query: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
        """
        Read from all three memory tiers.

        Args:
            query: (batch, seq_len, hidden_dim)
            layer_idx: which warm adapter to apply (default 0)

        Returns:
            Output of the warm adapter applied to
            ``query + hot_context + cold_context * cold_gate``.
        """
        # 1. Hot: attention over session cache
        hot_context = self.hot.read(query)

        # 2. Cold: mean-pool over the sequence axis, retrieve one context
        #    vector per batch element, then broadcast it back over the sequence.
        query_pooled = query.mean(dim=1)
        cold_context = self.pipeline.cold_read(query_pooled)
        cold_expanded = cold_context.unsqueeze(1).expand_as(query)

        # 3. Combine: query + hot context + gated cold context
        combined = query + hot_context + cold_expanded * self.cold_gate

        # 4. Warm: apply per-layer adapter. Call the module rather than
        #    .forward() so that any registered forward hooks still run.
        return self.warm(combined, layer_idx=layer_idx)

    def write(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Write to hot memory.

        Args:
            hidden: (batch, seq_len, hidden_dim)

        Returns:
            Whatever ``HotMemory.write`` returns (documented upstream as
            importance scores of shape (batch, seq_len, 1)).
        """
        return self.hot.write(hidden)

    @torch.no_grad()
    def consolidate(self):
        """
        Hot → Warm consolidation. Backward-compatible API.

        Delegates to the pipeline; runs with gradients disabled.
        """
        self.pipeline.consolidate_hot_to_warm()

    @torch.no_grad()
    def digest(self, step: int, domain_id: int = 0) -> dict:
        """
        Run the full digestion pipeline with gradients disabled.

        Args:
            step: current training step
            domain_id: current domain (forwarded to the pipeline)

        Returns:
            The pipeline's result for this step.
        """
        return self.pipeline.digest(step, domain_id=domain_id)

    def cold_read(self, query: torch.Tensor, top_k: int = 4) -> torch.Tensor:
        """
        Direct cold memory read (bypasses hot memory and the warm adapter).

        Args:
            query: (batch, hidden_dim)
            top_k: number of fossils to retrieve (forwarded to the pipeline)

        Returns:
            The pipeline's cold context for ``query``.
        """
        return self.pipeline.cold_read(query, top_k=top_k)

    def checkpoint(self, path: str) -> None:
        """
        Save all memory state to disk with ``torch.save``.

        Saved: hot/warm/cold state dicts, warm update count, the cold gate,
        pipeline counters, and — when the warm tier exposes them — subspace
        tracker tensors and the last 100 entries of the spectral history.

        Args:
            path: destination file.
        """
        warm_state = {
            'state_dict': self.warm.state_dict(),
            'update_count': self.warm.update_count.item(),
        }

        # Trackers and balancer history are saved explicitly; the restore
        # path reads them back from these keys.
        if hasattr(self.warm, 'trackers'):
            warm_state['subspace_states'] = [
                {
                    'basis': t.basis.clone(),
                    'importance': t.importance.clone(),
                    'initialized': t.initialized,
                }
                for t in self.warm.trackers
            ]
        if hasattr(self.warm, 'balancer'):
            warm_state['spectral_history'] = self.warm.balancer.history[-100:]

        state = {
            'hot_state_dict': self.hot.state_dict(),
            'warm_state': warm_state,
            'cold_state_dict': self.cold.state_dict(),
            # Detached copy so the saved value is independent of the parameter.
            'cold_gate': self.cold_gate.detach().clone(),
            'pipeline_stats': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        torch.save(state, path)
        logger.info(
            f"Memory checkpoint saved: hot={self.hot.usage.sum().item():.0f} usage, "
            f"warm={self.warm.update_count.item()} updates, "
            f"cold={self.cold.fossil_count.item()} fossils"
        )

    @staticmethod
    def _match_device(loaded, current):
        """Move ``loaded`` to the device of ``current`` when both are tensors."""
        if isinstance(loaded, torch.Tensor) and isinstance(current, torch.Tensor):
            return loaded.to(current.device)
        return loaded

    def restore(self, path: str, map_location=None) -> None:
        """
        Restore all memory state from a file written by ``checkpoint``.

        Args:
            path: checkpoint file.
            map_location: forwarded to ``torch.load``; use e.g. ``"cpu"`` to
                load a checkpoint saved on a device that is not available.
                The default (None) keeps ``torch.load``'s default behavior.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only restore checkpoints from a trusted source.
        state = torch.load(path, map_location=map_location, weights_only=False)

        # Hot
        self.hot.load_state_dict(state['hot_state_dict'])

        # Warm
        warm_state = state['warm_state']
        self.warm.load_state_dict(warm_state['state_dict'], strict=False)
        self.warm.update_count.fill_(warm_state['update_count'])

        if 'subspace_states' in warm_state and hasattr(self.warm, 'trackers'):
            for tracker, saved in zip(self.warm.trackers, warm_state['subspace_states']):
                # Keep restored tensors on the device the tracker already uses.
                tracker.basis = self._match_device(
                    saved['basis'], getattr(tracker, 'basis', None)
                )
                tracker.importance = self._match_device(
                    saved['importance'], getattr(tracker, 'importance', None)
                )
                tracker.initialized = saved.get('initialized', True)

        if 'spectral_history' in warm_state and hasattr(self.warm, 'balancer'):
            self.warm.balancer.history = warm_state['spectral_history']

        # Cold
        self.cold.load_state_dict(state['cold_state_dict'])

        # Cold gate: copy in place so the parameter keeps its own device and
        # dtype instead of adopting those of the checkpoint tensor.
        if 'cold_gate' in state:
            with torch.no_grad():
                self.cold_gate.copy_(state['cold_gate'])

        # Pipeline stats
        if 'pipeline_stats' in state:
            stats = state['pipeline_stats']
            self.pipeline.total_consolidated = stats.get('total_consolidated', 0)
            self.pipeline.total_archived = stats.get('total_archived', 0)
            self.pipeline.total_offloaded = stats.get('total_offloaded', 0)

        logger.info(
            f"Memory restored: cold={self.cold.fossil_count.item()} fossils, "
            f"warm={self.warm.update_count.item()} updates"
        )

    def diagnostics(self) -> dict:
        """
        Full memory system diagnostics. Use in heartbeat.

        Returns:
            dict with 'hot', 'warm', 'cold', 'cold_gate' and 'pipeline'
            entries. Warm condition numbers and adapter scales are included
            only when the warm tier exposes ``adapters`` (and ``balancer``).
        """
        usage = self.hot.usage
        diag = {
            'hot': {
                'total_usage': usage.sum().item(),
                'max_usage': usage.max().item(),
                'min_usage': usage.min().item(),
                'active_slots': (usage > 0).sum().item(),
                'write_head': self.hot.write_head.item(),
            },
            'warm': {
                'update_count': self.warm.update_count.item(),
            },
            'cold': self.cold.stats(),
            'cold_gate': self.cold_gate.item(),
            'pipeline': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        if hasattr(self.warm, 'adapters'):
            adapters = self.warm.adapters

            # Warm condition numbers
            if hasattr(self.warm, 'balancer'):
                balancer = self.warm.balancer
                cns = [balancer.condition_number(adapter) for adapter in adapters]
                diag['warm']['condition_numbers'] = cns
                diag['warm']['avg_condition_number'] = sum(cns) / len(cns) if cns else 0.0

            # Warm adapter scales
            diag['warm']['adapter_scales'] = [a.scale.item() for a in adapters]

        return diag

    def reset_hot(self):
        """Clear hot memory (session boundary). Warm and cold persist."""
        self.hot.reset()

    def __repr__(self) -> str:
        """One-line summary; tolerant of optional attributes being absent."""
        # A repr that raises makes debugging harder, so optional attributes
        # are read defensively. Output is unchanged when they exist.
        hot_slots = getattr(self.config, 'hot_memory_slots', '?')
        warm_layers = len(getattr(self.warm, 'adapters', ()))
        cold_count = self.cold.fossil_count.item()
        warm_updates = self.warm.update_count.item()
        return (
            f"ThreeTemperatureMemoryV2("
            f"hot_slots={hot_slots}, "
            f"warm_layers={warm_layers}, "
            f"warm_updates={warm_updates}, "
            f"cold_fossils={cold_count}/{self.cold.capacity}, "
            f"cold_gate={self.cold_gate.item():.4f})"
        )