File size: 4,972 Bytes
3973360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356be9d
3973360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from typing import Union

from langchain_core.messages import (
    BaseMessage,
    ToolMessage,
    AIMessage,
    HumanMessage,
    trim_messages,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode

from src.langgraph.state import State
from src.utils.mongo import chat_messages_history
from src.utils.logger import logger


def _content_word_count(content) -> int:
    """Count whitespace-separated words in a message's ``content``.

    ``content`` may be a plain string or a list of content parts; string
    parts and dict parts carrying a string ``"text"`` value are counted,
    any other part (e.g. image blocks) contributes nothing.
    """
    if isinstance(content, str):
        return len(content.split())
    total = 0
    for part in content or ():
        if isinstance(part, str):
            total += len(part.split())
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            total += len(part["text"].split())
    return total


def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by counting whitespace-separated words.

    This is a cheap heuristic (not a real tokenizer) intended to be passed as
    ``token_counter`` to ``trim_messages``.

    Args:
        messages: A single message or a list of messages. Each message's
            ``content`` may be a string or a list of content parts.

    Returns:
        The total number of words across the messages' text content.
    """
    if isinstance(messages, list):
        return sum(_content_word_count(message.content) for message in messages)
    return _content_word_count(messages.content)


def create_tool_node_with_fallback(tools: list) -> Runnable:
    """Build a ``ToolNode`` whose failures are turned into error tool messages.

    Args:
        tools: The tools exposed through the ``ToolNode``.

    Returns:
        The ``ToolNode`` wrapped with ``with_fallbacks``; it is a Runnable,
        not a dict. When the tool node raises, ``handle_tool_error`` runs as
        the fallback and receives the exception in its input under the
        ``"error"`` key (``exception_key="error"``).
    """
    error_fallback = RunnableLambda(handle_tool_error)
    return ToolNode(tools).with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state: State) -> dict:
    """Answer every pending tool call with an error ``ToolMessage``.

    Runs as the fallback installed by ``create_tool_node_with_fallback``.
    The exception is read from ``state["error"]`` (``None`` if absent). One
    ``ToolMessage`` is emitted per entry in the last message's
    ``tool_calls``, so every call id gets a reply.

    Args:
        state: Graph state; ``state["messages"][-1]`` must expose
            ``tool_calls`` (dicts with an ``"id"`` key).

    Returns:
        ``{"messages": [ToolMessage, ...]}``, one message per tool call.
    """
    failure = state.get("error")
    last_ai_message = state["messages"][-1]
    error_replies = []
    for call in last_ai_message.tool_calls:
        error_replies.append(
            ToolMessage(
                content=f"Error: {failure!r}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": error_replies}


async def get_history(state: State, config):
    """Load the session's chat history and trim it to a bounded size.

    A non-empty ``state["messages_history"]`` is reused as-is. Otherwise the
    history is fetched via ``chat_messages_history(session_id)``, where
    ``session_id`` comes from ``config["configurable"]["session_id"]``.
    The history is then trimmed with ``trim_messages``: it keeps the most
    recent messages within 2000 "tokens" (counted as words by
    ``fake_token_counter``), starts on a human message, ends on an AI
    message, and never splits a message.

    Failures while fetching or trimming are logged and yield an empty
    history, so this node does not abort the graph run.

    Args:
        state: Graph state, optionally holding ``"messages_history"``.
        config: Runnable config dict carrying ``configurable.session_id``.

    Returns:
        ``{"messages_history": [...]}`` containing the trimmed history
        (possibly empty).
    """
    logger.info("Get history node")
    # A missing or empty cached history is treated as "not loaded yet".
    history = state.get("messages_history") or None
    try:
        if history is None:
            session_id = config.get("configurable", {}).get("session_id")
            history = await chat_messages_history(session_id).aget_messages()
            if not history:
                return {"messages_history": []}
        chat_message_history = trim_messages(
            history,
            strategy="last",
            token_counter=fake_token_counter,
            max_tokens=2000,
            start_on="human",
            end_on="ai",
            include_system=False,
            allow_partial=False,
        )
    except Exception as e:
        # Deliberate best effort: the run continues without history. Log with
        # the traceback so the failure stays debuggable.
        logger.exception(f"Error getting chat history: {e}")
        chat_message_history = []

    return {"messages_history": chat_message_history}


async def save_history(state: State, config):
    """Persist the current user/assistant exchange to the session history.

    Nothing is saved unless ``state["manual_save"]`` is truthy. A missing
    key is treated as falsy instead of raising ``KeyError``. The exchange is
    saved only when ``state["messages"]`` is non-empty and its last message
    is an ``AIMessage``. In that case the content of the first message is
    stored as a ``HumanMessage`` and the last message's content as an
    ``AIMessage``, appended together in a single call.

    Args:
        state: Graph state with ``"manual_save"`` and ``"messages"``.
        config: Runnable config dict carrying ``configurable.session_id``.

    Returns:
        ``{"messages": []}`` in every case.
    """
    if not state.get("manual_save"):
        return {"messages": []}

    messages = state.get("messages") or []
    if not messages:
        # Nothing to persist; avoids an IndexError on messages[0] / [-1].
        return {"messages": []}

    final_output = messages[-1]
    if not isinstance(final_output, AIMessage):
        return {"messages": []}

    # NOTE(review): assumes the first message is this turn's user input.
    user_input = messages[0].content
    session_id = config.get("configurable", {}).get("session_id")
    history = chat_messages_history(session_id)
    await history.aadd_messages(
        [HumanMessage(user_input), AIMessage(final_output.content)]
    )
    return {"messages": []}


def human_review_node(state: State):
    """Ask the user to approve the tool calls requested by the last message.

    Builds a Markdown prompt with one section per pending tool call. In each
    section, underscores in the tool name and argument names are shown as
    spaces, and the prompt ends by asking for ``'y'`` or ``'n'``.

    Args:
        state: Graph state; ``state["messages"][-1]`` must expose
            ``tool_calls`` (dicts with ``"name"`` and an ``"args"`` mapping).

    Returns:
        ``{"messages": [AIMessage(prompt)], "accept": False}``.
    """
    logger.info("Human review node")
    conversation = state["messages"]
    pending_calls = conversation[-1].tool_calls
    user_message: HumanMessage = conversation[0]
    logger.info(f"User message: {user_message}")

    call_sections = []
    for tool_call in pending_calls:
        arg_lines = [
            f"#### {arg_name.replace('_', ' ')}: {arg_value}"
            for arg_name, arg_value in tool_call["args"].items()
        ]
        args_block = "\n".join(arg_lines)
        display_name = tool_call["name"].replace("_", " ")
        # Exact section layout: two leading newlines, four newlines after the
        # tool name, three after the "Arguments" label, two trailing newlines.
        call_sections.append(
            "\n\n"
            f"**Tool calling**: {display_name}\n\n\n\n"
            "**Arguments**:\n\n\n"
            f"{args_block}\n\n"
        )

    prompt = (
        "#### Do you want to run the following tool(s)?\n\n"
        + "\n".join(call_sections)
        + "\n\n"
        + "Enter **'y'** to run or **'n'** to cancel:"
    )
    return {"messages": [AIMessage(prompt)], "accept": False}


def format_accommodation_markdown(data):
    """Render accommodation entries as Markdown sections.

    Each entry must provide ``"Accommodation Name"``, ``"Address"`` and
    ``"distance_km"``; a missing one raises ``KeyError``. The optional keys
    below are rendered only when present and truthy:

    - ``"contact"``: mapping with optional ``"phone"`` / ``"email"`` keys.
    - ``"website"``: URL, rendered as a Markdown link.
    - ``"accommodation"``: mapping with optional ``"stars"`` / ``"rooms"``.

    Every entry is followed by a horizontal rule (``---``).

    Args:
        data: Iterable of entry dicts; ``None`` or empty yields ``""``.

    Returns:
        The concatenated Markdown text.
    """
    parts = []
    for entry in data or ():
        parts.append(f"### {entry['Accommodation Name']}\n")
        parts.append(f"- **Address:** {entry['Address']}\n")
        parts.append(f"- **Distance from center:** {entry['distance_km']}\n")

        contact_info = entry.get("contact")
        if contact_info:
            parts.append("- **Contact:**\n")
            if "phone" in contact_info:
                parts.append(f"  - Phone: {contact_info['phone']}\n")
            if "email" in contact_info:
                parts.append(f"  - Email: {contact_info['email']}\n")

        website = entry.get("website")
        if website:
            # Skip empty/None values instead of emitting a broken "[]()" link.
            parts.append(f"- **Website:** [{website}]({website})\n")

        accommodation_info = entry.get("accommodation")
        if accommodation_info:
            parts.append("- **Accommodation Info:**\n")
            if "stars" in accommodation_info:
                parts.append(f"  - Stars: {accommodation_info['stars']}\n")
            if "rooms" in accommodation_info:
                parts.append(f"  - Rooms: {accommodation_info['rooms']}\n")

        parts.append("\n---\n\n")

    # One join instead of repeated ``+=`` on an ever-growing string.
    return "".join(parts)