Spaces:
Sleeping
Sleeping
from typing import Union

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
    trim_messages,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode

from src.langgraph.state import State
from src.utils.logger import logger
from src.utils.mongo import chat_messages_history
def _message_word_count(message: BaseMessage) -> int:
    """Return the number of whitespace-separated words in *message*'s text."""
    content = message.content
    if isinstance(content, str):
        return len(content.split())
    # List content (multimodal messages): count plain-string blocks and dict
    # blocks carrying a ``"text"`` string; other blocks (e.g. images) add 0.
    total = 0
    for block in content:
        if isinstance(block, str):
            text = block
        elif isinstance(block, dict) and isinstance(block.get("text"), str):
            text = block["text"]
        else:
            continue
        total += len(text.split())
    return total


def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count as the number of whitespace-separated words.

    Used as the ``token_counter`` for ``trim_messages`` in ``get_history``; it
    is a cheap heuristic, not a real tokenizer.

    Args:
        messages: A single message or a list of messages. Each message's
            ``content`` may be a plain string or a list of content blocks
            (strings or dicts); only the textual parts are counted.

    Returns:
        The total word count across all given messages.
    """
    if isinstance(messages, list):
        return sum(_message_word_count(message) for message in messages)
    return _message_word_count(messages)
def create_tool_node_with_fallback(tools: list) -> Runnable:
    """Build a ``ToolNode`` for *tools* whose failures are turned into messages.

    If the tool node raises, the exception is handed to ``handle_tool_error``
    under the ``"error"`` key (``exception_key="error"``). That fallback
    answers each pending tool call with an error ``ToolMessage`` instead of
    letting the exception propagate.

    Args:
        tools: The tools to expose through the ``ToolNode``.

    Returns:
        The ``ToolNode`` wrapped with ``with_fallbacks`` (a runnable suitable
        for use as a graph node; it is not a ``dict``).
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def handle_tool_error(state: State) -> dict:
    """Fallback for the tool node: report a tool exception back to the model.

    ``create_tool_node_with_fallback`` registers this with
    ``exception_key="error"``, so the raised exception is read from
    ``state["error"]``. One ``ToolMessage`` carrying the error text is emitted
    for every tool call on the last message, so each ``tool_call_id`` gets an
    answer.

    Returns:
        ``{"messages": [...]}`` with one error ``ToolMessage`` per tool call.
    """
    raised = state.get("error")
    last_message = state["messages"][-1]
    error_text = f"Error: {raised!r}\n please fix your mistakes."
    replies = [
        ToolMessage(content=error_text, tool_call_id=call["id"])
        for call in last_message.tool_calls
    ]
    return {"messages": replies}
async def get_history(state: State, config):
    """Load this session's chat history and trim it to a bounded size.

    A non-empty ``state["messages_history"]`` is reused as-is. Otherwise the
    history is fetched through ``chat_messages_history`` for
    ``config["configurable"]["session_id"]``. The history is then passed
    through ``trim_messages``, which keeps the most recent messages fitting in
    2000 "tokens" as counted by ``fake_token_counter`` (a word count). The
    kept window starts on a human message and ends on an AI message, and no
    message is split.

    Any exception while loading or trimming is logged, and an empty history is
    returned instead, so a failure here does not propagate.

    Returns:
        ``{"messages_history": [...]}`` state update.
    """
    logger.info("Get history node")
    history = state.get("messages_history") or None
    try:
        if history is None:
            session_id = config.get("configurable", {}).get("session_id")
            store = chat_messages_history(session_id)
            history = await store.aget_messages()
        if not history:
            return {"messages_history": []}
        trimmed = trim_messages(
            history,
            strategy="last",
            token_counter=fake_token_counter,
            max_tokens=2000,
            start_on="human",
            end_on="ai",
            include_system=False,
            allow_partial=False,
        )
    except Exception as exc:
        logger.error(f"Error getting chat history: {exc}")
        return {"messages_history": []}
    return {"messages_history": trimmed}
