# MnemoCore / tests/mocks/mock_redis.py
# Granis87 — "Initial upload of MnemoCore" (commit dbb04e4, verified)
"""
Mock Async Redis Storage
========================
In-memory mock implementation of AsyncRedisStorage for offline testing.
Uses fakeredis for Redis-compatible behavior when available,
falls back to pure Python in-memory implementation.
Implements all public methods from mnemocore.core.async_storage.AsyncRedisStorage
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Any

from loguru import logger
# Optional dependency: fakeredis gives Redis-compatible behavior when it is
# installed. HAS_FAKEREDIS tells create_mock_redis_storage() whether it can
# build a fakeredis client or must fall back to the in-memory dict client.
try:
    import fakeredis.aioredis as fakeredis
except ImportError:
    HAS_FAKEREDIS = False
    logger.info("[MockRedis] fakeredis not available, using in-memory dict storage")
else:
    HAS_FAKEREDIS = True
@dataclass
class StreamEntry:
    """One record appended to an in-memory stream.

    Instances are created by ``InMemoryRedisClient.xadd`` and hold the
    generated entry ID together with the flat field mapping that was written.
    """

    # Entry ID; InMemoryRedisClient.xadd builds it as "<epoch-ms>-<counter>".
    id: str
    # Field -> value mapping exactly as passed to xadd.
    data: Dict[str, str]
class InMemoryRedisClient:
    """
    Pure Python in-memory Redis client mock.

    Implements the subset of Redis commands used by AsyncRedisStorage:
    - get/set/setex/delete
    - mget
    - zadd/zrange/zrem
    - xadd
    - ping
    - pipeline

    String values live in ``_data``. Keys with a TTL also have an absolute
    expiry deadline (seconds since the epoch, compared against
    ``time.time()``) in ``_ttls``. Expired keys are purged lazily, the next
    time a command reads them.
    """

    def __init__(self, decode_responses: bool = True):
        # String key -> stored value.
        self._data: Dict[str, Any] = {}
        # Key -> absolute expiry deadline (time.time() seconds); TTL keys only.
        self._ttls: Dict[str, float] = {}
        # Sorted-set key -> {member: score}.
        self._sorted_sets: Dict[str, Dict[str, float]] = {}
        # Stream key -> entries in append order.
        self._streams: Dict[str, List[StreamEntry]] = {}
        # Kept for interface parity; values are returned exactly as stored.
        self._decode_responses = decode_responses
        # Incremented on every xadd to keep generated entry IDs unique.
        self._id_counter = 0

    def _purge_if_expired(self, key: str) -> None:
        """Drop *key*'s value and TTL if its expiry deadline has passed."""
        deadline = self._ttls.get(key)
        if deadline is not None and time.time() > deadline:
            # pop() rather than del: defensively tolerate a TTL entry whose
            # value is already gone instead of raising KeyError.
            self._data.pop(key, None)
            del self._ttls[key]

    async def ping(self) -> bool:
        """Always report a healthy connection."""
        return True

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored at *key*, or None if missing or expired."""
        self._purge_if_expired(key)
        return self._data.get(key)

    async def set(self, key: str, value: str) -> bool:
        """Store *value* at *key* without expiry."""
        self._data[key] = value
        # A plain SET discards any TTL left over from an earlier SETEX,
        # matching Redis; otherwise the new value would vanish later.
        self._ttls.pop(key, None)
        return True

    async def setex(self, key: str, ttl: int, value: str) -> bool:
        """Store *value* at *key*, expiring *ttl* seconds from now."""
        self._data[key] = value
        # Keep sub-second precision so keys never expire early.
        self._ttls[key] = time.time() + ttl
        return True

    async def delete(self, key: str, *keys: str) -> int:
        """Delete one or more string keys; return how many existed.

        Expired keys count as missing. Sorted sets and streams are not touched.
        """
        removed = 0
        for name in (key, *keys):
            self._purge_if_expired(name)
            if name in self._data:
                del self._data[name]
                removed += 1
            self._ttls.pop(name, None)
        return removed

    async def mget(self, keys: List[str]) -> List[Optional[str]]:
        """Return the values for *keys* in order; None for missing/expired keys."""
        results: List[Optional[str]] = []
        for key in keys:
            self._purge_if_expired(key)
            results.append(self._data.get(key))
        return results

    async def zadd(self, key: str, mapping: Dict[str, float]) -> int:
        """Set member scores in sorted set *key*; return how many members were new."""
        zset = self._sorted_sets.setdefault(key, {})
        added = sum(1 for member in mapping if member not in zset)
        zset.update(mapping)
        return added

    async def zrange(self, key: str, start: int, stop: int) -> List[str]:
        """Return members of *key* by ascending score between *start* and *stop*.

        Follows Redis ZRANGE semantics: both indices are inclusive, negative
        indices count from the end (-1 is the last member), and members with
        equal scores are ordered lexicographically.
        """
        zset = self._sorted_sets.get(key)
        if not zset:
            return []
        ordered = sorted(zset, key=lambda member: (zset[member], member))
        # Convert Redis' inclusive stop into an exclusive slice end. -1 must
        # map to "through the end"; slicing to -1 would drop the last member.
        end = None if stop == -1 else stop + 1
        return ordered[start:end]

    async def zrem(self, key: str, member: str) -> int:
        """Remove *member* from sorted set *key*; return 1 if it was present."""
        zset = self._sorted_sets.get(key)
        if zset is None or member not in zset:
            return 0
        del zset[member]
        return 1

    async def xadd(self, stream: str, fields: Dict[str, str]) -> str:
        """Append *fields* to *stream* and return the generated entry ID."""
        self._id_counter += 1
        entry_id = f"{int(time.time() * 1000)}-{self._id_counter}"
        self._streams.setdefault(stream, []).append(StreamEntry(id=entry_id, data=fields))
        return entry_id

    def pipeline(self):
        """Return a new pipeline bound to this client."""
        return MockPipeline(self)
class MockPipeline:
    """Mock Redis pipeline for batch operations.

    ``incr`` and ``expire`` only queue commands. ``execute`` applies them in
    order to the owning client's private ``_data``/``_ttls`` dicts and returns
    one result per applied command. The queueing methods return the pipeline
    so calls can be chained, as with redis-py.
    """

    def __init__(self, client: "InMemoryRedisClient"):
        self._client = client
        # Queued commands: ("incr", key) or ("expire", key, seconds).
        self._commands: List[tuple] = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception raised in the with-body propagate.
        return None

    def incr(self, key: str):
        """Queue INCR *key*; return the pipeline for chaining."""
        self._commands.append(('incr', key))
        return self

    def expire(self, key: str, seconds: int):
        """Queue EXPIRE *key* *seconds*; return the pipeline for chaining."""
        self._commands.append(('expire', key, seconds))
        return self

    async def execute(self) -> List[Any]:
        """Apply all queued commands and return their results in order.

        Results:
            incr   -> the new integer value (a missing or expired key starts at 0)
            expire -> True if the key exists and a TTL was set, else False
        """
        # Detach the queue first so it is emptied even if a command raises.
        commands, self._commands = self._commands, []
        data = self._client._data
        ttls = self._client._ttls
        now = time.time()
        results: List[Any] = []
        for cmd in commands:
            op, key = cmd[0], cmd[1]
            # Lazily purge an expired key so INCR restarts at 0 and EXPIRE
            # treats the key as missing, as Redis would.
            deadline = ttls.get(key)
            if deadline is not None and now > deadline:
                data.pop(key, None)
                del ttls[key]
            if op == 'incr':
                new_val = int(data.get(key, '0')) + 1
                data[key] = str(new_val)
                results.append(new_val)
            elif op == 'expire':
                # Redis EXPIRE on a missing key sets nothing and reports 0.
                if key in data:
                    ttls[key] = now + cmd[2]
                    results.append(True)
                else:
                    results.append(False)
        return results

    async def aclose(self):
        """No-op; nothing to release for the in-memory pipeline."""
        pass
class MockAsyncRedisStorage:
    """
    Mock implementation of AsyncRedisStorage.

    Provides full implementation of all public methods:
    - store_memory()
    - retrieve_memory()
    - batch_retrieve()
    - delete_memory()
    - get_eviction_candidates()
    - update_ltp()
    - publish_event()
    - check_health()
    - close()

    By default it is backed by an ``InMemoryRedisClient`` (plain Python
    dicts), enabling full test isolation. Any object exposing the same async
    commands can be injected through ``client`` instead.

    Key layout used by the methods below:
    - ``haim:memory:<node_id>`` -> JSON-encoded memory payload
    - ``haim:ltp_index``        -> sorted set of node IDs scored by LTP strength
    - ``stream_key``            -> stream that receives publish_event() messages
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379/0",
        stream_key: str = "haim:subconscious",
        max_connections: int = 10,
        socket_timeout: int = 5,
        password: Optional[str] = None,
        client: Optional[Any] = None,
    ):
        """Initialize mock storage with configuration matching real AsyncRedisStorage.

        Args:
            url: Stored on ``self.url``; no connection is made.
            stream_key: Stream that publish_event() appends to.
            max_connections: Accepted for signature compatibility; unused.
            socket_timeout: Accepted for signature compatibility; unused.
            password: Accepted for signature compatibility; unused.
            client: Async Redis-like client to use. When None, a fresh
                InMemoryRedisClient is created.
        """
        self.stream_key = stream_key
        self.url = url
        # Compare against None instead of relying on truthiness, so an injected
        # client is never replaced just because the object evaluates falsy.
        if client is not None:
            self.redis_client = client
        else:
            self.redis_client = InMemoryRedisClient(decode_responses=True)

    async def close(self):
        """Close the client connection, preferring ``aclose`` over ``close``.

        Clients that expose neither method (e.g. InMemoryRedisClient) are left as is.
        """
        if hasattr(self.redis_client, 'aclose'):
            await self.redis_client.aclose()
        elif hasattr(self.redis_client, 'close'):
            await self.redis_client.close()

    # --- CRUD Operations ---

    async def store_memory(self, node_id: str, data: Dict[str, Any], ttl: Optional[int] = None):
        """
        Store memory metadata in Redis (Key-Value) + Update LTP Index.

        Args:
            node_id: Memory identifier; stored under ``haim:memory:<node_id>``.
            data: Payload, JSON-encoded with ``default=str`` so non-JSON values
                are stringified. Its ``ltp_strength`` (default 0.0) becomes the
                node's score in ``haim:ltp_index``.
            ttl: Expiry in seconds. A falsy value (None or 0) stores without TTL.
        """
        key = f"haim:memory:{node_id}"
        payload = json.dumps(data, default=str)
        if ttl:
            await self.redis_client.setex(key, ttl, payload)
        else:
            await self.redis_client.set(key, payload)
        # Update LTP Index (Sorted Set)
        ltp = float(data.get("ltp_strength", 0.0))
        await self.redis_client.zadd("haim:ltp_index", {node_id: ltp})

    async def retrieve_memory(self, node_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieve memory metadata by ID.

        Returns:
            Dict with memory data if found, None if not found.
        """
        key = f"haim:memory:{node_id}"
        data = await self.redis_client.get(key)
        if data:
            return json.loads(data)
        return None

    async def batch_retrieve(self, node_ids: List[str]) -> List[Optional[Dict[str, Any]]]:
        """
        Batch retrieve multiple memories using MGET.

        Returns:
            One entry per requested ID, in request order: the decoded dict, or
            None when the entry is missing, empty, or not valid JSON.
        """
        if not node_ids:
            return []
        keys = [f"haim:memory:{mid}" for mid in node_ids]
        results = await self.redis_client.mget(keys)
        parsed = []
        for r in results:
            if r:
                try:
                    parsed.append(json.loads(r))
                except json.JSONDecodeError:
                    parsed.append(None)
            else:
                parsed.append(None)
        return parsed

    async def delete_memory(self, node_id: str):
        """
        Delete memory from storage and index.
        """
        key = f"haim:memory:{node_id}"
        await self.redis_client.delete(key)
        await self.redis_client.zrem("haim:ltp_index", node_id)

    # --- Index/LTP Operations ---

    async def get_eviction_candidates(self, count: int = 10) -> List[str]:
        """
        Get IDs of memories with the lowest LTP scores.

        Note: ``count`` is forwarded as ``ZRANGE 0 count-1``. With count == 0
        that is ``0 -1``, which in Redis semantics selects the whole index.
        NOTE(review): confirm whether callers can pass 0.

        Returns:
            List of node IDs (empty list if none found).
        """
        # ZRANGE 0 (count-1) returns lowest scores
        members = await self.redis_client.zrange("haim:ltp_index", 0, count - 1)
        return members

    async def update_ltp(self, node_id: str, new_ltp: float):
        """
        Update just the LTP score in the index.
        """
        await self.redis_client.zadd("haim:ltp_index", {node_id: new_ltp})

    # --- Streaming (Subconscious Bus) ---

    async def publish_event(self, event_type: str, payload: Dict[str, Any]):
        """
        Publish an event to the Subconscious Bus (Redis Stream).

        The message is flattened to string values: dicts/lists are JSON-encoded
        and everything else goes through str(). *event_type* is written under
        the ``type`` field; a payload key named ``type`` overwrites it.
        """
        # XADD expects flat dict of strings
        msg = {"type": event_type}
        for k, v in payload.items():
            if isinstance(v, (dict, list)):
                msg[k] = json.dumps(v)
            else:
                msg[k] = str(v)
        await self.redis_client.xadd(self.stream_key, msg)

    async def check_health(self) -> bool:
        """Ping Redis to check connectivity; any exception reports False."""
        try:
            return await self.redis_client.ping()
        except Exception:
            return False

    # --- Utility methods for testing ---
    # These inspect InMemoryRedisClient internals directly and return empty
    # results for any other client type (e.g. fakeredis).

    def _get_stored_keys(self) -> List[str]:
        """Get all live memory keys (for testing assertions).

        Keys whose TTL has already elapsed, but which the client has not yet
        purged lazily, are excluded so assertions see what a read would see.
        """
        if not isinstance(self.redis_client, InMemoryRedisClient):
            return []
        now = time.time()
        ttls = self.redis_client._ttls
        return [
            k
            for k in self.redis_client._data
            if k.startswith("haim:memory:") and not (k in ttls and now > ttls[k])
        ]

    def _get_ltp_index(self) -> Dict[str, float]:
        """Get a copy of the LTP index contents (for testing assertions)."""
        if isinstance(self.redis_client, InMemoryRedisClient):
            return dict(self.redis_client._sorted_sets.get("haim:ltp_index", {}))
        return {}

    def _get_stream_events(self) -> List[Dict[str, Any]]:
        """Get all stream events as ``{"id", "data"}`` dicts (for testing assertions)."""
        if isinstance(self.redis_client, InMemoryRedisClient):
            events = self.redis_client._streams.get(self.stream_key, [])
            return [{"id": e.id, "data": e.data} for e in events]
        return []

    def _clear_all(self):
        """Clear all data (for test cleanup)."""
        if isinstance(self.redis_client, InMemoryRedisClient):
            self.redis_client._data.clear()
            self.redis_client._ttls.clear()
            self.redis_client._sorted_sets.clear()
            self.redis_client._streams.clear()
# Factory function to create appropriate mock based on available dependencies
def create_mock_redis_storage(
    url: str = "redis://localhost:6379/0",
    stream_key: str = "haim:subconscious",
    **kwargs
) -> MockAsyncRedisStorage:
    """
    Create a mock Redis storage instance.

    Uses fakeredis if available, otherwise falls back to in-memory dict.
    A non-None ``client`` passed in *kwargs* is always used as-is.

    Args:
        url: Forwarded to MockAsyncRedisStorage.
        stream_key: Forwarded to MockAsyncRedisStorage.
        **kwargs: Any other MockAsyncRedisStorage keyword arguments.
    """
    # Only auto-create a fakeredis client when the caller did not supply one.
    # Previously an explicit client collided with the injected one, and the
    # resulting TypeError was logged as a fakeredis failure.
    if HAS_FAKEREDIS and kwargs.get("client") is None:
        # Keep the try narrow: only fakeredis construction is a recoverable
        # failure; errors from the storage constructor should surface.
        try:
            kwargs["client"] = fakeredis.FakeRedis(decode_responses=True)
        except Exception as e:
            logger.warning(f"[MockRedis] Failed to create fakeredis client: {e}, using in-memory")
    return MockAsyncRedisStorage(
        url=url,
        stream_key=stream_key,
        **kwargs
    )