# Multi-Rag / src / MultiRag / graph / builder.py
# Uploaded by VashuTheGreat2 via huggingface_hub (commit 5551822, verified)
import logging
from langgraph.graph import START, END, StateGraph
from src.MultiRag.models.rag_model import State
from src.MultiRag.nodes.chat_node import chat_node
from src.MultiRag.graph.worker.builder import graph as worker_sub_graph
from src.MultiRag.nodes.orchestrator_node import orchestrator_node
from src.MultiRag.nodes.reducer_node import reducer_node
from langgraph.prebuilt import ToolNode
from src.MultiRag.memory import memory
from langgraph.types import Send
from src.MultiRag.tools.web_search import WebSearch
from langchain.agents.middleware import ToolCallLimitMiddleware
# Middleware that caps tool invocations per run (3 here); with
# exit_behavior="continue" the run presumably keeps going without tools
# once the cap is reached rather than terminating — confirm against
# langchain's ToolCallLimitMiddleware docs.
tool_limiter = ToolCallLimitMiddleware(
    run_limit=3,
    exit_behavior="continue",
)
def enforce_tool_limit(state: State):
    """Run the tool-call limiter's after-model hook and normalise its result.

    The middleware may return None when no state change is required; the
    graph node contract expects a dict of updates, so None becomes {}.
    """
    result = tool_limiter.after_model(state, runtime=None)
    if not result:
        return {}
    return result
def after_tool_limit(state: "State"):
    """Route after the tool-limit check node.

    Returns "chat_node" when the limiter asked to stop (jump_to == "end"),
    when there are no messages, or when the last message carries no tool
    calls; otherwise returns "tools" so the pending tool calls execute.
    """
    if state.get("jump_to") == "end":
        return "chat_node"
    messages = state.get("messages") or []
    if not messages:
        # Fix: the original indexed [-1] unconditionally, which raised
        # IndexError when the message list was empty or missing.
        return "chat_node"
    last_message = messages[-1]
    # getattr covers both "no tool_calls attribute" and "empty tool_calls".
    if getattr(last_message, "tool_calls", None):
        return "tools"
    return "chat_node"
logging.info("Initializing StateGraph with State model...")
# Top-level graph builder over the shared State schema; all nodes and
# edges below are registered on this instance before compilation.
graph_builder = StateGraph(State)
def fanout(state: State):
    """Decide where the orchestrator's output flows next.

    Returns "chat_node" when there is no plan, the plan opts out of
    workers, or the plan carries no tasks; otherwise returns one Send
    per task so the worker sub-graph runs once per task in parallel.
    """
    logging.info("Evaluating fanout condition from orchestrator_node")
    plan = state.get("plan")
    if not plan:
        logging.warning("No plan found in state, defaulting to chat_node")
        return "chat_node"
    if not plan.use_worker:
        logging.info("Orchestrator decided to bypass workers and go to chat")
        return "chat_node"
    tasks = plan.tasks or []
    if not tasks:
        logging.info("No tasks to execute, going to chat_node")
        return "chat_node"
    logging.info(f"Fanning out {len(tasks)} tasks to workers")
    sends = []
    for task in tasks:
        # Each worker gets its own instruction plus the shared thread id
        # and a fresh result accumulator.
        payload = {
            "plan_to_retrieve": task.instruction,
            "file_type": task.file_type,
            "file_path": task.file_path,
            "thread_id": state.get("thread_id", "1"),
            "worker_result": [],
        }
        sends.append(Send("worker", payload))
    return sends
def should_continue(state: "State"):
    """Route after chat_node: run tools (via the limiter) or finish.

    Returns "tool_limit" when the last message requested tool calls,
    END otherwise. Fix: the original dereferenced last_message.tool_calls
    unconditionally, raising AttributeError when the message list was
    empty (last_message is None) or when the last message type has no
    tool_calls attribute.
    """
    messages = state.get("messages") or []
    last_message = messages[-1] if messages else None
    if last_message is not None and getattr(last_message, "tool_calls", None):
        return "tool_limit"
    return END
