Spaces:
Paused
Paused
| """Tests for the RetainDB memory plugin. | |
| Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter, | |
| RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling. | |
| """ | |
| import json | |
| import os | |
| import sqlite3 | |
| import tempfile | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, patch, PropertyMock | |
| import pytest | |
| # --------------------------------------------------------------------------- | |
| # Imports — guarded since plugins/memory lives outside the standard test path | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture(autouse=True)
def _isolate_env(tmp_path, monkeypatch):
    """Point HERMES_HOME at a scratch directory and clear RETAINDB_* variables.

    Registered as an autouse fixture so that every test in this module runs
    with ``HERMES_HOME`` set to a freshly created ``<tmp_path>/.hermes`` and
    without any ``RETAINDB_API_KEY``, ``RETAINDB_BASE_URL`` or
    ``RETAINDB_PROJECT`` value inherited from the surrounding environment.
    Without the fixture registration this function is never invoked and
    tests such as ``test_is_available_without_key`` depend on the caller's
    shell.
    """
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    # raising=False: the variables are usually absent already; that is fine.
    for var_name in ("RETAINDB_API_KEY", "RETAINDB_BASE_URL", "RETAINDB_PROJECT"):
        monkeypatch.delenv(var_name, raising=False)
# We need the repo root on sys.path so the plugin can import agent.memory_provider
import sys
# parents[2] is two directories above the directory containing this file;
# the code treats that directory as the repository root.
_repo_root = str(Path(__file__).resolve().parents[2])
# Prepend only when missing so re-importing this module does not stack duplicates.
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)
# This import has to come after the sys.path adjustment above.
from plugins.memory.retaindb import (
    _Client,
    _WriteQueue,
    _build_overlay,
    RetainDBMemoryProvider,
    _ASYNC_SHUTDOWN,
    _DEFAULT_BASE_URL,
)
| # =========================================================================== | |
| # _Client tests | |
| # =========================================================================== | |
class TestClient:
    """Checks for ``_Client``: header construction and per-endpoint payloads.

    ``request`` is patched on the instance wherever a call would otherwise be
    issued, so none of these tests performs network I/O.
    """

    def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
        """Return a ``_Client`` built from test-friendly defaults."""
        return _Client(api_key, base_url, project)

    def test_base_url_trailing_slash_stripped(self):
        client = self._make_client(base_url="https://api.retaindb.com///")
        assert client.base_url == "https://api.retaindb.com"

    def test_headers_include_auth(self):
        headers = self._make_client()._headers("/v1/files")
        assert headers["Authorization"] == "Bearer rdb-test-key"
        assert "X-API-Key" not in headers

    def test_headers_include_api_key_for_memory_path(self):
        headers = self._make_client()._headers("/v1/memory/search")
        assert headers["X-API-Key"] == "rdb-test-key"

    def test_headers_include_api_key_for_context_path(self):
        headers = self._make_client()._headers("/v1/context/query")
        assert headers["X-API-Key"] == "rdb-test-key"

    def test_headers_strip_bearer_prefix(self):
        client = self._make_client(api_key="Bearer rdb-test-key")
        headers = client._headers("/v1/memory/search")
        assert headers["Authorization"] == "Bearer rdb-test-key"
        assert headers["X-API-Key"] == "rdb-test-key"

    def test_query_context_builds_correct_payload(self):
        client = self._make_client()
        expected_body = {
            "project": "test",
            "query": "test query",
            "user_id": "user1",
            "session_id": "sess1",
            "include_memories": True,
            "max_tokens": 500,
        }
        with patch.object(client, "request", return_value={"results": []}) as request_mock:
            client.query_context("user1", "sess1", "test query", max_tokens=500)
        request_mock.assert_called_once_with(
            "POST", "/v1/context/query", json_body=expected_body
        )

    def test_search_builds_correct_payload(self):
        client = self._make_client()
        expected_body = {
            "project": "test",
            "query": "find this",
            "user_id": "user1",
            "session_id": "sess1",
            "top_k": 5,
            "include_pending": True,
        }
        with patch.object(client, "request", return_value={"results": []}) as request_mock:
            client.search("user1", "sess1", "find this", top_k=5)
        request_mock.assert_called_once_with(
            "POST", "/v1/memory/search", json_body=expected_body
        )

    def test_add_memory_tries_fallback(self):
        client = self._make_client()
        attempts = []

        def flaky_request(method, path, **kwargs):
            # First attempt fails, every later attempt succeeds.
            attempts.append((method, path))
            if len(attempts) == 1:
                raise RuntimeError("404")
            return {"id": "mem-1"}

        with patch.object(client, "request", side_effect=flaky_request):
            outcome = client.add_memory("u1", "s1", "test fact")
        assert outcome == {"id": "mem-1"}
        assert len(attempts) == 2

    def test_delete_memory_tries_fallback(self):
        client = self._make_client()
        attempts = []

        def flaky_request(method, path, **kwargs):
            # First attempt fails, every later attempt succeeds.
            attempts.append((method, path))
            if len(attempts) == 1:
                raise RuntimeError("404")
            return {"deleted": True}

        with patch.object(client, "request", side_effect=flaky_request):
            outcome = client.delete_memory("mem-123")
        assert outcome == {"deleted": True}
        assert len(attempts) == 2

    def test_ingest_session_payload(self):
        client = self._make_client()
        messages = [{"role": "user", "content": "hi"}]
        expected_body = {
            "project": "test",
            "session_id": "s1",
            "user_id": "u1",
            "messages": messages,
            "write_mode": "sync",
        }
        with patch.object(client, "request", return_value={"status": "ok"}) as request_mock:
            client.ingest_session("u1", "s1", messages, timeout=10.0)
        request_mock.assert_called_once_with(
            "POST", "/v1/memory/ingest/session", json_body=expected_body, timeout=10.0
        )

    def test_ask_user_payload(self):
        client = self._make_client()
        with patch.object(client, "request", return_value={"answer": "test answer"}) as request_mock:
            client.ask_user("u1", "who am i?", reasoning_level="medium")
        request_mock.assert_called_once()
        # call_args unpacks into (positional args, keyword args).
        _, sent_kwargs = request_mock.call_args
        assert sent_kwargs["json_body"]["reasoning_level"] == "medium"

    def test_get_agent_model_path(self):
        client = self._make_client()
        with patch.object(client, "request", return_value={"memory_count": 3}) as request_mock:
            client.get_agent_model("hermes")
        request_mock.assert_called_once_with(
            "GET",
            "/v1/memory/agent/hermes/model",
            params={"project": "test"},
            timeout=4.0,
        )
