| """ |
| GLADIUS v2.0 — Three-Temperature Memory V2 |
| |
| The full pipeline, wired. Hot → Warm → Cold, all real. |
| |
| V1 had: Hot (working), Warm (stub LoRA), Cold (returns zeros). |
| V2 has: Hot (working), Warm (RealWarmMemory: Locas+Share+EBLoRA), Cold (FossilStore). |
| + MemoryDigestionPipeline orchestrating the metabolism. |
| |
| Drop-in replacement for ThreeTemperatureMemory. |
| Same API surface: read(), write(), consolidate(), checkpoint(), restore(). |
| New: digest(step), cold_read(query). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import inspect |
| import logging |
| from typing import Optional |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| |
| |
| |
| try: |
| from ..memory import HotMemory |
| from ..warm_memory import RealWarmMemory |
| except ImportError: |
| try: |
| from gladius_v2.src.kernel.memory import HotMemory |
| from gladius_v2.src.kernel.warm_memory import RealWarmMemory |
| except ImportError: |
| |
| HotMemory = None |
| RealWarmMemory = None |
|
|
| from .cold_memory import FossilStore |
| from .memory_pipeline import MemoryDigestionPipeline |
|
|
|
|
class ThreeTemperatureMemoryV2(nn.Module):
    """
    Three-temperature memory with full digestion pipeline.

    Hot:  Learned KV cache. Session-lived. Importance-gated writes.
    Warm: Locas GLU-FFN adapters + Share subspace + EBLoRA spectral balance.
    Cold: FossilStore — native vector store, persists with checkpoint.
    Pipeline: MemoryDigestionPipeline — orchestrates hot→warm→cold flow.

    API is backward-compatible with ThreeTemperatureMemory (V1).
    """

    def __init__(
        self,
        config,
        cold_capacity: int = 8192,
        cold_scale: float = 0.1,
        consolidate_every: int = 50,
        archive_every: int = 100,
        offload_every: int = 200,
    ):
        """
        Build the three tiers and the pipeline that moves data between them.

        Args:
            config: KernelConfig (must have hidden_dim, num_layers, warm_rank, etc.)
            cold_capacity: max fossils in cold storage
            cold_scale: initial value of the learnable cold-memory gate
            consolidate_every: steps between hot→warm consolidation
            archive_every: steps between hot→cold archival
            offload_every: steps between warm→cold offload check

        Raises:
            ImportError: if HotMemory / RealWarmMemory could not be imported
                (the module-level import fallback leaves them as None).
        """
        # The guarded imports at module level fall back to None; fail here
        # with a clear message instead of "'NoneType' object is not callable".
        if HotMemory is None or RealWarmMemory is None:
            raise ImportError(
                "ThreeTemperatureMemoryV2 requires HotMemory and RealWarmMemory, "
                "but they could not be imported."
            )

        super().__init__()
        self.config = config
        self.cold_scale = cold_scale

        # Hot tier.
        self.hot = HotMemory(config)

        # Warm tier: one adapter per layer.
        self.warm = RealWarmMemory(config, num_layers=config.num_layers)

        # Cold tier.
        self.cold = FossilStore(config, capacity=cold_capacity)

        # Pipeline shares the three tier objects created above.
        condition_threshold = getattr(config, 'warm_condition_threshold', 10.0)
        self.pipeline = MemoryDigestionPipeline(
            hot=self.hot,
            warm=self.warm,
            cold=self.cold,
            consolidate_every=consolidate_every,
            archive_every=archive_every,
            offload_every=offload_every,
            condition_threshold=condition_threshold,
        )

        # Learnable scalar weighting the cold contribution in read().
        # float() so an integer cold_scale (e.g. 1) still yields a floating
        # tensor; integer tensors cannot be wrapped in a grad-requiring Parameter.
        self.cold_gate = nn.Parameter(torch.tensor(float(cold_scale)))

    def read(self, query: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
        """
        Read from all three memory tiers.

        Args:
            query: (batch, seq_len, hidden_dim)
            layer_idx: which warm adapter to apply (default 0)

        Returns:
            (batch, seq_len, hidden_dim) — enriched hidden state
        """
        hot_context = self.hot.read(query)

        # Cold lookup uses one pooled vector per batch item (mean over dim 1),
        # then the result is broadcast back to every sequence position.
        query_pooled = query.mean(dim=1)
        cold_context = self.pipeline.cold_read(query_pooled)
        cold_expanded = cold_context.unsqueeze(1).expand_as(query)

        # Residual combination; the cold term is scaled by the learnable gate.
        combined = query + hot_context + cold_expanded * self.cold_gate

        # NOTE(review): forward() is invoked directly, so if the warm memory is
        # an nn.Module any hooks registered on it are skipped — confirm intended.
        return self.warm.forward(combined, layer_idx=layer_idx)

    def write(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Write to hot memory. Returns importance scores.

        Args:
            hidden: (batch, seq_len, hidden_dim)

        Returns:
            importance: (batch, seq_len, 1) — importance scores from write gate
        """
        return self.hot.write(hidden)

    @torch.no_grad()
    def consolidate(self):
        """
        Hot → Warm consolidation. Backward-compatible API.
        """
        self.pipeline.consolidate_hot_to_warm()

    @torch.no_grad()
    def digest(self, step: int, domain_id: int = 0) -> dict:
        """
        Run the full digestion pipeline.

        Args:
            step: current training step
            domain_id: current domain (for fossil metadata)

        Returns:
            dict with stats from all operations performed
        """
        return self.pipeline.digest(step, domain_id=domain_id)

    def cold_read(self, query: torch.Tensor, top_k: int = 4) -> torch.Tensor:
        """
        Direct cold memory read (bypass warm adapter).

        Args:
            query: (batch, hidden_dim)
            top_k: number of fossils to retrieve

        Returns:
            (batch, hidden_dim) — weighted fossil context
        """
        return self.pipeline.cold_read(query, top_k=top_k)

    def checkpoint(self, path: str) -> None:
        """
        Save all memory state to disk with torch.save.

        Saved content:
            hot:      state_dict
            warm:     state_dict, update_count, plus tracker state and the last
                      100 spectral-history entries when those attributes exist
            cold:     state_dict
            cold_gate and the pipeline's running totals

        Args:
            path: destination file path
        """
        warm_state = {
            'state_dict': self.warm.state_dict(),
            'update_count': self.warm.update_count.item(),
        }

        # Tracker fields are read as plain attributes here, so they are
        # stored explicitly alongside the warm state_dict.
        if hasattr(self.warm, 'trackers'):
            warm_state['subspace_states'] = [
                {
                    'basis': t.basis.clone(),
                    'importance': t.importance.clone(),
                    'initialized': t.initialized,
                }
                for t in self.warm.trackers
            ]

        # Only the most recent 100 entries are kept to bound checkpoint size.
        if hasattr(self.warm, 'balancer'):
            warm_state['spectral_history'] = self.warm.balancer.history[-100:]

        state = {
            'hot_state_dict': self.hot.state_dict(),
            'warm_state': warm_state,
            'cold_state_dict': self.cold.state_dict(),
            # Detached copy so the saved value is independent of the live parameter.
            'cold_gate': self.cold_gate.detach().clone(),
            'pipeline_stats': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        torch.save(state, path)
        logger.info(
            f"Memory checkpoint saved: hot={self.hot.usage.sum().item():.0f} usage, "
            f"warm={self.warm.update_count.item()} updates, "
            f"cold={self.cold.fossil_count.item()} fossils"
        )

    def restore(self, path: str) -> None:
        """
        Restore all memory state from a file written by checkpoint().

        Tensors are mapped onto the device currently holding ``cold_gate`` so a
        checkpoint saved on one device can be restored on another.

        Args:
            path: checkpoint file path
        """
        device = self.cold_gate.device

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only restore checkpoints from a trusted source.
        state = torch.load(path, map_location=device, weights_only=False)

        # Hot tier.
        self.hot.load_state_dict(state['hot_state_dict'])

        # Warm tier. strict=False tolerates key drift; report it rather than
        # letting it pass silently.
        warm_state = state['warm_state']
        load_result = self.warm.load_state_dict(warm_state['state_dict'], strict=False)
        missing = getattr(load_result, 'missing_keys', None)
        unexpected = getattr(load_result, 'unexpected_keys', None)
        if missing or unexpected:
            logger.warning(
                "Warm memory state_dict mismatch on restore: missing=%s, unexpected=%s",
                missing,
                unexpected,
            )
        self.warm.update_count.fill_(warm_state['update_count'])

        if 'subspace_states' in warm_state and hasattr(self.warm, 'trackers'):
            trackers = list(self.warm.trackers)
            saved_states = warm_state['subspace_states']
            if len(trackers) != len(saved_states):
                # zip() below stops at the shorter sequence.
                logger.warning(
                    "Subspace tracker count mismatch on restore: model has %d, checkpoint has %d",
                    len(trackers),
                    len(saved_states),
                )
            for tracker, ss in zip(trackers, saved_states):
                tracker.basis = ss['basis']
                tracker.importance = ss['importance']
                tracker.initialized = ss.get('initialized', True)

        if 'spectral_history' in warm_state and hasattr(self.warm, 'balancer'):
            self.warm.balancer.history = warm_state['spectral_history']

        # Cold tier.
        self.cold.load_state_dict(state['cold_state_dict'])

        # Copy in place so the Parameter keeps its own storage, device and
        # dtype (anything already holding a reference to it stays valid).
        if 'cold_gate' in state:
            with torch.no_grad():
                self.cold_gate.copy_(state['cold_gate'])

        # Pipeline counters; absent keys fall back to 0.
        if 'pipeline_stats' in state:
            ps = state['pipeline_stats']
            self.pipeline.total_consolidated = ps.get('total_consolidated', 0)
            self.pipeline.total_archived = ps.get('total_archived', 0)
            self.pipeline.total_offloaded = ps.get('total_offloaded', 0)

        logger.info(
            f"Memory restored: cold={self.cold.fossil_count.item()} fossils, "
            f"warm={self.warm.update_count.item()} updates"
        )

    def diagnostics(self) -> dict:
        """
        Full memory system diagnostics. Use in heartbeat.

        Returns:
            dict with 'hot', 'warm', 'cold', 'cold_gate' and 'pipeline' entries.
            'warm' additionally carries condition numbers and adapter scales
            when the warm memory exposes ``adapters`` (and ``balancer``).
        """
        usage = self.hot.usage
        diag = {
            'hot': {
                'total_usage': usage.sum().item(),
                'max_usage': usage.max().item(),
                'min_usage': usage.min().item(),
                'active_slots': (usage > 0).sum().item(),
                'write_head': self.hot.write_head.item(),
            },
            'warm': {
                'update_count': self.warm.update_count.item(),
            },
            'cold': self.cold.stats(),
            'cold_gate': self.cold_gate.item(),
            'pipeline': {
                'total_consolidated': self.pipeline.total_consolidated,
                'total_archived': self.pipeline.total_archived,
                'total_offloaded': self.pipeline.total_offloaded,
            },
        }

        has_adapters = hasattr(self.warm, 'adapters')

        # Per-adapter condition numbers as reported by the balancer.
        if has_adapters and hasattr(self.warm, 'balancer'):
            cns = [
                self.warm.balancer.condition_number(adapter)
                for adapter in self.warm.adapters
            ]
            diag['warm']['condition_numbers'] = cns
            diag['warm']['avg_condition_number'] = sum(cns) / len(cns) if cns else 0.0

        if has_adapters:
            diag['warm']['adapter_scales'] = [
                a.scale.item() for a in self.warm.adapters
            ]

        return diag

    def reset_hot(self):
        """Clear hot memory (session boundary). Warm and cold persist."""
        self.hot.reset()

    def __repr__(self) -> str:
        """Return a one-line summary; never raises on missing optional attributes."""
        cold_count = self.cold.fossil_count.item()
        warm_updates = self.warm.update_count.item()
        # diagnostics() treats `adapters` as optional; do the same here so that
        # printing the module cannot fail just because the attribute is absent.
        warm_layers = len(getattr(self.warm, 'adapters', ()))
        hot_slots = getattr(self.config, 'hot_memory_slots', None)
        return (
            f"ThreeTemperatureMemoryV2("
            f"hot_slots={hot_slots}, "
            f"warm_layers={warm_layers}, "
            f"warm_updates={warm_updates}, "
            f"cold_fossils={cold_count}/{self.cold.capacity}, "
            f"cold_gate={self.cold_gate.item():.4f})"
        )
|
|