|
|
""" |
|
|
CodeAgent: A LangGraph-based agent for executing Python code and using tools. |
|
|
Fully modular version with unified tool management. |
|
|
""" |
|
|
|
|
|
import os
import re
import time
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
from jinja2 import Template
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from core.types import AgentState, AgentConfig
from core.constants import SYSTEM_PROMPT_TEMPLATE
from managers import (
    ConsoleDisplay,
    PackageManager,
    PlanManager,
    PythonExecutor,
    StateManager,
    Timing,
    ToolManager,
    ToolSelector,
    ToolSource,
    WorkflowEngine,
)
|
|
|
|
|
|
|
|
# Populate os.environ from a .env file (e.g. OPENROUTER_API_KEY used by the demo below).
# The path is relative to the current working directory, not to this module's location.
load_dotenv(dotenv_path="./.env")
|
|
|
|
|
|
|
|
def get_system_prompt(functions: Dict[str, dict], packages: Optional[Dict[str, str]] = None) -> str:
    """Render ``SYSTEM_PROMPT_TEMPLATE`` with the given tools and packages.

    Args:
        functions: Mapping of tool name to its function schema. ``CodeAgent.generate``
            passes OpenAI-format schemas keyed by ``schema['function']['name']``.
            Exposed to the template as ``functions``.
        packages: Mapping exposed to the template as ``packages`` (presumably
            package name -> description; verify against the template). When
            ``None``, ``LIBRARY_CONTENT_DICT`` from ``core.constants`` is used.

    Returns:
        The rendered system prompt string.
    """
    if packages is None:
        # Imported lazily: only needed when the caller supplies no package mapping.
        from core.constants import LIBRARY_CONTENT_DICT

        packages = LIBRARY_CONTENT_DICT

    return Template(SYSTEM_PROMPT_TEMPLATE).render(functions=functions, packages=packages)
