File size: 2,692 Bytes
d545f81
 
 
 
3cdce90
d545f81
3cdce90
d545f81
3cdce90
 
d545f81
3cdce90
d545f81
 
 
 
 
 
 
3cdce90
 
 
d545f81
 
3cdce90
d545f81
 
 
 
 
 
 
 
 
 
 
 
 
3cdce90
d545f81
 
 
3cdce90
d545f81
 
3cdce90
d545f81
 
 
 
 
 
3cdce90
d545f81
3cdce90
d545f81
 
 
 
 
 
3cdce90
d545f81
 
 
3cdce90
d545f81
 
 
 
 
 
 
 
 
 
3cdce90
d545f81
 
 
 
 
 
 
 
3cdce90
d545f81
 
3cdce90
d545f81
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Memory module for XENO Bot
Handles LangGraph memory operations using SQLite
"""

import sqlite3
import uuid
from datetime import datetime
from typing import Any, Dict, List

from langgraph.checkpoint.sqlite import SqliteSaver

from src.config import SQLITE_DB_PATH

# === LangGraph Memory Setup ===
# Module-level SQLite connection, opened at import time on SQLITE_DB_PATH.
# check_same_thread=False turns off sqlite3's check that the connection is
# only used from the thread that created it.
conn = sqlite3.connect(SQLITE_DB_PATH, check_same_thread=False)
# Checkpoint saver bound to the connection above; the functions in this module
# read and write session messages through it.
memory = SqliteSaver(conn=conn)


def update_memory(
    config: Dict[str, Any], user_message: str, assistant_message: str, timer=None
):
    """
    Record one user/assistant exchange in the session's stored memory

    Args:
        config: Session configuration dictionary (carries the thread_id)
        user_message: Text sent by the user
        assistant_message: Text returned by the assistant
        timer: Optional object whose ``time_step(name)`` returns a context
            manager; when truthy, the write runs inside the
            "memory_update" step
    """
    # Untimed path first: no timer means a plain call.
    if not timer:
        _update_memory_impl(config, user_message, assistant_message)
        return

    with timer.time_step("memory_update"):
        _update_memory_impl(config, user_message, assistant_message)


def _update_memory_impl(config, user_message: str, assistant_message: str):
    """
    Append one user/assistant exchange to the session and save a new checkpoint

    Loads whatever checkpoint ``memory.get(config)`` returns (a missing one is
    treated as empty), extends its message list with the two new messages and
    writes the result back through ``memory.put``.

    Args:
        config: Session configuration dictionary, passed unchanged to the saver
        user_message: Text sent by the user
        assistant_message: Text returned by the assistant
    """
    # Function-scope imports so this block does not depend on edits to the
    # file's import section.
    from datetime import timezone

    from langgraph.checkpoint.base.id import uuid6

    full_checkpoint = memory.get(config) or {}
    previous = full_checkpoint.get("channel_values", {}).get("messages", [])

    # Build a fresh list instead of mutating the one that belongs to the
    # checkpoint object we just loaded.
    messages = [
        *previous,
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message},
    ]

    checkpoint_to_save = {
        "v": 1,
        # Time-ordered id (the same helper LangGraph uses for its own
        # checkpoints) instead of a random uuid4.
        # NOTE(review): SqliteSaver appears to select the latest checkpoint by
        # sorting on checkpoint id, so random ids could make memory.get()
        # return a stale checkpoint - confirm against the installed version.
        "id": str(uuid6()),
        # Timezone-aware UTC timestamp; a naive local time is ambiguous once
        # the database is read on another host or across a DST change.
        "ts": datetime.now(timezone.utc).isoformat(),
        "channel_values": {"messages": messages},
        "channel_versions": {},
        "versions_seen": {},
    }

    memory.put(config, checkpoint_to_save, {}, {})


def retrieve_memory(config: Dict[str, Any], timer=None) -> List[Dict[str, str]]:
    """
    Fetch the stored messages for a session

    Args:
        config: Session configuration dictionary (carries the thread_id)
        timer: Optional object whose ``time_step(name)`` returns a context
            manager; when truthy, the read runs inside the
            "memory_retrieval" step

    Returns:
        List of message dictionaries
    """
    # Untimed path first: no timer means a plain call.
    if not timer:
        return _retrieve_memory_impl(config)

    with timer.time_step("memory_retrieval"):
        return _retrieve_memory_impl(config)


def _retrieve_memory_impl(config) -> List[Dict[str, str]]:
    """Return the "messages" list from the checkpoint memory.get(config) yields"""
    checkpoint = memory.get(config)
    if not checkpoint:
        # Nothing stored (or an empty checkpoint): no messages.
        return []

    channel_values = checkpoint.get("channel_values", {})
    return channel_values.get("messages", [])


def create_session_config(session_id: str = "default") -> Dict[str, Any]:
    """
    Build the configuration dictionary that identifies a session

    Args:
        session_id: Unique session identifier; converted with ``str()``

    Returns:
        Dictionary with a single "configurable" entry holding the thread_id
        and an empty checkpoint_ns
    """
    configurable = {
        "thread_id": str(session_id),
        "checkpoint_ns": "",
    }
    return {"configurable": configurable}