ABAO77's picture
Upload 72 files
031378e verified
from typing import TypedDict, Optional, List
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.graph.message import add_messages
from typing import Sequence, Annotated
from langchain_core.messages import RemoveMessage
from langchain_core.documents import Document
from .tools import retrieve_document, python_repl, duckduckgo_search
from src.utils.logger import logger
from src.config.llm import get_llm
from .prompt import template_prompt
tools = [retrieve_document, python_repl, duckduckgo_search]
class State(TypedDict):
messages: Annotated[Sequence[AnyMessage], add_messages]
selected_ids: Optional[List[str]]
selected_documents: Optional[List[Document]]
tools: Optional[List[str]]
prompt: str
model_name: Optional[str]
def trim_history(state: State):
history = state.get("messages", [])
tool_names = state.get("tools", [])
if len(history) > 10:
num_to_remove = len(history) - 10
remove_messages = [
RemoveMessage(id=history[i].id) for i in range(num_to_remove)
]
return {
"messages": remove_messages,
"selected_ids": [],
"selected_documents": [],
}
return {}
def execute_tool(state: State):
tool_calls = state["messages"][-1].tool_calls
tool_names = state.get("tools", [])
tool_name_to_func = {tool.name: tool for tool in tools}
tool_functions = [tool_name_to_func[name] for name in tool_names if name in tool_name_to_func]
selected_ids = []
selected_documents = []
tool_messages = []
for tool_call in tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["args"]
tool_id = tool_call["id"]
tool_func = tool_name_to_func.get(tool_name)
if tool_func:
if tool_name == "retrieve_document":
documents = tool_func.invoke(tool_args.get("query"))
documents = dict(documents)
context_str = documents.get("context_str", "")
selected_documents = documents.get("selected_documents", [])
selected_ids = documents.get("selected_ids", [])
if documents:
tool_messages.append(
ToolMessage(
tool_call_id=tool_id,
content=context_str,
)
)
continue
tool_response = tool_func.invoke(tool_args)
print(f"tool_response: {tool_response}")
tool_messages.append(ToolMessage(
tool_call_id=tool_id,
content=tool_response,
))
return {
"selected_ids": selected_ids,
"selected_documents": selected_documents,
"messages": tool_messages,
}
def generate_answer_rag(state: State):
messages = state["messages"]
tool_names = state.get("tools", [])
prompt = state["prompt"]
model_name = state.get("model_name", "gemini-2.0-flash")
tool_name_to_func = {tool.name: tool for tool in tools}
tool_functions = [tool_name_to_func[name] for name in tool_names if name in tool_name_to_func]
print(f"tools: {tool_functions}")
llm_call = template_prompt | get_llm(model_name).bind_tools(tool_functions)
if tool_functions:
for tool in tool_functions:
if tool.name == "retrieve_document":
prompt += "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời"
if tool.name == "python_repl":
prompt += "Sử dụng tool `python_repl` để thực hiện các tác vụ liên quan đến tính toán phức tạp"
if tool.name == "duckduckgo_search":
prompt += "Sử dụng tool `duckduckgo_search` để tìm kiếm thông tin trên internet"
response = llm_call.invoke(
{
"messages": messages,
"prompt": prompt,
}
)
return {"messages": response}