# Provenance (scraped Hugging Face page header): author kofdai,
# commit 6d07351 "Upload folder using huggingface_hub" (verified).
from __future__ import annotations

import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional
@dataclass
class MemoryNode:
    """One vertex of the memory lattice.

    Instances are persisted field-by-field (the lattice serializes each
    node's attribute dict to JSON and rebuilds it with ``MemoryNode(**d)``),
    so every field value should be JSON-serializable.
    """

    # Unique key of the node within the lattice.
    id: str
    # Node kind: 'utterance', 'concept' or 'verified_fact'.
    type: str
    # Payload; its shape depends on ``type``.
    content: Any
    # Free-form extra attributes; each instance gets its own empty dict.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MemoryEdge:
    """Labelled link from the ``source`` node id to the ``target`` node id.

    Edges are persisted field-by-field (serialized to a JSON dict and rebuilt
    with ``MemoryEdge(**d)``), so all fields are plain strings.
    """

    source: str  # id of the node the edge starts from
    target: str  # id of the node the edge points to
    # Relation label. MemoryLattice.add_turn emits 'mentions' (utterance ->
    # concept), 'produced' (utterance -> verified_fact) and 'relates_to'
    # (verified_fact -> concept).
    relation: str
class MemoryLattice:
    """
    AXIS Memory Lattice Layer (MLL)
    Manages persistent graph of past conversations and verified facts.

    The graph lives in memory as ``nodes`` (node id -> MemoryNode) and
    ``edges`` (list of MemoryEdge) and is persisted as one JSON document at
    ``storage_path`` by :meth:`save`. Node ids are prefixed by kind:
    ``U_<turn_id>`` for utterances, ``E_<entity>`` for concepts and
    ``F_<8 hex chars>`` for verified facts.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, storage_path: str):
        """Create a lattice backed by ``storage_path``, loading it if present."""
        self.storage_path = storage_path
        self.nodes: Dict[str, MemoryNode] = {}
        self.edges: List[MemoryEdge] = []
        self._load()

    def _load(self) -> None:
        """Populate ``nodes``/``edges`` from ``storage_path`` (best-effort).

        A missing file is the normal first-run case and is ignored. An
        unreadable or malformed file is logged and skipped as a whole, so the
        lattice is never left half-populated.
        NOTE(review): after a failed load, the next ``save()`` replaces the
        unreadable file; back it up first if that data matters.
        """
        try:
            with open(self.storage_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return  # nothing persisted yet
        except (OSError, ValueError) as exc:  # ValueError covers JSON/Unicode decode errors
            self._log.warning(
                "MemoryLattice: could not read %s (%s); starting empty",
                self.storage_path, exc,
            )
            return
        try:
            # Build into locals first so a bad record leaves state untouched.
            nodes = {n["id"]: MemoryNode(**n) for n in data.get("nodes", [])}
            edges = [MemoryEdge(**e) for e in data.get("edges", [])]
        except (AttributeError, KeyError, TypeError) as exc:
            self._log.warning(
                "MemoryLattice: malformed data in %s (%s); starting empty",
                self.storage_path, exc,
            )
            return
        self.nodes.update(nodes)
        self.edges.extend(edges)

    def save(self) -> None:
        """Write the whole lattice to ``storage_path`` as indented UTF-8 JSON.

        Parent directories are created when the path has one. The data is
        written to ``<storage_path>.tmp`` and then moved over the target with
        ``os.replace``, so a failure mid-write keeps the previous file intact.
        """
        data = {
            "nodes": [vars(n) for n in self.nodes.values()],
            "edges": [vars(e) for e in self.edges],
        }
        directory = os.path.dirname(self.storage_path)
        if directory:  # a bare filename has no directory to create
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{self.storage_path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.storage_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the real error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _new_fact_id(self) -> str:
        """Return an ``F_<8 hex>`` id not yet used, so no fact is overwritten."""
        while True:
            f_id = f"F_{uuid.uuid4().hex[:8]}"
            if f_id not in self.nodes:
                return f_id

    def add_turn(self, turn_id: str, query: str, response: str, entities: List[str], facts: List[str]):
        """Compresses a chat turn into the lattice structure.

        Adds (or replaces) the utterance node ``U_<turn_id>``. It links that
        node to one concept node per distinct entity ('mentions') and to a new
        verified_fact node per fact ('produced'). Each fact is also linked
        ('relates_to') to every entity whose text occurs in it, compared
        caselessly.
        """
        u_node_id = f"U_{turn_id}"
        self.nodes[u_node_id] = MemoryNode(u_node_id, "utterance", {"query": query, "response": response})

        # Distinct entities in first-seen order: repeats would only add duplicate edges.
        unique_entities = list(dict.fromkeys(entities))
        for ent in unique_entities:
            e_id = f"E_{ent}"
            if e_id not in self.nodes:
                self.nodes[e_id] = MemoryNode(e_id, "concept", ent)
            self.edges.append(MemoryEdge(u_node_id, e_id, "mentions"))

        # casefold() is the str method intended for caseless matching; hoisted out of the loop.
        folded_entities = [(ent, ent.casefold()) for ent in unique_entities]
        for fact in facts:
            f_id = self._new_fact_id()
            self.nodes[f_id] = MemoryNode(f_id, "verified_fact", fact)
            self.edges.append(MemoryEdge(u_node_id, f_id, "produced"))
            folded_fact = fact.casefold()
            for ent, folded_ent in folded_entities:
                if folded_ent in folded_fact:
                    self.edges.append(MemoryEdge(f_id, f"E_{ent}", "relates_to"))

    def get_context_pack(self, current_entities: List[str], k_hop: int = 1) -> Dict[str, Any]:
        """Retrieves relevant past info based on current entities.

        The search starts from the concept node ``E_<ent>`` of every entity
        already in the lattice. It collects every node reachable within
        ``k_hop`` edges, following edges in both directions; ``k_hop <= 0``
        returns only the seed concepts. Edges that point at ids with no node
        are ignored.

        Returns a dict with the keys ``concepts``, ``past_facts`` and
        ``history_snip``. These hold the ``content`` of the concept,
        verified_fact and utterance nodes found, in breadth-first discovery
        order.
        """
        # Undirected adjacency, built once instead of rescanning edges per entity.
        adjacency: Dict[str, List[str]] = {}
        for edge in self.edges:
            adjacency.setdefault(edge.source, []).append(edge.target)
            adjacency.setdefault(edge.target, []).append(edge.source)

        # dict as an insertion-ordered set => deterministic output order.
        seen: Dict[str, None] = {}
        frontier: List[str] = []
        for ent in current_entities:
            e_id = f"E_{ent}"
            if e_id in self.nodes and e_id not in seen:
                seen[e_id] = None
                frontier.append(e_id)

        for _ in range(max(k_hop, 0)):
            next_frontier: List[str] = []
            for nid in frontier:
                for neighbor in adjacency.get(nid, ()):
                    if neighbor not in seen:
                        seen[neighbor] = None
                        next_frontier.append(neighbor)
            if not next_frontier:
                break
            frontier = next_frontier

        pack: Dict[str, Any] = {
            "concepts": [],
            "past_facts": [],
            "history_snip": []
        }
        bucket_for_type = {
            "concept": "concepts",
            "verified_fact": "past_facts",
            "utterance": "history_snip",
        }
        for nid in seen:
            node = self.nodes.get(nid)
            if node is None:  # dangling edge endpoint (e.g. from a hand-edited file)
                continue
            bucket = bucket_for_type.get(node.type)
            if bucket is not None:
                pack[bucket].append(node.content)
        return pack