File size: 2,334 Bytes
40d5e03
 
 
 
fc62e60
 
 
 
 
 
 
40d5e03
 
 
 
 
 
 
 
 
 
 
 
 
8b9e569
40d5e03
 
 
 
 
 
 
 
8b9e569
40d5e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Dict, List, Literal

from classes.base_models import ChatMessage

"""

This class is a temporary, in-memory stand-in intended for the demo only.
Remove it once the demo is over and all call sites have been migrated to the
LangGraph checkpointer, backed by a persistent implementation (e.g.
PostgresSaver or RedisSaver).

For more details: https://docs.langchain.com/oss/python/langchain/short-term-memory

"""


class SessionConversationStore:
    """In-memory chat history keyed by session id and conversation id.

    Messages live in a plain nested dict on the instance, so they are lost
    when the process exits and are not shared across processes.
    NOTE(review): no locking is done here; confirm the store is not mutated
    concurrently from multiple threads.
    """

    def __init__(self) -> None:
        # session_id -> conversation_id -> [ChatMessage], oldest message first.
        self.session_conversation_map: Dict[str, Dict[str, List[ChatMessage]]] = {}

    def add_human_message(
        self,
        session_id: str,
        conversation_id: str,
        human_message: str,
    ) -> "List[ChatMessage]":
        """Append a ``user`` message and return the conversation's history.

        The returned list is the store's own list, not a copy, so later
        messages added to the same conversation show up in it too.
        """
        return self.__add_message(session_id, conversation_id, human_message, role="user")

    def add_assistant_reply(
        self,
        session_id: str,
        conversation_id: str,
        reply: str,
    ) -> "List[ChatMessage]":
        """Append an ``assistant`` message and return the conversation's history.

        The returned list is the store's own list, not a copy.
        """
        return self.__add_message(session_id, conversation_id, reply, role="assistant")

    def delete_session_conversations(self, session_id: str) -> None:
        """Drop every conversation stored for ``session_id``.

        Unknown session ids are ignored.
        """
        self.session_conversation_map.pop(session_id, None)

    def __add_message(
        self,
        session_id: str,
        conversation_id: str,
        message: str,
        role: Literal["user", "assistant", "system"],
    ) -> "List[ChatMessage]":
        """Append ``message`` with ``role`` to a conversation.

        The session and conversation entries are created on first use.
        Returns the conversation's live message list.
        """
        # Build the message before touching the map. If construction raises,
        # no empty session or conversation entry is left behind.
        chat_message = ChatMessage(role=role, content=message)
        conversations = self.session_conversation_map.setdefault(session_id, {})
        history = conversations.setdefault(conversation_id, [])
        history.append(chat_message)
        return history