File size: 2,110 Bytes
a47e415
 
 
 
 
 
 
 
 
 
fd0d494
 
 
a47e415
 
fd0d494
 
 
 
 
 
 
a47e415
 
fd0d494
 
 
 
 
 
 
a47e415
 
fd0d494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from langchain.agents import AgentExecutor
from workflow.graph_state import GraphState
from pathlib import Path
import yaml

# Prompt templates are read from prompts.yaml located next to this module.
# yaml.safe_load only builds plain YAML types (no arbitrary Python objects).
# Decode explicitly as UTF-8: Path.open() otherwise uses the locale encoding,
# which can mis-decode or reject non-ASCII prompt text on some platforms.
yaml_path = Path(__file__).parent / "prompts.yaml"
with yaml_path.open(encoding="utf-8") as f:
    prompt_template = yaml.safe_load(f)

# Worker agents, in the order supervisor_router hands them out; once all of
# them appear in state["completed"], the router returns "synth".
AGENTS = ["news", "earnings", "market"]

def news_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the news agent for the current ticker and record its summary.

    Formats the ``news_user_prompt`` template with ``state["ticker"]``,
    invokes ``agent`` with it as ``{"input": query}``, stores the result's
    ``"output"`` in ``state["news_summary"]`` and marks ``"news"`` as
    completed.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when missing.
        agent: Executor whose ``invoke`` result is indexable by ``"output"``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["news_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["news_summary"] = res["output"]
    # dict.fromkeys de-duplicates while keeping first-seen order (set() does
    # not), and .get() tolerates a missing key like supervisor_router does.
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "news"]))
    return state


def earnings_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the earnings agent for the current ticker and record its summary.

    Formats the ``earnings_user_prompt`` template with ``state["ticker"]``,
    invokes ``agent`` with it as ``{"input": query}``, stores the result's
    ``"output"`` in ``state["earnings_summary"]`` and marks ``"earnings"``
    as completed.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when missing.
        agent: Executor whose ``invoke`` result is indexable by ``"output"``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["earnings_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["earnings_summary"] = res["output"]
    # dict.fromkeys de-duplicates while keeping first-seen order (set() does
    # not), and .get() tolerates a missing key like supervisor_router does.
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "earnings"]))
    return state


def market_node(state: GraphState, agent: AgentExecutor) -> GraphState:
    """Run the market agent for the current ticker and record its summary.

    Formats the ``market_user_prompt`` template with ``state["ticker"]``,
    invokes ``agent`` with it as ``{"input": query}``, stores the result's
    ``"output"`` in ``state["market_summary"]`` and marks ``"market"`` as
    completed.

    Args:
        state: Graph state. Must contain ``"ticker"``; ``"completed"`` is
            optional and treated as empty when missing.
        agent: Executor whose ``invoke`` result is indexable by ``"output"``.

    Returns:
        The same ``state`` object, mutated in place.
    """
    ticker = state["ticker"]
    query = prompt_template["market_user_prompt"].format(ticker=ticker)
    res = agent.invoke({"input": query})
    state["market_summary"] = res["output"]
    # dict.fromkeys de-duplicates while keeping first-seen order (set() does
    # not), and .get() tolerates a missing key like supervisor_router does.
    completed = state.get("completed", [])
    state["completed"] = list(dict.fromkeys([*completed, "market"]))
    return state


def synth_node(state: GraphState, synthesizer_chain) -> GraphState:
    """Merge the per-agent summaries into a final recommendation.

    Sends ``ticker`` plus the news, earnings and market summaries (empty
    string for any summary not yet in ``state``) to
    ``synthesizer_chain.invoke``. The result's ``content`` attribute is
    stored when present, otherwise ``str(result)``.

    Args:
        state: Graph state; must contain ``"ticker"``.
        synthesizer_chain: Object exposing ``invoke(dict)``.

    Returns:
        The same ``state`` object with ``"final_recommendation"`` set.
    """
    payload = {"ticker": state["ticker"]}
    for summary_key in ("news_summary", "earnings_summary", "market_summary"):
        payload[summary_key] = state.get(summary_key, "")

    result = synthesizer_chain.invoke(payload)

    # Chat-style results carry their text in .content; anything else is
    # stringified as-is.
    if hasattr(result, "content"):
        state["final_recommendation"] = result.content
    else:
        state["final_recommendation"] = str(result)
    return state
    
def supervisor_node(state: GraphState) -> GraphState:
    """Supervisor step: a pass-through hook for future bookkeeping.

    Returns:
        The very same ``state`` object, unmodified.
    """
    # Nothing to record yet; hand the state back untouched.
    unchanged = state
    return unchanged

def supervisor_router(state: GraphState) -> str:
    """Choose the name of the next node to run.

    Returns the first entry of ``AGENTS`` that is not listed in
    ``state["completed"]`` (a missing key counts as nothing completed).
    Returns ``"synth"`` once every agent has completed.
    """
    done = state.get("completed", [])
    return next((name for name in AGENTS if name not in done), "synth")