File size: 1,964 Bytes
35c0d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Conversation memory.

Short-term, per-session chat memory using LangChain's RunnableWithMessageHistory
backed by SQLChatMessageHistory, persisted to the SQLite file configured by
``settings.sqlite_path`` (./data/sessions.db).

Each Gradio session maps to a session_id. RunnableWithMessageHistory transparently
loads prior turns before each call and saves the new turn afterward; the actual
trimming to the configured window happens inside the assistant's _respond.
"""

from __future__ import annotations

import os

from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

from src.assistants.base import BaseAssistant
from src.config import settings


def _connection_string() -> str:
    """SQLAlchemy-style sqlite URL, ensuring the parent directory exists."""
    path = settings.sqlite_path
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    return f"sqlite:///{path}"


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Build the SQLite-backed message history for ``session_id``.

    A fresh SQLChatMessageHistory is constructed on every call, pointed at the
    URL from ``_connection_string()`` (which also makes sure the database's
    parent directory exists).
    """
    db_url = _connection_string()
    history = SQLChatMessageHistory(session_id=session_id, connection=db_url)
    return history


def build_conversational(assistant: BaseAssistant) -> RunnableWithMessageHistory:
    """Attach persistent, per-session memory to ``assistant._respond``.

    Usage:
        conversational.invoke(
            {"input": user_msg},
            config={"configurable": {"session_id": sid}},
        )

    The wrapper looks up the session's history via ``get_session_history``,
    passes it to the wrapped step under the "history" key alongside the
    "input" message, and records the new user input and the returned
    AIMessage once the step finishes.
    """
    respond_step = RunnableLambda(assistant._respond)
    with_memory = RunnableWithMessageHistory(
        respond_step,
        get_session_history,
        history_messages_key="history",
        input_messages_key="input",
    )
    return with_memory