"""Build and compile the top-level MultiRag LangGraph, plus thread-management helpers.

Graph flow as wired below:
    START -> orchestrator_node
    orchestrator_node -> worker (one Send per planned task) -> reducer_node -> chat_node
    orchestrator_node -> chat_node                    (no plan / no workers / no tasks)
    chat_node -> tool_limit -> tools -> chat_node     (while the model requests tools)
    chat_node -> END                                  (when it does not)
"""

import logging

from langchain.agents.middleware import ToolCallLimitMiddleware
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import Send

from src.MultiRag.graph.worker.builder import graph as worker_sub_graph
from src.MultiRag.memory import memory
from src.MultiRag.models.rag_model import State
from src.MultiRag.nodes.chat_node import chat_node
from src.MultiRag.nodes.orchestrator_node import orchestrator_node
from src.MultiRag.nodes.reducer_node import reducer_node
from src.MultiRag.tools.web_search import WebSearch

# Caps tool calls per run; applied manually by the ``tool_limit`` node below.
tool_limiter = ToolCallLimitMiddleware(
    run_limit=3,
    exit_behavior="continue",
)


def enforce_tool_limit(state: State) -> dict:
    """Graph node: run ``tool_limiter``'s after-model check on the current state.

    Returns the middleware's state update, or an empty dict when it returns
    nothing, so the node always produces a valid (possibly empty) update.
    """
    # NOTE(review): the middleware hook is called outside an agent runtime, so
    # ``runtime`` is None. Any counters or ``jump_to`` flag it returns only
    # persist if ``State`` declares matching keys — confirm against State.
    updates = tool_limiter.after_model(state, runtime=None)
    return updates or {}


def _last_message(state: State):
    """Return the newest entry of ``state["messages"]``, or None if there is none."""
    messages = state.get("messages") or []
    return messages[-1] if messages else None


def _has_tool_calls(message) -> bool:
    """Return True when ``message`` has a non-empty ``tool_calls`` attribute."""
    return bool(getattr(message, "tool_calls", None))


def after_tool_limit(state: State) -> str:
    """Route after ``tool_limit``: run tools if still requested, else back to chat.

    Returns ``"chat_node"`` when the limiter set ``jump_to`` to ``"end"``, when
    there are no messages, or when the last message has no tool calls;
    otherwise ``"tools"``.
    """
    if state.get("jump_to") == "end":
        return "chat_node"
    if _has_tool_calls(_last_message(state)):
        return "tools"
    return "chat_node"


logging.info("Initializing StateGraph with State model...")
graph_builder = StateGraph(State)


def fanout(state: State):
    """Route after ``orchestrator_node``.

    Returns ``"chat_node"`` when there is no plan, the plan opts out of
    workers, or the plan has no tasks. Otherwise returns one ``Send`` to the
    ``worker`` subgraph per task, each carrying that task's instruction, file
    type, file path and the state's ``thread_id`` (default ``"1"``).
    """
    logging.info("Evaluating fanout condition from orchestrator_node")
    plan = state.get("plan")
    if not plan:
        logging.warning("No plan found in state, defaulting to chat_node")
        return "chat_node"
    if not plan.use_worker:
        logging.info("Orchestrator decided to bypass workers and go to chat")
        return "chat_node"
    tasks = plan.tasks or []
    if not tasks:
        logging.info("No tasks to execute, going to chat_node")
        return "chat_node"

    logging.info("Fanning out %d tasks to workers", len(tasks))
    thread_id = state.get("thread_id", "1")
    return [
        Send(
            "worker",
            {
                "plan_to_retrieve": task.instruction,
                "file_type": task.file_type,
                "file_path": task.file_path,
                "thread_id": thread_id,
                "worker_result": [],
            },
        )
        for task in tasks
    ]


def should_continue(state: State) -> str:
    """Route after ``chat_node``.

    Returns ``"tool_limit"`` if the last message requested tool calls, else
    ``END``. Also returns ``END`` when there are no messages or the last one
    has no ``tool_calls`` attribute.
    """
    if _has_tool_calls(_last_message(state)):
        return "tool_limit"
    return END


logging.info("Adding nodes to graph builder: orchestrator_node, chat_node, worker, reducer_node")
graph_builder.add_node("orchestrator_node", orchestrator_node)
graph_builder.add_node("chat_node", chat_node)
graph_builder.add_node("worker", worker_sub_graph)
graph_builder.add_node("reducer_node", reducer_node)
graph_builder.add_node("tools", ToolNode([WebSearch().search]))
graph_builder.add_node("tool_limit", enforce_tool_limit)

logging.info("Configuring graph edges and flow...")
graph_builder.add_edge(START, "orchestrator_node")

logging.info("Setting up conditional edges from orchestrator_node using fanout")
graph_builder.add_conditional_edges(
    "orchestrator_node",
    fanout,
    {"worker": "worker", "chat_node": "chat_node"},
)

logging.info("Connecting worker to reducer_node and then to chat_node")
graph_builder.add_edge("worker", "reducer_node")
graph_builder.add_edge("reducer_node", "chat_node")

graph_builder.add_conditional_edges("chat_node", should_continue, ["tool_limit", END])
graph_builder.add_conditional_edges("tool_limit", after_tool_limit, ["tools", "chat_node"])
graph_builder.add_edge("tools", "chat_node")

logging.info("Compiling graph...")
graph = graph_builder.compile(checkpointer=memory)

# Best-effort diagram export at import time; failures are only logged.
# NOTE(review): PNG rendering may need network access depending on the draw
# method — confirm this side effect is wanted on every import.
try:
    png_data = graph.get_graph(xray=1).draw_mermaid_png()
    with open("graph.png", "wb") as f:
        f.write(png_data)
    logging.info("Graph visualization saved to graph.png")
except Exception as e:
    logging.warning("Could not generate graph visualization: %s", e)

logging.info("Graph compiled successfully.")


async def deleteThread(thread_id: str) -> bool:
    """Delete all checkpoints stored for ``thread_id``.

    Returns True if the thread existed and was deleted. Returns False if no
    checkpoint was found or an error occurred; errors are logged, not raised.
    """
    config = {"configurable": {"thread_id": thread_id}}
    try:
        if await memory.aget_tuple(config=config) is None:
            logging.info("Thread %s not found, nothing to delete.", thread_id)
            return False
        await memory.adelete_thread(thread_id=thread_id)
    except Exception as e:
        logging.exception("Error deleting thread %s: %s", thread_id, e)
        return False
    logging.info("Thread %s deleted successfully.", thread_id)
    return True


async def retrieve_all_threads() -> list:
    """Return the distinct thread ids that have at least one stored checkpoint.

    Ids come back in first-seen order. On error, logs it and returns [].
    """
    # Use the async listing API, as deleteThread does. Async-only savers (e.g.
    # AsyncSqliteSaver) reject sync calls made from the event-loop thread, and
    # the broad except below would hide that as an empty result.
    try:
        thread_ids = [
            checkpoint.config["configurable"]["thread_id"]
            async for checkpoint in memory.alist(None)
        ]
        return list(dict.fromkeys(thread_ids))
    except Exception as e:
        logging.exception("Error retrieving threads: %s", e)
        return []


async def load_conversation(thread_id) -> list:
    """Return the ``messages`` list from the latest saved state of ``thread_id``.

    Returns [] when the state has no messages or when loading fails (logged).
    """
    try:
        # Async variant for the same reason as in retrieve_all_threads.
        state = await graph.aget_state(config={"configurable": {"thread_id": thread_id}})
        return state.values.get("messages", [])
    except Exception as e:
        logging.exception("Error loading conversation: %s", e)
        return []