from typing import TypedDict, Optional, List
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.graph.message import add_messages
from typing import Sequence, Annotated
from langchain_core.messages import RemoveMessage
from langchain_core.documents import Document
from .tools import retrieve_document, python_repl, duckduckgo_search
from src.utils.logger import logger
from src.config.llm import get_llm
from .prompt import template_prompt
# Registry of every tool implementation available to the agent graph;
# nodes look tools up here by name when executing or binding them.
tools = [retrieve_document, python_repl, duckduckgo_search]
class State(TypedDict):
    """Shared graph state passed between LangGraph nodes."""

    # Conversation history; `add_messages` merges updates into the
    # existing list (and honors RemoveMessage deletions).
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # IDs of documents selected by the last `retrieve_document` call.
    selected_ids: Optional[List[str]]
    # Document objects backing `selected_ids`.
    selected_documents: Optional[List[Document]]
    # Names of the tools enabled for this run; nodes filter the global
    # `tools` registry against this list.
    tools: Optional[List[str]]
    # Base system prompt; nodes may append tool-usage instructions.
    prompt: str
    # LLM model identifier; nodes fall back to a default when absent.
    model_name: Optional[str]
def trim_history(state: State, max_messages: int = 10):
    """Trim conversation history down to the most recent messages.

    Args:
        state: Current graph state; only ``messages`` is read.
        max_messages: Keep at most this many messages (default 10,
            matching the original hard-coded limit).

    Returns:
        A partial state update: ``RemoveMessage`` markers for the surplus
        oldest messages (consumed by ``add_messages``) plus cleared
        document selections — or an empty dict when nothing needs trimming.
    """
    history = state.get("messages", [])
    surplus = len(history) - max_messages
    if surplus <= 0:
        # Within bounds — no state change.
        return {}
    return {
        # A RemoveMessage instructs add_messages to delete by message id.
        "messages": [RemoveMessage(id=msg.id) for msg in history[:surplus]],
        # The previous selections referred to trimmed context; reset them.
        "selected_ids": [],
        "selected_documents": [],
    }
def execute_tool(state: State):
    """Execute every tool call attached to the latest AI message.

    ``retrieve_document`` is special-cased: its result (a mapping) carries
    the retrieved documents and ids, which are lifted into graph state,
    while the formatted ``context_str`` becomes the ToolMessage content.
    Every other tool's raw response is wrapped in a ToolMessage directly.

    Returns:
        A partial state update with the produced ``ToolMessage`` list and
        the latest ``selected_ids`` / ``selected_documents``.
    """
    tool_calls = state["messages"][-1].tool_calls
    tool_name_to_func = {tool.name: tool for tool in tools}

    selected_ids = []
    selected_documents = []
    tool_messages = []

    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]
        tool_id = tool_call["id"]

        tool_func = tool_name_to_func.get(tool_name)
        if tool_func is None:
            # NOTE(review): unknown tool names are silently skipped, which
            # leaves the tool_call without a ToolMessage reply — confirm the
            # model provider tolerates unanswered tool calls.
            continue

        if tool_name == "retrieve_document":
            documents = dict(tool_func.invoke(tool_args.get("query")))
            # Lift retrieval results into graph state (last call wins,
            # matching the original behavior).
            selected_documents = documents.get("selected_documents", [])
            selected_ids = documents.get("selected_ids", [])
            if documents:
                tool_messages.append(
                    ToolMessage(
                        tool_call_id=tool_id,
                        content=documents.get("context_str", ""),
                    )
                )
            continue

        tool_response = tool_func.invoke(tool_args)
        # Use the module logger (lazy %-formatting) instead of print.
        logger.info("tool_response: %s", tool_response)
        tool_messages.append(
            ToolMessage(tool_call_id=tool_id, content=tool_response)
        )

    return {
        "selected_ids": selected_ids,
        "selected_documents": selected_documents,
        "messages": tool_messages,
    }
# Usage hint appended to the system prompt for each enabled tool
# (Vietnamese, matching the product's prompt language).
TOOL_USAGE_HINTS = {
    "retrieve_document": "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời",
    "python_repl": "Sử dụng tool `python_repl` để thực hiện các tác vụ liên quan đến tính toán phức tạp",
    "duckduckgo_search": "Sử dụng tool `duckduckgo_search` để tìm kiếm thông tin trên internet",
}


def generate_answer_rag(state: State):
    """LLM node: produce the next assistant message, possibly with tool calls.

    Binds only the tools enabled in ``state["tools"]`` and appends a usage
    hint for each bound tool to the system prompt before invoking the model.

    Returns:
        A partial state update containing the model's response message.
    """
    messages = state["messages"]
    tool_names = state.get("tools", [])
    prompt = state["prompt"]
    # Fall back to a default model when the caller did not pick one.
    model_name = state.get("model_name", "gemini-2.0-flash")

    tool_name_to_func = {tool.name: tool for tool in tools}
    tool_functions = [
        tool_name_to_func[name] for name in tool_names if name in tool_name_to_func
    ]
    logger.info("tools bound for generation: %s", tool_functions)

    # Append one hint per bound tool, newline-separated — the previous
    # version concatenated them with no separator, fusing instructions
    # into the preceding prompt text.
    for tool in tool_functions:
        hint = TOOL_USAGE_HINTS.get(tool.name)
        if hint:
            prompt += "\n" + hint

    llm_call = template_prompt | get_llm(model_name).bind_tools(tool_functions)
    response = llm_call.invoke({"messages": messages, "prompt": prompt})
    return {"messages": response}