Spaces:
Running
Running
Commit Β·
d967f3f
1
Parent(s): 4a3ea5f
added conversational memory
Browse files- app.py +22 -1
- db/memory.py +86 -0
- frontend/script.js +16 -1
app.py
CHANGED
|
@@ -27,6 +27,7 @@ app.add_middleware(
|
|
| 27 |
class QuestionRequest(BaseModel):
|
| 28 |
question: str
|
| 29 |
provider: str = "groq" # "groq" | "openai"
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
class GenerateSQLResponse(BaseModel):
|
|
@@ -54,9 +55,29 @@ def generate_sql_endpoint(req: QuestionRequest):
|
|
| 54 |
@app.post("/chat", response_model=ChatResponse)
|
| 55 |
def chat_endpoint(req: QuestionRequest):
|
| 56 |
from ai.pipeline import SQLAnalystPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
pipeline = SQLAnalystPipeline(provider=req.provider)
|
| 59 |
-
result = pipeline.run(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
return ChatResponse(**result)
|
| 61 |
|
| 62 |
|
|
|
|
| 27 |
class QuestionRequest(BaseModel):
    """Request body shared by the SQL-generation and chat endpoints."""

    # Natural-language question to answer.
    question: str
    # LLM backend to use.
    provider: str = "groq"  # "groq" | "openai"
    # Groups turns for multi-turn memory; the endpoint falls back to
    # "default" when omitted.
    conversation_id: str | None = None
| 31 |
|
| 32 |
|
| 33 |
class GenerateSQLResponse(BaseModel):
|
|
|
|
| 55 |
@app.post("/chat", response_model=ChatResponse)
def chat_endpoint(req: QuestionRequest):
    """Answer a natural-language question with recent conversation context.

    Loads up to the last 5 turns for ``req.conversation_id``, prepends them
    to the question so follow-ups resolve against earlier exchanges, runs the
    pipeline, then persists this turn for future context.
    """
    import logging

    from ai.pipeline import SQLAnalystPipeline
    from db.memory import get_recent_history, add_turn

    conversation_id = req.conversation_id or "default"

    history = get_recent_history(conversation_id, limit=5)

    # Augment the question with recent conversation context.
    if history:
        history_lines: list[str] = [
            "You are in a multi-turn conversation. Here are the recent exchanges:"
        ]
        for turn in history:
            history_lines.append(f"User: {turn['question']}")
            history_lines.append(f"Assistant: {turn['answer']}")
        history_lines.append(f"Now the user asks: {req.question}")
        question_with_context = "\n".join(history_lines)
    else:
        question_with_context = req.question

    pipeline = SQLAnalystPipeline(provider=req.provider)
    result = pipeline.run(question_with_context)

    # Persist this turn for future context. Use .get() so a pipeline result
    # missing a key cannot raise KeyError here, and treat memory as
    # best-effort: a storage failure must not turn a good answer into a 500.
    try:
        add_turn(conversation_id, req.question, result.get("answer", ""), result.get("sql"))
    except Exception:
        logging.getLogger(__name__).exception(
            "failed to persist chat turn for conversation %s", conversation_id
        )

    return ChatResponse(**result)
|
| 82 |
|
| 83 |
|
db/memory.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Conversation memory stored in PostgreSQL (Neon).
|
| 2 |
+
|
| 3 |
+
Keeps the last N turns per conversation so the AI can
|
| 4 |
+
use recent context for follow-up questions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, List, Dict
|
| 10 |
+
|
| 11 |
+
from sqlalchemy import text
|
| 12 |
+
|
| 13 |
+
from db.connection import get_engine
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
_TABLE_CREATED = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _ensure_table() -> None:
    """Create the chat_history table (and its lookup index) if missing.

    Idempotent: the DDL uses IF NOT EXISTS, and a module-level flag skips
    the database round-trip after the first successful call in this process.
    """
    global _TABLE_CREATED
    if _TABLE_CREATED:
        return

    ddl = text(
        """
        CREATE TABLE IF NOT EXISTS chat_history (
            id BIGSERIAL PRIMARY KEY,
            conversation_id TEXT NOT NULL,
            question TEXT NOT NULL,
            answer TEXT NOT NULL,
            sql_query TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        );
        """
    )
    # get_recent_history filters by conversation_id and orders by recency;
    # without this index every lookup scans the whole table.
    index_ddl = text(
        """
        CREATE INDEX IF NOT EXISTS idx_chat_history_conversation
        ON chat_history (conversation_id, created_at DESC);
        """
    )
    engine = get_engine()
    with engine.begin() as conn:
        conn.execute(ddl)
        conn.execute(index_ddl)

    _TABLE_CREATED = True
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def add_turn(conversation_id: str, question: str, answer: str, sql_query: str | None) -> None:
    """Append a single Q/A turn to the history."""
    _ensure_table()

    # Bind parameters by name; never interpolate user text into the SQL.
    params = {
        "conversation_id": conversation_id,
        "question": question,
        "answer": answer,
        "sql_query": sql_query,
    }
    insert_stmt = text(
        """
        INSERT INTO chat_history (conversation_id, question, answer, sql_query)
        VALUES (:conversation_id, :question, :answer, :sql_query)
        """
    )
    # engine.begin() commits on success and rolls back on error.
    with get_engine().begin() as conn:
        conn.execute(insert_stmt, params)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_recent_history(conversation_id: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Return the most recent `limit` turns for a conversation (oldest first).

    Orders by the BIGSERIAL primary key rather than created_at: NOW() can
    produce identical timestamps for rows inserted close together, making
    created_at-only ordering nondeterministic, while `id` is strictly
    monotonic per insert.
    """
    _ensure_table()
    engine = get_engine()
    query = text(
        """
        SELECT question, answer, sql_query, created_at
        FROM chat_history
        WHERE conversation_id = :conversation_id
        ORDER BY id DESC
        LIMIT :limit
        """
    )
    with engine.connect() as conn:
        rows = conn.execute(
            query, {"conversation_id": conversation_id, "limit": limit}
        ).mappings().all()

    # Reverse so caller sees oldest -> newest
    return list(reversed([dict(r) for r in rows]))
|
| 86 |
+
|
frontend/script.js
CHANGED
|
@@ -25,6 +25,17 @@
|
|
| 25 |
let selectedProvider = "groq";
|
| 26 |
let loadingStepTimer = null;
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
// ── Model Switcher ───────────────────────────────────────────────────
|
| 29 |
modelSwitcher.addEventListener("click", (e) => {
|
| 30 |
const btn = e.target.closest(".switcher-btn");
|
|
@@ -55,7 +66,11 @@
|
|
| 55 |
const res = await fetch("/chat", {
|
| 56 |
method: "POST",
|
| 57 |
headers: { "Content-Type": "application/json" },
|
| 58 |
-
body: JSON.stringify({
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
});
|
| 60 |
|
| 61 |
if (!res.ok) {
|
|
|
|
| 25 |
let selectedProvider = "groq";
|
| 26 |
let loadingStepTimer = null;
|
| 27 |
|
| 28 |
+
// Persistent conversation id per browser (for multi-turn memory)
let conversationId = window.localStorage.getItem("sqlbot_conversation_id");
if (!conversationId) {
  if (window.crypto && window.crypto.randomUUID) {
    conversationId = window.crypto.randomUUID();
  } else {
    // Fallback for older browsers: a timestamp alone collides for clients
    // created in the same millisecond, so mix in a random suffix.
    conversationId =
      "conv-" + Date.now().toString(36) + "-" + Math.random().toString(36).slice(2, 10);
  }
  window.localStorage.setItem("sqlbot_conversation_id", conversationId);
}
|
| 38 |
+
|
| 39 |
// ── Model Switcher ───────────────────────────────────────────────────
|
| 40 |
modelSwitcher.addEventListener("click", (e) => {
|
| 41 |
const btn = e.target.closest(".switcher-btn");
|
|
|
|
| 66 |
const res = await fetch("/chat", {
|
| 67 |
method: "POST",
|
| 68 |
headers: { "Content-Type": "application/json" },
|
| 69 |
+
body: JSON.stringify({
|
| 70 |
+
question,
|
| 71 |
+
provider: selectedProvider,
|
| 72 |
+
conversation_id: conversationId,
|
| 73 |
+
}),
|
| 74 |
});
|
| 75 |
|
| 76 |
if (!res.ok) {
|