File size: 4,280 Bytes
f580ad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Tool Selector for CodeAct Agent.
Pure LLM-based tool selection mechanism similar to Biomni's prompt_based_retrieval.
"""

import re
from typing import Dict, List, Optional
from langchain_core.messages import HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel


class ToolSelector:
    """
    LLM-based tool selection system inspired by Biomni's approach.

    The selector renders the available tools into a numbered list, asks the
    chat model to name the relevant ones, and maps the model's free-text
    answer back onto the exact tool names it was given.
    """

    # Tool names in the model's answer may be separated by commas or newlines.
    _SEPARATOR_RE = re.compile(r"[,\n]")
    # Leading list numbering such as "1. ".
    _NUMBER_PREFIX_RE = re.compile(r"^\d+\.\s*")
    # Leading bullet markers ("-", "*", "•") or "1)" style numbering.
    _BULLET_RE = re.compile(r"^\s*(?:[-*•]+|\d+[.)])\s*")
    # Characters models commonly wrap around a name (quotes, backticks, ...).
    _WRAPPER_CHARS = " \t\r`'\"*[]()."

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, model: "BaseChatModel"):
        """
        Initialize the ToolSelector.

        Args:
            model: The language model to use for tool selection
        """
        self.model = model

    def select_tools_for_task(self, query: str, available_tools: Dict[str, Dict], max_tools: int = 15) -> List[str]:
        """
        Use LLM-based selection to choose the most relevant tools for a query.
        Inspired by Biomni's prompt_based_retrieval mechanism.

        Args:
            query: The user's query/task description
            available_tools: Dictionary of {tool_name: tool_info} available
            max_tools: Maximum number of tools to select

        Returns:
            List of selected tool names, at most ``max_tools`` long. Empty when
            there are no tools or ``max_tools`` is not positive. If the model
            call or the handling of its response raises, the first
            ``max_tools`` available tool names are returned instead.
        """
        # A non-positive limit can never yield a tool; returning early also
        # keeps a negative value from slicing tools off the end of the result
        # and avoids a pointless model call.
        if not available_tools or max_tools <= 0:
            return []

        # Format tools for LLM prompt
        tools_list = self._format_tools_for_prompt(available_tools)

        # Create selection prompt (similar to Biomni's approach)
        selection_prompt = f"""You are an expert biomedical research assistant. Your task is to select the most relevant tools to help answer a user's query.

USER QUERY: {query}

Below are the available tools. Select items that are directly or indirectly relevant to answering the query.
Be generous in your selection - include tools that might be useful for the task, even if they're not explicitly mentioned in the query.
It's better to include slightly more tools than to miss potentially useful ones.

AVAILABLE TOOLS:
{tools_list}

Select up to {max_tools} tools that would be most helpful for this task.

Respond with ONLY a comma-separated list of the exact tool names, like this:
tool_name_1, tool_name_2, tool_name_3

Selected tools:"""

        try:
            # Get LLM response
            response = self.model.invoke([HumanMessage(content=selection_prompt)])
            response_text = self._response_text(response.content)

            # Parse the response to extract tool names
            selected_tools = self._parse_tool_selection_response(response_text, available_tools)

            # Ensure we don't exceed max_tools
            return selected_tools[:max_tools]

        except Exception as e:
            # Deliberate best-effort: selection must never break the caller.
            print(f"Error in LLM-based tool selection: {e}")
            # Return all tools if LLM fails (no keyword fallback)
            return list(available_tools)[:max_tools]

    @staticmethod
    def _response_text(content: object) -> str:
        """
        Flatten a chat message's content into stripped plain text.

        Args:
            content: Either a string, or a list/tuple whose items are strings
                or dicts carrying a string under the ``"text"`` key. Other
                items (for example non-text blocks) are ignored.

        Returns:
            The concatenated text with surrounding whitespace removed.

        Raises:
            TypeError: If ``content`` is neither a string nor a list/tuple.
        """
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, (list, tuple)):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces).strip()
        raise TypeError(f"Unsupported response content type: {type(content).__name__}")

    def _format_tools_for_prompt(self, tools: Dict[str, Dict]) -> str:
        """
        Format tools for the LLM prompt.

        Args:
            tools: Dictionary of {tool_name: tool_info}. ``tool_info`` may
                provide ``description`` and ``source``; missing keys (or a
                ``None`` info) fall back to placeholder text.

        Returns:
            One line per tool, numbered from 1, joined with newlines.
        """
        formatted = []
        for i, (tool_name, tool_info) in enumerate(tools.items(), 1):
            info = tool_info or {}  # tolerate tools registered without metadata
            description = info.get('description', 'No description available')
            source = info.get('source', 'unknown')
            formatted.append(f"{i}. {tool_name} ({source}): {description}")
        return "\n".join(formatted)

    def _parse_tool_selection_response(self, response: str, available_tools: Dict[str, Dict]) -> List[str]:
        """
        Parse the LLM response to extract valid tool names.

        Names are matched case-insensitively against ``available_tools`` and
        returned with the spelling used there, in the order the model listed
        them and without duplicates. Entries may be separated by commas or
        newlines. An entry that does not match as written is retried after
        removing a leading label ending in ":", a bullet or "1)" marker, and
        wrapping quotes, backticks, brackets or a trailing period.

        Args:
            response: Raw text returned by the model
            available_tools: Dictionary of {tool_name: tool_info} available

        Returns:
            List of matched tool names; empty if nothing matched.
        """
        # Build the case-insensitive lookup once. setdefault keeps the first
        # tool when two names differ only by case.
        canonical = {}
        for tool_name in available_tools:
            canonical.setdefault(tool_name.lower(), tool_name)

        selected_tools = []
        seen = set()

        for raw_candidate in self._SEPARATOR_RE.split(response):
            # Remove "1. " prefixes, then try the entry as written
            candidate = self._NUMBER_PREFIX_RE.sub('', raw_candidate.strip()).strip()
            match = canonical.get(candidate.lower())

            if match is None:
                # Retry with common decorations removed
                relaxed = candidate.rpartition(':')[2]  # drop "Selected tools:" style labels
                relaxed = self._BULLET_RE.sub('', relaxed)
                relaxed = relaxed.strip(self._WRAPPER_CHARS)
                match = canonical.get(relaxed.lower())

            if match is not None and match not in seen:  # Avoid duplicates
                seen.add(match)
                selected_tools.append(match)

        return selected_tools