|
|
from typing import Dict, Any, List, Optional |
|
|
from .base import BaseAgent, TaskInput, AgentTool |
|
|
from langgraph.graph import Graph, StateGraph |
|
|
from langgraph.prebuilt import ToolExecutor |
|
|
import operator |
|
|
from datetime import datetime |
|
|
import logging |
|
|
|
|
|
class LangGraphAgent(BaseAgent):
    """LangGraph-based agent implementation.

    ``initialize`` builds a tool executor and a two-node graph ("agent" and
    "tool_executor"); ``_execute_implementation`` runs one task through that
    graph.

    The class uses ``self.agent_config``, ``self.logger``,
    ``self._generate_response``, ``self._needs_tool`` and
    ``self._extract_tool_call``, none of which are defined here.
    NOTE(review): presumably they are provided by ``BaseAgent`` -- confirm.
    """

    async def initialize(self) -> None:
        """Initialize LangGraph agent.

        Runs the base-class initialization, then creates ``self.tool_executor``
        and ``self.graph``.

        Raises:
            Exception: any failure while building the executor or the graph is
                logged and re-raised unchanged.
        """
        await super().initialize()

        try:
            # The executor has to exist before the graph is built because the
            # graph registers it as the "tool_executor" node.
            self.tool_executor = ToolExecutor(
                tools=[
                    self._convert_tool_to_langgraph(tool)
                    for tool in self.agent_config.tools
                ]
            )
            self.graph = self._create_agent_graph()
        except Exception as e:
            self.logger.error(f"Failed to initialize LangGraph agent: {str(e)}")
            raise

    def _convert_tool_to_langgraph(self, tool: AgentTool) -> Dict[str, Any]:
        """Convert Lattice tool to LangGraph format.

        Args:
            tool: tool whose ``name``, ``description`` and ``function``
                attributes are copied.

        Returns:
            A dict with the keys ``name``, ``description``, ``func`` and an
            empty ``args_schema``.
        """
        # NOTE(review): ToolExecutor may expect tool objects rather than plain
        # dicts -- confirm against the installed langgraph version.
        return {
            "name": tool.name,
            "description": tool.description,
            "func": tool.function,
            "args_schema": {},
        }

    @staticmethod
    def _route_after_agent(state: Dict[str, Any]) -> str:
        """Pick the branch to follow after the "agent" node has run.

        Returns:
            ``"tool_executor"`` when the agent asked for a tool call,
            ``"end"`` otherwise.
        """
        if state.get("type") == "tool_call":
            return "tool_executor"
        return "end"

    def _create_agent_graph(self) -> Any:
        """Create agent workflow graph.

        Returns:
            The compiled graph produced by ``workflow.compile()``.
        """
        # Function-scope import: END is not among the module-level imports.
        from langgraph.graph import END

        async def agent_step(state: Dict[str, Any]) -> Dict[str, Any]:
            # Merge instead of replace so "message", "history" and friends
            # are still present when the agent is re-entered after a tool call.
            update = await self._agent_node(state)
            return {**state, **update}

        # StateGraph needs a state schema.
        # NOTE(review): with a plain ``dict`` schema each node's return value
        # is expected to replace the whole state -- confirm for the installed
        # langgraph version.
        workflow = StateGraph(dict)

        workflow.add_node("agent", agent_step)
        workflow.add_node("tool_executor", self.tool_executor)

        # Only visit the tool executor when the agent requested a tool; a
        # final answer ends the run instead of looping agent <-> tools forever.
        workflow.add_conditional_edges(
            "agent",
            self._route_after_agent,
            {"tool_executor": "tool_executor", "end": END},
        )
        workflow.add_edge("tool_executor", "agent")

        workflow.set_entry_point("agent")

        return workflow.compile()

    async def _agent_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Agent node implementation.

        Generates a response from the state's ``message`` and ``history`` and
        classifies it.

        Returns:
            ``{"type": "tool_call", "tool": ..., "args": ...}`` when the
            response needs a tool, otherwise
            ``{"type": "final", "response": ...}``.
        """
        message = state.get("message", "")
        history = state.get("history", [])

        response = await self._generate_response(message, history)

        if self._needs_tool(response):
            tool_call = self._extract_tool_call(response)
            return {
                "type": "tool_call",
                "tool": tool_call["name"],
                "args": tool_call["args"],
            }

        return {
            "type": "final",
            "response": response,
        }

    async def _execute_implementation(self, task: TaskInput) -> Dict[str, Any]:
        """Execute task using LangGraph.

        Args:
            task: task whose ``description``, ``context`` and ``tools`` seed
                the initial graph state.

        Returns:
            A dict with ``output`` (the final response), ``steps`` (the
            state's history) and ``metadata``.

        Raises:
            Exception: any failure during graph execution is logged and
                re-raised unchanged.
        """
        try:
            initial_state = {
                "message": task.description,
                "history": [],
                "context": task.context or {},
                "tools": task.tools or [],
            }

            # Compiled graphs are run asynchronously with ``ainvoke``.
            result = await self.graph.ainvoke(initial_state)

            return {
                'output': result.get("response", ""),
                'steps': result.get("history", []),
                'metadata': {
                    'framework': 'langgraph',
                    'tools_used': result.get("tools_used", []),
                },
            }

        except Exception as e:
            self.logger.error(f"LangGraph task execution failed: {str(e)}")
            raise
|
|
|
|
|
class LangGraphWorkflow:
    """LangGraph workflow implementation.

    Runs a list of tasks through a list of agents, one after another: task
    ``i`` is handled by agent ``i``. A "router" node picks the next agent
    from the state's ``current_task`` counter and ends the run once every
    task has been processed.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the workflow configuration and set up the logger."""
        self.config = config
        self.logger = logging.getLogger("lattice.agent.langgraph.workflow")

    @staticmethod
    def _pass_through(state: Dict[str, Any]) -> Dict[str, Any]:
        """Node body for "router": the routing itself happens on its edges."""
        return state

    @staticmethod
    def _route_after_agent(state: Dict[str, Any]) -> str:
        """Pick the branch to follow after an agent node has run.

        Returns:
            ``"tools"`` when the agent asked for a tool call, ``"router"``
            otherwise.
        """
        if state.get("type") == "tool_call":
            return "tools"
        return "router"

    def _create_task_node(self, agent: "LangGraphAgent", index: int):
        """Wrap ``agent._agent_node`` so it works on the shared workflow state.

        Args:
            agent: object exposing an async ``_agent_node(state)`` method.
            index: position of the task (and agent) this node is bound to.

        Returns:
            An async node function. It feeds the description of task
            ``index`` to the agent as ``message`` and merges the agent's
            output into the state. On a ``"final"`` output it also appends
            the response to ``results`` and advances ``current_task``.
        """
        async def task_node(state: Dict[str, Any]) -> Dict[str, Any]:
            tasks = state.get("tasks", [])
            task = tasks[index] if index < len(tasks) else {}
            # NOTE(review): assumes the serialized task exposes a
            # "description" key (mirrors ``task.description``) -- confirm.
            outcome = await agent._agent_node(
                {**state, "message": task.get("description", "")}
            )

            merged = {**state, **outcome}
            if outcome.get("type") == "final":
                # Build a new list so the caller's state is never mutated.
                merged["results"] = [
                    *state.get("results", []),
                    outcome.get("response"),
                ]
                merged["current_task"] = state.get("current_task", 0) + 1
            return merged

        return task_node

    def _create_workflow_graph(
        self,
        agents: "List[LangGraphAgent]",
        tasks: "List[TaskInput]",
    ) -> Any:
        """Create workflow graph connecting multiple agents.

        Args:
            agents: agents to wire in; agent ``i`` gets the nodes
                ``agent_{i}`` and ``tools_{i}``.
            tasks: tasks to process; task ``i`` is routed to ``agent_{i}``.

        Returns:
            The compiled graph produced by ``workflow.compile()``.

        Raises:
            ValueError: if there are more tasks than agents, since the router
                would then point at a node that does not exist.
        """
        # Function-scope import: END is not among the module-level imports.
        from langgraph.graph import END

        if len(tasks) > len(agents):
            raise ValueError(
                f"Workflow has {len(tasks)} tasks but only {len(agents)} agents"
            )

        # StateGraph needs a state schema.
        # NOTE(review): with a plain ``dict`` schema each node's return value
        # is expected to replace the whole state -- confirm for the installed
        # langgraph version.
        workflow = StateGraph(dict)

        # Register every node before wiring edges.
        for i, agent in enumerate(agents):
            workflow.add_node(f"agent_{i}", self._create_task_node(agent, i))
            # NOTE(review): the tool executor receives the raw state dict;
            # whether it accepts that input and preserves "current_task" and
            # "results" cannot be told from this file -- confirm.
            workflow.add_node(f"tools_{i}", agent.tool_executor)

        workflow.add_node("router", self._pass_through)

        # The router's decision is a branch, not a state update, so it is
        # attached as conditional edges; "end" terminates the graph.
        router_targets = {"end": END}
        for i, _ in enumerate(agents):
            agent_name = f"agent_{i}"
            tools_name = f"tools_{i}"
            router_targets[agent_name] = agent_name
            workflow.add_conditional_edges(
                agent_name,
                self._route_after_agent,
                {"tools": tools_name, "router": "router"},
            )
            workflow.add_edge(tools_name, "router")

        workflow.add_conditional_edges(
            "router", self._create_router(tasks), router_targets
        )

        workflow.set_entry_point("router")

        return workflow.compile()

    def _create_router(self, tasks: "List[TaskInput]"):
        """Create routing logic based on tasks.

        Returns:
            A function mapping a state to the name of the next node:
            ``"agent_{current_task}"`` while tasks remain, ``"end"`` once
            ``current_task`` has reached ``len(tasks)``. A missing
            ``current_task`` is treated as 0.
        """
        def router(state: Dict[str, Any]) -> str:
            current_task = state.get("current_task", 0)
            if current_task >= len(tasks):
                return "end"

            return f"agent_{current_task}"

        return router

    async def execute_workflow(
        self,
        agents: "List[LangGraphAgent]",
        tasks: "List[TaskInput]",
    ) -> Dict[str, Any]:
        """Execute a workflow with multiple agents.

        Args:
            agents: agents to run, in task order.
            tasks: tasks to process; each is serialized with ``task.dict()``.

        Returns:
            A dict with ``results``, ``history``, ``metadata`` (framework
            name, agent count, task count) and an ISO-format ``timestamp``.

        Raises:
            Exception: any failure while building or running the graph is
                logged and re-raised unchanged.
        """
        try:
            graph = self._create_workflow_graph(agents, tasks)

            initial_state = {
                "current_task": 0,
                "tasks": [task.dict() for task in tasks],
                "results": [],
                "history": [],
            }

            # Compiled graphs are run asynchronously with ``ainvoke``.
            # NOTE(review): long task lists may need a higher recursion
            # limit passed via the run config -- confirm.
            result = await graph.ainvoke(initial_state)

            return {
                'results': result.get("results", []),
                'history': result.get("history", []),
                'metadata': {
                    'framework': 'langgraph',
                    'agent_count': len(agents),
                    'task_count': len(tasks),
                },
                'timestamp': datetime.now().isoformat(),
            }

        except Exception as e:
            self.logger.error(f"LangGraph workflow execution failed: {str(e)}")
            raise
|
|
|
|
|
class LangGraphWorkflow |