# ABAO77's picture
# Upload 159 files
# 356be9d verified
from typing import Union

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
    trim_messages,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode

from src.langgraph.state import State
from src.utils.logger import logger
from src.utils.mongo import chat_messages_history
def fake_token_counter(messages: "Union[list[BaseMessage], BaseMessage]") -> int:
    """Approximate a token count by counting whitespace-separated words.

    Used as the ``token_counter`` for ``trim_messages``. It is a cheap
    stand-in for a real tokenizer, so the "token" budget is really a word
    budget.

    Args:
        messages: A single message or a list of messages. Each message's
            ``content`` may be a plain string or a list of content blocks
            (plain strings or dicts, as in multimodal messages).

    Returns:
        The total number of whitespace-separated words of text content.
    """

    def _count(message) -> int:
        content = message.content
        if isinstance(content, str):
            return len(content.split())
        # Content-block lists: count plain-string blocks and the "text" of
        # dict blocks; non-text blocks (e.g. images) contribute nothing.
        words = 0
        for block in content or []:
            if isinstance(block, str):
                words += len(block.split())
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                words += len(block["text"].split())
        return words

    if isinstance(messages, list):
        return sum(_count(message) for message in messages)
    return _count(messages)
def create_tool_node_with_fallback(tools: list) -> "Runnable":
    """Build a ``ToolNode`` whose failures are turned into error messages.

    Args:
        tools: Tools to expose through the ``ToolNode``.

    Returns:
        A runnable (the result of ``with_fallbacks``, not a ``dict``) that
        runs the tools and, if that raises, runs ``handle_tool_error``
        instead.
    """
    # exception_key="error" makes LangChain pass the raised exception to the
    # fallback inside its input under the "error" key, which is what
    # handle_tool_error reads via state.get("error").
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def handle_tool_error(state: State) -> dict:
    """Answer every pending tool call with an error ``ToolMessage``.

    Used as the fallback of the tool node built by
    ``create_tool_node_with_fallback``; the exception raised by the tool node
    is read from ``state["error"]``.

    Args:
        state: Graph state. ``state["messages"][-1]`` is expected to be the
            AI message whose ``tool_calls`` failed.

    Returns:
        ``{"messages": [ToolMessage, ...]}`` with one error message per tool
        call, each linked to its call via ``tool_call_id``.

    Raises:
        The original exception from ``state["error"]`` when the last message
        carries no ``tool_calls`` attribute, since there is nothing to answer.
    """
    error = state.get("error")
    last_message = state["messages"][-1]
    tool_calls = getattr(last_message, "tool_calls", None)
    if tool_calls is None:
        # The tool node failed on input that is not a tool-calling message;
        # surface the real failure instead of an unrelated AttributeError.
        if isinstance(error, BaseException):
            raise error
        tool_calls = []
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }
async def get_history(state: State, config):
    """Load the session's chat history and trim it to a word budget.

    Uses ``state["messages_history"]`` when it is non-empty; otherwise loads
    the stored messages for ``config["configurable"]["session_id"]`` through
    ``chat_messages_history``. The history is then trimmed with
    ``trim_messages``: keep the most recent messages within 2000 "tokens"
    (words, per ``fake_token_counter``), starting on a human message and
    ending on an AI message, without partial messages.

    Args:
        state: Graph state; only ``messages_history`` is read.
        config: Runnable config; ``configurable.session_id`` selects the
            stored history.

    Returns:
        ``{"messages_history": [...]}``. The list is empty when there is no
        history or when loading/trimming fails (failures are logged with
        their traceback, not raised).
    """
    logger.info("Get history node")
    history = state.get("messages_history") or None
    try:
        if history is None:
            session_id = config.get("configurable", {}).get("session_id")
            history = await chat_messages_history(session_id).aget_messages()
        if not history:
            return {"messages_history": []}
        chat_message_history = trim_messages(
            history,
            strategy="last",
            token_counter=fake_token_counter,
            max_tokens=2000,
            start_on="human",
            end_on="ai",
            include_system=False,
            allow_partial=False,
        )
    except Exception as e:
        # Best-effort: missing history must not abort the turn, but keep the
        # traceback so storage/trim failures can be diagnosed.
        logger.exception(f"Error getting chat history: {e}")
        chat_message_history = []
    return {"messages_history": chat_message_history}
