| import glob |
| import inspect |
| import json |
| import os |
| import signal |
| from collections.abc import Sequence |
| from functools import wraps |
| from multiprocessing import Process, Queue |
| from typing import Annotated, TypedDict |
|
|
| from langchain_core.messages import BaseMessage, SystemMessage, ToolMessage |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain_core.runnables import RunnableConfig |
| from langgraph.graph import END, StateGraph |
| from langgraph.graph.message import add_messages |
|
|
| from biomni.config import default_config |
| from biomni.env_desc import data_lake_dict, library_content_dict |
| from biomni.llm import get_llm |
| from biomni.model.retriever import ToolRetriever |
| from biomni.tool.tool_registry import ToolRegistry |
| from biomni.utils import ( |
| api_schema_to_langchain_tool, |
| function_to_api_schema, |
| pretty_print, |
| read_module2api, |
| ) |
|
|
|
|
| |
| class AgentState(TypedDict): |
| """The state of the agent.""" |
|
|
| |
| messages: Annotated[Sequence[BaseMessage], add_messages] |
|
|
|
|
| class react: |
| def __init__( |
| self, |
| path: str | None = None, |
| llm: str | None = None, |
| use_tool_retriever: bool | None = None, |
| timeout_seconds: int | None = None, |
| ): |
| |
| if path is None: |
| path = default_config.path |
| if llm is None: |
| llm = default_config.llm |
| if use_tool_retriever is None: |
| use_tool_retriever = default_config.use_tool_retriever |
| if timeout_seconds is None: |
| timeout_seconds = default_config.timeout_seconds |
|
|
| self.path = path |
| if not os.path.exists(path): |
| os.makedirs(path) |
| print(f"Created directory: {path}") |
| |
| else: |
| print(f"Data directory already exists: {path}, loading...") |
|
|
| module2api = read_module2api() |
|
|
| self.llm = get_llm(llm, config=default_config) |
| tools = [] |
| for module, api_list in module2api.items(): |
| print("Registering tools from module:", module) |
| tools += [api_schema_to_langchain_tool(api, mode="custom_tool", module_name=module) for api in api_list] |
| self.tools = tools |
| self.module2api = module2api |
| self.use_tool_retriever = use_tool_retriever |
|
|
| |
| self.data_lake_dict = data_lake_dict |
| self.library_content_dict = library_content_dict |
|
|
| if self.use_tool_retriever: |
| self.tool_registry = ToolRegistry(module2api) |
| self.retriever = ToolRetriever() |
|
|
| self.timeout_seconds = timeout_seconds |
|
|
| |
| self.tools = self._add_timeout_to_tools(self.tools) |
| self.prompt = "" |
| self.system_prompt = "" |
|
|
| def _add_timeout_to_tools(self, tools): |
| """Apply timeout wrapper to all tool functions using multiprocessing.""" |
|
|
| def create_timed_func(original_func, timeout): |
| """Factory function that creates a unique timed function for each tool.""" |
| tool_name = getattr(original_func, "__name__", "unknown") |
| |
|
|
| def process_func(func, args, kwargs, result_queue): |
| """Function to run in a separate process.""" |
| try: |
| result = func(*args, **kwargs) |
| result_queue.put(("success", result)) |
| except Exception as e: |
| result_queue.put(("error", str(e))) |
|
|
| @wraps(original_func) |
| def timed_func(*args, **kwargs): |
| |
| result_queue = Queue() |
|
|
| |
| proc = Process( |
| target=process_func, |
| args=(original_func, args, kwargs, result_queue), |
| ) |
| proc.start() |
|
|
| |
| proc.join(timeout) |
|
|
| |
| if proc.is_alive(): |
| print(f"TIMEOUT: Tool {tool_name} execution timed out after {timeout} seconds") |
| |
| proc.terminate() |
| proc.join(1) |
|
|
| |
| if proc.is_alive(): |
| os.kill(proc.pid, signal.SIGKILL) |
|
|
| return f"ERROR: Tool execution timed out after {timeout} seconds. Please try with simpler inputs or break your task into smaller steps." |
|
|
| |
| if not result_queue.empty(): |
| status, result = result_queue.get() |
| if status == "success": |
| return result |
| else: |
| return f"Error in tool execution: {result}" |
|
|
| return "Error: Tool execution completed but no result was returned" |
|
|
| return timed_func |
|
|
| wrapped_tools = [] |
| for tool in tools: |
| wrapped_tool = tool |
| wrapped_tool.func = create_timed_func(tool.func, self.timeout_seconds) |
| wrapped_tools.append(wrapped_tool) |
|
|
| return wrapped_tools |
|
|
| def add_tool(self, api): |
| function_code = inspect.getsource(api) |
| schema = function_to_api_schema(function_code, self.llm) |
| new_tool = api_schema_to_langchain_tool(schema, mode="custom_tool", module_name=api.__module__) |
|
|
| |
| wrapped_tools = self._add_timeout_to_tools([new_tool]) |
|
|
| |
| if wrapped_tools: |
| self.tools.append(wrapped_tools[0]) |
|
|
| def configure( |
| self, |
| plan=False, |
| reflect=False, |
| data_lake=False, |
| react_code_search=False, |
| library_access=False, |
| ): |
| data_lake_path = self.path + "/data_lake" |
| data_lake_content = glob.glob(data_lake_path + "/*") |
| data_lake_items = [x.split("/")[-1] for x in data_lake_content] |
|
|
| if react_code_search: |
| tools = [i for i in self.tools if i.name in ["run_python_repl", "search_google"]] |
|
|
| prompt_modifier = """ |
| You are a helpful biomedical assistant assigned with the task of problem-solving. |
| |
| You have access to two tools: |
| 1) run_python_repl: to write and run python code |
| 2) search_google: to search google for information |
| |
| You can use them to solve the problem. |
| """ |
| else: |
| tools = self.tools |
| if (not plan) and (not reflect): |
| prompt_modifier = """You are a helpful biologist and expert geneticist. |
| """ |
| elif plan and (not reflect): |
| prompt_modifier = """You are a helpful biologist and expert geneticist. |
| Given the question from the user, |
| - First, come up with a high level plan based on your understanding of the problem and available tools and record it in the Research Plan and Status. You can revise the plan later. |
| - Research Plan and Status should well organized and succinctly keep track of 1) high level plan (can be revised), 2) what steps have been done and what steps are in progress, 3) short results and conclusions of each step after it has been performed. Do not perform action in research plan. |
| - Research Plan and Status must only include progress that has been made by previous steps. It should not include results not directly confirmed by the previous observation. |
| - Follow the plan and try to achieve the goal as straightforwardly as possible. Use tools as necessary. |
| """ |
| elif (not plan) and reflect: |
| prompt_modifier = """You are a helpful biologist and expert geneticist. |
| In each round after the tool is used, conduct "reflection" step: reflect on the current state of the problem and the results of the last round. What does the observation mean? If there is an error, what caused the error and how to debug? |
| """ |
| else: |
| prompt_modifier = """You are a helpful biologist and expert geneticist. |
| Given the question from the user, |
| - First, come up with a high level plan based on your understanding of the problem and available tools and record it in the Research Plan and Status. You can revise the plan later. |
| - Research Plan and Status should well organized and succinctly keep track of 1) high level plan (can be revised), 2) what steps have been done and what steps are in progress, 3) short results and conclusions of each step after it has been performed. Do not perform action in research plan. |
| - Research Plan and Status must only include progress that has been made by previous steps. It should not include results not directly confirmed by the previous observation. |
| - Follow the plan and try to achieve the goal as straightforwardly as possible. Use tools as necessary. |
| In each round after the tool is used, conduct "reflection" step: reflect on the current state of the problem and the results of the last round. What does the observation mean? If there is an error, what caused the error and how to debug? |
| You have access to write_python_code and run_python_repl tool to write and run your own code if tools fail, or if the given tools are not enough. Please always make sure to write code when dealing with substantial data, including finding the length of long sequences or elements at different positions. |
| """ |
|
|
| if data_lake: |
| |
| data_lake_formatted = [] |
| for item in data_lake_items: |
| description = data_lake_dict.get(item, f"Data lake item: {item}") |
| data_lake_formatted.append(f"{item}: {description}") |
|
|
| prompt_modifier += """ |
| You can also access a biological data lake at the following path: {data_lake_path}. You can use the run_python_repl tool to write code to understand the data, process and utilize it for the task. |
| Here is the list of datasets with their descriptions: |
| ---- |
| {data_lake_formatted} |
| ---- |
| """.format( |
| data_lake_path=data_lake_path, |
| data_lake_formatted="\n".join(data_lake_formatted), |
| ) |
|
|
| if library_access: |
| |
| library_formatted = [] |
| for lib_name, lib_desc in library_content_dict.items(): |
| library_formatted.append(f"{lib_name}: {lib_desc}") |
|
|
| prompt_modifier += """ |
| You also have access to a list of software packages that can be used to perform various tasks. |
| You can use the run_python_repl tool to write code to access and utilize the library for the task. |
| Don't forget the import statement. |
| Here is the list of available libraries with their descriptions: |
| ---- |
| {library_formatted} |
| ---- |
| """.format(library_formatted="\n".join(library_formatted)) |
|
|
| print("=" * 25 + "System Prompt" + "=" * 25) |
| print(prompt_modifier) |
| self.system_prompt = prompt_modifier |
| self.prompt = ChatPromptTemplate.from_messages( |
| [ |
| ("system", prompt_modifier), |
| MessagesPlaceholder(variable_name="messages"), |
| ] |
| ) |
|
|
| |
| self.active_tools = tools |
|
|
| |
| self.app = self._create_custom_react_agent(self.llm, tools, self.prompt) |
|
|
| def _create_custom_react_agent(self, llm, tools, prompt): |
| """Create a custom ReAct agent using LangGraph.""" |
| |
| tools_by_name = {tool.name: tool for tool in tools} |
|
|
| |
| llm_with_tools = llm.bind_tools(tools) |
|
|
| |
| def call_model(state: AgentState, config: RunnableConfig = None): |
| """Node that calls the language model to get the next action.""" |
| system_message = SystemMessage(content=self.system_prompt) |
| messages = [system_message] + state["messages"] |
| response = llm_with_tools.invoke(messages, config=config) |
| return {"messages": [response]} |
|
|
| |
| def tool_node(state: AgentState): |
| """Node that executes tools based on the LLM's decisions.""" |
| outputs = [] |
| for tool_call in state["messages"][-1].tool_calls: |
| try: |
| tool_result = tools_by_name[tool_call["name"]].invoke(tool_call["args"]) |
| outputs.append( |
| ToolMessage( |
| content=json.dumps(tool_result), |
| name=tool_call["name"], |
| tool_call_id=tool_call["id"], |
| ) |
| ) |
| except Exception as e: |
| |
| outputs.append( |
| ToolMessage( |
| content=json.dumps({"error": str(e)}), |
| name=tool_call["name"], |
| tool_call_id=tool_call["id"], |
| ) |
| ) |
| return {"messages": outputs} |
|
|
| |
| def should_continue(state: AgentState): |
| """Determine if we should continue running the graph or finish.""" |
| messages = state["messages"] |
| last_message = messages[-1] |
| |
| if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: |
| return "end" |
| |
| else: |
| return "continue" |
|
|
| |
| workflow = StateGraph(AgentState) |
|
|
| |
| workflow.add_node("agent", call_model) |
| workflow.add_node("tools", tool_node) |
|
|
| |
| workflow.set_entry_point("agent") |
|
|
| |
| workflow.add_conditional_edges( |
| "agent", |
| should_continue, |
| { |
| "continue": "tools", |
| "end": END, |
| }, |
| ) |
|
|
| |
| workflow.add_edge("tools", "agent") |
|
|
| |
| return workflow.compile() |
|
|
| def go(self, prompt): |
| """Execute the agent with the given prompt. |
| |
| Args: |
| prompt: The user's query |
| |
| """ |
| if self.use_tool_retriever: |
| |
| all_tools = self.tool_registry.tools if hasattr(self, "tool_registry") else [] |
|
|
| |
| data_lake_path = self.path + "/data_lake" |
| data_lake_content = glob.glob(data_lake_path + "/*") |
| data_lake_items = [x.split("/")[-1] for x in data_lake_content] |
|
|
| |
| data_lake_descriptions = [] |
| for item in data_lake_items: |
| description = self.data_lake_dict.get(item, f"Data lake item: {item}") |
| data_lake_descriptions.append({"name": item, "description": description}) |
|
|
| |
| library_descriptions = [] |
| for lib_name, lib_desc in self.library_content_dict.items(): |
| library_descriptions.append({"name": lib_name, "description": lib_desc}) |
|
|
| |
| resources = { |
| "tools": all_tools, |
| "data_lake": data_lake_descriptions, |
| "libraries": library_descriptions, |
| } |
|
|
| |
| selected_resources = self.retriever.prompt_based_retrieval(prompt, resources, llm=self.llm) |
| print("Using prompt-based retrieval with the agent's LLM") |
|
|
| |
| print("\nSelected tools:") |
| for tool in selected_resources["tools"]: |
| if isinstance(tool, dict): |
| print(f"- {tool.get('name', 'Unknown')}: {tool.get('description', '')}") |
| else: |
| print(f"- {getattr(tool, 'name', str(tool))}: {getattr(tool, 'description', '')}") |
|
|
| print("\nSelected data lake items:") |
| for item in selected_resources["data_lake"]: |
| if isinstance(item, dict): |
| name = item.get("name", "Unknown") |
| description = self.data_lake_dict.get(name, f"Data lake item: {name}") |
| print(f"- {name}: {description}") |
| elif isinstance(item, str) and ": " in item: |
| |
| print(f"- {item}") |
| else: |
| description = self.data_lake_dict.get(item, f"Data lake item: {item}") |
| print(f"- {item}: {description}") |
|
|
| print("\nSelected libraries:") |
| for lib in selected_resources["libraries"]: |
| if isinstance(lib, dict): |
| print(f"- {lib.get('name', 'Unknown')}: {lib.get('description', '')}") |
| else: |
| print(f"- {lib}") |
|
|
| |
| tool_names = [ |
| tool["name"] if isinstance(tool, dict) else getattr(tool, "name", str(tool)) |
| for tool in selected_resources["tools"] |
| ] |
| retrieved_list_of_tools = [] |
|
|
| |
| for tool_name in tool_names: |
| |
| matching_tools = [t for t in self.tools if getattr(t, "name", None) == tool_name] |
| if matching_tools: |
| retrieved_list_of_tools.append(matching_tools[0]) |
|
|
| |
| if len([i for i in retrieved_list_of_tools if i.name == "run_python_repl"]) == 0: |
| retrieved_list_of_tools = retrieved_list_of_tools + [ |
| i for i in self.tools if i.name == "run_python_repl" |
| ] |
|
|
| print("Retrieved tools: \n" + "\n".join([l.name + ": " + l.description for l in retrieved_list_of_tools])) |
| |
| self.app = self._create_custom_react_agent(self.llm, retrieved_list_of_tools, self.prompt) |
|
|
| |
| config = {"recursion_limit": 50} |
| inputs = {"messages": [("user", prompt)]} |
| self.log = [] |
| for s in self.app.stream(inputs, stream_mode="values", config=config): |
| message = s["messages"][-1] |
| out = pretty_print(message) |
| self.log.append(out) |
| return self.log, s["messages"][-1].content |
|
|
| def result_formatting(self, output_class, task_intention): |
| self.format_check_prompt = ChatPromptTemplate.from_messages( |
| [ |
| ( |
| "system", |
| ( |
| "You are evaluateGPT, tasked with extract and parse the task output based on the history of an agent. " |
| "Review the entire history of messages provided. " |
| "Here is the task output requirement: \n" |
| f"'{task_intention.replace('{', '{{').replace('}', '}}')}'.\n" |
| ), |
| ), |
| ("placeholder", "{messages}"), |
| ] |
| ) |
|
|
| checker_llm = self.format_check_prompt | self.llm.with_structured_output(output_class) |
| result = checker_llm.invoke({"messages": [("user", str(self.log))]}).dict() |
| return result |
|
|