Spaces:
Sleeping
Sleeping
| from typing import TypedDict, Optional, List | |
| from langchain_core.messages import AnyMessage, ToolMessage | |
| from langgraph.graph.message import add_messages | |
| from typing import Sequence, Annotated | |
| from langchain_core.messages import RemoveMessage | |
| from langchain_core.documents import Document | |
| from .tools import retrieve_document, python_repl, duckduckgo_search | |
| from src.utils.logger import logger | |
| from src.config.llm import get_llm | |
| from .prompt import template_prompt | |
# Registry of every tool implementation available to the agent. Graph nodes
# filter this list by the tool *names* carried in State["tools"].
tools = [retrieve_document, python_repl, duckduckgo_search]
class State(TypedDict):
    """Graph state shared by all nodes of the RAG agent."""

    # Conversation history; the add_messages reducer appends new messages and
    # honors RemoveMessage deletion markers.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Ids of the documents selected by the last retrieve_document tool call.
    selected_ids: Optional[List[str]]
    # Document objects returned by the last retrieve_document tool call.
    selected_documents: Optional[List[Document]]
    # Names of the tools the agent is allowed to use on this run; used to
    # filter the module-level `tools` registry.
    tools: Optional[List[str]]
    # Base system prompt; per-tool usage hints are appended at generation time.
    prompt: str
    # LLM identifier; nodes fall back to a default model when absent.
    model_name: Optional[str]
def trim_history(state: State):
    """Keep only the 10 most recent messages in the conversation history.

    When the history exceeds 10 messages, emit ``RemoveMessage`` markers for
    the oldest surplus messages and clear any previously selected documents
    so stale retrieval context does not leak into the trimmed conversation.

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        A partial state update with removal markers and cleared selections,
        or an empty dict when no trimming is needed.
    """
    history = state.get("messages", [])
    if len(history) <= 10:
        return {}
    # The add_messages reducer interprets RemoveMessage(id=...) as "delete the
    # stored message with this id", so we only emit markers for the oldest
    # surplus messages instead of rebuilding the whole list.
    surplus = history[: len(history) - 10]
    return {
        "messages": [RemoveMessage(id=msg.id) for msg in surplus],
        "selected_ids": [],
        "selected_documents": [],
    }
def execute_tool(state: State):
    """Execute every tool call requested by the last AI message.

    ``retrieve_document`` is special-cased: its result carries document
    context plus the selected ids/documents, which are surfaced in the state
    update. All other tools have their raw response wrapped in a ToolMessage.

    Args:
        state: Current graph state; reads the last message's ``tool_calls``.

    Returns:
        A partial state update with one ToolMessage per tool call, plus the
        ids/documents selected by the last retrieve_document call (empty
        lists when retrieval was not invoked).
    """
    tool_calls = state["messages"][-1].tool_calls
    tool_name_to_func = {tool.name: tool for tool in tools}

    selected_ids = []
    selected_documents = []
    tool_messages = []

    for tool_call in tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]
        tool_id = tool_call["id"]
        tool_func = tool_name_to_func.get(tool_name)

        # Every tool_call id must receive a matching ToolMessage; otherwise
        # the chat-model API rejects the conversation on the next turn.
        if tool_func is None:
            tool_messages.append(
                ToolMessage(
                    tool_call_id=tool_id,
                    content=f"Unknown tool: {tool_name}",
                )
            )
            continue

        if tool_name == "retrieve_document":
            documents = dict(tool_func.invoke(tool_args.get("query")))
            context_str = documents.get("context_str", "")
            selected_documents = documents.get("selected_documents", [])
            selected_ids = documents.get("selected_ids", [])
            # Respond even when retrieval found nothing, to keep the
            # tool-call / tool-response pairing intact.
            tool_messages.append(
                ToolMessage(
                    tool_call_id=tool_id,
                    content=context_str,
                )
            )
            continue

        tool_response = tool_func.invoke(tool_args)
        # Use the module logger instead of print, consistent with the file's
        # logging setup.
        logger.info("tool_response: %s", tool_response)
        tool_messages.append(
            ToolMessage(
                tool_call_id=tool_id,
                content=tool_response,
            )
        )

    return {
        "selected_ids": selected_ids,
        "selected_documents": selected_documents,
        "messages": tool_messages,
    }
def generate_answer_rag(state: State):
    """Generate the assistant's next message with the enabled tools bound.

    Filters the module-level ``tools`` registry by the names in
    ``state["tools"]``, appends a usage hint to the prompt for each enabled
    tool, and invokes the prompt-template/LLM chain.

    Args:
        state: Current graph state; reads ``messages``, ``tools``,
            ``prompt`` and ``model_name`` (default ``"gemini-2.0-flash"``).

    Returns:
        A partial state update containing the model's response message.
    """
    messages = state["messages"]
    tool_names = state.get("tools", [])
    prompt = state["prompt"]
    model_name = state.get("model_name", "gemini-2.0-flash")

    tool_name_to_func = {tool.name: tool for tool in tools}
    tool_functions = [
        tool_name_to_func[name] for name in tool_names if name in tool_name_to_func
    ]
    # Module logger instead of print, consistent with the file's logging setup.
    logger.info("tools: %s", tool_functions)

    # Per-tool usage hints appended to the prompt (Vietnamese, matching the
    # rest of the prompt language). Strings are runtime behavior — verbatim.
    tool_hints = {
        "retrieve_document": "Sử dụng tool `retrieve_document` để truy xuất tài liệu để bổ sung thông tin cho câu trả lời",
        "python_repl": "Sử dụng tool `python_repl` để thực hiện các tác vụ liên quan đến tính toán phức tạp",
        "duckduckgo_search": "Sử dụng tool `duckduckgo_search` để tìm kiếm thông tin trên internet",
    }
    for tool in tool_functions:
        hint = tool_hints.get(tool.name)
        if hint:
            prompt += hint

    # The prompt string is passed at invoke time, so building the chain after
    # augmenting the prompt is equivalent to the original order.
    llm_call = template_prompt | get_llm(model_name).bind_tools(tool_functions)
    response = llm_call.invoke(
        {
            "messages": messages,
            "prompt": prompt,
        }
    )
    return {"messages": response}