async def save_history(state: State, config):
    """Append the current turn (user input + final AI reply) to stored history.

    Args:
        state: Graph state. Reads ``manual_save`` and ``messages``; the first
            message is assumed to be the user's input and the last one the
            final output — TODO confirm against the graph definition.
        config: Runnable config; ``configurable.session_id`` selects the
            stored history.

    Returns:
        ``{"messages": []}`` in every case.
    """
    # NOTE(review): history is written only when manual_save is truthy; a
    # missing key is treated as False. Confirm this polarity is intended.
    if not state.get("manual_save"):
        return {"messages": []}
    messages = state.get("messages") or []
    if not messages:
        # Nothing to persist for an empty turn.
        return {"messages": []}
    user_input = messages[0].content
    final_output = messages[-1]
    session_id = config.get("configurable", {}).get("session_id")
    to_store = [HumanMessage(user_input)]
    # Only an AI reply is stored as the answer; any other final message
    # (e.g. a tool message) results in saving the user input alone.
    if isinstance(final_output, AIMessage):
        to_store.append(AIMessage(final_output.content))
    await chat_messages_history(session_id).aadd_messages(to_store)
    return {"messages": []}
def human_review_node(state: State):
    """Ask the user to confirm the pending tool calls before they execute.

    Reads the tool calls on the last message of ``state["messages"]`` and
    renders them as a Markdown prompt listing each tool name and its
    arguments (underscores shown as spaces).

    Args:
        state: Graph state; ``messages[-1]`` must expose ``tool_calls``.

    Returns:
        ``{"messages": [AIMessage(prompt)], "accept": False}``.
    """
    logger.info("Human review node")
    messages = state["messages"]
    pending_calls = messages[-1].tool_calls
    first_message: HumanMessage = messages[0]
    logger.info(f"User message: {first_message}")

    def _describe(call) -> str:
        # One "#### arg name: value" line per argument.
        arg_lines = "\n".join(
            f"#### {key.replace('_', ' ')}: {value}"
            for key, value in call["args"].items()
        )
        tool_title = call["name"].replace("_", " ")
        return (
            "\n"
            f"**Tool calling**: {tool_title}\n"
            "**Arguments**:\n\n"
            f"{arg_lines}\n"
        )

    calls_text = "\n".join(_describe(call) for call in pending_calls)
    prompt = (
        "#### Do you want to run the following tool(s)?\n\n"
        f"{calls_text}\n\n"
        "Enter **'y'** to run or **'n'** to cancel:"
    )
    return {"messages": [AIMessage(prompt)], "accept": False}
def format_accommodation_markdown(data):
    """Render accommodation records as Markdown.

    Each entry becomes a ``### <name>`` heading followed by bullets for the
    address and distance from the center, then optional contact details,
    website and accommodation info, and ends with a ``---`` separator.

    Args:
        data: Iterable of dicts. ``"Accommodation Name"``, ``"Address"`` and
            ``"distance_km"`` are required. Optional keys: ``"contact"``
            (dict with optional ``"phone"``/``"email"``), ``"website"``, and
            ``"accommodation"`` (dict with optional ``"stars"``/``"rooms"``).
            ``None`` is treated as no entries.

    Returns:
        The Markdown text, or ``""`` when there are no entries.

    Raises:
        KeyError: if an entry lacks one of the required keys.
    """
    if data is None:
        return ""
    # Collect pieces and join once instead of repeated string `+=`.
    parts = []
    for entry in data:
        parts.append(f"### {entry['Accommodation Name']}\n")
        parts.append(f"- **Address:** {entry['Address']}\n")
        parts.append(f"- **Distance from center:** {entry['distance_km']}\n")

        contact_info = entry.get("contact")
        if contact_info:
            parts.append("- **Contact:**\n")
            if "phone" in contact_info:
                parts.append(f" - Phone: {contact_info['phone']}\n")
            if "email" in contact_info:
                parts.append(f" - Email: {contact_info['email']}\n")

        if "website" in entry:
            parts.append(f"- **Website:** [{entry['website']}]({entry['website']})\n")

        accommodation_info = entry.get("accommodation")
        if accommodation_info:
            parts.append("- **Accommodation Info:**\n")
            if "stars" in accommodation_info:
                parts.append(f" - Stars: {accommodation_info['stars']}\n")
            if "rooms" in accommodation_info:
                parts.append(f" - Rooms: {accommodation_info['rooms']}\n")

        parts.append("\n---\n\n")
    return "".join(parts)