|
|
|
|
|
|
|
|
class CodeAgent:
    """A code-based agent that can execute Python code and use tools to solve tasks.

    ``generate``, ``execute`` and ``should_continue`` are handed to
    ``WorkflowEngine.setup_workflow``. Tools are managed by a single
    ``ToolManager`` and can optionally be narrowed per task by an LLM-based
    ``ToolSelector``.
    """

    # First <execute>...</execute> block in a message; DOTALL lets the code span lines.
    _EXECUTE_PATTERN = re.compile(r"<execute>(.*?)</execute>", re.DOTALL)

    def __init__(self, model: BaseChatModel,
                 config: Optional[AgentConfig] = None,
                 use_tool_manager: bool = True,
                 use_tool_selection: bool = True,
                 max_selected_tools: int = 15):
        """
        Initialize the CodeAgent with unified tool management.

        Args:
            model: The language model to use for generation.
            config: Configuration for the agent; ``AgentConfig()`` when omitted.
            use_tool_manager: Must be True. Kept for backward compatibility; the
                legacy non-ToolManager mode has been removed.
            use_tool_selection: Whether to use LLM-based tool selection (like Biomni).
            max_selected_tools: ``max_tools`` value passed to the ToolSelector
                when tool selection is enabled.

        Raises:
            ValueError: If ``use_tool_manager`` is False.
        """
        self.model = model
        self.config = config or AgentConfig()
        self.use_tool_manager = use_tool_manager
        self.use_tool_selection = use_tool_selection
        self.max_selected_tools = max_selected_tools

        # Tool schemas picked by the selector; None until the first selection runs.
        self._selected_tools_cache = None

        self.package_manager = PackageManager()
        self.console = ConsoleDisplay()
        self.state_manager = StateManager()
        self.plan_manager = PlanManager()

        if not self.use_tool_manager:
            raise ValueError("ToolManager is required. Legacy mode (use_tool_manager=False) has been removed.")

        self.tool_manager = ToolManager(self.console)
        self.tool_selector = ToolSelector(self.model) if self.use_tool_selection else None
        self.workflow_engine = WorkflowEngine(model, self.config, self.console, self.state_manager)
        self.python_executor = PythonExecutor()

        self._setup_workflow()

    # ------------------------------------------------------------------ #
    # Workflow nodes
    # ------------------------------------------------------------------ #

    def _setup_workflow(self):
        """Register the generate/execute/should_continue callables with the WorkflowEngine."""
        self.workflow_engine.setup_workflow(
            self.generate,
            self.execute,
            self.should_continue,
        )

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten LangChain message ``content`` into plain text.

        ``content`` is normally a ``str``, but LangChain also allows a list of
        content blocks (plain strings or dicts carrying a ``"text"`` key). The
        text parts are joined so that tag checks and regex searches see the
        model's actual text instead of a Python ``repr`` of the list.
        ``None`` becomes ``""``; any other value goes through ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return "".join(parts)
        return "" if content is None else str(content)

    def _functions_for_prompt(self, state: AgentState,
                              all_functions_dict: Dict[str, dict]) -> Dict[str, dict]:
        """Return the tool schemas to advertise in the system prompt.

        When tool selection is enabled, the first call that finds a message
        with text runs the selector on that text. The result is cached and
        reused by later calls (including later ``run`` calls) until
        ``reset_tool_selection`` is called. Otherwise every tool is returned.
        """
        if self.use_tool_selection and self._selected_tools_cache is not None:
            return self._selected_tools_cache

        messages = state.get("messages")
        if not (self.use_tool_selection and self.tool_selector and messages):
            return all_functions_dict

        # The first message that carries text is treated as the user's task.
        user_query = ""
        for message in messages:
            text = self._content_to_text(getattr(message, "content", None))
            if text:
                user_query = text
                break
        if not user_query:
            return all_functions_dict

        # The selector only needs a short description of each candidate tool.
        available_tools = {
            name: {
                'description': schema['function'].get('description', 'No description'),
                'source': 'tool_manager',
            }
            for name, schema in all_functions_dict.items()
        }
        selected_tool_names = self.tool_selector.select_tools_for_task(
            user_query, available_tools, max_tools=self.max_selected_tools
        )

        # Ignore any names the selector returns that are not actually registered.
        self._selected_tools_cache = {
            name: all_functions_dict[name]
            for name in selected_tool_names
            if name in all_functions_dict
        }
        self.console.console.print(
            f"🎯 Selected {len(self._selected_tools_cache)} tools from "
            f"{len(all_functions_dict)} available tools (cached for session)"
        )
        return self._selected_tools_cache

    def generate(self, state: AgentState) -> AgentState:
        """Ask the model for the next response using a tool-aware system prompt.

        Anything after the first ``</execute>`` in the reply is discarded, so the
        stored message ends with the code block the model wants run. Increments
        ``step_count``; ``error_count`` and ``start_time`` are carried over and
        ``current_plan`` is re-extracted from the reply.
        """
        all_schemas = self.tool_manager.get_tool_schemas(openai_format=True)
        all_functions_dict = {schema['function']['name']: schema for schema in all_schemas}
        functions_dict = self._functions_for_prompt(state, all_functions_dict)

        system_prompt = get_system_prompt(functions_dict, self.package_manager.get_all_packages())
        messages = [SystemMessage(content=system_prompt)] + state["messages"]
        response = self.model.invoke(messages)

        msg = self._content_to_text(response.content)
        if "</execute>" in msg:
            msg = msg.split("</execute>")[0] + "</execute>"

        return self.state_manager.create_state_dict(
            messages=[AIMessage(content=msg.strip())],
            step_count=state.get("step_count", 0) + 1,
            error_count=state.get("error_count", 0),
            start_time=state.get("start_time", time.time()),
            current_plan=self._extract_current_plan(msg),
        )

    def _extract_current_plan(self, content: str) -> Optional[str]:
        """Extract the current plan from the agent's response via the PlanManager."""
        return self.plan_manager.extract_plan_from_content(content)

    def _execution_state(self, state: AgentState, content: str, *, is_error: bool = False) -> AgentState:
        """Build the state update for an execute step, carrying counters forward.

        ``step_count``, ``start_time`` and ``current_plan`` are copied from
        ``state``; ``error_count`` is incremented when ``is_error`` is True.
        """
        return self.state_manager.create_state_dict(
            messages=[AIMessage(content=content)],
            step_count=state.get("step_count", 0),
            error_count=state.get("error_count", 0) + (1 if is_error else 0),
            start_time=state.get("start_time", time.time()),
            current_plan=state.get("current_plan"),
        )

    def execute(self, state: AgentState) -> AgentState:
        """Run the code inside the last message's ``<execute>`` block.

        The executor's output is returned as an ``<observation>`` message.
        If no block is found, or any exception is raised in this method, an
        ``<error>`` message is returned and ``error_count`` is incremented.
        ``step_count`` is not changed here.
        """
        try:
            last_message = self._content_to_text(state["messages"][-1].content)
            execute_match = self._EXECUTE_PATTERN.search(last_message)
            if execute_match is None:
                return self._execution_state(
                    state, "<error>No executable code found</error>", is_error=True
                )

            result = self.python_executor(execute_match.group(1).strip())
            obs = f"\n<observation>\nCode Output:\n{result}</observation>"
            return self._execution_state(state, obs.strip())
        except Exception as e:
            # Deliberate catch-all: the failure becomes an error message so the
            # workflow can continue; should_continue stops it once error_count
            # reaches config.retry_attempts.
            return self._execution_state(
                state, f"<error>Execution error: {str(e)}</error>", is_error=True
            )

    def should_continue(self, state: AgentState) -> str:
        """Decide whether to continue executing or end the workflow.

        Returns ``"end"`` when the timeout, step limit or error limit from
        ``self.config`` is reached, when the last message contains a complete
        ``<solution>...</solution>``, or when it has no complete
        ``<execute>...</execute>`` block; otherwise ``"execute"``.
        """
        last_message = self._content_to_text(state["messages"][-1].content)
        step_count = state.get("step_count", 0)
        error_count = state.get("error_count", 0)
        start_time = state.get("start_time", time.time())

        budget_exhausted = (
            time.time() - start_time > self.config.timeout_seconds
            or step_count >= self.config.max_steps
            or error_count >= self.config.retry_attempts
        )
        if budget_exhausted:
            return "end"
        if "<solution>" in last_message and "</solution>" in last_message:
            return "end"
        if "<execute>" in last_message and "</execute>" in last_message:
            return "execute"
        return "end"

    # ------------------------------------------------------------------ #
    # Packages
    # ------------------------------------------------------------------ #

    def add_packages(self, packages: Dict[str, str]) -> bool:
        """Add new packages to the available packages (delegates to PackageManager)."""
        return self.package_manager.add_packages(packages)

    def get_all_packages(self) -> Dict[str, str]:
        """Get all available packages (default + custom)."""
        return self.package_manager.get_all_packages()

    # ------------------------------------------------------------------ #
    # Tools
    # ------------------------------------------------------------------ #

    def add_tool(self, function: Callable, name: Optional[str] = None,
                 description: Optional[str] = None) -> bool:
        """Register ``function`` as a LOCAL tool with the ToolManager."""
        return self.tool_manager.add_tool(function, name, description, ToolSource.LOCAL)

    def remove_tool(self, name: str) -> bool:
        """Remove a tool by name."""
        return self.tool_manager.remove_tool(name)

    def list_tools(self, source: str = "all", include_details: bool = False) -> List[Dict]:
        """List available tools, optionally filtered by source.

        ``source`` is matched case-insensitively against ``"local"``,
        ``"decorated"`` and ``"mcp"``; any other value lists all tools.
        """
        normalized = source.lower()
        if normalized in ("local", "decorated", "mcp"):
            source_enum = ToolSource(normalized)
        else:
            source_enum = ToolSource.ALL
        return self.tool_manager.list_tools(source_enum, include_details)

    def search_tools(self, query: str) -> List[Dict]:
        """Search tools by name and description."""
        return self.tool_manager.search_tools(query)

    def get_tool_info(self, name: str) -> Optional[Dict]:
        """Return a dict describing the named tool, or None if it is unknown."""
        tool_info = self.tool_manager.get_tool(name)
        if not tool_info:
            return None
        return {
            "name": tool_info.name,
            "description": tool_info.description,
            "source": tool_info.source.value,
            "server": tool_info.server,
            "module": tool_info.module,
            "has_function": tool_info.function is not None,
            "required_parameters": tool_info.required_parameters,
            "optional_parameters": tool_info.optional_parameters,
        }

    def get_all_tool_functions(self) -> Dict[str, Callable]:
        """Get all tool functions as a dictionary."""
        return self.tool_manager.get_all_functions()

    # ------------------------------------------------------------------ #
    # MCP
    # ------------------------------------------------------------------ #

    def add_mcp(self, config_path: str = "./mcp_config.yaml") -> None:
        """Add MCP tools from a configuration file."""
        self.tool_manager.add_mcp_server(config_path)

    def list_mcp_tools(self) -> List[Dict]:
        """List all loaded MCP tools."""
        # ToolSource is the module-level enum (as used by add_tool/list_tools).
        return self.tool_manager.list_tools(ToolSource.MCP)

    def list_mcp_servers(self) -> Dict[str, List[str]]:
        """List all MCP servers and their tools."""
        return self.tool_manager.list_mcp_servers()

    def show_mcp_status(self) -> None:
        """Display detailed MCP status information to the user."""
        self.tool_manager.show_mcp_status()

    def get_mcp_summary(self) -> Dict[str, Any]:
        """Get a summary of MCP tools for programmatic access."""
        return self.tool_manager.get_mcp_summary()

    # ------------------------------------------------------------------ #
    # Diagnostics
    # ------------------------------------------------------------------ #

    def get_tool_statistics(self) -> Dict[str, Any]:
        """Get comprehensive tool statistics."""
        return self.tool_manager.get_tool_statistics()

    def validate_tools(self) -> Dict[str, List[str]]:
        """Validate all tools and return any issues."""
        return self.tool_manager.validate_tools()

    # ------------------------------------------------------------------ #
    # Tool selection cache
    # ------------------------------------------------------------------ #

    def reset_tool_selection(self):
        """Clear the cached tool selection so the next query triggers re-selection."""
        self._selected_tools_cache = None
        if self.use_tool_selection:
            self.console.console.print("🔄 Tool selection cache cleared - will re-select tools on next query")

    def get_selected_tools(self) -> Optional[List[str]]:
        """Return the names of the selected tools, or None if nothing is selected."""
        return list(self._selected_tools_cache.keys()) if self._selected_tools_cache else None

    # ------------------------------------------------------------------ #
    # Traces and summaries
    # ------------------------------------------------------------------ #

    def get_trace(self) -> Dict:
        """Return config plus the workflow engine's message history and trace logs.

        ``execution_time`` is the wall-clock time at which this method is called.
        Returns ``{}`` when there is no workflow engine.
        """
        if not self.workflow_engine:
            return {}
        return {
            "execution_time": time.strftime('%Y-%m-%d %H:%M:%S'),
            "config": {
                "max_steps": self.config.max_steps,
                "timeout_seconds": self.config.timeout_seconds,
                "verbose": self.config.verbose,
            },
            "messages": self.workflow_engine.message_history,
            "trace_logs": self.workflow_engine.trace_logs,
        }

    def get_summary(self) -> Dict:
        """Get a summary of the last execution ({} when there is no workflow engine)."""
        if not self.workflow_engine:
            return {}
        return self.workflow_engine.generate_summary()

    def save_trace(self, filepath: Optional[str] = None) -> str:
        """Save the trace of the last execution to a file and return its path.

        Raises:
            RuntimeError: If there is no workflow engine.
        """
        if not self.workflow_engine:
            raise RuntimeError("No workflow engine available")
        return self.workflow_engine.save_trace_to_file(filepath)

    def save_summary(self, filepath: Optional[str] = None) -> str:
        """Save the summary of the last execution to a file and return its path.

        Raises:
            RuntimeError: If there is no workflow engine.
        """
        if not self.workflow_engine:
            raise RuntimeError("No workflow engine available")
        return self.workflow_engine.save_summary_to_file(filepath)

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    def _print_tool_overview(self) -> None:
        """Print the number of loaded tools, broken down by source and MCP server."""
        stats = self.tool_manager.get_tool_statistics()
        mcp_servers = self.tool_manager.list_mcp_servers()
        by_source = stats['by_source']
        out = self.console.console

        out.print(f"🛠️ Loaded {stats['total_tools']} total tools:")
        if by_source['local'] > 0:
            out.print(f" 📁 Local tools: {by_source['local']}")
        if by_source['decorated'] > 0:
            out.print(f" 🎯 Decorated tools: {by_source['decorated']}")
        if by_source['mcp'] > 0:
            out.print(f" 🌐 MCP tools: {by_source['mcp']} from {len(mcp_servers)} servers")
            for server_name, tools in mcp_servers.items():
                out.print(f" • {server_name}: {len(tools)} tools")

    def _save_run_artifacts(self, trace_dir: str, save_trace: bool, save_summary: bool) -> None:
        """Write the requested trace and/or summary JSON files into ``trace_dir``."""
        from pathlib import Path

        trace_path = Path(trace_dir)
        trace_path.mkdir(parents=True, exist_ok=True)
        # One timestamp for both files so a run's trace and summary share a suffix.
        timestamp = time.strftime('%Y%m%d_%H%M%S')

        if save_trace:
            trace_file = trace_path / f"agent_trace_{timestamp}.json"
            saved_trace = self.workflow_engine.save_trace_to_file(str(trace_file))
            self.console.console.print(f"💾 Trace saved to: {saved_trace}")

        if save_summary:
            summary_file = trace_path / f"agent_summary_{timestamp}.json"
            saved_summary = self.workflow_engine.save_summary_to_file(str(summary_file))
            self.console.console.print(f"📊 Summary saved to: {saved_summary}")

    def run(self, query: str, save_trace: bool = False, save_summary: bool = False,
            trace_dir: str = "traces") -> str:
        """
        Run the agent with a given query using modular components.

        Args:
            query: The task/question to solve.
            save_trace: Whether to save the complete trace to a file.
            save_summary: Whether to save the execution summary to a file.
            trace_dir: Directory to save trace and summary files (created if missing).

        Returns:
            The final response content, as returned by ``WorkflowEngine.run_workflow``.
        """
        overall_timing = Timing(start_time=time.time())
        self.console.print_task_header(query)

        functions_dict = self.get_all_tool_functions()
        self._print_tool_overview()

        # Hand the tool callables to the executor (presumably so executed code
        # can call them by name), then import the configured packages into it.
        self.python_executor.send_functions(functions_dict)
        imported_packages, failed_packages = self.package_manager.import_packages(self.python_executor)
        self.console.print_packages_info(imported_packages, failed_packages)

        # No extra variables are injected at the moment; an empty mapping is still sent.
        self.python_executor.send_variables({})

        input_state = self.state_manager.create_state_dict(
            messages=[HumanMessage(content=query)],
            step_count=0,
            error_count=0,
            start_time=time.time(),
            current_plan=None,
        )
        result, final_state = self.workflow_engine.run_workflow(input_state)
        overall_timing.end_time = time.time()

        final_step_count = final_state.get("step_count", 0) if final_state else 0
        final_error_count = final_state.get("error_count", 0) if final_state else 0
        self.console.print_execution_summary(final_step_count, final_error_count, overall_timing.duration)

        if save_trace or save_summary:
            self._save_run_artifacts(trace_dir, save_trace, save_summary)

        return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Credentials come from the environment (or ./.env, loaded at import time);
    # never hard-code them in source. Check them up front so the demo fails fast.
    openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
    if not openrouter_api_key:
        raise SystemExit("OPENROUTER_API_KEY is not set (export it or add it to ./.env).")

    alphagenome_api_key = os.environ.get("ALPHAGENOME_API_KEY")
    if not alphagenome_api_key:
        raise SystemExit("ALPHAGENOME_API_KEY is not set (export it or add it to ./.env).")

    # Model served through OpenRouter's OpenAI-compatible endpoint; set
    # OPENROUTER_MODEL to pick another OpenRouter model id. To call Anthropic
    # directly, install langchain-anthropic and build a ChatAnthropic model here
    # instead (it is not imported by this module).
    model = ChatOpenAI(
        model=os.environ.get("OPENROUTER_MODEL", "google/gemini-2.5-flash"),
        base_url="https://openrouter.ai/api/v1",
        temperature=0.7,
        api_key=openrouter_api_key,
    )

    config = AgentConfig(
        max_steps=15,
        max_conversation_length=30,
        retry_attempts=3,
        timeout_seconds=1200,
        verbose=True,
    )

    agent = CodeAgent(model=model, config=config, use_tool_manager=True, use_tool_selection=True)

    print("\n🔧 Tool Management Demo:")

    stats = agent.get_tool_statistics()
    print(f"📊 Tool Statistics: {stats}")

    # MCP loading is best-effort: the demo continues with the remaining tools on failure.
    try:
        print("🔧 Loading MCP tools...")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(script_dir, "mcp_config.yaml")
        agent.add_mcp(config_path)
        print("✅ MCP tools loaded successfully!")

        agent.show_mcp_status()

        all_tools = agent.list_tools("all", include_details=True)
        print("\n📋 All tools loaded:")
        for tool in all_tools:
            print(f" • {tool['name']} ({tool['source']}) - {tool['description'][:50]}...")
    except Exception as e:
        print(f"⚠️ Could not load MCP tools: {e}")

    issues = agent.validate_tools()
    if any(issues.values()):
        print(f"⚠️ Tool validation issues: {issues}")
    else:
        print("✅ All tools validated successfully!")

    print(f"\n📦 Available packages: {list(agent.get_all_packages().keys())}")

    print("\n🚀 Running agent with trace and summary saving...")
    # NOTE: the key is embedded in the query, so it is sent to the model provider
    # and will likely appear in the saved trace files as well.
    result = agent.run(
        query=(
            "Use AlphaGenome MCP to analyze heart gene expression data to identify the causal gene "
            "for the variant chr11:116837649:T>G, associated with Hypoalphalipoproteinemia. "
            f"My API key is: {alphagenome_api_key}."
        ),
        save_trace=True,
        save_summary=True,
        trace_dir="traces",
    )

    print("\n📊 Execution Summary:")
    summary = agent.get_summary()
    print(f" Total steps: {summary.get('total_steps', 0)}")
    print(f" Code executions: {len(summary.get('code_executions', []))}")
    print(f" Observations: {len(summary.get('observations', []))}")
    print(f" Errors: {len(summary.get('errors', []))}")
|
|
|
|
|
|
|
|
|
|
|
|