File size: 2,532 Bytes
a47e415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
import yaml
from langgraph.graph import StateGraph, END

from agents.earnings_agent.earnings_agent import create_earnings_agent
from agents.market_agent.market_agent import create_market_agent
from agents.news_agent.news_agent import create_news_agent
from model.init_model import init_main_model
from workflow.graph_state import GraphState
from workflow.nodes.nodes import news_node, earnings_node, market_node, synth_node, supervisor_node, AGENTS, supervisor_router
from pathlib import Path

# Load the prompt templates stored in prompts.yaml next to this module.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows).
yaml_path = Path(__file__).parent / "prompts.yaml"
with yaml_path.open(encoding="utf-8") as f:
    # safe_load only builds plain Python objects (no arbitrary constructors).
    prompt_template = yaml.safe_load(f)

def make_synthesizer(model):
    """Build the final-writer chain that merges agent outputs into recommendations.

    The chat prompt is assembled from the "system" and "human" entries of the
    module-level ``prompt_template`` mapping and then piped into ``model``.

    Args:
        model: The object placed on the right-hand side of the ``|`` pipe,
            i.e. whatever should consume the rendered prompt.

    Returns:
        The composed ``prompt | model`` chain.
    """
    # One (role, template-text) pair per role, in system -> human order.
    role_messages = [(role, prompt_template[role]) for role in ("system", "human")]
    synth_prompt = ChatPromptTemplate.from_messages(role_messages)

    # LC chain: Prompt -> LLM
    return synth_prompt | model

def build_agents_workflow(llm_model_name):
    """Assemble and compile the supervisor-driven multi-agent graph.

    Args:
        llm_model_name: Passed unchanged to ``init_main_model``; the model it
            returns is shared by the three agents and the synthesizer chain.

    Returns:
        The result of calling ``compile()`` on the wired ``StateGraph``.
    """
    # A single base LLM is used everywhere for now; per-agent models could be
    # initialised here instead.
    llm = init_main_model(llm_model_name)

    # Specialized agents first, then the synthesizer chain.
    news = create_news_agent(llm)
    earnings = create_earnings_agent(llm)
    market = create_market_agent(llm)
    synth_chain = make_synthesizer(llm)

    graph = StateGraph(GraphState)

    # Worker nodes: each closure hands the graph state plus its own dependency
    # to the matching node function.
    graph.add_node("news", lambda state: news_node(state, news))
    graph.add_node("earnings", lambda state: earnings_node(state, earnings))
    graph.add_node("market", lambda state: market_node(state, market))
    graph.add_node("synth", lambda state: synth_node(state, synth_chain))

    # The supervisor needs no extra dependency and is also the entry point.
    graph.add_node("supervisor", supervisor_node)
    graph.set_entry_point("supervisor")

    # Every node named in AGENTS reports back to the supervisor; the
    # synthesizer is terminal.
    for agent_name in AGENTS:
        graph.add_edge(agent_name, "supervisor")
    graph.add_edge("synth", END)

    # The router returns one of these strings, and each string maps to the
    # node of the same name.
    destinations = ("news", "earnings", "market", "synth")
    routing_table = {name: name for name in destinations}
    graph.add_conditional_edges("supervisor", supervisor_router, routing_table)

    return graph.compile()