| # =========================================================================== | |
| # _WriteQueue tests | |
| # =========================================================================== | |
class TestWriteQueue:
    """Test the SQLite-backed write queue with real SQLite.

    Every queue created here is shut down in a ``finally`` clause so that a
    failing assertion does not leave the queue's background work running into
    later tests, and every direct SQLite read closes its connection even
    when the query itself raises.
    """

    def _make_queue(self, tmp_path, client=None):
        """Return ``(queue, client, db_path)`` backed by a DB under *tmp_path*.

        When *client* is omitted, a ``MagicMock`` whose ``ingest_session``
        returns ``{"status": "ok"}`` is used.
        """
        if client is None:
            client = MagicMock()
            client.ingest_session = MagicMock(return_value={"status": "ok"})
        db_path = tmp_path / "test_queue.db"
        return _WriteQueue(client, db_path), client, db_path

    @staticmethod
    def _fetch_all(db_path, sql):
        """Run *sql* against the queue database and return all rows.

        The connection is always closed, including when ``execute`` raises.
        """
        conn = sqlite3.connect(str(db_path))
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    def test_enqueue_creates_row(self, tmp_path):
        q, client, _ = self._make_queue(tmp_path)
        try:
            q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
            # Give the writer thread a moment to process
            time.sleep(1)
        finally:
            q.shutdown()
        # If ingest succeeded, the row should be deleted
        client.ingest_session.assert_called_once()

    def test_enqueue_persists_to_sqlite(self, tmp_path):
        client = MagicMock()
        # Make ingest hang so the row stays in SQLite
        client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5))
        db_path = tmp_path / "test_queue.db"
        q = _WriteQueue(client, db_path)
        try:
            q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
            # Check SQLite directly — row should exist since flush is slow
            rows = self._fetch_all(db_path, "SELECT user_id, session_id FROM pending")
            assert len(rows) >= 1
            assert rows[0][0] == "user1"
        finally:
            q.shutdown()

    def test_flush_deletes_row_on_success(self, tmp_path):
        q, _, db_path = self._make_queue(tmp_path)
        try:
            q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
            time.sleep(1)
        finally:
            q.shutdown()
        # Row should be gone
        remaining = self._fetch_all(db_path, "SELECT COUNT(*) FROM pending")[0][0]
        assert remaining == 0

    def test_flush_records_error_on_failure(self, tmp_path):
        client = MagicMock()
        client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
        db_path = tmp_path / "test_queue.db"
        q = _WriteQueue(client, db_path)
        try:
            q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
            time.sleep(3)  # Allow retry + sleep(2) in _flush_row
        finally:
            q.shutdown()
        # Row should still exist with error recorded
        rows = self._fetch_all(db_path, "SELECT last_error FROM pending")
        assert rows, "expected the failed row to remain in the pending table"
        assert "API down" in rows[0][0]

    def test_thread_local_connection_reuse(self, tmp_path):
        q, _, _ = self._make_queue(tmp_path)
        try:
            # Same thread should get same connection
            conn1 = q._get_conn()
            conn2 = q._get_conn()
            assert conn1 is conn2
        finally:
            q.shutdown()

    def test_crash_recovery_replays_pending(self, tmp_path):
        """Simulate crash: create rows, then new queue should replay them."""
        db_path = tmp_path / "recovery_test.db"
        # First: create a queue and insert rows, but don't let them flush
        client1 = MagicMock()
        client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
        q1 = _WriteQueue(client1, db_path)
        try:
            q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
            time.sleep(3)
        finally:
            q1.shutdown()
        # Now create a new queue — it should replay the pending rows
        client2 = MagicMock()
        client2.ingest_session = MagicMock(return_value={"status": "ok"})
        q2 = _WriteQueue(client2, db_path)
        try:
            time.sleep(2)
        finally:
            q2.shutdown()
        # The replayed row should have been ingested via client2
        client2.ingest_session.assert_called_once()
        call_args = client2.ingest_session.call_args
        assert call_args[0][0] == "user1"  # user_id
