Lattice / core / agents / langgraph_agent.py
cryogenic22's picture
Create core/agents/langgraph_agent.py
1f95c9c verified
import logging
import operator
from datetime import datetime
from typing import Any, Dict, List, Optional

from langgraph.graph import END, Graph, StateGraph
from langgraph.prebuilt import ToolExecutor

from .base import AgentTool, BaseAgent, TaskInput
class LangGraphAgent(BaseAgent):
    """LangGraph-based agent implementation.

    The compiled graph alternates between an ``agent`` node, which asks the
    LLM what to do next, and a ``tool_executor`` node. The agent node's
    ``type`` field decides the next step: ``"tool_call"`` goes to the tool
    executor (which hands control back to the agent); anything else ends the
    run.
    """

    async def initialize(self) -> None:
        """Initialize the base agent, then build the tool executor and graph.

        Raises:
            Exception: whatever building the tool executor or the graph raises
                is logged and re-raised unchanged.
        """
        await super().initialize()
        try:
            # NOTE(review): ToolExecutor presumably expects tool objects (with
            # a ``.name`` attribute), not the plain dicts produced by
            # _convert_tool_to_langgraph -- verify against the installed
            # langgraph version.
            self.tool_executor = ToolExecutor(
                tools=[
                    self._convert_tool_to_langgraph(tool)
                    for tool in self.agent_config.tools
                ]
            )
            self.graph = self._create_agent_graph()
        except Exception as e:
            self.logger.error(f"Failed to initialize LangGraph agent: {str(e)}")
            raise

    def _convert_tool_to_langgraph(self, tool: AgentTool) -> Dict[str, Any]:
        """Convert a Lattice tool into a name/description/func/args_schema dict.

        ``args_schema`` is always empty; no argument schema is derived yet.
        """
        return {
            "name": tool.name,
            "description": tool.description,
            "func": tool.function,
            "args_schema": {},  # TODO: derive a real argument schema from the tool
        }

    def _create_agent_graph(self) -> StateGraph:
        """Build and compile the agent <-> tool-executor graph.

        Returns:
            The result of ``workflow.compile()`` (the runnable graph); the
            ``StateGraph`` annotation is kept for compatibility.
        """
        # NOTE(review): StateGraph requires a state schema. ``dict`` (which has
        # no field annotations) is assumed to hold the whole state as a single
        # value, matching the dict-in/dict-out node functions. Confirm this for
        # the installed langgraph version.
        workflow = StateGraph(dict)

        workflow.add_node("agent", self._agent_node)
        # NOTE(review): this node receives the full state dict, not a tool
        # invocation object -- verify ToolExecutor accepts that input.
        workflow.add_node("tool_executor", self.tool_executor)

        # Only visit the tool executor when a tool was requested. A final
        # answer routes to END, so the graph terminates instead of cycling
        # agent -> tool_executor -> agent indefinitely.
        workflow.add_conditional_edges(
            "agent",
            self._route_after_agent,
            {"tool_executor": "tool_executor", "end": END},
        )
        workflow.add_edge("tool_executor", "agent")

        workflow.set_entry_point("agent")
        return workflow.compile()

    @staticmethod
    def _route_after_agent(state: Dict[str, Any]) -> str:
        """Return ``"tool_executor"`` after a tool call, otherwise ``"end"``."""
        return "tool_executor" if state.get("type") == "tool_call" else "end"

    async def _agent_node(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the LLM for the next step and record the decision in the state.

        Args:
            state: Current graph state; ``message`` and ``history`` are read
                (defaulting to ``""`` and ``[]``).

        Returns:
            The incoming state plus either ``type="tool_call"`` with
            ``tool``/``args``, or ``type="final"`` with ``response``.
        """
        message = state.get("message", "")
        history = state.get("history", [])

        # _generate_response/_needs_tool/_extract_tool_call are not defined in
        # this class; presumably they are provided by BaseAgent -- verify.
        response = await self._generate_response(message, history)

        if self._needs_tool(response):
            tool_call = self._extract_tool_call(response)
            decision = {
                "type": "tool_call",
                "tool": tool_call["name"],
                "args": tool_call["args"],
            }
        else:
            decision = {
                "type": "final",
                "response": response,
            }

        # With a single-value dict state, whatever a node returns becomes the
        # new state. Returning only the decision would therefore drop
        # message/history/context for the next step, so carry them forward.
        return {**state, **decision}

    async def _execute_implementation(self, task: TaskInput) -> Dict[str, Any]:
        """Run ``task`` through the compiled graph.

        Returns:
            ``output`` (final ``response``, default ``""``), ``steps`` (final
            ``history``, default ``[]``) and ``metadata`` with the framework
            name and ``tools_used`` (default ``[]``).

        Raises:
            Exception: any execution error is logged and re-raised.
        """
        try:
            initial_state = {
                "message": task.description,
                "history": [],
                "context": task.context or {},
                "tools": task.tools or [],
            }

            # Compiled graphs follow the LangChain Runnable interface: async
            # execution is ``ainvoke`` (there is no ``arun``).
            result = await self.graph.ainvoke(initial_state)

            # NOTE(review): nothing in this class appends to ``history`` or
            # writes ``tools_used``; both fall back to their defaults here.
            return {
                'output': result.get("response", ""),
                'steps': result.get("history", []),
                'metadata': {
                    'framework': 'langgraph',
                    'tools_used': result.get("tools_used", [])
                }
            }
        except Exception as e:
            self.logger.error(f"LangGraph task execution failed: {str(e)}")
            raise
class LangGraphWorkflow:
    """LangGraph workflow that routes a list of tasks across multiple agents.

    Task ``i`` is handled by agent ``i``:

    1. The router sends control to ``agent_{current_task}``.
    2. That agent optionally calls its tool executor.
    3. ``current_task`` is advanced and control returns to the router.

    This repeats until every task has been handled.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the workflow config and create the workflow logger."""
        self.config = config
        self.logger = logging.getLogger("lattice.agent.langgraph.workflow")

    def _create_workflow_graph(
        self,
        agents: List["LangGraphAgent"],
        tasks: List["TaskInput"],
    ) -> "StateGraph":
        """Create and compile the graph connecting the agents.

        Args:
            agents: Initialized agents; agent ``i`` handles task ``i``.
            tasks: Tasks to run, in order.

        Returns:
            The compiled graph.

        Raises:
            ValueError: if there are more tasks than agents; the router would
                otherwise target an ``agent_{i}`` node that does not exist.
        """
        if len(tasks) > len(agents):
            raise ValueError(
                f"Workflow has {len(tasks)} tasks but only {len(agents)} agents; "
                "each task is routed to the agent with the same index"
            )

        # NOTE(review): StateGraph requires a state schema. ``dict`` is assumed
        # to hold the whole state as a single value; confirm this for the
        # installed langgraph version.
        workflow = StateGraph(dict)

        # Add nodes before the edges that reference them; some langgraph
        # versions may reject edges to nodes that are not registered yet.
        workflow.add_node("router", self._passthrough)
        for i, agent in enumerate(agents):
            workflow.add_node(f"agent_{i}", agent._agent_node)
            workflow.add_node(f"tools_{i}", agent.tool_executor)

        if agents:
            # With no agents, nothing routes to "advance". It would be an
            # unreachable node, so only add it when there are agents.
            workflow.add_node("advance", self._advance_task)
            workflow.add_edge("advance", "router")

        for i in range(len(agents)):
            # Call the tools only when this agent asked for one; either way,
            # move on to the next task afterwards.
            workflow.add_conditional_edges(
                f"agent_{i}",
                self._route_after_agent,
                {"tools": f"tools_{i}", "advance": "advance"},
            )
            workflow.add_edge(f"tools_{i}", "advance")

        # The router picks exactly one agent per step via a conditional edge,
        # and hands off to END once the tasks run out.
        route_targets = {f"agent_{i}": f"agent_{i}" for i in range(len(agents))}
        route_targets["end"] = END
        workflow.add_conditional_edges(
            "router", self._create_router(tasks), route_targets
        )

        workflow.set_entry_point("router")
        return workflow.compile()

    @staticmethod
    def _passthrough(state: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``state`` unchanged; routing happens on the router's edges."""
        return state

    @staticmethod
    def _advance_task(state: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``state`` with ``current_task`` incremented by one.

        NOTE(review): this assumes the agent/tool nodes return a state that
        still carries ``current_task``. If a node returns only its own output,
        the counter restarts from the default 0 here -- verify.
        """
        return {**state, "current_task": state.get("current_task", 0) + 1}

    @staticmethod
    def _route_after_agent(state: Dict[str, Any]) -> str:
        """Return ``"tools"`` after a tool call, otherwise ``"advance"``."""
        return "tools" if state.get("type") == "tool_call" else "advance"

    def _create_router(self, tasks: List["TaskInput"]):
        """Return a routing function mapping ``current_task`` to a node name.

        The returned function yields ``"agent_{i}"`` for task index ``i``
        (``current_task`` defaults to 0). Once every task has been handled it
        yields ``"end"``, which ``_create_workflow_graph`` maps to ``END``.
        """
        def router(state: Dict[str, Any]) -> str:
            current_task = state.get("current_task", 0)
            if current_task >= len(tasks):
                return "end"
            # Task i is handled by agent i.
            return f"agent_{current_task}"
        return router

    async def execute_workflow(
        self,
        agents: List["LangGraphAgent"],
        tasks: List["TaskInput"],
    ) -> Dict[str, Any]:
        """Run ``tasks`` through ``agents`` and summarize the final state.

        Returns:
            ``results`` and ``history`` from the final graph state (default
            ``[]``), ``metadata`` (framework, agent and task counts), and an
            ISO-8601 ``timestamp`` from ``datetime.now()`` (naive local time).

        Raises:
            ValueError: if there are more tasks than agents.
            Exception: any error from building or running the graph is
                logged and re-raised.
        """
        try:
            graph = self._create_workflow_graph(agents, tasks)

            initial_state = {
                "current_task": 0,
                # presumably pydantic models (``.dict()``) -- verify TaskInput
                "tasks": [task.dict() for task in tasks],
                "results": [],
                "history": []
            }

            # Compiled graphs follow the LangChain Runnable interface: async
            # execution is ``ainvoke`` (there is no ``arun``).
            result = await graph.ainvoke(initial_state)

            # TODO(review): no node in this workflow appears to write
            # ``results``, so it is returned as initialised.
            return {
                'results': result.get("results", []),
                'history': result.get("history", []),
                'metadata': {
                    'framework': 'langgraph',
                    'agent_count': len(agents),
                    'task_count': len(tasks)
                },
                'timestamp': datetime.now().isoformat()
            }
        except Exception as e:
            self.logger.error(f"LangGraph workflow execution failed: {str(e)}")
            raise
class LangGraphWorkflow