Spaces:
Sleeping
Sleeping
File size: 4,972 Bytes
3973360 356be9d 3973360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from typing import Union

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
    trim_messages,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langgraph.prebuilt import ToolNode

from src.langgraph.state import State
from src.utils.logger import logger
from src.utils.mongo import chat_messages_history
def _count_words(content) -> int:
    """Return the number of whitespace-delimited words in a message's content.

    ``content`` may be a plain string or a list of content parts; string
    parts and the ``"text"`` field of dict parts are counted, other parts
    (images, tool-use blocks, ...) contribute nothing.
    """
    if isinstance(content, str):
        return len(content.split())
    if isinstance(content, list):
        total = 0
        for part in content:
            if isinstance(part, str):
                total += len(part.split())
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                total += len(part["text"].split())
        return total
    # Unknown/absent content counts as zero rather than crashing the trimmer.
    return 0


def fake_token_counter(messages: "Union[list[BaseMessage], BaseMessage]") -> int:
    """Approximate a token count by counting whitespace-separated words.

    Used as the ``token_counter`` for ``trim_messages`` in ``get_history``.

    Args:
        messages: A single message or a list of messages; each must expose
            a ``content`` attribute (string or list of content parts).

    Returns:
        The total word count across all given messages.
    """
    if isinstance(messages, list):
        return sum(_count_words(message.content) for message in messages)
    return _count_words(messages.content)
