Aleksey Matsarski
Refactoring code, provide better abstraction and file structure
a47e415
from langchain.agents import AgentExecutor
from workflow.graph_state import GraphState
from pathlib import Path
import yaml
# Prompt templates are stored next to this module, so the lookup does not
# depend on the current working directory.
yaml_path = Path(__file__).parent / "prompts.yaml"
# Decode explicitly as UTF-8: without `encoding`, open() uses the platform
# default (e.g. cp1252 on Windows), which mis-decodes non-ASCII prompt text.
with yaml_path.open(encoding="utf-8") as f:
    # safe_load constructs only plain Python objects (no arbitrary YAML tags).
    prompt_template = yaml.safe_load(f)

# Research agents, in the priority order supervisor_router walks when picking
# the next agent that has not completed yet.
AGENTS = ["news", "earnings", "market"]
def _run_agent_node(state: GraphState, agent: AgentExecutor, name: str) -> GraphState:
    """Run one research agent for ``state["ticker"]`` and record its result.

    Formats the ``"<name>_user_prompt"`` template from ``prompt_template`` with
    the ticker and invokes ``agent`` with it as ``{"input": query}``. The
    agent's ``"output"`` is stored under ``state["<name>_summary"]`` and
    ``name`` is added to ``state["completed"]``.

    Args:
        state: Shared graph state; must contain ``"ticker"``. Mutated in place.
        agent: Executor whose ``invoke`` result is read as ``res["output"]``.
        name: Agent identifier (an entry of ``AGENTS``). It selects both the
            prompt key and the summary key.

    Returns:
        The same ``state`` object, updated in place.

    Raises:
        KeyError: If ``"ticker"`` is missing from ``state`` or the prompt key
            is missing from ``prompt_template``.
    """
    query = prompt_template[f"{name}_user_prompt"].format(ticker=state["ticker"])
    res = agent.invoke({"input": query})
    state[f"{name}_summary"] = res["output"]
    # Dedupe while keeping first-seen order. Round-tripping through set()
    # makes the list order depend on per-process string hash randomization.
    # A missing "completed" key is treated as empty, matching supervisor_router.
    state["completed"] = list(dict.fromkeys([*state.get("completed", []), name]))
    return state


def news_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the news agent; store its output in ``state["news_summary"]``."""
    return _run_agent_node(state, agent, "news")


def earnings_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the earnings agent; store its output in ``state["earnings_summary"]``."""
    return _run_agent_node(state, agent, "earnings")


def market_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the market agent; store its output in ``state["market_summary"]``."""
    return _run_agent_node(state, agent, "market")
def synth_node(state: GraphState, synthesizer_chain) -> GraphState:
    """Combine the per-agent summaries into a final recommendation.

    Sends the ticker plus the news, earnings and market summaries to
    ``synthesizer_chain.invoke``. Any summary not yet present in ``state`` is
    sent as an empty string.

    The result is stored in ``state["final_recommendation"]``. If the chain's
    output has a ``content`` attribute, that value is used; otherwise the
    output is converted with ``str``.

    Returns:
        The same ``state`` object, updated in place.
    """
    payload = {"ticker": state["ticker"]}
    for summary_key in ("news_summary", "earnings_summary", "market_summary"):
        payload[summary_key] = state.get(summary_key, "")

    response = synthesizer_chain.invoke(payload)

    if hasattr(response, "content"):
        recommendation = response.content
    else:
        recommendation = str(response)
    state["final_recommendation"] = recommendation
    return state
def supervisor_node(state: GraphState) -> GraphState:
    """Supervisor step that currently performs no work.

    This is the place for any future bookkeeping. Choosing the next agent is
    computed by ``supervisor_router``; presumably that function is wired as
    this node's routing function, but the wiring is not visible here.

    Returns:
        ``state`` itself, unmodified.
    """
    return state
def supervisor_router(state: GraphState) -> str:
    """Pick the next step for the graph.

    Returns the first name in ``AGENTS`` that does not appear in
    ``state["completed"]`` (a missing key counts as nothing completed). Once
    every agent has run, returns ``"synth"``.
    """
    done = state.get("completed", [])
    return next((agent for agent in AGENTS if agent not in done), "synth")