"""
GLADIUS v2.0 — Three-Temperature Memory V2

The full pipeline, wired. Hot → Warm → Cold, all real.

V1 had: Hot (working), Warm (stub LoRA), Cold (returns zeros).
V2 has: Hot (working), Warm (RealWarmMemory: Locas+Share+EBLoRA), Cold (FossilStore).
        + MemoryDigestionPipeline orchestrating the metabolism.

Drop-in replacement for ThreeTemperatureMemory.
Same API surface: read(), write(), consolidate(), checkpoint(), restore().
New: digest(step), cold_read(query).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect
import logging
from typing import Optional

logger = logging.getLogger(__name__)

# Import from siblings in the kernel package.
# These paths work both in the gladius_v2/src/kernel/ tree
# and in the staging/kernel/memory/ tree.
# When deployed to production kernel/, adjust imports accordingly.
try:
    from ..memory import HotMemory
    from ..warm_memory import RealWarmMemory
except ImportError:
    try:
        from gladius_v2.src.kernel.memory import HotMemory
        from gladius_v2.src.kernel.warm_memory import RealWarmMemory
    except ImportError:
        # Fallback: allow standalone testing with local definitions.
        # ThreeTemperatureMemoryV2.__init__ checks for these placeholders and
        # raises a descriptive ImportError if they were never replaced.
        HotMemory = None
        RealWarmMemory = None

from .cold_memory import FossilStore
from .memory_pipeline import MemoryDigestionPipeline


class ThreeTemperatureMemoryV2(nn.Module):
    """
    Three-temperature memory with full digestion pipeline.

    Hot:  Learned KV cache. Session-lived. Importance-gated writes.
    Warm: Locas GLU-FFN adapters + Share subspace + EBLoRA spectral balance.
    Cold: FossilStore — native vector store, persists with checkpoint.

    Pipeline: MemoryDigestionPipeline — orchestrates hot→warm→cold flow.

    API is backward-compatible with ThreeTemperatureMemory (V1).
    """

    def __init__(
        self,
        config,
        cold_capacity: int = 8192,
        cold_scale: float = 0.1,
        consolidate_every: int = 50,
        archive_every: int = 100,
        offload_every: int = 200,
    ):
        """
        Build the three memory tiers and the pipeline that connects them.

        Args:
            config: KernelConfig (must have hidden_dim, num_layers, warm_rank, etc.).
                ``warm_condition_threshold`` is optional and defaults to 10.0.
            cold_capacity: max fossils in cold storage
            cold_scale: initial value of the learnable ``cold_gate`` that scales
                the cold-memory contribution in ``read()``
            consolidate_every: steps between hot→warm consolidation
            archive_every: steps between hot→cold archival
            offload_every: steps between warm→cold offload check

        Raises:
            ImportError: if neither import path for HotMemory / RealWarmMemory
                resolved and no replacement was installed at module level.
        """
        # The import fallback above leaves these names as None. Fail with an
        # actionable message instead of "'NoneType' object is not callable".
        if HotMemory is None or RealWarmMemory is None:
            raise ImportError(
                "ThreeTemperatureMemoryV2 requires HotMemory and RealWarmMemory, "
                "but neither the relative nor the gladius_v2 import path resolved."
            )

        super().__init__()
        self.config = config
        self.cold_scale = cold_scale

        # === Hot Memory ===
        self.hot = HotMemory(config)

        # === Warm Memory (the real one) ===
        self.warm = RealWarmMemory(config, num_layers=config.num_layers)

        # === Cold Memory (fossil store) ===
        self.cold = FossilStore(config, capacity=cold_capacity)

        # === Digestion Pipeline ===
        condition_threshold = getattr(config, 'warm_condition_threshold', 10.0)
        self.pipeline = MemoryDigestionPipeline(
            hot=self.hot,
            warm=self.warm,
            cold=self.cold,
            consolidate_every=consolidate_every,
            archive_every=archive_every,
            offload_every=offload_every,
            condition_threshold=condition_threshold,
        )

        # Learnable cold blend gate — starts at cold_scale, can be tuned.
        # float() keeps the tensor floating-point even when an int such as 1
        # is passed; integer tensors cannot be wrapped in a trainable Parameter.
        self.cold_gate = nn.Parameter(torch.tensor(float(cold_scale)))

    def read(self, query: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
        """
        Read from all three memory tiers.

        Args:
            query: (batch, seq_len, hidden_dim)
            layer_idx: which warm adapter to apply (default 0)

        Returns:
            (batch, seq_len, hidden_dim) — enriched hidden state
        """
        # 1. Hot: attention over session cache
        hot_context = self.hot.read(query)  # (B, S, D)

        # 2. Cold: fossil retrieval (mean-pool query → retrieve → expand)
        query_pooled = query.mean(dim=1)  # (B, D)
        cold_context = self.pipeline.cold_read(query_pooled)  # (B, D)
        # Expand to sequence length and scale
        cold_expanded = cold_context.unsqueeze(1).expand_as(query)  # (B, S, D)

        # 3. Combine: query + hot context + scaled cold context
        combined = query + hot_context + cold_expanded * self.cold_gate

        # 4. Warm: apply per-layer adapter.
        # Call the module itself rather than .forward() so that any registered
        # forward hooks run.
        output = self.warm(combined, layer_idx=layer_idx)

        return output

    def write(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Write to hot memory. Returns importance scores.

        Args:
            hidden: (batch, seq_len, hidden_dim)

        Returns:
            importance: (batch, seq_len, 1) — importance scores from write gate
        """
        return self.hot.write(hidden)

    @torch.no_grad()
    def consolidate(self):
        """
        Hot → Warm consolidation. Backward-compatible API.

        Delegates to the pipeline's ``consolidate_hot_to_warm()``; runs with
        gradients disabled.
        """
        self.pipeline.consolidate_hot_to_warm()

    @torch.no_grad()
    def digest(self, step: int, domain_id: int = 0) -> dict:
        """
        Run the full digestion pipeline.

        Args:
            step: current training step
            domain_id: current domain (for fossil metadata)

        Returns:
            dict with stats from all operations performed
        """
        return self.pipeline.digest(step, domain_id=domain_id)

    def cold_read(self, query: torch.Tensor, top_k: int = 4) -> torch.Tensor:
        """
        Direct cold memory read (bypass warm adapter).

        Args:
            query: (batch, hidden_dim)
            top_k: number of fossils to retrieve

        Returns:
            (batch, hidden_dim) — weighted fossil context
        """
        return self.pipeline.cold_read(query, top_k=top_k)

    def checkpoint(self, path: str):
        """
        Save all memory state to disk with ``torch.save``.

        Saved contents:
            Hot:      state_dict
            Warm:     state_dict, update_count, and — when the warm module
                      exposes them — subspace tracker tensors and the last
                      100 entries of the balancer history
            Cold:     state_dict
            Gate:     current value of ``cold_gate``
            Pipeline: consolidated / archived / offloaded counters

        Args:
            path: destination file path
        """
        state = {
            'hot_state_dict': self.hot.state_dict(),
            'warm_state': {
                'state_dict': self.warm.state_dict(),
                'update_count': self.warm.update_count.item(),
            },
            'cold_state_dict': self.cold.state_dict(),
            # Detached snapshot; restore() copies it back in place.
            'cold_gate': self.cold_gate.detach().clone(),
            'pipeline_stats': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        # Save warm subspace trackers if available
        if hasattr(self.warm, 'trackers'):
            state['warm_state']['subspace_states'] = [
                {
                    'basis': t.basis.clone(),
                    'importance': t.importance.clone(),
                    'initialized': t.initialized,
                }
                for t in self.warm.trackers
            ]

        # Only the most recent 100 history entries are persisted.
        if hasattr(self.warm, 'balancer'):
            state['warm_state']['spectral_history'] = self.warm.balancer.history[-100:]

        torch.save(state, path)
        logger.info(
            f"Memory checkpoint saved: hot={self.hot.usage.sum().item():.0f} usage, "
            f"warm={self.warm.update_count.item()} updates, "
            f"cold={self.cold.fossil_count.item()} fossils"
        )

    def restore(self, path: str):
        """
        Restore all memory state from a file written by ``checkpoint()``.

        Args:
            path: checkpoint file path
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from trusted sources.
        # NOTE(review): no map_location is passed, so tensors are loaded onto
        # the device they were saved from — confirm this suits cross-device
        # restores.
        state = torch.load(path, weights_only=False)

        # Hot
        self.hot.load_state_dict(state['hot_state_dict'])

        # Warm (non-strict: tolerates missing / unexpected keys)
        warm_state = state['warm_state']
        self.warm.load_state_dict(warm_state['state_dict'], strict=False)
        self.warm.update_count.fill_(warm_state['update_count'])

        if 'subspace_states' in warm_state and hasattr(self.warm, 'trackers'):
            # zip() stops at the shorter sequence, so extra trackers or extra
            # saved states are left untouched / ignored.
            for tracker, ss in zip(self.warm.trackers, warm_state['subspace_states']):
                tracker.basis = ss['basis']
                tracker.importance = ss['importance']
                tracker.initialized = ss.get('initialized', True)

        if 'spectral_history' in warm_state and hasattr(self.warm, 'balancer'):
            self.warm.balancer.history = warm_state['spectral_history']

        # Cold
        self.cold.load_state_dict(state['cold_state_dict'])

        # Cold gate: copy in place so the parameter keeps its own device and
        # dtype instead of adopting those of the checkpoint tensor.
        if 'cold_gate' in state:
            with torch.no_grad():
                self.cold_gate.copy_(state['cold_gate'])

        # Pipeline stats
        if 'pipeline_stats' in state:
            ps = state['pipeline_stats']
            self.pipeline.total_consolidated = ps.get('total_consolidated', 0)
            self.pipeline.total_archived = ps.get('total_archived', 0)
            self.pipeline.total_offloaded = ps.get('total_offloaded', 0)

        logger.info(
            f"Memory restored: cold={self.cold.fossil_count.item()} fossils, "
            f"warm={self.warm.update_count.item()} updates"
        )

    def diagnostics(self) -> dict:
        """
        Full memory system diagnostics. Use in heartbeat.

        Returns:
            dict with 'hot', 'warm', 'cold', 'cold_gate' and 'pipeline' entries.
            Warm condition numbers and adapter scales are included only when
            the warm module exposes ``adapters`` (and ``balancer``).
        """
        diag = {
            'hot': {
                'total_usage': self.hot.usage.sum().item(),
                'max_usage': self.hot.usage.max().item(),
                'min_usage': self.hot.usage.min().item(),
                'active_slots': (self.hot.usage > 0).sum().item(),
                'write_head': self.hot.write_head.item(),
            },
            'warm': {
                'update_count': self.warm.update_count.item(),
            },
            'cold': self.cold.stats(),
            'cold_gate': self.cold_gate.item(),
            'pipeline': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        # Warm condition numbers
        if hasattr(self.warm, 'adapters') and hasattr(self.warm, 'balancer'):
            cns = [
                self.warm.balancer.condition_number(adapter)
                for adapter in self.warm.adapters
            ]
            diag['warm']['condition_numbers'] = cns
            # Guard against an empty adapter list to avoid division by zero.
            diag['warm']['avg_condition_number'] = sum(cns) / len(cns) if cns else 0.0

        # Warm adapter scales
        if hasattr(self.warm, 'adapters'):
            diag['warm']['adapter_scales'] = [
                a.scale.item() for a in self.warm.adapters
            ]

        return diag

    def reset_hot(self):
        """Clear hot memory (session boundary). Warm and cold persist."""
        self.hot.reset()

    def __repr__(self) -> str:
        """Return a one-line summary of tier sizes, counters and the cold gate."""
        cold_count = self.cold.fossil_count.item()
        warm_updates = self.warm.update_count.item()
        return (
            f"ThreeTemperatureMemoryV2("
            f"hot_slots={self.config.hot_memory_slots}, "
            f"warm_layers={len(self.warm.adapters)}, "
            f"warm_updates={warm_updates}, "
            f"cold_fossils={cold_count}/{self.cold.capacity}, "
            f"cold_gate={self.cold_gate.item():.4f})"
        )