"""Graph nodes for the per-ticker research workflow.

Each specialist node (news / earnings / market) sends a ticker-specific
prompt from ``prompts.yaml`` to a LangChain agent and stores the agent's
output in the shared graph state. The supervisor router then picks the next
specialist that has not run yet, falling back to ``"synth"`` once all of
them have completed.
"""

from pathlib import Path

import yaml
from langchain.agents import AgentExecutor

from workflow.graph_state import GraphState

yaml_path = Path(__file__).parent / "prompts.yaml"
# Prompt templates are loaded once, at import time. Explicit UTF-8 keeps the
# result independent of the platform's default locale encoding.
with yaml_path.open(encoding="utf-8") as f:
    prompt_template = yaml.safe_load(f)

# Specialist agents, in the order supervisor_router dispatches them.
AGENTS = ["news", "earnings", "market"]


def _mark_completed(completed: list, name: str) -> list:
    """Return a de-duplicated copy of ``completed`` with ``name`` appended.

    ``dict.fromkeys`` drops duplicates while keeping first-seen order, so the
    result is deterministic (unlike ``list(set(...))``).
    """
    return list(dict.fromkeys([*completed, name]))


def _run_agent_node(
    state: GraphState,
    agent: AgentExecutor,
    name: str,
    prompt_key: str,
    summary_key: str,
) -> GraphState:
    """Run one specialist agent for ``state["ticker"]`` and record its output.

    Formats ``prompt_template[prompt_key]`` with ``ticker=...``, invokes the
    agent with ``{"input": query}``, stores ``res["output"]`` under
    ``state[summary_key]`` and marks ``name`` as completed. ``state`` is
    mutated in place and also returned.
    """
    query = prompt_template[prompt_key].format(ticker=state["ticker"])
    res = agent.invoke({"input": query})
    state[summary_key] = res["output"]
    # Treat a missing "completed" key as empty, matching supervisor_router.
    state["completed"] = _mark_completed(state.get("completed", []), name)
    return state


def news_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the news agent; store its output in ``state["news_summary"]``."""
    return _run_agent_node(state, agent, "news", "news_user_prompt", "news_summary")


def earnings_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the earnings agent; store its output in ``state["earnings_summary"]``."""
    return _run_agent_node(
        state, agent, "earnings", "earnings_user_prompt", "earnings_summary"
    )


def market_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the market agent; store its output in ``state["market_summary"]``."""
    return _run_agent_node(
        state, agent, "market", "market_user_prompt", "market_summary"
    )


def synth_node(state: GraphState, synthesizer_chain) -> GraphState:
    """Combine the specialist summaries into ``state["final_recommendation"]``.

    Missing summaries are passed to the chain as empty strings. If the chain
    result has a ``content`` attribute (e.g. a chat message), that attribute
    is stored; otherwise the result's ``str()`` form is stored.
    """
    out = synthesizer_chain.invoke(
        {
            "ticker": state["ticker"],
            "news_summary": state.get("news_summary", ""),
            "earnings_summary": state.get("earnings_summary", ""),
            "market_summary": state.get("market_summary", ""),
        }
    )
    state["final_recommendation"] = out.content if hasattr(out, "content") else str(out)
    return state


def supervisor_node(state: GraphState) -> GraphState:
    """Pass ``state`` through unchanged; a hook for future bookkeeping."""
    return state


def supervisor_router(state: GraphState) -> str:
    """Return the first agent in ``AGENTS`` not yet completed, else ``"synth"``.

    The returned string presumably names the next graph node to route to —
    confirm against the graph wiring.
    """
    done = set(state.get("completed", []))
    remaining = [a for a in AGENTS if a not in done]
    return remaining[0] if remaining else "synth"