|
|
from langchain.agents import AgentExecutor |
|
|
from workflow.graph_state import GraphState |
|
|
from pathlib import Path |
|
|
import yaml |
|
|
|
|
|
# Prompt templates live in prompts.yaml next to this module. The file is read
# as UTF-8 explicitly so loading does not depend on the platform's locale
# encoding; yaml.safe_load only builds plain YAML types (no arbitrary objects).
yaml_path = Path(__file__).parent / "prompts.yaml"
with yaml_path.open(encoding="utf-8") as f:
    prompt_template = yaml.safe_load(f)

# Research agents in the order supervisor_router hands them out. Each agent
# node appends its own name to state["completed"] when it finishes, which is
# how the router knows to move on (presumably these are also the graph's node
# names — confirm in the graph builder).
AGENTS = ["news", "earnings", "market"]
|
|
|
|
|
def news_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the news agent for ``state["ticker"]`` and store its summary.

    Formats the ``news_user_prompt`` template with the ticker, sends it to
    ``agent`` as ``{"input": query}``, and saves the agent's ``"output"``
    under ``state["news_summary"]``. Also records ``"news"`` in
    ``state["completed"]``, which ``supervisor_router`` consults to choose
    the next node.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when absent.
        agent: Agent executor that answers the news prompt.

    Returns:
        The same ``state`` object, updated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["news_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["news_summary"] = res["output"]

    # .get() matches supervisor_router's tolerance of a missing key.
    # dict.fromkeys de-duplicates while keeping first-seen order, so the list
    # is reproducible across runs (set() order varies with hash randomization).
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "news"]))
    return state
|
|
|
|
|
|
|
|
def earnings_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the earnings agent for ``state["ticker"]`` and store its summary.

    Formats the ``earnings_user_prompt`` template with the ticker, sends it
    to ``agent`` as ``{"input": query}``, and saves the agent's ``"output"``
    under ``state["earnings_summary"]``. Also records ``"earnings"`` in
    ``state["completed"]``, which ``supervisor_router`` consults to choose
    the next node.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when absent.
        agent: Agent executor that answers the earnings prompt.

    Returns:
        The same ``state`` object, updated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["earnings_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["earnings_summary"] = res["output"]

    # .get() matches supervisor_router's tolerance of a missing key.
    # dict.fromkeys de-duplicates while keeping first-seen order, so the list
    # is reproducible across runs (set() order varies with hash randomization).
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "earnings"]))
    return state
|
|
|
|
|
|
|
|
def market_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the market agent for ``state["ticker"]`` and store its summary.

    Formats the ``market_user_prompt`` template with the ticker, sends it to
    ``agent`` as ``{"input": query}``, and saves the agent's ``"output"``
    under ``state["market_summary"]``. Also records ``"market"`` in
    ``state["completed"]``, which ``supervisor_router`` consults to choose
    the next node.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when absent.
        agent: Agent executor that answers the market prompt.

    Returns:
        The same ``state`` object, updated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["market_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["market_summary"] = res["output"]

    # .get() matches supervisor_router's tolerance of a missing key.
    # dict.fromkeys de-duplicates while keeping first-seen order, so the list
    # is reproducible across runs (set() order varies with hash randomization).
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "market"]))
    return state
|
|
|
|
|
|
|
|
def synth_node(state: GraphState, synthesizer_chain) -> GraphState:
    """Merge the per-agent summaries into a final recommendation.

    Invokes ``synthesizer_chain`` with the ticker and the news, earnings and
    market summaries, then stores the result under
    ``state["final_recommendation"]``.

    Args:
        state: Graph state. Must contain ``"ticker"``; any summary that is
            missing is passed to the chain as an empty string.
        synthesizer_chain: Object with an ``invoke(dict)`` method.

    Returns:
        The same ``state`` object, updated in place.
    """
    payload = {"ticker": state["ticker"]}
    for summary_key in ("news_summary", "earnings_summary", "market_summary"):
        payload[summary_key] = state.get(summary_key, "")

    result = synthesizer_chain.invoke(payload)

    # Results exposing ``.content`` (presumably chat-model messages) are
    # unwrapped; anything else is stored as its string form.
    try:
        recommendation = result.content
    except AttributeError:
        recommendation = str(result)

    state["final_recommendation"] = recommendation
    return state
|
|
|
|
|
def supervisor_node(state: GraphState) -> GraphState:
    """Hand the state on unchanged.

    The supervisor does no work of its own. The choice of which node runs
    next is made by ``supervisor_router``, presumably attached as this
    node's conditional edge (confirm in the graph builder).
    """
    return state
|
|
|
|
|
def supervisor_router(state: GraphState) -> str:
    """Choose the next node to run.

    Returns the first name in ``AGENTS`` that is not yet present in
    ``state["completed"]`` (a missing key counts as nothing completed), or
    ``"synth"`` once every agent has finished.
    """
    done = state.get("completed", [])
    return next((name for name in AGENTS if name not in done), "synth")