"""AAM Diffusion LLM — Dual Memory System
Working Memory (ring buffer) + Long-Term Memory (compressed, attention-gated).
Adapted from Losion for AAM's graph-conditioned architecture.
AAM-specific mapping:
Working Memory → recent generation outputs (high detail)
Long-Term Memory → compressed graph understanding (persistent)
This is ESSENTIAL for AAM: when generating long narratives, the model
needs to "remember" what it already said and what the graph knows.
Losion v1.9.0/v1.9.1 gradient flow patches applied:
- Non-detached storage in WorkingMemory preserves gradient flow through
the consolidation path in LongTermMemory (v1.9.1).
- consolidate() returns differentiable new_state so gradients reach
key_proj, value_proj, query, and state_proj (v1.9.0).
- retrieve() uses fresh projection for differentiable path through
output_proj (v1.9.0).
- DualMemorySystem.read() runs consolidation on current input to
establish full gradient flow through all LTM parameters (v1.9.0).
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class DualMemoryConfig:
    """Hyper-parameters for the dual (working + long-term) memory system.

    ``DualMemorySystem`` reads ``d_model``, ``working_memory_size``,
    ``long_term_memory_dim`` and ``consolidation_method``.  The remaining
    fields are not read by any class visible in this file.
    """

    # Width of the vectors written to and read from memory.
    d_model: int = 768
    # Number of rows in the working-memory ring buffer.
    working_memory_size: int = 512
    # Width of the compressed long-term state.
    long_term_memory_dim: int = 256
    # "attention" and "gated" have dedicated branches in LongTermMemory;
    # any other value takes its mean-pool + projection fallback.
    consolidation_method: str = "attention"
    # NOTE(review): not referenced in the visible code -- confirm usage.
    retrieval_method: str = "attention"
    # NOTE(review): not referenced in the visible code -- confirm usage.
    n_retrieval_heads: int = 4
    # NOTE(review): not referenced in the visible code -- confirm usage.
    dropout: float = 0.0
class WorkingMemory(nn.Module):
    """Working Memory: direct access to recent token/layer outputs (ring buffer).

    Two views of the written data are kept:

    * ``buffer`` / ``occupation`` -- a fixed-size ring of *detached* rows,
      used by :meth:`read_recent` and as the fallback for :meth:`read_all`.
    * ``_latest_entries`` -- a reference to the rows of the most recent
      :meth:`write` call, still attached to their autograd graph, so that a
      consolidation computed from :meth:`read_all` can backpropagate into
      whatever produced them (v1.9.1).

    Both buffers are registered with ``persistent=False`` and the write
    pointer / count are plain Python ints, so none of this state is stored in
    ``state_dict()``.
    """

    def __init__(self, d_model: int, capacity: int = 512) -> None:
        """Allocate an empty ring of ``capacity`` rows of width ``d_model``."""
        super().__init__()
        self.d_model = d_model
        self.capacity = capacity
        self.register_buffer("buffer", torch.zeros(capacity, d_model), persistent=False)
        self.register_buffer(
            "occupation", torch.zeros(capacity, dtype=torch.bool), persistent=False
        )
        # Next ring slot to be overwritten, and number of valid rows.
        self._write_ptr: int = 0
        self._count: int = 0
        # v1.9.1: graph-attached rows of the latest write (None until the
        # first write and after clear()).
        self._latest_entries: Optional[torch.Tensor] = None

    def write(self, entries: torch.Tensor) -> None:
        """Append ``entries`` to the ring, overwriting the oldest rows.

        Args:
            entries: Tensor whose last dimension equals ``d_model``.  Any
                leading dimensions are flattened, so a single 1-D vector is
                stored as one row.

        Raises:
            ValueError: If the last dimension of ``entries`` is not
                ``d_model``.
        """
        if entries.dim() == 0 or entries.shape[-1] != self.d_model:
            raise ValueError(
                f"expected entries with last dimension {self.d_model}, "
                f"got shape {tuple(entries.shape)}"
            )
        rows = entries.reshape(-1, self.d_model)

        # v1.9.1: keep the rows attached to their graph.  Detaching here would
        # cut the gradient path that read_all() exists to provide.  The
        # reference keeps that graph alive until the next write() or clear().
        self._latest_entries = rows

        n = rows.shape[0]
        if n == 0:
            return

        # Only the last `capacity` rows can survive a single write, so drop
        # the rest up front; this also keeps the target slots unique.
        tail = rows.detach()[-self.capacity:]
        kept = tail.shape[0]
        with torch.no_grad():
            slots = (
                self._write_ptr
                + (n - kept)
                + torch.arange(kept, device=self.buffer.device)
            ) % self.capacity
            # The ring stores detached copies for inference stability and to
            # prevent "backward through graph a second time" errors across
            # training steps.
            self.buffer[slots] = tail.to(
                device=self.buffer.device, dtype=self.buffer.dtype
            )
            self.occupation[slots] = True
        self._write_ptr = (self._write_ptr + n) % self.capacity
        self._count = min(self._count + n, self.capacity)

    def read_all(self) -> torch.Tensor:
        """Return the rows to consolidate.

        Prefers the graph-attached rows of the latest write (v1.9.1); when
        there are none, falls back to every occupied row of the detached ring
        in slot order.  The result has shape ``(k, d_model)`` with ``k == 0``
        when nothing is stored.
        """
        latest = self._latest_entries
        if latest is not None and latest.shape[0] > 0:
            return latest
        # A boolean mask with no True entries yields a (0, d_model) tensor.
        return self.buffer[self.occupation]

    def read_recent(self, n: int) -> torch.Tensor:
        """Return up to ``n`` of the newest ring rows, newest first.

        The rows come from the detached ring buffer.  Returns an empty
        ``(0, d_model)`` tensor when the ring is empty or ``n <= 0``.
        """
        if self._count == 0 or n <= 0:
            return self.buffer.new_zeros((0, self.d_model))
        n = min(n, self._count)
        steps_back = torch.arange(1, n + 1, device=self.buffer.device)
        # `%` on tensors follows Python semantics, so negative offsets wrap.
        slots = (self._write_ptr - steps_back) % self.capacity
        return self.buffer[slots]

    def clear(self) -> None:
        """Reset the ring, the bookkeeping and the latest-entries reference."""
        self.buffer.zero_()
        self.occupation.zero_()
        self._write_ptr = 0
        self._count = 0
        self._latest_entries = None

    def get_occupation_ratio(self) -> float:
        """Return the fraction of ring slots currently holding a row."""
        return self._count / self.capacity
class LongTermMemory(nn.Module):
    """Long-Term Memory: compressed, persistent state from graph understanding.

    A set of ``(N, d_model)`` rows is compressed into one ``d_state`` vector
    by one of three methods:

    * ``"attention"`` -- softmax attention of a learned query over projected
      keys, pooling projected values.
    * ``"gated"`` -- mean-pool, project, then blend with the stored state
      through a learned sigmoid gate.
    * anything else -- mean-pool and project.

    The ``compressed_state`` buffer is an exponential moving average of the
    states produced so far.  It is registered with ``persistent=False``, so it
    is not part of ``state_dict()``.
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 256,
        consolidation_method: str = "attention",
        ema_decay: float = 0.9,
    ) -> None:
        """Build the projections for the chosen consolidation method.

        Args:
            d_model: Width of the rows being consolidated and of the
                retrieved vectors.
            d_state: Width of the compressed state.
            consolidation_method: ``"attention"``, ``"gated"``, or any other
                string for the mean-pool fallback.
            ema_decay: Weight kept from the previous ``compressed_state`` on
                each update; the new state gets ``1 - ema_decay``.

        Raises:
            ValueError: If ``ema_decay`` is outside ``[0, 1]``.
        """
        super().__init__()
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError(f"ema_decay must be in [0, 1], got {ema_decay}")
        self.d_model = d_model
        self.d_state = d_state
        self.consolidation_method = consolidation_method
        self.ema_decay = ema_decay

        self.state_proj = nn.Linear(d_model, d_state, bias=False)
        if consolidation_method == "attention":
            self.query = nn.Parameter(torch.randn(d_state) * 0.02)
            self.key_proj = nn.Linear(d_model, d_state, bias=False)
            self.value_proj = nn.Linear(d_model, d_state, bias=False)
            self.scale = math.sqrt(d_state)
        elif consolidation_method == "gated":
            self.gate = nn.Sequential(
                nn.Linear(d_state, d_state, bias=False),
                nn.Sigmoid(),
            )
        self.output_proj = nn.Linear(d_state, d_model, bias=False)
        self.register_buffer("compressed_state", torch.zeros(d_state), persistent=False)

    def consolidate(self, working_memory_entries: torch.Tensor) -> torch.Tensor:
        """Compress ``(N, d_model)`` rows into a ``d_state`` vector.

        Side effect: ``compressed_state`` is moved towards the new state by an
        exponential moving average (detached from autograd).

        Returns:
            The freshly computed, differentiable state (v1.9.0) so gradients
            can reach the consolidation parameters.  For an empty input the
            stored ``compressed_state`` buffer is returned unchanged.
        """
        if working_memory_entries.shape[0] == 0:
            return self.compressed_state

        if self.consolidation_method == "attention":
            keys = self.key_proj(working_memory_entries)
            values = self.value_proj(working_memory_entries)
            scores = torch.matmul(keys, self.query) / self.scale
            weights = F.softmax(scores, dim=0)
            new_state = torch.matmul(weights, values)
        elif self.consolidation_method == "gated":
            projected = self.state_proj(working_memory_entries.mean(dim=0))
            gate = self.gate(projected)
            # Blend against a snapshot, not the live buffer: the buffer is
            # rewritten below, and autograd must keep seeing the value that
            # was actually used in this forward pass.
            previous = self.compressed_state.clone()
            new_state = gate * projected + (1 - gate) * previous
        else:
            new_state = self.state_proj(working_memory_entries.mean(dim=0))

        # Buffer update remains detached for inference stability.
        # NOTE(review): `.data` is rebound (as in the original) rather than
        # updated with an in-place op -- presumably so graphs that captured
        # the buffer do not trip autograd's in-place version check; confirm.
        decay = self.ema_decay
        with torch.no_grad():
            self.compressed_state.data = (
                decay * self.compressed_state.data + (1.0 - decay) * new_state.detach()
            )
        return new_state

    def retrieve(
        self,
        query: torch.Tensor,
        differentiable_state: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project a state back to ``d_model`` and broadcast it over ``query``.

        Args:
            query: Only its shape is used.  For rank >= 2 the projected vector
                is expanded to ``query``'s shape; otherwise it is returned as
                a ``(d_model,)`` vector.
            differentiable_state: State from a just-computed
                :meth:`consolidate` (v1.9.0).  When given, gradients flow
                through it and ``output_proj``; when omitted, the stored
                ``compressed_state`` is projected instead.
        """
        state = self.compressed_state if differentiable_state is None else differentiable_state
        retrieved = self.output_proj(state)
        if query.dim() > 1:
            # expand_as prepends as many leading dimensions as needed.
            retrieved = retrieved.expand_as(query)
        return retrieved
class DualMemorySystem(nn.Module):
    """Two-Level Memory System for AAM: Working Memory + Long-Term Memory.

    * :meth:`write` appends rows to the working-memory ring buffer.
    * :meth:`read` returns the input plus small memory-derived residuals and
      is the differentiable training path (v1.9.0).
    * :meth:`retrieve` returns only the gated memory mixture plus diagnostics.
    * :meth:`consolidate` compresses working memory into long-term memory.

    Inputs may have any rank >= 1 as long as their last dimension is
    ``d_model``; all leading dimensions are pooled together.
    """

    # Number of most-recent working-memory rows scanned per lookup.
    RECENT_WINDOW: int = 64
    # Residual weight of the gated working/long-term mixture in read().
    COMBINED_SCALE: float = 0.05
    # Residual weight of the direct state_proj -> output_proj path in read().
    DIRECT_SCALE: float = 0.01

    def __init__(self, config: Optional[DualMemoryConfig] = None) -> None:
        """Build both memories and the retrieval gate from ``config``.

        A default ``DualMemoryConfig`` is used when ``config`` is ``None``.
        """
        super().__init__()
        self.config = config or DualMemoryConfig()
        self.d_model = self.config.d_model
        self.working_memory = WorkingMemory(
            d_model=self.d_model,
            capacity=self.config.working_memory_size,
        )
        self.long_term_memory = LongTermMemory(
            d_model=self.d_model,
            d_state=self.config.long_term_memory_dim,
            consolidation_method=self.config.consolidation_method,
        )
        # Two softmax weights: index 0 -> working memory, index 1 -> long-term.
        self.retrieval_gate = nn.Sequential(
            nn.Linear(self.d_model, 2, bias=False),
            nn.Softmax(dim=-1),
        )
        self.working_retrieve_proj = nn.Linear(self.d_model, self.d_model, bias=False)

    # ------------------------------------------------------------------
    # Private helpers shared by read() and retrieve()
    # ------------------------------------------------------------------
    def _as_rows(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten ``x`` to ``(N, d_model)``, validating its last dimension."""
        if x.dim() == 0 or x.shape[-1] != self.d_model:
            raise ValueError(
                f"expected a tensor whose last dimension is {self.d_model}, "
                f"got shape {tuple(x.shape)}"
            )
        return x.reshape(-1, self.d_model)

    def _working_memory_output(self, pooled: torch.Tensor) -> torch.Tensor:
        """Project the recent working-memory row most similar to ``pooled``.

        Returns a zero vector when working memory is empty.
        """
        recent = self.working_memory.read_recent(self.RECENT_WINDOW)
        if recent.shape[0] == 0:
            return pooled.new_zeros(self.d_model)
        scores = F.cosine_similarity(pooled.unsqueeze(0), recent, dim=-1)
        # Hard top-1 selection: gradients reach working_retrieve_proj but not
        # the choice of row.
        return self.working_retrieve_proj(recent[scores.argmax()])

    def _blend(
        self, pooled: torch.Tensor, ltm_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Mix working- and long-term outputs with the learned softmax gate.

        Returns ``(combined, wm_weight, ltm_weight)``.  ``combined`` has the
        broadcast shape of the ``(d_model,)`` working-memory vector and
        ``ltm_output``.
        """
        wm_output = self._working_memory_output(pooled)
        gates = self.retrieval_gate(pooled)
        wm_weight, ltm_weight = gates[0], gates[1]
        combined = wm_weight * wm_output + ltm_weight * ltm_output
        return combined, wm_weight, ltm_weight

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def write(self, x: torch.Tensor) -> None:
        """Store ``x`` in working memory, one row per ``d_model`` vector.

        A 1-D input is stored as a single row, matching how :meth:`read`
        interprets it.  v1.9.1: the rows are passed on without detaching;
        what is retained of the autograd graph is decided by
        ``WorkingMemory.write``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``d_model``.
        """
        self.working_memory.write(self._as_rows(x))

    def read(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` enriched with memory-derived residuals.

        v1.9.0: consolidation is run on the CURRENT (non-detached) input
        rather than on detached buffer contents, and its result is fed to
        ``retrieve`` so the consolidation parameters and ``output_proj`` take
        part in the graph.  Side effect: this also updates the long-term
        ``compressed_state`` buffer on every call.

        Returns:
            ``x + COMBINED_SCALE * combined + DIRECT_SCALE * ltm_direct``,
            with the same shape as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``d_model``.
        """
        rows = self._as_rows(x)
        pooled = rows.mean(dim=0)

        # --- Differentiable consolidation on current input ---
        differentiable_state = self.long_term_memory.consolidate(rows)
        # --- Retrieve with differentiable state so output_proj gets gradients ---
        ltm_output = self.long_term_memory.retrieve(
            x, differentiable_state=differentiable_state
        )
        # --- Working memory lookup + gated combination ---
        combined, _, _ = self._blend(pooled, ltm_output)

        # --- Direct differentiable path: pooled -> state_proj -> output_proj ---
        # v1.9.0: lightweight residual that guarantees gradient flow through
        # state_proj and output_proj even when the consolidation method does
        # not use state_proj.
        ltm_direct = self.long_term_memory.output_proj(
            self.long_term_memory.state_proj(pooled)
        )
        # Both residuals broadcast over the leading dimensions of x.
        return x + self.COMBINED_SCALE * combined + self.DIRECT_SCALE * ltm_direct

    def consolidate(self) -> Optional[torch.Tensor]:
        """Compress working memory into long-term memory.

        Returns:
            The state returned by ``LongTermMemory.consolidate`` (v1.9.0), or
            ``None`` when working memory holds nothing to consolidate.
        """
        entries = self.working_memory.read_all()
        if entries.shape[0] == 0:
            return None
        return self.long_term_memory.consolidate(entries)

    def retrieve(self, query: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, object]]:
        """Return the gated memory mixture for ``query`` plus diagnostics.

        Unlike :meth:`read`, this projects the stored long-term state and
        does not run consolidation.

        Returns:
            ``(combined, info)``: ``combined`` has the shape of ``query``;
            ``info`` holds ``"working_memory_occupation"``, ``"wm_weight"``
            and ``"ltm_weight"`` as Python floats.

        Raises:
            ValueError: If the last dimension of ``query`` is not ``d_model``.
        """
        pooled = self._as_rows(query).mean(dim=0)
        ltm_output = self.long_term_memory.retrieve(query)
        combined, wm_weight, ltm_weight = self._blend(pooled, ltm_output)
        info = {
            "working_memory_occupation": self.working_memory.get_occupation_ratio(),
            "wm_weight": wm_weight.item(),
            "ltm_weight": ltm_weight.item(),
        }
        return combined.expand_as(query), info

    def clear(self) -> None:
        """Empty working memory; the long-term state is left untouched."""
        self.working_memory.clear()

    def get_stats(self) -> Dict[str, object]:
        """Return working-memory occupation and the long-term state norm."""
        return {
            "working_memory_occupation": self.working_memory.get_occupation_ratio(),
            "long_term_memory_norm": self.long_term_memory.compressed_state.norm().item(),
        }
|