client_chatbot / backend /storage.py
jashdoshi77's picture
Initial commit: AI Consultant Chatbot for HF Spaces
c27eaf1
"""SQLite-backed async session storage."""
from __future__ import annotations
import json
import os
from datetime import datetime, timezone
import aiosqlite
from models import Message, Phase, SessionState
# The SQLite file lives in a ``data`` directory that is a sibling of the
# directory containing this module, i.e. ``<module dir>/../data/sessions.db``.
DB_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),  # directory holding this module
    "..",
    "data",
)
DB_PATH = os.path.join(DB_DIR, "sessions.db")
# ---------------------------------------------------------------------------
# DB initialisation
# ---------------------------------------------------------------------------
async def init_db() -> None:
    """Ensure the data directory and the ``sessions`` table both exist.

    Safe to call repeatedly: the directory is created with ``exist_ok=True``
    and the table with ``CREATE TABLE IF NOT EXISTS``.
    """
    schema = """
    CREATE TABLE IF NOT EXISTS sessions (
        session_id TEXT PRIMARY KEY,
        state_json TEXT NOT NULL,
        created_at TEXT NOT NULL,
        updated_at TEXT NOT NULL
    )
    """
    # The directory must exist before SQLite can create the database file.
    os.makedirs(DB_DIR, exist_ok=True)
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(schema)
        await db.commit()
# ---------------------------------------------------------------------------
# CRUD helpers
# ---------------------------------------------------------------------------
def _now() -> str:
    """Return the current time in UTC as an ISO-8601 string (``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
async def load_session(session_id: str) -> SessionState | None:
    """Fetch and rebuild the stored state for *session_id*.

    Returns ``None`` when no row matches. Otherwise the stored ``state_json``
    text is decoded with ``json.loads`` and passed as keyword arguments to
    ``SessionState``.
    """
    async with aiosqlite.connect(DB_PATH) as db:
        cursor = await db.execute(
            "SELECT state_json FROM sessions WHERE session_id = ?",
            (session_id,),
        )
        found = await cursor.fetchone()

    if found is None:
        return None
    # The SELECT names exactly one column, so the row is a 1-tuple.
    (stored_json,) = found
    return SessionState(**json.loads(stored_json))
async def save_session(session_id: str, state: SessionState) -> None:
    """Insert *state* under *session_id*, or overwrite the existing row.

    Mutates *state*: ``updated_at`` is set to the current UTC timestamp and
    ``created_at`` is filled with the same timestamp if it is empty. When the
    row already exists, only ``state_json`` and ``updated_at`` are replaced;
    the row's ``created_at`` column keeps its original value.
    """
    timestamp = _now()
    state.updated_at = timestamp
    if not state.created_at:
        state.created_at = timestamp

    params = (session_id, state.model_dump_json(), state.created_at, timestamp)
    upsert_sql = """
        INSERT INTO sessions (session_id, state_json, created_at, updated_at)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(session_id)
        DO UPDATE SET state_json = excluded.state_json,
                      updated_at = excluded.updated_at
    """
    async with aiosqlite.connect(DB_PATH) as db:
        await db.execute(upsert_sql, params)
        await db.commit()
async def list_sessions() -> list[dict]:
    """Summarise every stored session, most recently updated first.

    Each summary has ``session_id``, ``phase`` and ``confidence`` (read from
    the stored JSON, defaulting to ``"discovery"`` and ``0.0`` when absent),
    ``message_count`` (length of the stored ``messages`` list, ``0`` when
    absent), and the row's ``created_at`` / ``updated_at`` columns.
    """

    def summarise(row) -> dict:
        # Decode one (session_id, state_json, created_at, updated_at) row.
        sid, blob, created, updated = row
        payload = json.loads(blob)
        return {
            "session_id": sid,
            "phase": payload.get("phase", "discovery"),
            "confidence": payload.get("confidence", 0.0),
            "message_count": len(payload.get("messages", [])),
            "created_at": created,
            "updated_at": updated,
        }

    query = (
        "SELECT session_id, state_json, created_at, updated_at "
        "FROM sessions ORDER BY updated_at DESC"
    )
    async with aiosqlite.connect(DB_PATH) as db:
        cursor = await db.execute(query)
        rows = await cursor.fetchall()
    return [summarise(row) for row in rows]
async def delete_session(session_id: str) -> bool:
    """Remove the row stored under *session_id*.

    Returns ``True`` if the DELETE affected at least one row, ``False`` if no
    row matched.
    """
    statement = "DELETE FROM sessions WHERE session_id = ?"
    async with aiosqlite.connect(DB_PATH) as db:
        cursor = await db.execute(statement, (session_id,))
        await db.commit()
        deleted_rows = cursor.rowcount
    return deleted_rows > 0