Spaces:
Running
Running
File size: 2,322 Bytes
d967f3f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | """Conversation memory stored in PostgreSQL (Neon).
Stores every turn of each conversation and reads back the most recent N
on request, so the AI can use recent context for follow-up questions.
"""
from __future__ import annotations
from typing import Any, List, Dict
from sqlalchemy import text
from db.connection import get_engine
# Module-level flag: set to True once the DDL in _ensure_table() has run
# successfully, so later calls in this process skip the database round trip.
_TABLE_CREATED = False


def _ensure_table() -> None:
    """Create the ``chat_history`` table and its lookup index if missing.

    Both statements use ``IF NOT EXISTS``, so running them against an existing
    schema is harmless. After the first successful run in this process the
    function returns immediately.
    """
    global _TABLE_CREATED
    if _TABLE_CREATED:
        return

    create_table = text(
        """
        CREATE TABLE IF NOT EXISTS chat_history (
            id BIGSERIAL PRIMARY KEY,
            conversation_id TEXT NOT NULL,
            question TEXT NOT NULL,
            answer TEXT NOT NULL,
            sql_query TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    )
    # History reads filter on conversation_id and order by created_at; without
    # this composite index every read scans and sorts the whole (append-only,
    # ever-growing) table. ``id`` is appended so ordering by (created_at, id)
    # is also satisfied straight from the index.
    create_index = text(
        """
        CREATE INDEX IF NOT EXISTS chat_history_conversation_created_idx
        ON chat_history (conversation_id, created_at, id)
        """
    )

    engine = get_engine()
    # One transaction: the table and index are created together or not at all.
    with engine.begin() as conn:
        conn.execute(create_table)
        conn.execute(create_index)
    _TABLE_CREATED = True
def add_turn(
    conversation_id: str,
    question: str,
    answer: str,
    sql_query: str | None = None,
) -> None:
    """Append a single question/answer turn to ``chat_history``.

    Ensures the table exists, then inserts one row in its own transaction
    (committed when the ``engine.begin()`` block exits).

    Args:
        conversation_id: Identifier of the conversation the turn belongs to.
        question: Question text for this turn.
        answer: Answer text for this turn.
        sql_query: SQL associated with the turn, if any; stored as NULL when
            ``None``. Defaults to ``None`` so callers without SQL may omit it.
    """
    _ensure_table()
    engine = get_engine()
    insert_stmt = text(
        """
        INSERT INTO chat_history (conversation_id, question, answer, sql_query)
        VALUES (:conversation_id, :question, :answer, :sql_query)
        """
    )
    with engine.begin() as conn:
        # Bound parameters (not string formatting) keep user text out of the SQL.
        conn.execute(
            insert_stmt,
            {
                "conversation_id": conversation_id,
                "question": question,
                "answer": answer,
                "sql_query": sql_query,
            },
        )
def get_recent_history(conversation_id: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Return the most recent ``limit`` turns for a conversation, oldest first.

    Args:
        conversation_id: Identifier of the conversation whose turns to fetch.
        limit: Maximum number of turns to return. Values ``<= 0`` yield an
            empty list without touching the database (PostgreSQL rejects a
            negative ``LIMIT``).

    Returns:
        A list of dicts with keys ``question``, ``answer``, ``sql_query`` and
        ``created_at``, ordered from oldest to newest.
    """
    if limit <= 0:
        return []

    _ensure_table()
    engine = get_engine()
    # ``id`` breaks ties between turns that share a ``created_at`` value
    # (NOW() is the transaction start time), so both the selected window and
    # its order are deterministic.
    query = text(
        """
        SELECT question, answer, sql_query, created_at
        FROM chat_history
        WHERE conversation_id = :conversation_id
        ORDER BY created_at DESC, id DESC
        LIMIT :limit
        """
    )
    with engine.connect() as conn:
        rows = conn.execute(
            query, {"conversation_id": conversation_id, "limit": limit}
        ).mappings().all()
    # Rows arrive newest-first; flip them so callers read the conversation in order.
    return [dict(row) for row in reversed(rows)]
|