File size: 3,558 Bytes
b4856f1
2473009
b4856f1
752f5cc
b4856f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752f5cc
b4856f1
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
752f5cc
b4856f1
 
 
 
752f5cc
 
b4856f1
 
 
 
752f5cc
b4856f1
b4c4175
 
 
 
 
 
 
 
 
752f5cc
b4856f1
 
 
 
 
 
752f5cc
b4856f1
 
 
 
b4c4175
 
 
b4856f1
 
 
 
752f5cc
b4856f1
 
 
 
 
 
 
752f5cc
b4856f1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
dataRetrievalAgentGraph.py - Data Retrieval Agent Graph Builder
"""

from langgraph.graph import StateGraph, START, END
from src.llms.groqllm import GroqLLM
from src.states.dataRetrievalAgentState import DataRetrievalAgentState
from src.nodes.dataRetrievalAgentNode import DataRetrievalAgentNode


class DataRetrievalAgentGraph(DataRetrievalAgentNode):
    """Builds the Data Retrieval Agent LangGraph workflow.

    Pipeline wired in ``build_data_retrieval_agent_graph``::

        START -> master_delegator -> prepare_worker_tasks -> worker
              -> aggregate_results -> classifier_agent -> format_output -> END

    The node callables ``master_agent_node``, ``worker_agent_node``,
    ``tool_node`` and ``classifier_agent_node`` are not defined in this
    class; they come from the ``DataRetrievalAgentNode`` base.
    """

    def __init__(self, llm):
        """Initialise the base node class and keep a reference to *llm*.

        Args:
            llm: Chat model forwarded to ``DataRetrievalAgentNode.__init__``.
        """
        super().__init__(llm)
        # Stored explicitly; whether the base class already sets ``self.llm``
        # is not visible from this file.
        self.llm = llm

    def prepare_worker_tasks(self, state: DataRetrievalAgentState) -> dict:
        """Split ``state.generated_tasks`` into one worker input per task.

        Each worker sub-graph run receives a state dict whose
        ``generated_tasks`` list holds exactly one task.

        Returns:
            ``{"tasks_for_workers": [{"generated_tasks": [task]}, ...]}``;
            the list is empty when there are no generated tasks.
        """
        # Treat a None task list as "no tasks" instead of raising TypeError
        # while iterating.
        tasks = state.generated_tasks or []
        initial_states = [{"generated_tasks": [task]} for task in tasks]
        return {"tasks_for_workers": initial_states}

    def create_worker_graph(self):
        """Compile the per-task worker sub-graph.

        Flow: START -> worker_agent -> tool_node -> END, using the same
        ``DataRetrievalAgentState`` schema as the parent graph.
        """
        worker_graph_builder = StateGraph(DataRetrievalAgentState)

        worker_graph_builder.add_node("worker_agent", self.worker_agent_node)
        worker_graph_builder.add_node("tool_node", self.tool_node)

        # Equivalent to set_entry_point("worker_agent").
        worker_graph_builder.add_edge(START, "worker_agent")
        worker_graph_builder.add_edge("worker_agent", "tool_node")
        worker_graph_builder.add_edge("tool_node", END)

        return worker_graph_builder.compile()

    def aggregate_results(self, state: DataRetrievalAgentState) -> dict:
        """Flatten the ``worker_results`` of every worker run into one list.

        ``state.worker`` is the list produced by the ``worker`` node (one
        final worker state per task). Entries are normally dicts; objects
        exposing a ``worker_results`` attribute are also accepted. Entries
        with a missing, None or empty ``worker_results`` are skipped, as is
        a ``state.worker`` that is not a list.

        Returns:
            dict holding the combined list under both ``worker_results`` and
            ``latest_worker_results``.

        NOTE(review): if ``worker_results`` has an additive reducer in the
        state schema, these results are appended to the existing value
        rather than replacing it — confirm against DataRetrievalAgentState.
        """
        worker_outputs = getattr(state, "worker", [])
        new_results = []

        if isinstance(worker_outputs, list):
            for output in worker_outputs:
                # Read dicts by key and anything else by attribute, so a None
                # or attribute-style output is skipped rather than raising
                # TypeError on a membership test.
                if isinstance(output, dict):
                    results = output.get("worker_results")
                else:
                    results = getattr(output, "worker_results", None)
                if results:
                    new_results.extend(results)

        return {"worker_results": new_results, "latest_worker_results": new_results}

    def format_output(self, state: DataRetrievalAgentState) -> dict:
        """Convert classified events into ``domain_insights`` for the parent graph.

        Each event in ``state.classified_buffer`` must expose ``event_id``,
        ``target_agent``, ``content_summary`` and ``confidence_score``.
        ``severity`` is currently the fixed placeholder ``"medium"``, and
        ``risk_score`` is copied directly from ``confidence_score``.
        A None buffer is treated as empty.
        """
        classified_events = state.classified_buffer or []
        insights = [
            {
                "source_event_id": event.event_id,
                "domain": event.target_agent,
                "severity": "medium",
                "summary": event.content_summary,
                "risk_score": event.confidence_score,
            }
            for event in classified_events
        ]

        print(f"[DATA RETRIEVAL] Formatted {len(insights)} insights for parent graph")
        return {"domain_insights": insights}

    def build_data_retrieval_agent_graph(self):
        """Build and compile the parent data-retrieval graph.

        Returns:
            The compiled graph running master_delegator ->
            prepare_worker_tasks -> worker -> aggregate_results ->
            classifier_agent -> format_output.
        """
        worker_graph = self.create_worker_graph()

        def run_workers(state: DataRetrievalAgentState) -> dict:
            """Invoke the worker sub-graph once per entry in ``tasks_for_workers``."""
            tasks = state.tasks_for_workers or []
            # ``.map()`` wraps the sub-graph so ``invoke`` takes a list of
            # inputs and returns the list of final worker states.
            return {"worker": worker_graph.map().invoke(tasks)}

        workflow = StateGraph(DataRetrievalAgentState)

        workflow.add_node("master_delegator", self.master_agent_node)
        workflow.add_node("prepare_worker_tasks", self.prepare_worker_tasks)
        workflow.add_node("worker", run_workers)
        workflow.add_node("aggregate_results", self.aggregate_results)
        workflow.add_node("classifier_agent", self.classifier_agent_node)
        workflow.add_node("format_output", self.format_output)

        # Equivalent to set_entry_point("master_delegator").
        workflow.add_edge(START, "master_delegator")
        workflow.add_edge("master_delegator", "prepare_worker_tasks")
        workflow.add_edge("prepare_worker_tasks", "worker")
        workflow.add_edge("worker", "aggregate_results")
        workflow.add_edge("aggregate_results", "classifier_agent")
        workflow.add_edge("classifier_agent", "format_output")
        workflow.add_edge("format_output", END)

        return workflow.compile()


# Importing this module builds the Groq chat model and compiles the full
# data-retrieval graph. ``graph`` is presumably the entry point that callers
# (e.g. a parent graph or LangGraph tooling) import — confirm before renaming.
llm = GroqLLM().get_llm()
graph_builder = DataRetrievalAgentGraph(llm=llm)
graph = graph_builder.build_data_retrieval_agent_graph()