agent-persistence / persistence.py
pratikshahp's picture
Create persistence.py
5f4b642 verified
import sqlite3
import json
from typing import Dict, Optional, List
from datetime import datetime
import threading
class AgentPersistence:
"""Handles saving and loading agent state using SQLite."""
def __init__(self, db_path: str = "agent_memory.db"):
self.db_path = db_path
self._local = threading.local()
self._init_db()
def _get_conn(self) -> sqlite3.Connection:
"""Get thread-local database connection."""
if not hasattr(self._local, "conn"):
self._local.conn = sqlite3.connect(self.db_path)
# Enable foreign keys
self._local.conn.execute("PRAGMA foreign_keys = ON")
return self._local.conn
def _init_db(self):
"""Initialize the database schema."""
with self._get_conn() as conn:
conn.executescript("""
CREATE TABLE IF NOT EXISTS agents (
name TEXT PRIMARY KEY,
persona TEXT,
instruction TEXT,
strategy TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS agent_states (
id INTEGER PRIMARY KEY AUTOINCREMENT,
agent_name TEXT,
task TEXT,
history TEXT, -- Store JSON as TEXT
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (agent_name) REFERENCES agents(name) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_agent_states_agent_name
ON agent_states(agent_name);
CREATE TRIGGER IF NOT EXISTS update_agent_timestamp
AFTER UPDATE ON agents
BEGIN
UPDATE agents
SET last_updated = CURRENT_TIMESTAMP
WHERE name = NEW.name;
END;
""")
def save_agent_state(self, agent) -> bool:
"""Save agent's current state to database."""
try:
with self._get_conn() as conn:
# First, update or insert the agent's base information
conn.execute("""
INSERT INTO agents (name, persona, instruction, strategy)
VALUES (?, ?, ?, ?)
ON CONFLICT(name) DO UPDATE SET
persona = excluded.persona,
instruction = excluded.instruction,
strategy = excluded.strategy
""", (
agent.name,
agent.persona,
agent.instruction,
agent.strategy.__class__.__name__ if agent.strategy else None
))
# Convert history to JSON string
history_json = json.dumps(agent.history)
# Then, save the current state
conn.execute("""
INSERT INTO agent_states (agent_name, task, history)
VALUES (?, ?, ?)
""", (
agent.name,
agent.task,
history_json # Save as JSON string
))
return True
except Exception as e:
print(f"Error saving agent state: {str(e)}")
return False
def load_agent_state(self, agent, agent_name: Optional[str] = None) -> bool:
"""Load agent's most recent state from database."""
try:
name_to_load = agent_name or agent.name
with self._get_conn() as conn:
# Load agent base information
agent_data = conn.execute("""
SELECT persona, instruction, strategy
FROM agents
WHERE name = ?
""", (name_to_load,)).fetchone()
if not agent_data:
return False
# Load most recent state
state_data = conn.execute("""
SELECT task, history
FROM agent_states
WHERE agent_name = ?
ORDER BY timestamp DESC
LIMIT 1
""", (name_to_load,)).fetchone()
if not state_data:
return False
# Update agent with loaded data
agent.persona = agent_data[0]
agent.instruction = agent_data[1]
if agent_data[2]:
agent.strategy = agent_data[2]
agent.task = state_data[0]
# Parse JSON string back to list
agent._history = json.loads(state_data[1]) if state_data[1] else []
return True
except Exception as e:
print(f"Error loading agent state: {str(e)}")
return False
def get_agent_history(self, agent_name: str, limit: int = 10) -> List[Dict]:
"""Retrieve the last N states for an agent."""
try:
with self._get_conn() as conn:
states = conn.execute("""
SELECT task, history, timestamp
FROM agent_states
WHERE agent_name = ?
ORDER BY timestamp DESC
LIMIT ?
""", (agent_name, limit)).fetchall()
return [{
'task': state[0],
'history': json.loads(state[1]) if state[1] else [],
'timestamp': state[2]
} for state in states]
except Exception as e:
print(f"Error retrieving agent history: {str(e)}")
return []
def list_saved_agents(self) -> Dict[str, datetime]:
"""Return a dictionary of saved agent names and their last update times."""
saved_agents = {}
try:
with self._get_conn() as conn:
results = conn.execute("""
SELECT name, last_updated
FROM agents
ORDER BY last_updated DESC
""").fetchall()
for name, timestamp in results:
saved_agents[name] = datetime.fromisoformat(timestamp)
except Exception as e:
print(f"Error listing saved agents: {str(e)}")
return saved_agents
def delete_agent_state(self, agent_name: str) -> bool:
"""Delete all data for an agent."""
try:
with self._get_conn() as conn:
# Due to foreign key constraints, this will also delete
# all associated states
conn.execute("DELETE FROM agents WHERE name = ?", (agent_name,))
return True
except Exception as e:
print(f"Error deleting agent state: {str(e)}")
return False
def cleanup_old_states(self, agent_name: str, keep_last: int = 10) -> bool:
"""Clean up old states keeping only the N most recent ones."""
try:
with self._get_conn() as conn:
conn.execute("""
DELETE FROM agent_states
WHERE agent_name = ?
AND id NOT IN (
SELECT id FROM agent_states
WHERE agent_name = ?
ORDER BY timestamp DESC
LIMIT ?
)
""", (agent_name, agent_name, keep_last))
return True
except Exception as e:
print(f"Error cleaning up old states: {str(e)}")
return False