from typing import TYPE_CHECKING

from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate

from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import HandleInput, IntInput, MultilineInput
from langflow.template import Output

if TYPE_CHECKING:
    from langchain_core.runnables import Runnable


class SQLGeneratorComponent(LCChainComponent):
    """Turn a natural-language question into a SQL query using an LLM.

    Builds a LangChain ``create_sql_query_chain`` from the connected language
    model and SQL database, optionally with a user-supplied prompt, and exposes
    the generated query on the ``text`` output.
    """

    display_name = "Natural Language to SQL"
    description = "Generate SQL from natural language."
    name = "SQLGenerator"
    legacy: bool = True
    icon = "LangChain"

    inputs = [
        MultilineInput(
            name="input_value",
            display_name="Input",
            info="The input value to pass to the chain.",
            required=True,
        ),
        HandleInput(
            name="llm",
            display_name="Language Model",
            input_types=["LanguageModel"],
            required=True,
        ),
        HandleInput(
            name="db",
            display_name="SQLDatabase",
            input_types=["SQLDatabase"],
            required=True,
        ),
        IntInput(
            name="top_k",
            display_name="Top K",
            info="The number of results per select statement to return.",
            value=5,
        ),
        MultilineInput(
            name="prompt",
            display_name="Prompt",
            info="The prompt must contain `{question}`.",
        ),
    ]

    outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

    def invoke_chain(self) -> Message:
        """Generate a SQL query for ``self.input_value``.

        Returns:
            The value stored under ``"query"`` in the chain's response. It is
            produced by stripping every ``"SQLQuery:"`` label from the raw LLM
            output and trimming surrounding whitespace. The same value is also
            assigned to ``self.status``.
            NOTE(review): this is presumably a plain ``str`` even though the
            signature is annotated ``-> Message``; confirm what the ``text``
            output consumers expect.

        Raises:
            ValueError: If ``top_k`` is less than 1, or if a custom prompt is
                given that lacks a ``{question}`` placeholder.
        """
        # A falsy prompt (None or empty string) means "use LangChain's default
        # prompt". The template is parsed before top_k is checked, so a
        # malformed template surfaces its own error first.
        custom_prompt = None
        if self.prompt:
            custom_prompt = PromptTemplate.from_template(template=self.prompt)

        if self.top_k < 1:
            msg = "Top K must be greater than 0."
            raise ValueError(msg)

        prompt_kwargs = {}
        if custom_prompt:
            # The raw text must contain "{question}" AND the parser must have
            # recognised it as a variable; an escaped "{{question}}" passes the
            # substring test but is not a variable.
            # NOTE(review): LangChain documents that a custom prompt for
            # create_sql_query_chain must expose `input`, `top_k` and
            # `table_info`. Confirm this `{question}` requirement against the
            # installed LangChain version.
            mentions_question = (
                "{question}" in custom_prompt.template
                and "question" in custom_prompt.input_variables
            )
            if not mentions_question:
                msg = "Prompt must contain `{question}` to be used with Natural Language to SQL."
                raise ValueError(msg)
            prompt_kwargs["prompt"] = custom_prompt

        sql_chain = create_sql_query_chain(llm=self.llm, db=self.db, k=self.top_k, **prompt_kwargs)

        # Piping into a dict wraps the chain's string output as {"query": ...},
        # with the "SQLQuery:" label removed.
        pipeline: Runnable = sql_chain | {"query": lambda raw: raw.replace("SQLQuery:", "").strip()}
        result = pipeline.invoke(
            {"question": self.input_value},
            config={"callbacks": self.get_langchain_callbacks()},
        )

        generated_sql = result.get("query")
        self.status = generated_sql
        return generated_sql