|
|
from __future__ import annotations

import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
|
|
|
|
|
@dataclass()
class MemoryNode:
    """A single vertex of the memory lattice.

    Fields are declared in constructor order, so a node can be built
    positionally, e.g. ``MemoryNode(node_id, "concept", text)``, or from a
    mapping with the same keys via ``MemoryNode(**record)``.
    """

    # Key under which the node is stored in the lattice.
    id: str
    # Category label; this module uses "utterance", "concept" and "verified_fact".
    type: str
    # Arbitrary payload attached to the node.
    content: Any
    # Extra free-form attributes; every instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
@dataclass()
class MemoryEdge:
    """A directed, labelled link between two lattice nodes.

    ``source`` and ``target`` hold node ids; ``relation`` names the link
    (this module uses "mentions", "produced" and "relates_to").
    """

    # Id of the node the edge starts from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Label describing how source relates to target.
    relation: str
|
|
|
|
|
class MemoryLattice:
    """
    AXIS Memory Lattice Layer (MLL)

    Manages a persistent graph of past conversations and verified facts.

    Node ids are namespaced by kind (see ``add_turn``):

    * ``U_<turn_id>`` -- an ``"utterance"`` node whose content is
      ``{"query": ..., "response": ...}``;
    * ``E_<entity>`` -- a ``"concept"`` node whose content is the entity string;
    * ``F_<8 hex chars>`` -- a ``"verified_fact"`` node whose content is the fact.

    Edges are directed (``source`` -> ``target``) with the relations
    ``"mentions"`` (U -> E), ``"produced"`` (U -> F) and ``"relates_to"``
    (F -> E). The graph is held in memory; ``save`` persists it as JSON at
    ``storage_path``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, storage_path: str):
        """Create a lattice backed by the JSON file at *storage_path*.

        If the file exists its contents are loaded immediately; otherwise the
        lattice starts empty.
        """
        self.storage_path = storage_path
        self.nodes: Dict[str, MemoryNode] = {}
        self.edges: List[MemoryEdge] = []
        self._load()

    def _load(self):
        """Merge nodes/edges stored at ``storage_path`` into the lattice.

        Loading is best-effort: an unreadable or malformed file is logged and
        skipped instead of raising. All records are parsed into temporaries
        first, so a file that fails part-way through never leaves the lattice
        half-populated.
        """
        if not os.path.exists(self.storage_path):
            return
        try:
            with open(self.storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            nodes = {n["id"]: MemoryNode(**n) for n in data.get("nodes", [])}
            edges = [MemoryEdge(**e) for e in data.get("edges", [])]
        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            # ValueError covers json.JSONDecodeError; KeyError/TypeError cover
            # records missing "id" or carrying unexpected fields; AttributeError
            # covers a top-level JSON value that is not an object.
            self._logger.warning(
                "Could not load memory lattice from %s; ignoring stored data",
                self.storage_path,
                exc_info=True,
            )
            return
        self.nodes.update(nodes)
        self.edges.extend(edges)

    def save(self):
        """Write the whole lattice to ``storage_path`` as pretty-printed JSON.

        The parent directory is created if needed. Data is first written to a
        sibling ``.tmp`` file and then swapped in with ``os.replace``, so a
        failed save (e.g. non-JSON-serializable content) leaves the previous
        file intact instead of truncated.

        Raises:
            OSError: if the file cannot be written.
            TypeError: if some node content is not JSON-serializable.
        """
        data = {
            "nodes": [vars(n) for n in self.nodes.values()],
            "edges": [vars(e) for e in self.edges],
        }
        directory = os.path.dirname(self.storage_path)
        if directory:
            # os.makedirs("") raises FileNotFoundError, so a bare filename
            # (stored in the current directory) must skip this step.
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.storage_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.storage_path)
        except BaseException:
            # Do not leave a half-written temp file behind.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _new_fact_id(self) -> str:
        """Return a ``F_<8 hex>`` id not yet used by any node (retries on collision)."""
        while True:
            f_id = f"F_{uuid.uuid4().hex[:8]}"
            if f_id not in self.nodes:
                return f_id

    def add_turn(self, turn_id: str, query: str, response: str, entities: List[str], facts: List[str]):
        """Compresses a chat turn into the lattice structure.

        Adds (in memory only; call ``save`` to persist):

        * an utterance node ``U_<turn_id>`` holding the query/response pair
          (replacing any node with the same id);
        * a concept node ``E_<entity>`` for each entity not already present,
          plus a ``"mentions"`` edge from the utterance to it;
        * a new ``verified_fact`` node per fact, plus a ``"produced"`` edge
          from the utterance to it;
        * a ``"relates_to"`` edge from a fact to each non-empty entity that
          occurs in the fact text (case-insensitive substring match).

        Repeated entries in *entities* are ignored so they do not create
        duplicate edges.
        """
        u_node_id = f"U_{turn_id}"
        self.nodes[u_node_id] = MemoryNode(u_node_id, "utterance", {"query": query, "response": response})

        # dict.fromkeys de-duplicates while preserving the caller's order.
        unique_entities = list(dict.fromkeys(entities))

        for ent in unique_entities:
            e_id = f"E_{ent}"
            if e_id not in self.nodes:
                self.nodes[e_id] = MemoryNode(e_id, "concept", ent)
            self.edges.append(MemoryEdge(u_node_id, e_id, "mentions"))

        for fact in facts:
            # A random 8-hex id can collide; a collision would silently
            # overwrite an existing fact node, so pick an unused one.
            f_id = self._new_fact_id()
            self.nodes[f_id] = MemoryNode(f_id, "verified_fact", fact)
            self.edges.append(MemoryEdge(u_node_id, f_id, "produced"))

            fact_lower = fact.lower()
            for ent in unique_entities:
                # "" is a substring of every string; skip it so an empty
                # entity does not relate to every fact.
                if ent and ent.lower() in fact_lower:
                    self.edges.append(MemoryEdge(f_id, f"E_{ent}", "relates_to"))

    def get_context_pack(self, current_entities: List[str], k_hop: int = 1) -> Dict[str, Any]:
        """Retrieves relevant past info based on current entities.

        Starting from the concept nodes ``E_<entity>`` that exist for
        *current_entities*, collects every node reachable within *k_hop*
        edges, following edges in either direction. ``k_hop=0`` yields only
        the matched concepts; the default ``1`` adds their direct neighbours.
        Negative values behave like ``0``. Edges whose endpoint is not in
        ``nodes`` are ignored.

        Returns:
            A dict with three lists, each in discovery order:
            ``"concepts"`` (concept contents), ``"past_facts"`` (fact contents)
            and ``"history_snip"`` (utterance contents).
        """
        # Undirected adjacency built once: O(E) per call instead of one full
        # edge scan per entity.
        neighbours: Dict[str, List[str]] = {}
        for edge in self.edges:
            neighbours.setdefault(edge.source, []).append(edge.target)
            neighbours.setdefault(edge.target, []).append(edge.source)

        # A dict is used as an insertion-ordered set so output is deterministic.
        relevant: Dict[str, None] = {}
        frontier: List[str] = []
        for ent in current_entities:
            e_id = f"E_{ent}"
            if e_id in self.nodes and e_id not in relevant:
                relevant[e_id] = None
                frontier.append(e_id)

        # Breadth-first expansion, one hop per iteration.
        for _ in range(max(k_hop, 0)):
            next_frontier: List[str] = []
            for nid in frontier:
                for nb in neighbours.get(nid, ()):
                    if nb in self.nodes and nb not in relevant:
                        relevant[nb] = None
                        next_frontier.append(nb)
            if not next_frontier:
                break
            frontier = next_frontier

        pack = {
            "concepts": [],
            "past_facts": [],
            "history_snip": []
        }
        bucket_for_type = {
            "concept": "concepts",
            "verified_fact": "past_facts",
            "utterance": "history_snip",
        }
        for nid in relevant:
            node = self.nodes[nid]
            bucket = bucket_for_type.get(node.type)
            if bucket is not None:
                pack[bucket].append(node.content)

        return pack
|
|
|