File size: 2,322 Bytes
d967f3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Conversation memory stored in PostgreSQL (Neon).

Stores every question/answer turn per conversation and returns the most
recent N on request, so the AI can use recent context for follow-up questions.
"""

from __future__ import annotations

from typing import Any, List, Dict

from sqlalchemy import text

from db.connection import get_engine


_TABLE_CREATED = False  # process-local flag: True once the DDL below has committed


def _ensure_table() -> None:
    """Create the chat_history table and its lookup index if they don't exist.

    After the first successful run, the module-level ``_TABLE_CREATED`` flag
    short-circuits later calls in this process. Both statements use
    ``IF NOT EXISTS``, so re-running them against an already initialised
    database is harmless.
    """
    global _TABLE_CREATED
    if _TABLE_CREATED:
        return

    create_table = text(
        """
        CREATE TABLE IF NOT EXISTS chat_history (
            id BIGSERIAL PRIMARY KEY,
            conversation_id TEXT NOT NULL,
            question TEXT NOT NULL,
            answer TEXT NOT NULL,
            sql_query TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        );
        """
    )
    # get_recent_history() filters on conversation_id and orders by
    # (created_at, id) newest first; a B-tree on these columns can be scanned
    # backwards to serve that query without a full-table scan and sort.
    # NOTE: this is a plain (non-CONCURRENTLY) build, so the first run against
    # an existing, large table briefly blocks writes while the index is built.
    create_index = text(
        """
        CREATE INDEX IF NOT EXISTS chat_history_conversation_created_idx
        ON chat_history (conversation_id, created_at, id);
        """
    )
    engine = get_engine()
    # engine.begin() commits when the block exits cleanly and rolls back if
    # either statement raises; the statements are executed separately rather
    # than as one multi-statement string.
    with engine.begin() as conn:
        conn.execute(create_table)
        conn.execute(create_index)

    _TABLE_CREATED = True


def add_turn(conversation_id: str, question: str, answer: str, sql_query: str | None) -> None:
    """Persist one question/answer exchange under ``conversation_id``.

    Args:
        conversation_id: Key that groups turns into a single conversation.
        question: The question text for this turn.
        answer: The answer text for this turn.
        sql_query: SQL associated with this turn, or ``None`` (the column is
            nullable).
    """
    _ensure_table()
    db = get_engine()

    params = dict(
        conversation_id=conversation_id,
        question=question,
        answer=answer,
        sql_query=sql_query,
    )
    stmt = text(
        """
        INSERT INTO chat_history (conversation_id, question, answer, sql_query)
        VALUES (:conversation_id, :question, :answer, :sql_query)
        """
    )

    # begin() wraps the insert in a transaction that commits on clean exit.
    with db.begin() as conn:
        conn.execute(stmt, params)


def get_recent_history(conversation_id: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Return the most recent ``limit`` turns for a conversation, oldest first.

    Args:
        conversation_id: Conversation whose turns should be fetched.
        limit: Maximum number of turns to return; the newest turns are kept.

    Returns:
        A list of dicts with keys ``question``, ``answer``, ``sql_query`` and
        ``created_at``, in chronological order (oldest to newest). Empty when
        the conversation has no stored turns.
    """
    _ensure_table()
    engine = get_engine()
    # ``id`` breaks ties between rows that share a created_at value (NOW() is
    # the transaction start time, so equal timestamps are possible). Without
    # it, which rows survive LIMIT, and their order, would be arbitrary.
    query = text(
        """
        SELECT question, answer, sql_query, created_at
        FROM chat_history
        WHERE conversation_id = :conversation_id
        ORDER BY created_at DESC, id DESC
        LIMIT :limit
        """
    )
    with engine.connect() as conn:
        rows = conn.execute(
            query, {"conversation_id": conversation_id, "limit": limit}
        ).mappings().all()

    # Rows arrive newest-first; flip them so callers read oldest -> newest.
    return [dict(row) for row in reversed(rows)]