from __future__ import annotations
import asyncio
import contextlib
import json
import sqlite3
import threading
from pathlib import Path

from ..items import TResponseInputItem
from .session import SessionABC
class SQLiteSession(SessionABC):
    """SQLite-based implementation of session storage.

    Stores conversation history in a SQLite database. By default an
    in-memory database is used, which is lost when the process ends;
    pass a file path for persistent storage.
    """

    def __init__(
        self,
        session_id: str,
        db_path: str | Path = ":memory:",
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):
        """Initialize the SQLite session.

        Args:
            session_id: Unique identifier for the conversation session.
            db_path: Path to the SQLite database file. Defaults to ':memory:'
                (in-memory database).
            sessions_table: Name of the table to store session metadata.
                Defaults to 'agent_sessions'.
            messages_table: Name of the table to store message data.
                Defaults to 'agent_messages'.
        """
        self.session_id = session_id
        self.db_path = db_path
        self.sessions_table = sessions_table
        self.messages_table = messages_table
        self._local = threading.local()
        # Serializes access to the single shared in-memory connection.
        self._lock = threading.Lock()

        # An in-memory SQLite database is private to its connection, so every
        # thread must share one connection (guarded by self._lock). A file
        # database persists, so each thread can hold its own connection for
        # better concurrency.
        self._is_memory_db = str(db_path) == ":memory:"
        if self._is_memory_db:
            self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False)
            self._shared_connection.execute("PRAGMA journal_mode=WAL")
            self._init_db_for_connection(self._shared_connection)
        else:
            # For file databases, initialize the schema once since it persists.
            init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
            init_conn.execute("PRAGMA journal_mode=WAL")
            self._init_db_for_connection(init_conn)
            init_conn.close()

    def _serialized(self) -> contextlib.AbstractContextManager:
        """Return the context manager guarding database access.

        The shared in-memory connection must be serialized with a real lock.
        Thread-local file connections need no cross-thread locking, so a no-op
        context is returned instead of constructing a fresh (and therefore
        never-contended, meaningless) ``threading.Lock()`` on every call.
        """
        return self._lock if self._is_memory_db else contextlib.nullcontext()

    def _get_connection(self) -> sqlite3.Connection:
        """Get a database connection for the current thread."""
        if self._is_memory_db:
            # Use shared connection for in-memory database to avoid thread isolation.
            return self._shared_connection
        else:
            # Lazily create one connection per thread for file databases.
            if not hasattr(self._local, "connection"):
                self._local.connection = sqlite3.connect(
                    str(self.db_path),
                    check_same_thread=False,
                )
                self._local.connection.execute("PRAGMA journal_mode=WAL")
            assert isinstance(self._local.connection, sqlite3.Connection), (
                f"Expected sqlite3.Connection, got {type(self._local.connection)}"
            )
            return self._local.connection

    def _init_db_for_connection(self, conn: sqlite3.Connection) -> None:
        """Initialize the database schema for a specific connection."""
        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.sessions_table} (
                session_id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """
        )

        conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.messages_table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL,
                message_data TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (session_id) REFERENCES {self.sessions_table} (session_id)
                    ON DELETE CASCADE
            )
        """
        )

        conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS idx_{self.messages_table}_session_id
            ON {self.messages_table} (session_id, created_at)
        """
        )

        conn.commit()

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history.
        """

        def _get_items_sync():
            conn = self._get_connection()
            with self._serialized():
                # Order by the monotonic AUTOINCREMENT id, not created_at:
                # CURRENT_TIMESTAMP has one-second resolution, so rows inserted
                # within the same second would otherwise have an undefined order.
                if limit is None:
                    # Fetch all items in chronological (insertion) order.
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id ASC
                    """,
                        (self.session_id,),
                    )
                else:
                    # Fetch the newest N rows, then reverse below to restore
                    # chronological order.
                    cursor = conn.execute(
                        f"""
                        SELECT message_data FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT ?
                    """,
                        (self.session_id, limit),
                    )

                rows = cursor.fetchall()

                # Reverse to get chronological order when using DESC.
                if limit is not None:
                    rows = list(reversed(rows))

                items = []
                for (message_data,) in rows:
                    try:
                        item = json.loads(message_data)
                        items.append(item)
                    except json.JSONDecodeError:
                        # Skip corrupted entries rather than failing the whole read.
                        continue

                return items

        return await asyncio.to_thread(_get_items_sync)

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Args:
            items: List of input items to add to the history. A no-op if empty.
        """
        if not items:
            return

        def _add_items_sync():
            conn = self._get_connection()
            with self._serialized():
                # Ensure the parent session row exists (FK target).
                conn.execute(
                    f"""
                    INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
                """,
                    (self.session_id,),
                )

                # Insert all items in one batch; insertion order fixes their ids.
                message_data = [(self.session_id, json.dumps(item)) for item in items]
                conn.executemany(
                    f"""
                    INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?)
                """,
                    message_data,
                )

                # Track last-activity time on the session row.
                conn.execute(
                    f"""
                    UPDATE {self.sessions_table}
                    SET updated_at = CURRENT_TIMESTAMP
                    WHERE session_id = ?
                """,
                    (self.session_id,),
                )

                conn.commit()

        await asyncio.to_thread(_add_items_sync)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty
            (or the stored entry is corrupted JSON — it is still deleted).
        """

        def _pop_item_sync():
            conn = self._get_connection()
            with self._serialized():
                # DELETE ... RETURNING (requires SQLite 3.35+) atomically removes
                # and returns the newest row. Select by id, not created_at, so
                # same-second inserts cannot make "most recent" ambiguous.
                cursor = conn.execute(
                    f"""
                    DELETE FROM {self.messages_table}
                    WHERE id = (
                        SELECT id FROM {self.messages_table}
                        WHERE session_id = ?
                        ORDER BY id DESC
                        LIMIT 1
                    )
                    RETURNING message_data
                """,
                    (self.session_id,),
                )

                result = cursor.fetchone()
                conn.commit()

                if result:
                    message_data = result[0]
                    try:
                        item = json.loads(message_data)
                        return item
                    except json.JSONDecodeError:
                        # Return None for corrupted JSON entries (already deleted).
                        return None

                return None

        return await asyncio.to_thread(_pop_item_sync)

    async def clear_session(self) -> None:
        """Clear all items for this session."""

        def _clear_session_sync():
            conn = self._get_connection()
            with self._serialized():
                conn.execute(
                    f"DELETE FROM {self.messages_table} WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.execute(
                    f"DELETE FROM {self.sessions_table} WHERE session_id = ?",
                    (self.session_id,),
                )
                conn.commit()

        await asyncio.to_thread(_clear_session_sync)

    def close(self) -> None:
        """Close the database connection.

        Note: for file databases only the calling thread's connection is
        closed; other threads' connections are released when they exit.
        """
        if self._is_memory_db:
            if hasattr(self, "_shared_connection"):
                self._shared_connection.close()
        else:
            if hasattr(self._local, "connection"):
                self._local.connection.close()
|