| import os |
| import time |
| import hashlib |
| from schemas.agent import MemoryEntry |
|
|
# Directory that AgentMemory creates on construction. Set the
# AGENT_MEMORY_PATH environment variable to override the default.
# NOTE(review): /tmp is shared between users; consider a per-user default.
DB_PATH = os.environ.get("AGENT_MEMORY_PATH", "/tmp/agent_memory")
|
|
|
|
class AgentMemory:
    """In-process store of ``MemoryEntry`` records, grouped by session id.

    Records are kept in a plain dict on the instance (``self._store``), so
    they live only as long as this object does.
    """

    def __init__(self):
        # Maps "<session_id>:<entry_id>" -> dict snapshot of a MemoryEntry
        # plus a "created_at" timestamp.
        self._store = {}
        self._init_path()

    def _init_path(self):
        """Create the ``DB_PATH`` directory if it does not already exist."""
        # NOTE(review): nothing in this class reads or writes DB_PATH; the
        # store itself is purely in memory. Confirm whether persistence is
        # expected elsewhere.
        os.makedirs(DB_PATH, exist_ok=True)

    async def store(self, entry: MemoryEntry) -> str:
        """Record *entry* and return its newly generated id.

        The id is a random 32-character hex string that is unique per call.
        (Deriving it from session/action/timestamp let two entries stored in
        the same clock tick collide, and the later one overwrote the first.)
        """
        import uuid

        entry_id = uuid.uuid4().hex
        key = f"{entry.session_id}:{entry_id}"
        self._store[key] = {
            "session_id": entry.session_id,
            "task_id": entry.task_id,
            "action": entry.action,
            "prompt": entry.prompt,
            "response": entry.response,
            "tool_used": entry.tool_used,
            "result": entry.result,
            "success": entry.success,
            "tokens_used": entry.tokens_used,
            "latency_ms": entry.latency_ms,
            "created_at": time.time(),
        }
        return entry_id

    @staticmethod
    def _to_entry(record: dict) -> MemoryEntry:
        """Rebuild a ``MemoryEntry`` from a stored record, defaulting missing fields."""
        return MemoryEntry(
            session_id=record["session_id"],
            task_id=record.get("task_id"),
            action=record.get("action", ""),
            prompt=record.get("prompt"),
            response=record.get("response"),
            tool_used=record.get("tool_used"),
            result=record.get("result"),
            success=record.get("success", True),
            tokens_used=record.get("tokens_used", 0),
            latency_ms=record.get("latency_ms", 0),
        )

    async def recall(self, session_id: str, query: str = "", limit: int = 20) -> list[MemoryEntry]:
        """Return up to *limit* entries belonging to *session_id*.

        Without a *query*, the newest entries come first. With a *query*,
        every entry in the session is scored with ``_simple_score``. The
        highest-scoring entries are returned, newest first among equal
        scores. Entries that do not match at all can still fill the
        remaining slots.

        Returns an empty list when *limit* is zero or negative.
        """
        if limit <= 0:
            return []

        records = [rec for rec in self._store.values() if rec["session_id"] == session_id]
        # list.sort is stable (also with reverse=True), so records with equal
        # timestamps keep their insertion order.
        records.sort(key=lambda rec: rec.get("created_at", 0), reverse=True)

        if not query:
            return [self._to_entry(rec) for rec in records[:limit]]

        # Score the whole session before truncating. Truncating first would
        # hide relevant entries older than the newest `limit` ones.
        entries = [self._to_entry(rec) for rec in records]
        entries.sort(key=lambda e: self._simple_score(e, query), reverse=True)
        return entries[:limit]

    def _simple_score(self, entry: MemoryEntry, query: str) -> float:
        """Return a relevance score for *entry* against *query*.

        The query is matched as a case-insensitive substring of each text
        field. Each matching field adds a fixed weight: action 0.5,
        result 0.4, prompt 0.3, response 0.2, tool_used 0.1. Empty or
        missing fields are skipped.
        """
        q = query.lower()
        score = 0.0
        if entry.action and q in entry.action.lower():
            score += 0.5
        if entry.result and q in entry.result.lower():
            score += 0.4
        if entry.prompt and q in entry.prompt.lower():
            score += 0.3
        if entry.response and q in entry.response.lower():
            score += 0.2
        if entry.tool_used and q in entry.tool_used.lower():
            score += 0.1
        return score

    async def get_session_summary(self, session_id: str) -> str:
        """Summarize the 50 most recent entries of *session_id* as one line of text."""
        entries = await self.recall(session_id, "", limit=50)
        if not entries:
            return "No prior session data."
        success_count = sum(1 for e in entries if e.success)
        total = len(entries)
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (most recent first). A set's iteration order for strings can vary
        # between interpreter runs.
        tools_used = list(dict.fromkeys(e.tool_used for e in entries if e.tool_used))
        last_action = entries[0].action  # entries is newest-first and non-empty here
        return (f"Session: {success_count}/{total} successful actions. "
                f"Tools used: {', '.join(tools_used) if tools_used else 'none'}. "
                f"Last action: {last_action}")

    async def clear_session(self, session_id: str):
        """Delete every stored entry that belongs to *session_id*."""
        # Match on the stored session_id, not the key prefix. Keys are
        # "<session_id>:<entry_id>", so a prefix test for "a:" would also
        # delete the entries of a session named "a:b".
        doomed = [key for key, rec in self._store.items() if rec["session_id"] == session_id]
        for key in doomed:
            del self._store[key]
|
|