"""Stage-driven conversational agent built on a LangGraph ``StateGraph``.

Each stage returned by ``get_stage_list()`` becomes a graph node. All nodes share
the same logic: when the latest message is from the user, the node asks the LLM
for a reply (prefixed with a system prompt describing stage progress) and, if
``is_stage_complete`` accepts that reply, records the current stage as completed
and moves ``current_stage`` to ``get_next_stage(current_stage)``.

Importing this module builds and compiles ``stage_graph`` and tries to write a
diagram of it to ``graph_output.png``.
"""

import logging
from typing import Annotated, Optional, TypedDict

# HumanMessage, ToolNode and tools_condition are currently unused: tool edges are
# disabled because Modal and Nebius do not support conditional tool edges yet.
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from components.stage_mapping import get_next_stage, get_stage_list
from llm_utils import call_llm_api, is_stage_complete

logger = logging.getLogger(__name__)


# Define the agent state
class AgentState(TypedDict):
    """Graph state shared by every stage node."""

    # Conversation history; the add_messages reducer merges node updates into it.
    messages: Annotated[list[AnyMessage], add_messages]
    # Stage being worked on; None once get_next_stage() reports no further stage.
    current_stage: Optional[str]
    # Stages already marked complete, in the order they were completed.
    completed_stages: list[str]


stage_list = get_stage_list()

# LangChain message ``type`` -> chat-completion role. Messages of any other type
# (e.g. tool messages) are not forwarded to the LLM.
_ROLE_BY_MESSAGE_TYPE = {"system": "system", "human": "user", "ai": "assistant"}


def _to_llm_messages(messages):
    """Convert LangChain messages into ``{"role": ..., "content": ...}`` dicts."""
    converted = []
    for msg in messages:
        role = _ROLE_BY_MESSAGE_TYPE.get(getattr(msg, "type", None))
        if role is not None:
            converted.append({"role": role, "content": msg.content})
    return converted


def _build_stage_context_prompt(current_stage, completed_stages):
    """Return the system prompt that tells the LLM where the conversation stands."""
    # str() keeps the join safe even if a non-string slipped into persisted state.
    completed = (
        ", ".join(str(stage) for stage in completed_stages) if completed_stages else "None"
    )
    return (
        f"[Stage Management]\n"
        f"Current stage: {current_stage}\n"
        f"Completed stages: {completed}\n"
        "You must always check if the current stage is complete. "
        "You must look at evidence in the conversation to determine if you have enough "
        "logical information and reasoning to conclude the stage is complete. "
        "If it is, clearly state that the stage is complete and suggest moving to the next stage. "
        "If not, ask clarifying questions or provide guidance for the current stage. "
        "Never forget to consider the current stage and completed stages in your reasoning."
    )


def make_stage_node(stage_name):
    """Create the node callable registered under ``stage_name``.

    The returned node does not use ``stage_name``; it always acts on
    ``state["current_stage"]``.  NOTE(review): confirm this is intended.

    Returns:
        A function ``stage_node(state) -> dict`` producing a state update.
    """

    def stage_node(state: AgentState):
        history = state["messages"]
        # Only call the LLM when the latest message comes from the user. With an
        # empty history or a non-user last message, leave the state untouched so
        # the graph waits for user input.
        if not history or getattr(history[-1], "type", None) != "human":
            return state

        current_stage = state["current_stage"]
        completed_stages = list(state["completed_stages"])

        # The stage-management prompt goes first, ahead of the conversation history.
        llm_messages = [
            {
                "role": "system",
                "content": _build_stage_context_prompt(current_stage, completed_stages),
            }
        ] + _to_llm_messages(history)
        assistant_reply = call_llm_api(llm_messages)

        # Advance only while there is a stage to complete. Once current_stage is
        # None every stage is done; appending None would corrupt completed_stages.
        if current_stage is not None and is_stage_complete(assistant_reply):
            completed_stages.append(current_stage)
            current_stage = get_next_stage(current_stage) or None

        return {
            # add_messages appends this new message to the existing history.
            "messages": [AIMessage(content=assistant_reply)],
            "current_stage": current_stage,
            "completed_stages": completed_stages,
        }

    return stage_node


if not stage_list:
    raise ValueError("get_stage_list() returned no stages; cannot build the stage graph")

# Build the graph: one node per stage, chained sequentially in stage order.
builder = StateGraph(AgentState)
for stage_name in stage_list:
    builder.add_node(stage_name, make_stage_node(stage_name))

builder.add_edge(START, stage_list[0])
for stage_name in stage_list:
    following_stage = get_next_stage(stage_name)
    # Plain sequential edges only (no conditional tool edges, see the import note);
    # the final stage exits the graph explicitly.
    builder.add_edge(stage_name, following_stage if following_stage else END)

# Compile the graph
stage_graph = builder.compile()

# Export a diagram of the graph (best-effort). Render first so a failure cannot
# leave a truncated file behind. By default draw_mermaid_png() renders through an
# external web service, so failures are logged instead of aborting the import.
try:
    graph_png = stage_graph.get_graph().draw_mermaid_png()
except Exception:
    logger.warning(
        "Could not render the stage graph diagram; skipping graph_output.png",
        exc_info=True,
    )
else:
    with open("graph_output.png", "wb") as png_file:
        png_file.write(graph_png)