async def save_history(state: State, config):
    """Persist the current turn (user input plus final AI reply) to history.

    Saving only happens when ``state["manual_save"]`` is truthy; otherwise
    this node does nothing. The first message in ``state["messages"]`` is
    stored as a ``HumanMessage`` (this assumes it is the user's input for the
    turn — confirm against the graph's entry point). The last message is
    stored only if it is an ``AIMessage``, and only its ``content`` is copied.
    Messages are appended to the history of
    ``config["configurable"]["session_id"]``.

    Returns:
        ``{"messages": []}`` in every case; this node adds no messages to the
        state.
    """
    # ``.get``: a state where ``manual_save`` was never set means "don't save"
    # rather than raising KeyError.
    if not state.get("manual_save"):
        return {"messages": []}

    messages = state.get("messages") or []
    if not messages:
        # Nothing to persist; ``messages[0]`` below would raise IndexError.
        return {"messages": []}

    user_input = messages[0].content
    final_output = messages[-1]
    session_id = config.get("configurable", {}).get("session_id")

    to_persist = [HumanMessage(user_input)]
    # Only a final AI reply is stored, and only its content (any tool-call
    # metadata on it is not copied).
    if isinstance(final_output, AIMessage):
        to_persist.append(AIMessage(final_output.content))

    await chat_messages_history(session_id).aadd_messages(to_persist)
    return {"messages": []}
def human_review_node(state: State):
    """Ask the user to confirm the tool calls requested by the last message.

    Builds a Markdown prompt that lists each pending tool call's name and
    arguments (underscores rendered as spaces) and asks the user to answer
    ``y`` or ``n``. The prompt is returned as an ``AIMessage``, and ``accept``
    is set to ``False`` (presumably so the graph waits for the user's answer —
    confirm against the routing logic).

    Returns:
        ``{"messages": [AIMessage(prompt)], "accept": False}``.
    """
    logger.info("Human review node")
    pending_calls = state["messages"][-1].tool_calls
    user_message: HumanMessage = state["messages"][0]
    logger.info(f"User message: {user_message}")

    sections = []
    for tool_call in pending_calls:
        arg_lines = [
            f"#### {name.replace('_', ' ')}: {value}"
            for name, value in tool_call["args"].items()
        ]
        args_block = "\n".join(arg_lines)
        pretty_name = tool_call["name"].replace("_", " ")
        sections.append(
            f"\n**Tool calling**: {pretty_name}\n**Arguments**:\n\n{args_block}\n"
        )

    tool_list = "\n".join(sections)
    prompt = (
        "#### Do you want to run the following tool(s)?\n\n"
        f"{tool_list}\n\n"
        "Enter **'y'** to run or **'n'** to cancel:"
    )
    return {"messages": [AIMessage(prompt)], "accept": False}
def format_accommodation_markdown(data):
    """Render accommodation records as a Markdown document.

    Each record becomes a ``###`` heading followed by a bullet list. A
    horizontal rule follows every record.

    Args:
        data: Iterable of dicts, or ``None``. Each dict must contain
            ``"Accommodation Name"``, ``"Address"`` and ``"distance_km"``.
            It may contain ``"contact"`` (dict with optional ``"phone"`` /
            ``"email"``), ``"website"``, and ``"accommodation"`` (dict with
            optional ``"stars"`` / ``"rooms"``).

    Returns:
        The Markdown text; an empty string when ``data`` is empty or ``None``.
    """
    # Children of a "- " bullet must be indented by at least two spaces (the
    # width of "- ") for CommonMark to nest them under that bullet.
    child = "  - "
    parts = []
    for entry in data or ():
        parts.append(f"### {entry['Accommodation Name']}\n")
        parts.append(f"- **Address:** {entry['Address']}\n")
        parts.append(f"- **Distance from center:** {entry['distance_km']}\n")

        contact_info = entry.get("contact")
        if contact_info:
            contact_lines = [
                f"{child}{label}: {contact_info[key]}\n"
                for key, label in (("phone", "Phone"), ("email", "Email"))
                if key in contact_info
            ]
            # Don't emit an empty sub-heading when no known field is present.
            if contact_lines:
                parts.append("- **Contact:**\n")
                parts.extend(contact_lines)

        if "website" in entry:
            website = entry["website"]
            parts.append(f"- **Website:** [{website}]({website})\n")

        accommodation_info = entry.get("accommodation")
        if accommodation_info:
            info_lines = [
                f"{child}{label}: {accommodation_info[key]}\n"
                for key, label in (("stars", "Stars"), ("rooms", "Rooms"))
                if key in accommodation_info
            ]
            if info_lines:
                parts.append("- **Accommodation Info:**\n")
                parts.extend(info_lines)

        parts.append("\n---\n\n")
    return "".join(parts)