# alpha/managers/tools/tool_selector.py
"""
Tool Selector for CodeAct Agent.
Pure LLM-based tool selection mechanism similar to Biomni's prompt_based_retrieval.
"""
import re
from typing import Dict, List, Optional
from langchain_core.messages import HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel
class ToolSelector:
    """
    LLM-based tool selection system inspired by Biomni's approach.

    Uses a language model to intelligently select the most relevant tools
    for a given task, via a single ranking prompt similar to Biomni's
    ``prompt_based_retrieval`` mechanism.
    """

    def __init__(self, model: BaseChatModel):
        """
        Initialize the ToolSelector.

        Args:
            model: The language model to use for tool selection.
        """
        self.model = model

    def select_tools_for_task(self, query: str, available_tools: Dict[str, Dict], max_tools: int = 15) -> List[str]:
        """
        Use LLM-based selection to choose the most relevant tools for a query.
        Inspired by Biomni's prompt_based_retrieval mechanism.

        Args:
            query: The user's query/task description.
            available_tools: Dictionary of {tool_name: tool_info} available.
            max_tools: Maximum number of tools to select.

        Returns:
            List of selected tool names (always a subset of
            ``available_tools`` keys, capped at ``max_tools``). If the LLM
            call fails, falls back to the first ``max_tools`` available
            tools rather than returning nothing.
        """
        if not available_tools:
            return []

        # Format tools for LLM prompt
        tools_list = self._format_tools_for_prompt(available_tools)

        # Create selection prompt (similar to Biomni's approach)
        selection_prompt = f"""You are an expert biomedical research assistant. Your task is to select the most relevant tools to help answer a user's query.
USER QUERY: {query}
Below are the available tools. Select items that are directly or indirectly relevant to answering the query.
Be generous in your selection - include tools that might be useful for the task, even if they're not explicitly mentioned in the query.
It's better to include slightly more tools than to miss potentially useful ones.
AVAILABLE TOOLS:
{tools_list}
Select up to {max_tools} tools that would be most helpful for this task.
Respond with ONLY a comma-separated list of the exact tool names, like this:
tool_name_1, tool_name_2, tool_name_3
Selected tools:"""

        try:
            # Get LLM response
            response = self.model.invoke([HumanMessage(content=selection_prompt)])
            response_content = response.content.strip()
            # Parse the response to extract tool names
            selected_tools = self._parse_tool_selection_response(response_content, available_tools)
            # Enforce the cap even if the model over-selected.
            return selected_tools[:max_tools]
        except Exception as e:
            # Deliberate best-effort fallback: selection is an optimization,
            # not a correctness requirement, so degrade to "all tools"
            # rather than propagate the failure (no keyword fallback).
            print(f"Error in LLM-based tool selection: {e}")
            return list(available_tools.keys())[:max_tools]

    def _format_tools_for_prompt(self, tools: Dict[str, Dict]) -> str:
        """Format tools as a numbered "name (source): description" listing.

        Missing 'description'/'source' keys fall back to placeholder text
        so a sparsely-described tool never breaks prompt construction.
        """
        formatted = []
        for i, (tool_name, tool_info) in enumerate(tools.items(), 1):
            description = tool_info.get('description', 'No description available')
            source = tool_info.get('source', 'unknown')
            formatted.append(f"{i}. {tool_name} ({source}): {description}")
        return "\n".join(formatted)

    def _parse_tool_selection_response(self, response: str, available_tools: Dict[str, Dict]) -> List[str]:
        """Parse the LLM response to extract valid tool names.

        Matching is case-insensitive; first-mention order is preserved and
        duplicates are dropped. Candidate names not present in
        ``available_tools`` (model hallucinations) are silently ignored.
        """
        # Build the case-folded lookup once: O(1) per candidate instead of
        # re-scanning every tool name for every candidate. setdefault keeps
        # the FIRST key on a (pathological) case-insensitive collision,
        # matching the original first-match scan.
        lookup: Dict[str, str] = {}
        for tool_name in available_tools.keys():
            lookup.setdefault(tool_name.lower(), tool_name)

        selected_tools: List[str] = []
        seen = set()
        # Models sometimes answer one-name-per-line despite the comma
        # instruction, so split on both commas and newlines.
        for candidate in re.split(r'[,\n]', response):
            # Remove any extra characters, numbers, or formatting
            clean_candidate = re.sub(r'^\d+\.\s*', '', candidate.strip())  # Remove "1. " prefixes
            clean_candidate = clean_candidate.strip('`"\' ')  # Drop stray quotes/backticks
            match = lookup.get(clean_candidate.lower())
            if match is not None and match not in seen:  # Avoid duplicates
                seen.add(match)
                selected_tools.append(match)
        return selected_tools