Spaces:
Runtime error
Runtime error
File size: 3,535 Bytes
"""SQL agent."""
from typing import Any, Dict, List, Optional, Sequence
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
SQL_SUFFIX,
)
from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.tools import BaseTool
def create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SQLDatabaseToolkit,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = SQL_PREFIX,
    suffix: Optional[str] = None,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Any,
) -> AgentExecutor:
    """Construct an SQL agent from an LLM and tools.

    The toolkit's tools, followed by any ``extra_tools``, are given to an
    agent whose prompt starts with ``prefix`` after it has been formatted
    with the toolkit's ``dialect`` and the ``top_k`` value.

    Args:
        llm: Language model that drives the agent.
        toolkit: Provides the SQL tools (via ``get_tools()``) and ``dialect``.
        agent_type: Either ``AgentType.ZERO_SHOT_REACT_DESCRIPTION`` or
            ``AgentType.OPENAI_FUNCTIONS``; any other value raises.
        callback_manager: Passed to the agent (and, for the zero-shot agent,
            to its ``LLMChain``) as well as to the executor.
        prefix: Prompt prefix. It is ``str.format``-ed with ``dialect`` and
            ``top_k``, so any other literal braces in it must be doubled.
        suffix: Prompt suffix. A falsy value falls back to ``SQL_SUFFIX``
            (zero-shot) or ``SQL_FUNCTIONS_SUFFIX`` (OpenAI functions).
        format_instructions: Used by the zero-shot prompt only; ignored for
            OpenAI functions.
        input_variables: Used by the zero-shot prompt only; the
            OpenAI-functions prompt always uses ``["input", "agent_scratchpad"]``.
        top_k: Value substituted into ``prefix``.
        max_iterations: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        max_execution_time: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        early_stopping_method: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        verbose: Forwarded to ``AgentExecutor.from_agent_and_tools``.
        agent_executor_kwargs: Additional executor keyword arguments. A key
            that duplicates one of the explicitly passed executor arguments
            makes the call fail with ``TypeError``.
        extra_tools: Tools appended after the toolkit's own tools.
        **kwargs: Forwarded to the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the constructed agent and its tools.

    Raises:
        ValueError: If ``agent_type`` is not one of the supported types.
    """
    all_tools = toolkit.get_tools() + list(extra_tools)
    formatted_prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
    agent: BaseSingleActionAgent

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        react_prompt = ZeroShotAgent.create_prompt(
            all_tools,
            prefix=formatted_prefix,
            suffix=suffix or SQL_SUFFIX,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        chain = LLMChain(
            llm=llm,
            prompt=react_prompt,
            callback_manager=callback_manager,
        )
        # Restrict the agent to exactly the tools it was given.
        allowed = [t.name for t in all_tools]
        agent = ZeroShotAgent(llm_chain=chain, allowed_tools=allowed, **kwargs)
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        # Fixed chat layout: system prefix, the user's question, an AI message
        # carrying the suffix, then the agent_scratchpad placeholder.
        chat_prompt = ChatPromptTemplate(
            input_variables=["input", "agent_scratchpad"],
            messages=[
                SystemMessage(content=formatted_prefix),
                HumanMessagePromptTemplate.from_template("{input}"),
                AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
                MessagesPlaceholder(variable_name="agent_scratchpad"),
            ],
        )
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=chat_prompt,
            tools=all_tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")

    executor_extras = agent_executor_kwargs or {}
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=all_tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **executor_extras,
    )
|