| # =========================================================================== | |
| # _build_overlay tests | |
| # =========================================================================== | |
class TestBuildOverlay:
    """Test the overlay formatter (pure function).

    ``_build_overlay(profile, query_result, local_entries=...)`` is called
    with plain dicts and lists; each test inspects the returned string.
    """

    def test_empty_inputs_returns_empty(self):
        assert _build_overlay({}, {}) == ""

    def test_empty_memories_returns_empty(self):
        assert _build_overlay({"memories": []}, {"results": []}) == ""

    def test_profile_items_included(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, {})
        assert "User likes Python" in result
        assert "[RetainDB Context]" in result

    def test_query_results_included(self):
        query_result = {"results": [{"content": "Previous discussion about Rust"}]}
        result = _build_overlay({}, query_result)
        assert "Previous discussion about Rust" in result

    def test_deduplication_removes_duplicates(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        query_result = {"results": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, query_result)
        assert result.count("User likes Python") == 1

    def test_local_entries_filter(self):
        profile = {"memories": [{"content": "Already known fact"}]}
        result = _build_overlay(profile, {}, local_entries=["Already known fact"])
        # The profile item matches a local entry, should be filtered
        assert result == ""

    def test_max_five_items_per_section(self):
        profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
        result = _build_overlay(profile, {})
        # Should only include first 5
        assert "Fact 0" in result
        assert "Fact 4" in result
        assert "Fact 5" not in result

    def test_none_content_handled(self):
        profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
        result = _build_overlay(profile, {})
        assert "Real fact" in result

    def test_truncation_at_320_chars(self):
        long_content = "x" * 500
        profile = {"memories": [{"content": long_content}]}
        result = _build_overlay(profile, {})
        item_lines = [line for line in result.split("\n") if line.startswith("- ")]
        # Guard against a vacuous pass: with no item lines the length check
        # below would never execute and the test would prove nothing.
        assert item_lines, "expected at least one '- ' item line in the overlay"
        # Each item is compacted to 320 chars max
        for line in item_lines:
            assert len(line) <= 322  # "- " + 320
| # =========================================================================== | |
| # RetainDBMemoryProvider tests | |
| # =========================================================================== | |
class TestRetainDBMemoryProvider:
    """Test the main plugin class.

    Covers availability/config reporting, ``initialize`` (client, queue,
    project selection, SOUL.md seeding), the tool schema list, and
    ``handle_tool_call`` dispatch including argument validation and error
    wrapping. Every test that calls ``initialize`` also calls ``shutdown``.
    """

    def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
        """Return an uninitialized provider with API key and HERMES_HOME set."""
        monkeypatch.setenv("RETAINDB_API_KEY", api_key)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
        (tmp_path / ".hermes").mkdir(exist_ok=True)
        provider = RetainDBMemoryProvider()
        return provider

    def _make_initialized_provider(self, tmp_path, monkeypatch):
        """Return a provider initialized for session ``"test-session"``.

        The caller is responsible for calling ``shutdown()`` on the result.
        """
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        return p

    def test_name(self):
        p = RetainDBMemoryProvider()
        assert p.name == "retaindb"

    def test_is_available_without_key(self):
        p = RetainDBMemoryProvider()
        assert p.is_available() is False

    def test_is_available_with_key(self, monkeypatch):
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
        p = RetainDBMemoryProvider()
        assert p.is_available() is True

    def test_config_schema(self):
        p = RetainDBMemoryProvider()
        schema = p.get_config_schema()
        assert len(schema) == 3
        keys = [s["key"] for s in schema]
        assert "api_key" in keys
        assert "base_url" in keys
        assert "project" in keys

    def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        assert p._client is not None
        assert p._queue is not None
        assert p._session_id == "test-session"
        p.shutdown()

    def test_initialize_default_project(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        assert p._client.project == "default"
        p.shutdown()

    def test_initialize_explicit_project(self, tmp_path, monkeypatch):
        monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        assert p._client.project == "my-project"
        p.shutdown()

    def test_initialize_profile_project(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        profile_home = str(tmp_path / "profiles" / "coder")
        p.initialize("test-session", hermes_home=profile_home)
        assert p._client.project == "hermes-coder"
        p.shutdown()

    def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        soul_path = tmp_path / ".hermes" / "SOUL.md"
        soul_path.write_text("I am a helpful agent.")
        with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
            p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
            # Give thread time to start
            time.sleep(0.5)
            mock_seed.assert_called_once_with("I am a helpful agent.")
        p.shutdown()

    def test_system_prompt_block(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        block = p.system_prompt_block()
        assert "RetainDB Memory" in block
        assert "Active" in block
        p.shutdown()

    def test_tool_schemas_count(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        schemas = p.get_tool_schemas()
        assert len(schemas) == 10  # 5 memory + 5 file tools
        names = [s["name"] for s in schemas]
        assert "retaindb_profile" in names
        assert "retaindb_search" in names
        assert "retaindb_context" in names
        assert "retaindb_remember" in names
        assert "retaindb_forget" in names
        assert "retaindb_upload_file" in names
        assert "retaindb_list_files" in names
        assert "retaindb_read_file" in names
        assert "retaindb_ingest_file" in names
        assert "retaindb_delete_file" in names

    def test_handle_tool_call_not_initialized(self):
        p = RetainDBMemoryProvider()
        result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "error" in result
        assert "not initialized" in result["error"]

    def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
        assert result == {"error": "Unknown tool: retaindb_nonexistent"}
        p.shutdown()

    def test_dispatch_profile(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "memories" in result
        p.shutdown()

    def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_search", {}))
        assert result == {"error": "query is required"}
        p.shutdown()

    def test_dispatch_search(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
            result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
        assert "results" in result
        p.shutdown()

    def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "search") as mock_search:
            mock_search.return_value = {"results": []}
            p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
            # top_k should be capped at 20
            assert mock_search.call_args[1]["top_k"] == 20
        p.shutdown()

    def test_dispatch_remember(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
            result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
        assert result["id"] == "mem-1"
        p.shutdown()

    def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_remember", {}))
        assert result == {"error": "content is required"}
        p.shutdown()

    def test_dispatch_forget(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
            result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
        assert result["deleted"] is True
        p.shutdown()

    def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_forget", {}))
        assert result == {"error": "memory_id is required"}
        p.shutdown()

    def test_dispatch_context(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
                patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
        assert "context" in result
        assert "raw" in result
        p.shutdown()

    def test_dispatch_file_list(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "list_files", return_value={"files": []}):
            result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
        assert "files" in result
        p.shutdown()

    def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
        assert "error" in result
        # Release the provider like every other initialized test does.
        p.shutdown()

    def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
        assert "File not found" in result["error"]
        p.shutdown()

    def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "API exploded" in result["error"]
        p.shutdown()
| # =========================================================================== | |
| # Prefetch and thread management tests | |
| # =========================================================================== | |
class TestPrefetch:
    """Background prefetch behaviour and the guard against piling up threads."""

    def _make_initialized_provider(self, tmp_path, monkeypatch):
        """Return a provider already initialized for ``"test-session"``."""
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir(exist_ok=True)
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        provider = RetainDBMemoryProvider()
        provider.initialize("test-session", hermes_home=str(home_dir))
        return provider

    def test_queue_prefetch_skips_without_client(self):
        # No initialize() call, so there is no client; this must not raise.
        RetainDBMemoryProvider().queue_prefetch("test")

    def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        assert provider.prefetch("test") == ""
        provider.shutdown()

    def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        # Plant a cached value directly instead of running a real prefetch.
        with provider._lock:
            provider._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
        first_read = provider.prefetch("test")
        assert "User likes tests" in first_read
        # A second read finds nothing: the first one consumed the cache.
        second_read = provider.prefetch("test")
        assert second_read == ""
        provider.shutdown()

    def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        with provider._lock:
            provider._dialectic_result = "User is a software engineer who prefers Python."
        text = provider.prefetch("test")
        assert "[RetainDB User Synthesis]" in text
        assert "software engineer" in text
        provider.shutdown()

    def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        model = {
            "memory_count": 5,
            "persona": "Helpful coding assistant",
            "persistent_instructions": ["Be concise", "Use Python"],
            "working_style": "Direct and efficient",
        }
        with provider._lock:
            provider._agent_model = model
        text = provider.prefetch("test")
        assert "[RetainDB Agent Self-Model]" in text
        assert "Helpful coding assistant" in text
        assert "Be concise" in text
        assert "Direct and efficient" in text
        provider.shutdown()

    def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        with provider._lock:
            provider._agent_model = {"memory_count": 0}
        text = provider.prefetch("test")
        assert "Agent Self-Model" not in text
        provider.shutdown()

    def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
        """A second queue_prefetch call must replace, not add to, the thread list."""
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        # Replace the three prefetch workers with half-second sleeps.
        slow_context = patch.object(
            provider, "_prefetch_context", side_effect=lambda query: time.sleep(0.5)
        )
        slow_dialectic = patch.object(
            provider, "_prefetch_dialectic", side_effect=lambda query: time.sleep(0.5)
        )
        slow_agent_model = patch.object(
            provider, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)
        )
        with slow_context, slow_dialectic, slow_agent_model:
            provider.queue_prefetch("query 1")
            first_batch = list(provider._prefetch_threads)
            assert len(first_batch) == 3
            # Second call while the first batch is still (or was just) running.
            provider.queue_prefetch("query 2")
            second_batch = list(provider._prefetch_threads)
            assert len(second_batch) == 3
            # None of the tracked threads may be carried over from the first batch.
            assert all(thread not in first_batch for thread in second_batch)
        provider.shutdown()

    def test_reasoning_level_short(self):
        level = RetainDBMemoryProvider._reasoning_level("hi")
        assert level == "low"

    def test_reasoning_level_medium(self):
        level = RetainDBMemoryProvider._reasoning_level("x" * 200)
        assert level == "medium"

    def test_reasoning_level_long(self):
        level = RetainDBMemoryProvider._reasoning_level("x" * 500)
        assert level == "high"
| # =========================================================================== | |
| # sync_turn tests | |
| # =========================================================================== | |
class TestSyncTurn:
    """``sync_turn`` should hand conversation turns to the write queue."""

    def _initialized_provider(self, tmp_path, monkeypatch):
        """Return a provider initialized for ``"test-session"`` under *tmp_path*."""
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        home_dir = tmp_path / ".hermes"
        home_dir.mkdir(exist_ok=True)
        monkeypatch.setenv("HERMES_HOME", str(home_dir))
        provider = RetainDBMemoryProvider()
        provider.initialize("test-session", hermes_home=str(home_dir))
        return provider

    def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
        provider = self._initialized_provider(tmp_path, monkeypatch)
        with patch.object(provider._queue, "enqueue") as enqueue_mock:
            provider.sync_turn("user msg", "assistant msg")
        enqueue_mock.assert_called_once()
        # call_args unpacks into (positional args, keyword args).
        positional, _ = enqueue_mock.call_args
        assert positional[0] == "default"  # user_id
        assert positional[1] == "test-session"  # session_id
        messages = positional[2]
        assert len(messages) == 2
        assert [messages[0]["role"], messages[1]["role"]] == ["user", "assistant"]
        provider.shutdown()

    def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
        provider = self._initialized_provider(tmp_path, monkeypatch)
        with patch.object(provider._queue, "enqueue") as enqueue_mock:
            provider.sync_turn("", "assistant msg")
        enqueue_mock.assert_not_called()
        provider.shutdown()
| # =========================================================================== | |
| # on_memory_write hook tests | |
| # =========================================================================== | |
class TestOnMemoryWrite:
    """Test the built-in memory mirror hook.

    ``on_memory_write(action, target, content)`` is expected to forward
    ``"add"`` actions with non-empty content to ``_client.add_memory`` and to
    ignore everything else; ``add_memory`` is patched in every test.
    """

    def _make_initialized_provider(self, tmp_path, monkeypatch):
        """Return a provider initialized for ``"test-session"`` under *tmp_path*.

        The caller is responsible for calling ``shutdown()`` on the result.
        """
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir(exist_ok=True)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        p = RetainDBMemoryProvider()
        p.initialize("test-session", hermes_home=str(hermes_home))
        return p

    def test_mirrors_add_action(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "user", "User prefers dark mode")
        mock_add.assert_called_once()
        assert mock_add.call_args[1]["memory_type"] == "preference"
        p.shutdown()

    def test_skips_non_add_action(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("remove", "user", "something")
        mock_add.assert_not_called()
        p.shutdown()

    def test_skips_empty_content(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("add", "user", "")
        mock_add.assert_not_called()
        p.shutdown()

    def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
        p = self._make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "memory", "Some env fact")
        # Assert the call first: call_args is None when the mock was never
        # called, and subscripting None would raise a confusing TypeError.
        mock_add.assert_called_once()
        assert mock_add.call_args[1]["memory_type"] == "factual"
        p.shutdown()
| # =========================================================================== | |
| # register() test | |
| # =========================================================================== | |
class TestRegister:
    """The plugin entry point must register exactly one provider instance."""

    def test_register_calls_register_memory_provider(self):
        from plugins.memory.retaindb import register

        plugin_ctx = MagicMock()
        register(plugin_ctx)

        registrar = plugin_ctx.register_memory_provider
        registrar.assert_called_once()
        # call_args unpacks into (positional args, keyword args).
        positional, _ = registrar.call_args
        assert isinstance(positional[0], RetainDBMemoryProvider)