def create_tool_node_with_fallback(tools: list) -> "Runnable":
    """Build a ``ToolNode`` whose failures are reported back as messages.

    The node is wrapped with ``with_fallbacks`` so that an exception raised
    while running the tools is routed to ``handle_tool_error`` instead of
    propagating. ``exception_key="error"`` makes the fallback receive the
    caught exception under the ``"error"`` key of its input, which is where
    ``handle_tool_error`` reads it from.

    Args:
        tools: The tools to expose through the ``ToolNode``.

    Returns:
        The ``ToolNode`` wrapped in a fallback runnable (a ``Runnable``,
        not a ``dict``).
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
def handle_tool_error(state: State) -> dict:
    """Report a tool failure back to the model, once per requested tool call.

    Fallback for the node built by ``create_tool_node_with_fallback``: the
    caught exception is read from ``state["error"]``, and one ``ToolMessage``
    describing it is produced for every tool call on the last message.

    Args:
        state: Graph state; reads ``"error"`` and ``"messages"``. The last
            message is assumed to be the AI message whose ``tool_calls``
            failed.

    Returns:
        ``{"messages": [ToolMessage, ...]}`` with one error message per call.
    """
    error = state.get("error")
    calling_message = state["messages"][-1]

    def _error_reply(call) -> ToolMessage:
        # Each reply must carry the id of the call it answers.
        return ToolMessage(
            content=f"Error: {repr(error)}\n please fix your mistakes.",
            tool_call_id=call["id"],
        )

    replies = [_error_reply(call) for call in calling_message.tool_calls]
    return {"messages": replies}
async def get_history(state: State, config):
    """Load this session's chat history and trim it to a bounded window.

    If ``state["messages_history"]`` is already non-empty it is reused;
    otherwise the history is fetched with
    ``chat_messages_history(session_id).aget_messages()``, where
    ``session_id`` comes from ``config["configurable"]["session_id"]``.

    The history is then trimmed with ``trim_messages``:
      * keeps the most recent messages (``strategy="last"``),
      * at most 2000 "tokens" as measured by ``fake_token_counter``
        (whitespace-delimited words),
      * starting on a human message and ending on an AI message,
      * system messages are not kept and messages are never split.

    Args:
        state: Graph state; only ``"messages_history"`` is read.
        config: Runnable config mapping that carries the session id.

    Returns:
        ``{"messages_history": [...]}`` with the trimmed history, or an
        empty list if there is no history or loading/trimming fails.
    """
    logger.info("Get history node")
    # A missing or empty cached history both mean "fetch from the store".
    history = state.get("messages_history") or None
    try:
        if history is None:
            session_id = config.get("configurable", {}).get("session_id")
            history = await chat_messages_history(session_id).aget_messages()
        if not history:
            return {"messages_history": []}
        chat_message_history = trim_messages(
            history,
            strategy="last",
            token_counter=fake_token_counter,
            max_tokens=2000,
            start_on="human",
            end_on="ai",
            include_system=False,
            allow_partial=False,
        )
    except Exception as e:
        # Best-effort: history problems must not break the turn, but keep
        # the traceback so the failure is diagnosable.
        logger.exception(f"Error getting chat history: {e}")
        chat_message_history = []
    return {"messages_history": chat_message_history}
async def save_history(state: State, config):
    """Persist the turn's user input and final AI reply to the chat store.

    Does nothing unless ``state["manual_save"]`` is truthy. When enabled,
    the content of the first message is saved as a ``HumanMessage`` and, if
    the last message is an ``AIMessage``, its content is saved as a new
    ``AIMessage`` (tool calls are not copied). Messages are written with
    ``chat_messages_history(session_id).aadd_messages``, where
    ``session_id`` comes from ``config["configurable"]["session_id"]``.

    Args:
        state: Graph state; reads ``"manual_save"`` and ``"messages"``. The
            first message is assumed to be the user's input for this turn.
        config: Runnable config mapping that carries the session id.

    Returns:
        ``{"messages": []}`` in every case (an empty messages update).
    """
    # `.get` so a state without the flag skips saving instead of raising KeyError.
    if not state.get("manual_save"):
        return {"messages": []}
    messages = state.get("messages") or []
    if not messages:
        # Nothing to persist; avoids IndexError on messages[0].
        return {"messages": []}
    user_input = messages[0].content
    final_output = messages[-1]
    session_id = config.get("configurable", {}).get("session_id")
    messages_add_to_history = [HumanMessage(user_input)]
    if isinstance(final_output, AIMessage):
        # Store only the reply text, not tool-call metadata.
        messages_add_to_history.append(AIMessage(final_output.content))
    history = chat_messages_history(session_id)
    await history.aadd_messages(messages_add_to_history)
    return {"messages": []}
def human_review_node(state: State):
    """Ask the user to confirm the pending tool calls before they run.

    Reads ``tool_calls`` from the last message in ``state["messages"]`` and
    renders each call as a Markdown snippet. Underscores in the tool name
    and in argument names are shown as spaces. The prompt ends by asking
    for 'y' (run) or 'n' (cancel).

    Args:
        state: Graph state; reads ``"messages"``.

    Returns:
        ``{"messages": [AIMessage(prompt)], "accept": False}``.
    """
    logger.info("Human review node")
    pending_calls = state["messages"][-1].tool_calls
    user_message: HumanMessage = state["messages"][0]
    logger.info(f"User message: {user_message}")

    def _render_call(call) -> str:
        arg_lines = "\n".join(
            f"#### {name.replace('_', ' ')}: {value}"
            for name, value in call["args"].items()
        )
        tool_label = call["name"].replace("_", " ")
        return (
            f"\n**Tool calling**: {tool_label}\n"
            "**Arguments**:\n\n"
            f"{arg_lines}\n"
        )

    rendered_calls = "\n".join(_render_call(call) for call in pending_calls)
    prompt = (
        "#### Do you want to run the following tool(s)?\n\n"
        f"{rendered_calls}\n\n"
        "Enter **'y'** to run or **'n'** to cancel:"
    )
    return {"messages": [AIMessage(prompt)], "accept": False}
def format_accommodation_markdown(data):
    """Render accommodation records as a Markdown document.

    Each entry becomes a ``###`` heading followed by bullet points and is
    terminated by a horizontal rule (``---``). The optional "Contact" and
    "Accommodation Info" sections are emitted only when they have at least
    one item to show.

    Args:
        data: Iterable of dicts. ``"Accommodation Name"``, ``"Address"`` and
            ``"distance_km"`` are required (``KeyError`` if missing).
            ``"contact"`` (optional ``"phone"``/``"email"``), ``"website"``
            and ``"accommodation"`` (optional ``"stars"``/``"rooms"``) are
            optional. ``None`` is treated as no entries.

    Returns:
        The Markdown text, or ``""`` when there are no entries.
    """
    if data is None:
        return ""
    parts = []
    for entry in data:
        parts.append(f"### {entry['Accommodation Name']}\n")
        parts.append(f"- **Address:** {entry['Address']}\n")
        parts.append(f"- **Distance from center:** {entry['distance_km']}\n")

        contact_info = entry.get("contact") or {}
        contact_lines = [
            f" - {label}: {contact_info[key]}\n"
            for key, label in (("phone", "Phone"), ("email", "Email"))
            if key in contact_info
        ]
        # Avoid a dangling heading when the contact dict has no known fields.
        if contact_lines:
            parts.append("- **Contact:**\n")
            parts.extend(contact_lines)

        if "website" in entry:
            parts.append(f"- **Website:** [{entry['website']}]({entry['website']})\n")

        accommodation_info = entry.get("accommodation") or {}
        accommodation_lines = [
            f" - {label}: {accommodation_info[key]}\n"
            for key, label in (("stars", "Stars"), ("rooms", "Rooms"))
            if key in accommodation_info
        ]
        if accommodation_lines:
            parts.append("- **Accommodation Info:**\n")
            parts.extend(accommodation_lines)

        parts.append("\n---\n\n")
    # join instead of repeated += keeps building linear in output size.
    return "".join(parts)
|