# zorkclaw / ari_graph_memory.py
# Simon Sassi
# feat: add valid actions, NPC puzzle hints, and state hash tracking for loop detection
# d254e99
"""AriGraph-style memory for Jericho text adventures.
This module implements a lightweight bi-layer memory:
- Semantic facts (entity -> relation -> entity) with temporal validity.
- Episodic turns (action + observation) that can be searched and used as context.
It uses an embedded Kuzu graph database for persistence.
Design goals for this assignment:
- No network calls required.
- Deterministic, fast ingestion.
- Provide "last known location" style retrieval for common game objects.
If you later want full Graphiti ingestion (LLM-based extraction + embeddings),
this module can be swapped out behind the same interface.
"""
from __future__ import annotations
import os
import re
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Any
import kuzu
def _now_iso() -> str:
    """Return the current time as a timezone-aware ISO-8601 string in UTC."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
# Directions the memory tracks per room, both as known exits ("exit_<dir>")
# and as already-attempted moves ("tried_<dir>").
_COMMON_DIRECTIONS: list[str] = "north south east west up down enter exit".split()
def _norm(s: str) -> str:
    """Normalize free text for entity ids and fuzzy matching.

    Lower-cases the text, replaces every character other than ASCII letters,
    digits, whitespace, ``-`` and ``'`` with a space, collapses whitespace runs
    and trims the result.

    Trimming is done last on purpose: punctuation at either end (``"Lamp."``)
    is turned into a space by the substitution, and trimming first would leave
    that space behind (``"lamp "``). That would give a different entity id than
    the bare word, and would also leave a trailing space in search needles.
    """
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s\-']+", " ", s)
    s = re.sub(r"\s+", " ", s)
    return s.strip()
def _entity_id(kind: str, name: str) -> str:
    """Build the primary key for a SemanticEntity: ``"<kind>::<normalized name>"``."""
    return "::".join((kind, _norm(name)))
@dataclass(frozen=True)
class MemoryConfig:
    """Immutable settings used to construct an :class:`AriGraphMemory`."""

    # Filesystem path handed to kuzu.Database; AriGraphMemory creates the
    # parent directory if it does not exist yet.
    db_path: str
class AriGraphMemory:
    """Kuzu-backed semantic+episodic memory.

    Semantic layer: ``SemanticEntity`` nodes joined by ``SemanticRel`` edges.
    Each edge carries a predicate (``pred``), the turn it was recorded
    (``valid_from``) and the turn it was closed (``valid_to``). A ``valid_to``
    of ``-1`` marks a fact that is still believed to hold.

    Episodic layer: one ``EpisodicTurn`` node per ingested turn, linked to the
    entities it mentions through ``Mentions`` edges.
    """

    # Parser abbreviations/synonyms mapped onto the names in _COMMON_DIRECTIONS.
    _DIRECTION_ALIASES: dict[str, str] = {
        "n": "north",
        "s": "south",
        "e": "east",
        "w": "west",
        "u": "up",
        "d": "down",
        "in": "enter",
        "out": "exit",
    }

    def __init__(self, config: MemoryConfig):
        """Open (or create) the database at ``config.db_path`` and ensure the schema."""
        self._db_path = config.db_path
        # Make sure the parent directory exists before Kuzu opens the path.
        Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
        self._db = kuzu.Database(self._db_path)
        self._conn = kuzu.Connection(self._db)
        self._init_schema()

    def close(self) -> None:
        """Release the Kuzu connection and database handles.

        Closing is optional for the embedded database. Releasing the handles
        may be required before the same path can be reopened elsewhere (TODO
        confirm against the Kuzu version in use). Handles without a ``close()``
        method are skipped, so on such bindings this stays a no-op.
        """
        # Close the connection before the database it belongs to.
        for handle in (getattr(self, "_conn", None), getattr(self, "_db", None)):
            closer = getattr(handle, "close", None)
            if callable(closer):
                closer()

    def _qr(
        self, query: str, parameters: dict[str, Any] | None = None
    ) -> kuzu.QueryResult:
        """Execute a query and normalize Kuzu's return type.

        Depending on the Kuzu Python bindings version, Connection.execute() may
        return a single QueryResult or a list of QueryResult objects.
        """
        res = self._conn.execute(query, parameters or {})
        if isinstance(res, list):
            return res[0]
        return res

    # ---------------------------------------------------------------------
    # Schema
    # ---------------------------------------------------------------------
    def _init_schema(self) -> None:
        """Create the node/relationship tables (and best-effort indexes) if missing."""
        statements = [
            "CREATE NODE TABLE SemanticEntity(id STRING, name STRING, kind STRING, created_at STRING, PRIMARY KEY(id))",
            "CREATE NODE TABLE EpisodicTurn(id STRING, game STRING, turn INT64, action STRING, location STRING, observation STRING, created_at STRING, PRIMARY KEY(id))",
            "CREATE REL TABLE Mentions(FROM EpisodicTurn TO SemanticEntity, role STRING)",
            "CREATE REL TABLE SemanticRel(FROM SemanticEntity TO SemanticEntity, pred STRING, valid_from INT64, valid_to INT64, turn_id STRING)",
        ]
        for stmt in statements:
            try:
                self._conn.execute(stmt)
            except Exception:
                # Best-effort idempotency: the table most likely exists already.
                continue
        # Basic indexes to speed up lookups.
        # NOTE(review): this DDL may not be supported by every Kuzu version;
        # failures are ignored either way.
        for stmt in [
            "CREATE INDEX SemanticEntity_name_idx ON SemanticEntity(name)",
            "CREATE INDEX EpisodicTurn_game_turn_idx ON EpisodicTurn(game, turn)",
        ]:
            try:
                self._conn.execute(stmt)
            except Exception:
                continue

    # ---------------------------------------------------------------------
    # Ingestion
    # ---------------------------------------------------------------------
    def ingest_turn(
        self,
        *,
        game: str,
        turn: int,
        action: str,
        location: str,
        observation: str,
        inventory: list[str] | None = None,
    ) -> None:
        """Store an episodic node for a turn, then update semantic facts.

        Semantic extraction is intentionally heuristic and deterministic.

        Args:
            game: Game identifier stored on the episodic node.
            turn: Turn number; used as the validity timestamp of facts.
            action: Command the player issued.
            location: Room the player is in after the turn.
            observation: Game text returned for the turn.
            inventory: Full inventory after the turn. A list (even an empty
                one) is treated as authoritative, so ``has`` facts for items no
                longer listed are closed. ``None`` means "unknown" and leaves
                existing ``has`` facts untouched.
        """
        turn_id = str(uuid.uuid4())
        self._upsert_turn(
            turn_id=turn_id,
            game=game,
            turn=turn,
            action=action,
            location=location,
            observation=observation,
        )

        # Entities: agent + room.
        agent_id = self._upsert_entity(kind="agent", name="player")
        room_id = self._upsert_entity(kind="room", name=location or "Unknown")
        self._add_mention(turn_id, agent_id, role="actor")
        self._add_mention(turn_id, room_id, role="location")

        # The agent is in exactly one room: close the previous "at" fact.
        self._set_singleton_relation(
            subj_id=agent_id,
            pred="at",
            obj_id=room_id,
            turn=turn,
            turn_id=turn_id,
        )

        self._ingest_items(
            turn_id=turn_id,
            turn=turn,
            agent_id=agent_id,
            room_id=room_id,
            observation=observation,
            inventory=inventory,
        )

        move_dir = self._normalize_direction(action)
        if move_dir and location:
            # Successful moves are expected to arrive via ingest_move, so a
            # movement command here is most likely a failed move. Remember the
            # direction as tried so it stops being suggested as unexplored.
            dir_id = self._upsert_entity(kind="direction", name=move_dir)
            self._add_relation(
                subj_id=room_id,
                pred=f"tried_{move_dir}",
                obj_id=dir_id,
                turn=turn,
                turn_id=turn_id,
                singleton=True,  # only store once per (room, direction)
            )

    def ingest_move(
        self,
        *,
        game: str,
        turn: int,
        action: str,
        from_location: str,
        to_location: str,
        observation: str,
        inventory: list[str] | None = None,
    ) -> None:
        """Ingest a movement turn and create an exit edge from old->new room.

        ``inventory`` follows the same contract as in :meth:`ingest_turn`.
        """
        turn_id = str(uuid.uuid4())
        self._upsert_turn(
            turn_id=turn_id,
            game=game,
            turn=turn,
            action=action,
            location=to_location,
            observation=observation,
        )

        agent_id = self._upsert_entity(kind="agent", name="player")
        from_room_id = self._upsert_entity(kind="room", name=from_location or "Unknown")
        to_room_id = self._upsert_entity(kind="room", name=to_location or "Unknown")
        self._add_mention(turn_id, agent_id, role="actor")
        self._add_mention(turn_id, from_room_id, role="from")
        self._add_mention(turn_id, to_room_id, role="to")

        self._set_singleton_relation(
            subj_id=agent_id,
            pred="at",
            obj_id=to_room_id,
            turn=turn,
            turn_id=turn_id,
        )

        # Exit edge (semantic) from -> to; self-loops are not recorded.
        move_dir = self._normalize_direction(action)
        if move_dir and from_location and to_location and from_location != to_location:
            self._add_relation(
                subj_id=from_room_id,
                pred=f"exit_{move_dir}",
                obj_id=to_room_id,
                turn=turn,
                turn_id=turn_id,
                singleton=True,  # only store once per (from_room, direction, to_room)
            )

        self._ingest_items(
            turn_id=turn_id,
            turn=turn,
            agent_id=agent_id,
            room_id=to_room_id,
            observation=observation,
            inventory=inventory,
        )

    def _ingest_items(
        self,
        *,
        turn_id: str,
        turn: int,
        agent_id: str,
        room_id: str,
        observation: str,
        inventory: list[str] | None,
    ) -> None:
        """Record inventory ("has") and visible-item ("located_in") facts for a turn."""
        inv_item_ids: set[str] = set()
        for raw in inventory or []:
            item_name = self._clean_item_name(raw)
            if not item_name:
                continue
            item_id = self._upsert_entity(kind="item", name=item_name)
            inv_item_ids.add(item_id)
            self._add_mention(turn_id, item_id, role="inventory")
            # "has" is multi-valued: add without closing other held items.
            self._add_relation(
                subj_id=agent_id,
                pred="has",
                obj_id=item_id,
                turn=turn,
                turn_id=turn_id,
                singleton=True,
            )
            # If we carry it, it is not lying in a room.
            self._invalidate_predicate_for_subject(
                subj_id=item_id,
                pred="located_in",
                turn=turn,
            )

        # A provided inventory is a full snapshot: anything no longer listed
        # was dropped/consumed, so its "has" fact ends now.
        if inventory is not None:
            self._drop_stale_holdings(agent_id=agent_id, keep=inv_item_ids, turn=turn)

        # Visible items (best-effort) are located in the current room unless held.
        for visible in self._extract_visible_items(observation):
            item_id = self._upsert_entity(kind="item", name=visible)
            if item_id in inv_item_ids:
                continue
            self._add_mention(turn_id, item_id, role="visible")
            # An item is in one room at a time: close any older location.
            self._set_singleton_relation(
                subj_id=item_id,
                pred="located_in",
                obj_id=room_id,
                turn=turn,
                turn_id=turn_id,
            )

    # ---------------------------------------------------------------------
    # Retrieval
    # ---------------------------------------------------------------------
    def context_summary(self, *, max_facts: int = 10) -> str:
        """Render current beliefs (location, inventory) and recent facts as text."""
        agent_id = _entity_id("agent", "player")
        location = self._get_active_object_name(subj_id=agent_id, pred="at")
        inv = self._get_active_objects(subj_id=agent_id, pred="has", limit=20)

        lines: list[str] = ["AriGraph Memory (semantic+episodic)"]
        if location:
            lines.append(f"- Belief: you are at: {location}")
        if inv:
            lines.append(f"- Belief: you have: {', '.join(inv)}")

        # Most recently recorded facts, including ones that were closed since.
        facts = self._get_recent_facts(limit=max_facts)
        if facts:
            lines.append("- Recent facts:")
            lines.extend(f" - {fact}" for fact in facts)
        return "\n".join(lines)

    def search(self, query: str, *, limit: int = 10) -> str:
        """Answer a free-text memory query.

        Recognized intents, checked in order: current location, inventory,
        untried directions, known exits of a room, and "where is X". Anything
        else (or an unresolved "where is X") falls back to a keyword search
        over episodic turns, returning at most ``limit`` matches.
        """
        q = _norm(query)
        if not q:
            return "No query provided."

        agent_id = _entity_id("agent", "player")

        # "Where am I?"
        if re.search(r"\b(where am i|current location|my location)\b", q):
            location = self._get_active_object_name(subj_id=agent_id, pred="at")
            return f"You are at: {location}" if location else "Location unknown."

        # "What do I have?" / inventory
        if re.search(r"\b(what do i have|inventory|what am i carrying|carrying)\b", q):
            inv = self._get_active_objects(subj_id=agent_id, pred="has", limit=25)
            return (
                "You have: " + ", ".join(inv)
                if inv
                else "Inventory: empty-handed (based on memory)."
            )

        # "Unexplored exits" / "untried directions"
        m_untried = re.search(
            r"(?:unexplored exits|untried directions|untried exits|where should i explore)(?:\s+from\s+(.+))?",
            q,
        )
        if m_untried:
            return self._answer_untried(agent_id, m_untried.group(1) or "")

        # "Exits from <room>" / navigation
        m_exits = re.search(
            r"(?:exits from|where can i go from|ways out of|paths from)\s+(.+)",
            q,
        )
        if m_exits:
            return self._answer_exits(self._strip_article(m_exits.group(1)))

        # "Where is X": try the item's semantic location first.
        needle = q
        m_where = re.search(r"\b(?:where is|find|locate)\s+(.+)", q)
        if m_where:
            target = self._strip_article(m_where.group(1))
            answer = self._answer_where_is(target)
            if answer:
                return answer
            # Search episodes for the object itself, not the whole question.
            needle = target or q

        return self._episodic_search(needle, limit=limit)

    @staticmethod
    def _strip_article(text: str) -> str:
        """Trim *text* and drop one leading lowercase article (the/a/an)."""
        return re.sub(r"^(the|a|an)\s+", "", text.strip())

    def _answer_untried(self, agent_id: str, room_hint: str) -> str:
        """List directions with neither a known exit nor a failed attempt recorded.

        Uses *room_hint* when non-empty, otherwise the agent's current room.
        """
        room_name = self._strip_article(room_hint)
        if not room_name:
            room_name = self._get_active_object_name(subj_id=agent_id, pred="at") or ""
        if not room_name:
            return "Current room unknown; cannot compute untried directions."

        room_id = self._best_match_entity_id(kind="room", name_fragment=room_name)
        if not room_id:
            return f"Room not found in memory: '{room_name}'."

        preds = [
            f"{prefix}_{direction}"
            for prefix in ("exit", "tried")
            for direction in _COMMON_DIRECTIONS
        ]
        rows = self._qr(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred IN $preds AND r.valid_to = -1 "
            "RETURN r.pred",
            {"sid": room_id, "preds": preds},
        )
        tried: set[str] = set()
        while rows.has_next():
            (pred,) = rows.get_next()
            # "exit_north" / "tried_north" -> "north" (also safe for "exit_exit").
            tried.add(str(pred).split("_", 1)[-1].strip())

        untried = [d for d in _COMMON_DIRECTIONS if d not in tried]
        if not untried:
            return f"Untried directions from '{room_name}': none recorded (you've tried them all in memory)."
        return (
            f"Untried directions from '{room_name}': {', '.join(untried)}\n"
            "(These are candidate directions; some may be blocked.)"
        )

    def _answer_exits(self, room_name: str) -> str:
        """List the currently known exit edges leaving the best-matching room."""
        room_id = self._best_match_entity_id(kind="room", name_fragment=room_name)
        if not room_id:
            return f"Room not found in memory: '{room_name}'."

        rows = self._qr(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred IN $preds AND r.valid_to = -1 "
            "RETURN r.pred, o.name",
            {"sid": room_id, "preds": [f"exit_{d}" for d in _COMMON_DIRECTIONS]},
        )
        exits: set[str] = set()
        while rows.has_next():
            pred, dst = rows.get_next()
            direction = str(pred).split("_", 1)[-1]
            exits.add(f"{direction} -> {dst}")

        if exits:
            return f"Known exits from '{room_name}':\n" + "\n".join(
                f"- {e}" for e in sorted(exits)
            )
        return f"No known exits recorded yet for '{room_name}'."

    def _answer_where_is(self, target: str) -> str | None:
        """Return the remembered whereabouts of item *target*, or None if unknown."""
        item_id = self._best_match_entity_id(kind="item", name_fragment=target)
        if not item_id:
            return None
        active_loc = self._get_active_object_name(subj_id=item_id, pred="located_in")
        if active_loc:
            return f"Last known location of '{target}': {active_loc}"
        if self._is_agent_holding(item_id):
            return f"You are holding '{target}'."
        return None

    def _episodic_search(self, needle: str, *, limit: int) -> str:
        """Keyword search over episodic turns, newest turn number first."""
        # Kuzu's substring() is 1-based, so start at 1 for the first character.
        rows = self._qr(
            "MATCH (t:EpisodicTurn) WHERE lower(t.observation) CONTAINS $q OR lower(t.action) CONTAINS $q "
            "RETURN t.turn, t.action, substring(t.observation, 1, 180) AS obs "
            "ORDER BY t.turn DESC LIMIT $lim",
            {"q": needle, "lim": int(limit)},
        )
        results: list[str] = []
        while rows.has_next():
            turn_num, act, obs = rows.get_next()
            results.append(f"Turn {turn_num} | {act} -> {obs}")
        if results:
            return "Episodic matches:\n" + "\n".join(results)
        return "No memory matches found."

    # ---------------------------------------------------------------------
    # Low-level helpers
    # ---------------------------------------------------------------------
    def _count(self, query: str, parameters: dict[str, Any]) -> int:
        """Run a single-column ``RETURN count(*)`` query and return it as int (0 if absent)."""
        rows = self._qr(query, parameters)
        if not rows.has_next():
            return 0
        (value,) = rows.get_next()
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    def _entity_exists(self, entity_id: str) -> bool:
        """Return True if a SemanticEntity with primary key *entity_id* exists."""
        return (
            self._count(
                "MATCH (e:SemanticEntity {id: $id}) RETURN count(*)",
                {"id": entity_id},
            )
            > 0
        )

    def _has_active_relation(self, *, subj_id: str, pred: str, obj_id: str) -> bool:
        """Return True if an open (valid_to = -1) subj -pred-> obj edge exists."""
        return (
            self._count(
                "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity {id: $oid}) "
                "WHERE r.pred = $pred AND r.valid_to = -1 RETURN count(*)",
                {"sid": subj_id, "oid": obj_id, "pred": pred},
            )
            > 0
        )

    def _upsert_entity(self, *, kind: str, name: str) -> str:
        """Create the entity for (kind, name) if it is new; return its id.

        The first spelling of ``name`` seen for a normalized id is the one kept.
        """
        entity_id = _entity_id(kind, name)
        # Explicit existence check instead of swallowing every CREATE error,
        # so genuine database failures are no longer hidden.
        if not self._entity_exists(entity_id):
            self._conn.execute(
                "CREATE (e:SemanticEntity {id: $id, name: $name, kind: $kind, created_at: $created_at})",
                {"id": entity_id, "name": name, "kind": kind, "created_at": _now_iso()},
            )
        return entity_id

    def _upsert_turn(
        self,
        *,
        turn_id: str,
        game: str,
        turn: int,
        action: str,
        location: str,
        observation: str,
    ) -> None:
        """Insert a new EpisodicTurn node (turn_id is expected to be fresh)."""
        self._conn.execute(
            "CREATE (t:EpisodicTurn {id: $id, game: $game, turn: $turn, action: $action, location: $loc, observation: $obs, created_at: $created_at})",
            {
                "id": turn_id,
                "game": game,
                "turn": int(turn),
                "action": action,
                "loc": location,
                "obs": observation,
                "created_at": _now_iso(),
            },
        )

    def _add_mention(self, turn_id: str, entity_id: str, *, role: str) -> None:
        """Best-effort link from an episodic turn to an entity it mentions."""
        try:
            self._conn.execute(
                "MATCH (t:EpisodicTurn {id: $tid}), (e:SemanticEntity {id: $eid}) "
                "CREATE (t)-[:Mentions {role: $role}]->(e)",
                {"tid": turn_id, "eid": entity_id, "role": role},
            )
        except Exception:
            # Mentions are auxiliary context; never fail ingestion over them.
            pass

    def _invalidate_predicate_for_subject(
        self, *, subj_id: str, pred: str, turn: int
    ) -> None:
        """Close every open *pred* edge leaving *subj_id* at *turn*."""
        # -1 in valid_to means "still valid".
        self._conn.execute(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred = $pred AND r.valid_to = -1 "
            "SET r.valid_to = $turn",
            {"sid": subj_id, "pred": pred, "turn": int(turn)},
        )

    def _invalidate_relation(
        self, *, subj_id: str, pred: str, obj_id: str, turn: int
    ) -> None:
        """Close the open subj -pred-> obj edge (if any) at *turn*."""
        self._conn.execute(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity {id: $oid}) "
            "WHERE r.pred = $pred AND r.valid_to = -1 "
            "SET r.valid_to = $turn",
            {"sid": subj_id, "oid": obj_id, "pred": pred, "turn": int(turn)},
        )

    def _set_singleton_relation(
        self,
        *,
        subj_id: str,
        pred: str,
        obj_id: str,
        turn: int,
        turn_id: str,
    ) -> None:
        """Make subj -pred-> obj the only open *pred* fact for *subj_id*."""
        # No-op if the active relation already points to this exact object;
        # this avoids noisy invalidation chains when nothing changed.
        if self._has_active_relation(subj_id=subj_id, pred=pred, obj_id=obj_id):
            return
        self._invalidate_predicate_for_subject(subj_id=subj_id, pred=pred, turn=turn)
        # Everything for (subj, pred) was just closed, so no duplicate check needed.
        self._add_relation(
            subj_id=subj_id,
            pred=pred,
            obj_id=obj_id,
            turn=turn,
            turn_id=turn_id,
            singleton=False,
        )

    def _add_relation(
        self,
        *,
        subj_id: str,
        pred: str,
        obj_id: str,
        turn: int,
        turn_id: str,
        singleton: bool,
    ) -> None:
        """Open a new subj -pred-> obj fact starting at *turn*.

        With ``singleton=True`` nothing is written if an identical open edge
        already exists. Other edges are never closed here.
        """
        if singleton and self._has_active_relation(
            subj_id=subj_id, pred=pred, obj_id=obj_id
        ):
            return
        self._conn.execute(
            "MATCH (s:SemanticEntity {id: $sid}), (o:SemanticEntity {id: $oid}) "
            "CREATE (s)-[:SemanticRel {pred: $pred, valid_from: $vf, valid_to: -1, turn_id: $tid}]->(o)",
            {
                "sid": subj_id,
                "oid": obj_id,
                "pred": pred,
                "vf": int(turn),
                "tid": turn_id,
            },
        )

    def _get_active_object_name(self, *, subj_id: str, pred: str) -> str | None:
        """Return the name of one open *pred* target of *subj_id*, or None."""
        rows = self._qr(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred = $pred AND r.valid_to = -1 "
            "RETURN o.name LIMIT 1",
            {"sid": subj_id, "pred": pred},
        )
        if rows.has_next():
            (name,) = rows.get_next()
            return name
        return None

    def _get_active_objects(self, *, subj_id: str, pred: str, limit: int) -> list[str]:
        """Return up to *limit* names of open *pred* targets of *subj_id*."""
        rows = self._qr(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred = $pred AND r.valid_to = -1 "
            "RETURN o.name LIMIT $lim",
            {"sid": subj_id, "pred": pred, "lim": int(limit)},
        )
        out: list[str] = []
        while rows.has_next():
            (name,) = rows.get_next()
            out.append(name)
        return out

    def _get_active_object_ids(self, *, subj_id: str, pred: str) -> list[str]:
        """Return the ids of all open *pred* targets of *subj_id*."""
        rows = self._qr(
            "MATCH (s:SemanticEntity {id: $sid})-[r:SemanticRel]->(o:SemanticEntity) "
            "WHERE r.pred = $pred AND r.valid_to = -1 "
            "RETURN o.id",
            {"sid": subj_id, "pred": pred},
        )
        ids: list[str] = []
        while rows.has_next():
            (oid,) = rows.get_next()
            ids.append(oid)
        return ids

    def _drop_stale_holdings(self, *, agent_id: str, keep: set[str], turn: int) -> None:
        """Close open "has" facts for items whose id is not in *keep*."""
        # The id list is fully materialized before any update runs.
        for item_id in self._get_active_object_ids(subj_id=agent_id, pred="has"):
            if item_id not in keep:
                self._invalidate_relation(
                    subj_id=agent_id, pred="has", obj_id=item_id, turn=turn
                )

    def _is_agent_holding(self, item_id: str) -> bool:
        """Return True if the player currently has an open "has" fact for *item_id*."""
        return self._has_active_relation(
            subj_id=_entity_id("agent", "player"), pred="has", obj_id=item_id
        )

    def _get_recent_facts(self, *, limit: int) -> list[str]:
        """Format the *limit* most recently opened facts (open or closed)."""
        rows = self._qr(
            "MATCH (s:SemanticEntity)-[r:SemanticRel]->(o:SemanticEntity) "
            "RETURN s.name, r.pred, o.name, r.valid_from, r.valid_to "
            "ORDER BY r.valid_from DESC LIMIT $lim",
            {"lim": int(limit)},
        )
        facts: list[str] = []
        while rows.has_next():
            s, pred, o, vf, vt = rows.get_next()
            # Truncate long entity names to keep facts readable.
            s_short = (str(s)[:50] + "…") if len(str(s)) > 50 else str(s)
            o_short = (str(o)[:50] + "…") if len(str(o)) > 50 else str(o)
            if vt == -1:
                facts.append(f"{s_short} {pred} {o_short} (since {vf})")
            else:
                facts.append(f"{s_short} {pred} {o_short} ({vf}..{vt})")
        return facts

    def _best_match_entity_id(self, *, kind: str, name_fragment: str) -> str | None:
        """Resolve a user-supplied name to an entity id of the given kind.

        An exact (normalized) name match wins. Otherwise the shortest name
        containing the fragment is chosen, with ties broken alphabetically.
        Returns None when nothing matches.
        """
        frag = _norm(name_fragment)
        if not frag:
            return None

        # Exact match first, so "lamp" does not resolve to "brass lamp".
        exact_id = _entity_id(kind, frag)
        if self._entity_exists(exact_id):
            return exact_id

        rows = self._qr(
            "MATCH (e:SemanticEntity) WHERE e.kind = $k AND lower(e.name) CONTAINS $q "
            "RETURN e.id, e.name",
            {"k": kind, "q": frag},
        )
        best: tuple[tuple[int, str], str] | None = None
        while rows.has_next():
            entity_id, name = rows.get_next()
            key = (len(str(name)), str(name))
            if best is None or key < best[0]:
                best = (key, entity_id)
        return best[1] if best else None

    # ---------------------------------------------------------------------
    # Extraction heuristics
    # ---------------------------------------------------------------------
    def _clean_item_name(self, raw: str) -> str:
        """Extract a display name from an inventory entry.

        Handles Jericho object strings such as
        ``"Obj12: brass lantern Parent4 Sibling0 ..."`` as well as plain names.
        Trailing parenthesized metadata like ``"(providing light)"`` is removed.
        """
        s = re.sub(r"\s+", " ", str(raw).strip())
        # Locate the Jericho "Parent<N>" field as a separate token. A plain
        # substring search would also hit words such as "transparent".
        marker = re.search(r"\s+parent\s*\d+", s, flags=re.IGNORECASE)
        if marker:
            s = s[: marker.start()].strip()
            # Also strip the leading "Obj<N>: " prefix.
            if ":" in s:
                s = s.split(":", 1)[1].strip()
        elif ":" in s:
            # Strip a leading "Obj<N>: " style prefix only.
            head, tail = s.split(":", 1)
            if re.match(r"obj\d+", head.strip().lower()):
                s = tail.strip()
        # Remove a trailing (metadata) group.
        s = re.sub(r"\s*\(.*?\)\s*$", "", s)
        return s.strip()

    def _normalize_direction(self, action: str) -> str | None:
        """Map a movement command to a name from _COMMON_DIRECTIONS, else None.

        Accepts bare directions ("north"), abbreviations ("n"), "in"/"out"
        (as "enter"/"exit"), and a "go " prefix ("go n").
        """
        a = _norm(action)
        if a.startswith("go "):
            a = a[3:].strip()
        a = self._DIRECTION_ALIASES.get(a, a)
        return a if a in _COMMON_DIRECTIONS else None

    def _extract_visible_items(self, observation: str) -> Iterable[str]:
        """Instance hook around the module-level :func:`extract_visible_items`."""
        return extract_visible_items(observation)
def extract_visible_items(observation: str) -> list[str]:
    """Heuristic object-mention extractor for room/turn text.

    Kept pure (no DB access) so other modules can reuse it.

    Two passes are combined:
      * regex patterns for common Zork phrasing ("A lamp is here",
        "There is a mailbox here", "You can see a ...");
      * whole lines of the form "A <something>." / "An <something>.".

    Returns the distinct, alphabetically sorted item strings. Generic words
    such as "nothing"/"something" are dropped.
    """
    text = (observation or "").replace("\r", "")
    found: set[str] = set()

    # Common Zork-ish patterns.
    patterns = [
        r"\bA\s+([A-Za-z][A-Za-z0-9 '\-]+?)\s+is\s+here\b",
        r"\bAn\s+([A-Za-z][A-Za-z0-9 '\-]+?)\s+is\s+here\b",
        r"\bYou\s+can\s+see\s+an?\s+([A-Za-z][A-Za-z0-9 '\-]+?)(?:\.|,|\s+here\b)",
        r"\bThere\s+is\s+an?\s+([A-Za-z][A-Za-z0-9 '\-]+?)(?:\.|,|\s+here\b)",
    ]
    for pat in patterns:
        for m in re.finditer(pat, text, flags=re.IGNORECASE):
            item = re.sub(r"\s+", " ", m.group(1).strip())
            if len(item) < 2:
                continue
            found.add(item)

    # Lines that look like object mentions.
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        match = re.match(r"^(?:A|An)\s+(.+?)\.$", line)
        if not match:
            continue
        item = re.sub(r"\s+", " ", match.group(1).strip())
        # "A lamp is here." was already captured above as "lamp". Without this,
        # the line pass would add a second, junk item "lamp is here".
        item = re.sub(r"\s+is\s+here$", "", item, flags=re.IGNORECASE)
        if 2 <= len(item) <= 50:
            found.add(item)

    # Filter low-value generic words.
    generic = {"nothing", "someone", "something"}
    return sorted(item for item in found if _norm(item) not in generic)
def default_memory(game: str | None = None) -> AriGraphMemory:
    """Open the default AriGraph memory database.

    A non-empty ``ARIGRAPH_DB`` environment variable names the database path
    directly. Otherwise each game gets its own file,
    ``.memory/ari_graph_<normalized game>.kuzu`` (relative to the current
    working directory), so facts from different games never mix. ``game``
    falls back to ``"default"`` when omitted or empty.
    """
    explicit_path = os.environ.get("ARIGRAPH_DB")
    db_path = explicit_path or str(
        Path(".memory") / f"ari_graph_{_norm(game or 'default')}.kuzu"
    )
    return AriGraphMemory(MemoryConfig(db_path=db_path))