Spaces:
Running
Running
| import sqlite3 | |
| import json | |
| from typing import Dict, Optional, List | |
| from datetime import datetime | |
| import threading | |
class AgentPersistence:
    """Persist agent configuration and state snapshots in a SQLite database.

    Two tables are used: ``agents`` holds one row per agent name (persona,
    instruction, strategy class name) and ``agent_states`` holds an
    append-only list of ``(task, history)`` snapshots per agent, with the
    history serialized as JSON text.

    Each thread gets its own ``sqlite3.Connection`` (stored in a
    ``threading.local``). Call :meth:`close` from a thread to release the
    connection that thread opened.

    All public methods catch exceptions, print a message and return a
    failure value (``False`` / empty container) instead of raising.
    """

    def __init__(self, db_path: str = "agent_memory.db"):
        """Open (or create) the database at *db_path* and ensure the schema exists."""
        self.db_path = db_path
        self._local = threading.local()
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Return this thread's connection, creating it on first use."""
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(self.db_path)
            self._local.conn = conn
            # SQLite enforces foreign keys (and thus ON DELETE CASCADE) only
            # when enabled per connection.
            conn.execute("PRAGMA foreign_keys = ON")
        return conn

    def close(self) -> None:
        """Close the calling thread's connection, if one is open.

        A later call to any other method transparently opens a new one.
        """
        conn = getattr(self._local, "conn", None)
        if conn is not None:
            conn.close()
            del self._local.conn

    def _init_db(self):
        """Create tables, index and the ``last_updated`` trigger if missing."""
        with self._get_conn() as conn:
            conn.executescript("""
                CREATE TABLE IF NOT EXISTS agents (
                    name TEXT PRIMARY KEY,
                    persona TEXT,
                    instruction TEXT,
                    strategy TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
                CREATE TABLE IF NOT EXISTS agent_states (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    agent_name TEXT,
                    task TEXT,
                    history TEXT, -- Store JSON as TEXT
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (agent_name) REFERENCES agents(name) ON DELETE CASCADE
                );
                CREATE INDEX IF NOT EXISTS idx_agent_states_agent_name
                ON agent_states(agent_name);
                CREATE TRIGGER IF NOT EXISTS update_agent_timestamp
                AFTER UPDATE ON agents
                BEGIN
                    UPDATE agents
                    SET last_updated = CURRENT_TIMESTAMP
                    WHERE name = NEW.name;
                END;
            """)

    def save_agent_state(self, agent) -> bool:
        """Upsert the agent's base row and append a new state snapshot.

        Reads ``agent.name``, ``agent.persona``, ``agent.instruction``,
        ``agent.strategy``, ``agent.task`` and ``agent.history``. Only the
        strategy's class name is stored. ``agent.history`` must be
        JSON-serializable.

        Returns:
            True on success, False if any step failed (the transaction is
            rolled back and the error is printed).
        """
        try:
            strategy_name = (
                agent.strategy.__class__.__name__ if agent.strategy else None
            )
            with self._get_conn() as conn:
                # Insert the agent, or refresh its base information if the
                # name already exists.
                conn.execute("""
                    INSERT INTO agents (name, persona, instruction, strategy)
                    VALUES (?, ?, ?, ?)
                    ON CONFLICT(name) DO UPDATE SET
                        persona = excluded.persona,
                        instruction = excluded.instruction,
                        strategy = excluded.strategy
                """, (agent.name, agent.persona, agent.instruction, strategy_name))
                history_json = json.dumps(agent.history)
                conn.execute("""
                    INSERT INTO agent_states (agent_name, task, history)
                    VALUES (?, ?, ?)
                """, (agent.name, agent.task, history_json))
            return True
        except Exception as e:
            print(f"Error saving agent state: {str(e)}")
            return False

    def load_agent_state(self, agent, agent_name: Optional[str] = None) -> bool:
        """Populate *agent* from the most recently saved snapshot.

        Args:
            agent: Object whose ``persona``, ``instruction``, ``strategy``,
                ``task`` and ``_history`` attributes are overwritten.
            agent_name: Name to look up; defaults to ``agent.name``.

        Returns:
            True if both the agent row and at least one snapshot were found
            and applied; False otherwise (including on error). The agent is
            left untouched when False is returned because a row is missing.
        """
        try:
            name_to_load = agent_name or agent.name
            with self._get_conn() as conn:
                agent_data = conn.execute("""
                    SELECT persona, instruction, strategy
                    FROM agents
                    WHERE name = ?
                """, (name_to_load,)).fetchone()
                if not agent_data:
                    return False
                # CURRENT_TIMESTAMP only has one-second resolution, so the
                # autoincrement id breaks ties between snapshots saved within
                # the same second; otherwise "most recent" is arbitrary.
                state_data = conn.execute("""
                    SELECT task, history
                    FROM agent_states
                    WHERE agent_name = ?
                    ORDER BY timestamp DESC, id DESC
                    LIMIT 1
                """, (name_to_load,)).fetchone()
                if not state_data:
                    return False
                persona, instruction, strategy = agent_data
                task, history_json = state_data
                agent.persona = persona
                agent.instruction = instruction
                if strategy:
                    # NOTE(review): only the class name was persisted, so this
                    # assigns a string, not a strategy instance -- confirm the
                    # agent accepts that.
                    agent.strategy = strategy
                agent.task = task
                # NOTE(review): saving reads ``agent.history`` but loading
                # writes ``agent._history``; presumably ``history`` is a
                # property over ``_history`` -- confirm.
                agent._history = json.loads(history_json) if history_json else []
            return True
        except Exception as e:
            print(f"Error loading agent state: {str(e)}")
            return False

    def get_agent_history(self, agent_name: str, limit: int = 10) -> List[Dict]:
        """Return up to *limit* most recent snapshots for *agent_name*, newest first.

        Each item is a dict with keys ``'task'``, ``'history'`` (decoded from
        JSON, ``[]`` when empty) and ``'timestamp'`` (as stored by SQLite).
        Returns an empty list on error or when nothing is stored.
        """
        try:
            with self._get_conn() as conn:
                # id DESC makes the order deterministic for snapshots that
                # share the same one-second timestamp.
                states = conn.execute("""
                    SELECT task, history, timestamp
                    FROM agent_states
                    WHERE agent_name = ?
                    ORDER BY timestamp DESC, id DESC
                    LIMIT ?
                """, (agent_name, limit)).fetchall()
            return [
                {
                    'task': task,
                    'history': json.loads(history_json) if history_json else [],
                    'timestamp': timestamp,
                }
                for task, history_json, timestamp in states
            ]
        except Exception as e:
            print(f"Error retrieving agent history: {str(e)}")
            return []

    def list_saved_agents(self) -> Dict[str, datetime]:
        """Map each saved agent name to its ``last_updated`` time.

        Entries are inserted most recently updated first. On error the
        entries collected so far (possibly none) are returned.
        """
        saved_agents = {}
        try:
            with self._get_conn() as conn:
                results = conn.execute("""
                    SELECT name, last_updated
                    FROM agents
                    ORDER BY last_updated DESC
                """).fetchall()
            for name, timestamp in results:
                saved_agents[name] = datetime.fromisoformat(timestamp)
        except Exception as e:
            print(f"Error listing saved agents: {str(e)}")
        return saved_agents

    def delete_agent_state(self, agent_name: str) -> bool:
        """Delete the agent row and, via ON DELETE CASCADE, all its snapshots.

        Returns True when the statement ran (even if no such agent existed),
        False on error.
        """
        try:
            with self._get_conn() as conn:
                conn.execute("DELETE FROM agents WHERE name = ?", (agent_name,))
            return True
        except Exception as e:
            print(f"Error deleting agent state: {str(e)}")
            return False

    def cleanup_old_states(self, agent_name: str, keep_last: int = 10) -> bool:
        """Delete all but the *keep_last* most recent snapshots of *agent_name*.

        Returns True when the statement ran, False on error.
        """
        try:
            with self._get_conn() as conn:
                # The id tie-breaker ensures that, among snapshots saved in
                # the same second, the newest ones are the ones kept.
                conn.execute("""
                    DELETE FROM agent_states
                    WHERE agent_name = ?
                    AND id NOT IN (
                        SELECT id FROM agent_states
                        WHERE agent_name = ?
                        ORDER BY timestamp DESC, id DESC
                        LIMIT ?
                    )
                """, (agent_name, agent_name, keep_last))
            return True
        except Exception as e:
            print(f"Error cleaning up old states: {str(e)}")
            return False