Spaces:
Runtime error
Runtime error
| from typing import List, Optional, TypedDict, Union | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.runnables import Runnable, RunnableParallel | |
| from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS | |
| from langchain.utilities.sql_database import SQLDatabase | |
| def _strip(text: str) -> str: | |
| return text.strip() | |
class SQLInput(TypedDict, total=True):
    """Input for a SQL Chain.

    Carries only the natural-language ``question`` to be turned into SQL.
    """

    # Natural-language question the chain should answer with a SQL query.
    question: str
class SQLInputWithTables(TypedDict):
    """Input for a SQL Chain that also names the tables to use.

    Same as ``SQLInput`` plus ``table_names_to_use``. ``create_sql_query_chain``
    forwards that list to ``SQLDatabase.get_table_info(table_names=...)``,
    presumably so only those tables' info is placed in the prompt (confirm
    against ``SQLDatabase``).
    """

    # Natural-language question the chain should answer with a SQL query.
    question: str
    # Names of the tables whose info should be given to the model.
    table_names_to_use: List[str]
def create_sql_query_chain(
    llm: BaseLanguageModel,
    db: SQLDatabase,
    prompt: Optional[BasePromptTemplate] = None,
    k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
    """Create a chain that generates SQL queries.

    *Security Note*: This chain generates SQL queries for the given database.

        The SQLDatabase class provides a get_table_info method that can be used
        to get column information as well as sample data from the table.

        To mitigate risk of leaking sensitive data, limit permissions
        to read and scope to the tables that are needed.

        Optionally, use the SQLInputWithTables input type to specify which tables
        are allowed to be accessed.

        Control access to who can submit requests to this chain.

        See https://python.langchain.com/docs/security for more information.

    Args:
        llm: The language model to use.
        db: The SQLDatabase to generate the query for.
        prompt: The prompt to use. If none is provided, will choose one
            based on dialect. Defaults to None.
        k: The number of results per select statement to return. Defaults to 5.

    Returns:
        A chain that takes in a question and generates a SQL query that answers
        that question.
    """
    # Prompt selection: explicit prompt > dialect-specific prompt > generic PROMPT.
    if prompt is not None:
        prompt_to_use = prompt
    elif db.dialect in SQL_PROMPTS:
        prompt_to_use = SQL_PROMPTS[db.dialect]
    else:
        prompt_to_use = PROMPT

    inputs = {
        # The "SQLQuery: " cue pairs with the "\nSQLResult:" stop sequence bound
        # below, so the model's output ends after the generated query.
        "input": lambda x: x["question"] + "\nSQLQuery: ",
        "top_k": lambda _: k,
        # A plain SQLInput has no "table_names_to_use", so .get() passes None;
        # presumably get_table_info then describes all usable tables (confirm).
        "table_info": lambda x: db.get_table_info(
            table_names=x.get("table_names_to_use")
        ),
    }
    if "dialect" in prompt_to_use.input_variables:
        # Supply only the dialect name. Returning a (dialect, prompt) tuple here
        # would interpolate the tuple's repr -- including the whole prompt
        # object -- into the rendered prompt text.
        inputs["dialect"] = lambda _: db.dialect

    return (
        RunnableParallel(inputs)
        | prompt_to_use
        | llm.bind(stop=["\nSQLResult:"])
        | StrOutputParser()
        | _strip
    )