from __future__ import annotations

import asyncio
import contextlib
import json
import sqlite3
import threading
from pathlib import Path

from ..items import TResponseInputItem
from .session import SessionABC
| |
|
| |
|
class SQLiteSession(SessionABC):
    """SQLite-based implementation of session storage.

    This implementation stores conversation history in a SQLite database.
    By default, uses an in-memory database that is lost when the process ends.
    For persistent storage, provide a file path.

    Items are ordered by the autoincrement ``id`` column (insertion order) rather
    than ``created_at``: ``CURRENT_TIMESTAMP`` has one-second resolution, so items
    added within the same second would otherwise be returned in arbitrary order.
    """

    def __init__(
        self,
        session_id: str,
        db_path: str | Path = ":memory:",
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initialize the SQLite session.

        Args:
            session_id: Unique identifier for the conversation session
            db_path: Path to the SQLite database file. Defaults to ':memory:' (in-memory database)
            sessions_table: Name of the table to store session metadata. Defaults to
                'agent_sessions'
            messages_table: Name of the table to store message data. Defaults to 'agent_messages'
        """
        self.session_id = session_id
        self.db_path = db_path
        self.sessions_table = sessions_table
        self.messages_table = messages_table
        # Per-thread connection holder used for file databases.
        self._local = threading.local()
        # Serializes use of the single shared connection of an in-memory database.
        self._lock = threading.Lock()
        # Every per-thread connection opened for a file database, so close() can
        # release all of them (including those opened in worker threads).
        self._connections: list[sqlite3.Connection] = []
        self._connections_lock = threading.Lock()

        self._is_memory_db = str(db_path) == ":memory:"
        if self._is_memory_db:
            # Each ":memory:" connection is its own database, so one connection is
            # shared by all threads and guarded by self._lock.
            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self._init_db_for_connection(self._shared_connection)
        else:
            # Create the schema once with a throwaway connection; per-thread
            # connections are opened lazily by _get_connection().
            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            try:
                init_conn.execute("PRAGMA journal_mode=WAL")
                self._init_db_for_connection(init_conn)
            finally:
                init_conn.close()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the database connection for the calling thread.

        In-memory databases always use the shared connection. File databases get
        one connection per thread, opened on first use and registered so that
        ``close()`` can release it later.
        """
        if self._is_memory_db:
            return self._shared_connection

        conn: sqlite3.Connection | None = getattr(self._local, "connection", None)
        if conn is None:
            conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            conn.execute("PRAGMA journal_mode=WAL")
            self._local.connection = conn
            with self._connections_lock:
                self._connections.append(conn)
        return conn

    def _lock_for_operation(self) -> contextlib.AbstractContextManager[object]:
        """Return the context manager that guards a single database operation.

        The shared in-memory connection is used from several worker threads, so
        operations on it are serialized with ``self._lock``. File databases use a
        separate connection per thread and take no in-process lock.
        """
        if self._is_memory_db:
            return self._lock
        return contextlib.nullcontext()

    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
        """Create the sessions table, messages table and index if they are missing.

        Args:
            conn: Connection on which the ``CREATE ... IF NOT EXISTS`` statements run;
                the changes are committed before returning.
        """
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
        )

        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.messages_table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
                    ON DELETE CASCADE
            )
        """
        )

        conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
            ON {self.messages_table} (session_id, created_at)
        """
        )

        conn.commit()

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                   When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history. Rows whose
            stored JSON cannot be decoded are skipped.
        """

        def _get_items_sync() -> list[TResponseInputItem]:
            conn = self._get_connection()
            with self._lock_for_operation():
                if limit is None:
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id ASC
                    """,
                        (self.session_id,),
                    )
                else:
                    # Take the newest `limit` rows, then flip them back to oldest-first.
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT ?
                    """,
                        (self.session_id, limit),
                    )
                rows = cursor.fetchall()

            if limit is not None:
                rows.reverse()

            items: list[TResponseInputItem] = []
            for (message_data,) in rows:
                try:
                    items.append(json.loads(message_data))
                except json.JSONDecodeError:
                    # Best effort: skip a corrupted row instead of failing the whole read.
                    continue
            return items

        return await asyncio.to_thread(_get_items_sync)

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        The session row is created if needed, all items are inserted, and the
        session's ``updated_at`` is refreshed, all in one transaction.

        Args:
            items: List of input items to add to the history

        Raises:
            TypeError: If an item is not JSON-serializable; nothing is written then.
        """
        if not items:
            return

        def _add_items_sync() -> None:
            # Serialize first so a bad item fails before any SQL runs.
            message_rows = [(self.session_id, json.dumps(item)) for item in items]
            conn = self._get_connection()
            # `with conn` commits on success and rolls back on error, so a failure
            # never leaves a half-applied transaction open on the connection.
            with self._lock_for_operation(), conn:
                conn.execute(
                    f"""
                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
                """,
                    (self.session_id,),
                )
                conn.executemany(
                    f"""
                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
                """,
                    message_rows,
                )
                conn.execute(
                    f"""
                    UPDATE {self.sessions_table}
                    SET updated_at = CURRENT_TIMESTAMP
                    WHERE session_id = ?
                """,
                    (self.session_id,),
                )

        await asyncio.to_thread(_add_items_sync)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If the
            most recent row holds undecodable JSON, it is still deleted and None is
            returned.
        """

        def _pop_item_sync() -> TResponseInputItem | None:
            conn = self._get_connection()
            with self._lock_for_operation(), conn:
                # Delete and read back the newest row in one statement (needs
                # SQLite with RETURNING support).
                cursor = conn.execute(
                    f"""
                    DELETE FROM {self.messages_table}
                    WHERE id = (
                        SELECT id FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT 1
                    )
                    RETURNING message_data
                """,
                    (self.session_id,),
                )
                row = cursor.fetchone()

            if row is None:
                return None
            try:
                return json.loads(row[0])
            except json.JSONDecodeError:
                # The corrupted row is already removed; report it like an empty pop.
                return None

        return await asyncio.to_thread(_pop_item_sync)

    async def clear_session(self) -> None:
        """Clear all items for this session, together with its session row."""

        def _clear_session_sync() -> None:
            conn = self._get_connection()
            with self._lock_for_operation(), conn:
                conn.execute(
                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.execute(
                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
                    (self.session_id,),
                )

        await asyncio.to_thread(_clear_session_sync)

    def close(self) -> None:
        """Close the database connection(s) held by this session.

        For an in-memory database this closes the shared connection, which
        discards its data. For a file database it closes the connections opened
        by every thread (including ``asyncio.to_thread`` workers), not only the
        calling thread's; a later operation opens a fresh connection.
        """
        if self._is_memory_db:
            if hasattr(self, "_shared_connection"):
                self._shared_connection.close()
            return

        with self._connections_lock:
            connections, self._connections = self._connections, []
            # Forget per-thread handles so no thread keeps using a closed connection.
            self._local = threading.local()
        for conn in connections:
            conn.close()
| |
|