File size: 3,558 Bytes
b4856f1 2473009 b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 752f5cc b4856f1 b4c4175 752f5cc b4856f1 752f5cc b4856f1 b4c4175 b4856f1 752f5cc b4856f1 752f5cc b4856f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
"""
dataRetrievalAgentGraph.py - Data Retrieval Agent Graph Builder
"""
from langgraph.graph import StateGraph, START, END
from src.llms.groqllm import GroqLLM
from src.states.dataRetrievalAgentState import DataRetrievalAgentState
from src.nodes.dataRetrievalAgentNode import DataRetrievalAgentNode
class DataRetrievalAgentGraph(DataRetrievalAgentNode):
    """Assemble the data-retrieval agent as a LangGraph ``StateGraph``.

    The top-level graph runs::

        master_delegator -> prepare_worker_tasks -> worker
            -> aggregate_results -> classifier_agent -> format_output -> END

    The ``worker`` step runs a separately compiled subgraph
    (``worker_agent -> tool_node -> END``) once per prepared task.

    The node callables ``master_agent_node``, ``worker_agent_node``,
    ``tool_node`` and ``classifier_agent_node`` are not defined in this class;
    they are expected to come from ``DataRetrievalAgentNode``.
    """

    def __init__(self, llm):
        """Initialise the node base class and keep a reference to *llm*.

        Args:
            llm: Chat model forwarded to ``DataRetrievalAgentNode.__init__``.
        """
        super().__init__(llm)
        self.llm = llm

    def prepare_worker_tasks(self, state: DataRetrievalAgentState) -> dict:
        """Split ``state.generated_tasks`` into one worker input per task.

        Args:
            state: Current graph state; only ``generated_tasks`` is read.

        Returns:
            ``{"tasks_for_workers": [...]}`` where each element is an input
            dict for the worker subgraph whose ``generated_tasks`` list holds
            exactly one task. A missing/``None`` task list yields ``[]``.
        """
        # Treat an absent task list as "no work" instead of crashing on iteration.
        tasks = state.generated_tasks or []
        initial_states = [{"generated_tasks": [task]} for task in tasks]
        return {"tasks_for_workers": initial_states}

    def create_worker_graph(self):
        """Build and compile the per-task worker subgraph.

        Flow: ``START -> worker_agent -> tool_node -> END``.

        Returns:
            The compiled worker graph.
        """
        worker_graph_builder = StateGraph(DataRetrievalAgentState)
        worker_graph_builder.add_node("worker_agent", self.worker_agent_node)
        worker_graph_builder.add_node("tool_node", self.tool_node)
        # Equivalent to set_entry_point("worker_agent").
        worker_graph_builder.add_edge(START, "worker_agent")
        worker_graph_builder.add_edge("worker_agent", "tool_node")
        worker_graph_builder.add_edge("tool_node", END)
        return worker_graph_builder.compile()

    def aggregate_results(self, state: DataRetrievalAgentState) -> dict:
        """Flatten the ``worker_results`` of every worker run into one list.

        Args:
            state: Current graph state; ``state.worker`` is expected to be the
                list of worker-subgraph outputs written by the ``worker`` node.

        Returns:
            ``{"worker_results": [...], "latest_worker_results": [...]}``, both
            containing the concatenated results (empty if nothing usable).
        """
        worker_outputs = getattr(state, "worker", [])
        new_results = []
        if isinstance(worker_outputs, list):
            for output in worker_outputs:
                # Worker outputs are expected to be final-state dicts; skip
                # anything else (e.g. None) instead of raising on `in`.
                if not isinstance(output, dict):
                    continue
                results = output.get("worker_results")
                if results:
                    new_results.extend(results)
        # Give each state key its own list so later mutation of one field
        # cannot silently change the other.
        return {
            "worker_results": new_results,
            "latest_worker_results": list(new_results),
        }

    def format_output(self, state: DataRetrievalAgentState) -> dict:
        """Convert classified events into ``domain_insights`` dicts.

        Args:
            state: Current graph state; ``classified_buffer`` is read. Each
                event must expose ``event_id``, ``target_agent``,
                ``content_summary`` and ``confidence_score``.

        Returns:
            ``{"domain_insights": [...]}`` with one insight dict per event.
        """
        classified_events = state.classified_buffer or []
        insights = [
            {
                "source_event_id": event.event_id,
                "domain": event.target_agent,
                # NOTE(review): severity is a fixed placeholder, not derived
                # from the event — confirm this is intended.
                "severity": "medium",
                "summary": event.content_summary,
                # NOTE(review): the classifier's confidence is reused as the
                # risk score — confirm these are meant to be the same measure.
                "risk_score": event.confidence_score,
            }
            for event in classified_events
        ]
        print(f"[DATA RETRIEVAL] Formatted {len(insights)} insights for parent graph")
        return {"domain_insights": insights}

    def build_data_retrieval_agent_graph(self):
        """Build and compile the top-level data-retrieval graph.

        Returns:
            The compiled graph. Its last step (``format_output``) writes
            ``domain_insights`` into the state.
        """
        worker_graph = self.create_worker_graph()

        def run_workers(state: DataRetrievalAgentState) -> dict:
            """Invoke the worker subgraph once per prepared task input."""
            # `.map()` wraps the compiled subgraph so `invoke` accepts a list
            # of inputs and returns the list of corresponding outputs.
            tasks = state.tasks_for_workers or []
            return {"worker": worker_graph.map().invoke(tasks)}

        workflow = StateGraph(DataRetrievalAgentState)
        workflow.add_node("master_delegator", self.master_agent_node)
        workflow.add_node("prepare_worker_tasks", self.prepare_worker_tasks)
        workflow.add_node("worker", run_workers)
        workflow.add_node("aggregate_results", self.aggregate_results)
        workflow.add_node("classifier_agent", self.classifier_agent_node)
        workflow.add_node("format_output", self.format_output)

        # Equivalent to set_entry_point("master_delegator").
        workflow.add_edge(START, "master_delegator")
        workflow.add_edge("master_delegator", "prepare_worker_tasks")
        workflow.add_edge("prepare_worker_tasks", "worker")
        workflow.add_edge("worker", "aggregate_results")
        workflow.add_edge("aggregate_results", "classifier_agent")
        workflow.add_edge("classifier_agent", "format_output")
        workflow.add_edge("format_output", END)
        return workflow.compile()
# Build the compiled data-retrieval graph once, at import time, and expose it
# (together with the LLM and the builder) as module-level names.
llm = GroqLLM().get_llm()
graph_builder = DataRetrievalAgentGraph(llm=llm)
graph = graph_builder.build_data_retrieval_agent_graph()
|