jashdoshi77 commited on
Commit
d967f3f
Β·
1 Parent(s): 4a3ea5f

added conversational memory

Browse files
Files changed (3) hide show
  1. app.py +22 -1
  2. db/memory.py +86 -0
  3. frontend/script.js +16 -1
app.py CHANGED
@@ -27,6 +27,7 @@ app.add_middleware(
27
  class QuestionRequest(BaseModel):
28
  question: str
29
  provider: str = "groq" # "groq" | "openai"
 
30
 
31
 
32
  class GenerateSQLResponse(BaseModel):
@@ -54,9 +55,29 @@ def generate_sql_endpoint(req: QuestionRequest):
54
  @app.post("/chat", response_model=ChatResponse)
55
  def chat_endpoint(req: QuestionRequest):
56
  from ai.pipeline import SQLAnalystPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  pipeline = SQLAnalystPipeline(provider=req.provider)
59
- result = pipeline.run(req.question)
 
 
 
 
60
  return ChatResponse(**result)
61
 
62
 
 
27
  class QuestionRequest(BaseModel):
28
  question: str
29
  provider: str = "groq" # "groq" | "openai"
30
+ conversation_id: str | None = None
31
 
32
 
33
  class GenerateSQLResponse(BaseModel):
 
55
  @app.post("/chat", response_model=ChatResponse)
56
  def chat_endpoint(req: QuestionRequest):
57
  from ai.pipeline import SQLAnalystPipeline
58
+ from db.memory import get_recent_history, add_turn
59
+
60
+ conversation_id = req.conversation_id or "default"
61
+
62
+ history = get_recent_history(conversation_id, limit=5)
63
+
64
+ # Augment the question with recent conversation context
65
+ if history:
66
+ history_lines: list[str] = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
67
+ for turn in history:
68
+ history_lines.append(f"User: {turn['question']}")
69
+ history_lines.append(f"Assistant: {turn['answer']}")
70
+ history_lines.append(f"Now the user asks: {req.question}")
71
+ question_with_context = "\n".join(history_lines)
72
+ else:
73
+ question_with_context = req.question
74
 
75
  pipeline = SQLAnalystPipeline(provider=req.provider)
76
+ result = pipeline.run(question_with_context)
77
+
78
+ # Persist this turn for future context
79
+ add_turn(conversation_id, req.question, result["answer"], result["sql"])
80
+
81
  return ChatResponse(**result)
82
 
83
 
db/memory.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conversation memory stored in PostgreSQL (Neon).
2
+
3
+ Keeps the last N turns per conversation so the AI can
4
+ use recent context for follow‑up questions.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, List, Dict
10
+
11
+ from sqlalchemy import text
12
+
13
+ from db.connection import get_engine
14
+
15
+
16
+ _TABLE_CREATED = False
17
+
18
+
19
+ def _ensure_table() -> None:
20
+ """Create the chat_history table if it doesn't exist."""
21
+ global _TABLE_CREATED
22
+ if _TABLE_CREATED:
23
+ return
24
+
25
+ ddl = text(
26
+ """
27
+ CREATE TABLE IF NOT EXISTS chat_history (
28
+ id BIGSERIAL PRIMARY KEY,
29
+ conversation_id TEXT NOT NULL,
30
+ question TEXT NOT NULL,
31
+ answer TEXT NOT NULL,
32
+ sql_query TEXT,
33
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
34
+ );
35
+ """
36
+ )
37
+ engine = get_engine()
38
+ with engine.begin() as conn:
39
+ conn.execute(ddl)
40
+
41
+ _TABLE_CREATED = True
42
+
43
+
44
+ def add_turn(conversation_id: str, question: str, answer: str, sql_query: str | None) -> None:
45
+ """Append a single Q/A turn to the history."""
46
+ _ensure_table()
47
+ engine = get_engine()
48
+ insert_stmt = text(
49
+ """
50
+ INSERT INTO chat_history (conversation_id, question, answer, sql_query)
51
+ VALUES (:conversation_id, :question, :answer, :sql_query)
52
+ """
53
+ )
54
+ with engine.begin() as conn:
55
+ conn.execute(
56
+ insert_stmt,
57
+ {
58
+ "conversation_id": conversation_id,
59
+ "question": question,
60
+ "answer": answer,
61
+ "sql_query": sql_query,
62
+ },
63
+ )
64
+
65
+
66
+ def get_recent_history(conversation_id: str, limit: int = 5) -> List[Dict[str, Any]]:
67
+ """Return the most recent `limit` turns for a conversation (oldest first)."""
68
+ _ensure_table()
69
+ engine = get_engine()
70
+ query = text(
71
+ """
72
+ SELECT question, answer, sql_query, created_at
73
+ FROM chat_history
74
+ WHERE conversation_id = :conversation_id
75
+ ORDER BY created_at DESC
76
+ LIMIT :limit
77
+ """
78
+ )
79
+ with engine.connect() as conn:
80
+ rows = conn.execute(
81
+ query, {"conversation_id": conversation_id, "limit": limit}
82
+ ).mappings().all()
83
+
84
+ # Reverse so caller sees oldest β†’ newest
85
+ return list(reversed([dict(r) for r in rows]))
86
+
frontend/script.js CHANGED
@@ -25,6 +25,17 @@
25
  let selectedProvider = "groq";
26
  let loadingStepTimer = null;
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  // ── Model Switcher ───────────────────────────────────────────────────
29
  modelSwitcher.addEventListener("click", (e) => {
30
  const btn = e.target.closest(".switcher-btn");
@@ -55,7 +66,11 @@
55
  const res = await fetch("/chat", {
56
  method: "POST",
57
  headers: { "Content-Type": "application/json" },
58
- body: JSON.stringify({ question, provider: selectedProvider }),
 
 
 
 
59
  });
60
 
61
  if (!res.ok) {
 
25
  let selectedProvider = "groq";
26
  let loadingStepTimer = null;
27
 
28
+ // Persistent conversation id per browser (for multi-turn memory)
29
+ let conversationId = window.localStorage.getItem("sqlbot_conversation_id");
30
+ if (!conversationId) {
31
+ if (window.crypto && window.crypto.randomUUID) {
32
+ conversationId = window.crypto.randomUUID();
33
+ } else {
34
+ conversationId = "conv-" + Date.now().toString(36);
35
+ }
36
+ window.localStorage.setItem("sqlbot_conversation_id", conversationId);
37
+ }
38
+
39
  // ── Model Switcher ───────────────────────────────────────────────────
40
  modelSwitcher.addEventListener("click", (e) => {
41
  const btn = e.target.closest(".switcher-btn");
 
66
  const res = await fetch("/chat", {
67
  method: "POST",
68
  headers: { "Content-Type": "application/json" },
69
+ body: JSON.stringify({
70
+ question,
71
+ provider: selectedProvider,
72
+ conversation_id: conversationId,
73
+ }),
74
  });
75
 
76
  if (!res.ok) {