logging.info("Adding nodes to graph builder: orchestrator_node, chat_node, worker, reducer_node")
# Register all nodes: planning, chat, the parallel worker sub-graph,
# result reduction, tool execution, and the tool-call limiter hook.
graph_builder.add_node("orchestrator_node", orchestrator_node)
graph_builder.add_node("chat_node", chat_node)
graph_builder.add_node("worker", worker_sub_graph)
graph_builder.add_node("reducer_node", reducer_node)
graph_builder.add_node("tools", ToolNode([WebSearch().search]))
graph_builder.add_node("tool_limit", enforce_tool_limit)
logging.info("Configuring graph edges and flow...")
# Every run enters through the orchestrator.
graph_builder.add_edge(START, "orchestrator_node")
logging.info("Setting up conditional edges from orchestrator_node using fanout")
# fanout returns either "chat_node" or a list of Send("worker", ...)
# objects, which LangGraph expands into parallel worker invocations.
graph_builder.add_conditional_edges(
    "orchestrator_node",
    fanout,
    {
        "worker": "worker",
        "chat_node": "chat_node"
    }
)
logging.info("Connecting worker to reducer_node and then to chat_node")
# Worker results flow into the reducer, which merges them before chat replies.
graph_builder.add_edge("worker", "reducer_node")
graph_builder.add_edge("reducer_node", "chat_node")
# After chat: finish (END) or go through the tool-call limiter first.
graph_builder.add_conditional_edges(
    "chat_node",
    should_continue,
    ["tool_limit", END]
)
# graph_builder.add_conditional_edges("chat_node", should_continue, ["tools", END])
# The limiter either admits the pending tool calls ("tools") or forces
# the flow back to chat once the per-run limit has been reached.
graph_builder.add_conditional_edges(
    "tool_limit",
    after_tool_limit,
    ["tools", "chat_node"]
)
# Tool output loops back into chat for the next model turn.
graph_builder.add_edge("tools", "chat_node")
logging.info("Compiling graph...")
# Compile with the shared checkpointer so conversation state persists per thread.
graph = graph_builder.compile(checkpointer=memory)
try:
    # Best-effort diagram export; any failure (e.g. renderer/network
    # unavailable) is logged and otherwise ignored.
    png_data = graph.get_graph(xray=1).draw_mermaid_png()
    with open("graph.png", "wb") as f:
        f.write(png_data)
    logging.info("Graph visualization saved to graph.png")
except Exception as e:
    logging.warning(f"Could not generate graph visualization: {e}")
logging.info("Graph compiled successfully.")
async def deleteThread(thread_id: str):
    """Delete all checkpoints stored for *thread_id*.

    Returns True when the thread existed and was removed; False when it
    was absent or the deletion failed (errors are logged, not raised).
    """
    try:
        checkpointer = memory
        config = {'configurable': {'thread_id': thread_id}}
        snapshot = await checkpointer.aget_tuple(config=config)
        if snapshot is None:
            logging.info(f"Thread {thread_id} not found, nothing to delete.")
            return False
        await checkpointer.adelete_thread(thread_id=thread_id)
        logging.info(f"Thread {thread_id} deleted successfully.")
        return True
    except Exception as e:
        logging.error(f"Error deleting thread {thread_id}: {e}")
        return False
async def retrieve_all_threads():
    """Return the distinct thread ids known to the checkpointer.

    Errors are logged and an empty list is returned instead of raising.
    """
    try:
        checkpointer = memory
        # De-duplicate via a set comprehension over every stored checkpoint.
        thread_ids = {
            ckpt.config["configurable"]["thread_id"]
            for ckpt in checkpointer.list(None)
        }
        return list(thread_ids)
    except Exception as e:
        logging.error(f"Error retrieving threads: {e}")
        return []
async def load_conversation(thread_id):
    """Return the message history checkpointed for *thread_id*.

    Falls back to an empty list when the state is missing or lookup fails.
    """
    config = {'configurable': {'thread_id': thread_id}}
    try:
        snapshot = graph.get_state(config=config)
        return snapshot.values.get('messages', [])
    except Exception as e:
        logging.error(f"Error loading conversation: {e}")
        return []