"""NL2SQL demo: pick a table, ask an LLM for SQLite SQL, run it, summarize it.

The module wires a small Gradio UI around a retrying "generate SQL ->
execute -> summarize" loop backed by an NVIDIA-hosted chat model.
"""

import os
import re
import sqlite3
import warnings
from contextlib import closing

import gradio as gr
import pandas as pd
from langchain_nvidia_ai_endpoints import ChatNVIDIA

from schema import schema

warnings.filterwarnings("ignore")

# SECURITY: the key used to be hard-coded in this file. It is now read from the
# environment; the previously committed key should be treated as leaked and
# rotated. When the variable is unset this is None.
# NOTE(review): ChatNVIDIA is expected to fall back to its own NVIDIA_API_KEY
# lookup when api_key is None -- confirm against the installed version.
API_KEY = os.environ.get("NVIDIA_API_KEY")
db_path = "wash_db.db"

# How many times the pipelines ask the LLM for SQL before giving up.
MAX_SQL_ATTEMPTS = 5

# Tokens inside a "TABLE name (...)" block that are type/constraint keywords
# rather than column names.
_NON_COLUMN_TOKENS = frozenset(
    {"int", "primary", "key", "string", "bit", "real", "references"}
)

# Reasoning blocks that some models prepend to the answer.
# NOTE(review): assumes deepseek-r1 emits its reasoning as <think>...</think>
# in the message content -- confirm; the substitution is a no-op otherwise.
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)

client = ChatNVIDIA(
    model="deepseek-ai/deepseek-r1",
    api_key=API_KEY,
    temperature=0.1,
    top_p=1,
    max_tokens=4096,
)


def _response_text(resp):
    """Return the text of an LLM response as a ``str``.

    Uses ``resp.content`` when the attribute exists, otherwise the response
    itself; anything that is not already a string is passed through ``str``.
    """
    content = getattr(resp, "content", resp)
    return content if isinstance(content, str) else str(content)


def get_table_names(schema: str):
    """Return every identifier that follows the word ``TABLE`` in *schema*."""
    return re.findall(r'TABLE (\w+)', schema)


def get_table_columns(schema: str, table: str):
    """Return the column names declared for *table* in *schema*.

    The column block is the text between ``TABLE <table> (`` and the first
    closing parenthesis; every word in it that is not a known type/constraint
    keyword is reported as a column. Returns ``[]`` if the table is not found.
    """
    # re.escape keeps a table name with regex metacharacters from changing
    # the meaning of the pattern.
    m = re.search(rf'TABLE {re.escape(table)} \((.*?)\)', schema, re.DOTALL)
    if not m:
        return []
    words = re.findall(r'(\w+)', m.group(1))
    return [w for w in words if w.lower() not in _NON_COLUMN_TOKENS]


def agent_select_table(user_query, schema):
    """Pick the table whose name contains the longest word of *user_query*.

    Falls back to the first table in *schema* when no word matches.

    Raises:
        ValueError: if *schema* declares no tables at all.
    """
    tables = get_table_names(schema)
    if not tables:
        raise ValueError("No tables found in schema.")

    words = user_query.lower().split()  # invariant across tables
    best = ""
    best_len = 0
    for table in tables:
        table_lower = table.lower()
        for word in words:
            # Strict '>' keeps the earliest table on ties.
            if word in table_lower and len(word) > best_len:
                best = table
                best_len = len(word)
    # Fallback: first table.
    return best if best else tables[0]


def agent_select_columns(user_query, table, schema):
    """Return the columns of *table* whose name contains a word of *user_query*.

    Falls back to all columns of the table when nothing matches.
    """
    columns = get_table_columns(schema, table)
    words = user_query.lower().split()
    selected = [
        col for col in columns
        if any(word in col.lower() for word in words)
    ]
    return selected if selected else columns


def build_sql_prompt(table, columns, schema, user_question, error_reason=None):
    """Build the LLM prompt asking for a SQLite query answering *user_question*.

    *table* and *columns* are accepted for interface stability but are not
    currently injected into the prompt; only the full *schema* is sent.
    When *error_reason* is given, a repair request quoting it is appended.
    """
    prompt = (
        f"You are an expert SQL assistant.\n"
        f"Schema: {schema}\n"
        f"User question: {user_question}\n"
        "Write a valid SQLite SQL query answering the question using only the given table and columns.\n"
    )
    if error_reason:
        prompt += f"The previous SQL query failed with the error: {error_reason}\nPlease fix and regenerate the SQL only."
    return prompt


def extract_sql_query(text):
    """Extract a SQL statement from free-form LLM output.

    Tries, in order: a fenced code block (``sql``-tagged first), the first
    ``<KEYWORD> ... ;`` statement, the lines that mention a SQL keyword, and
    finally the stripped text itself.
    """
    # Drop reasoning blocks first so SQL drafts inside them are not picked up.
    text = _THINK_BLOCK_RE.sub("", text)

    fence_patterns = [
        # Tolerate any whitespace (or none) after the language tag so the tag
        # never leaks into the extracted query.
        r"```(?:sql|sqlite)\b\s*(.*?)```",
        r"```\n(.*?)```",
        r"```(.*?)```",
    ]
    for pattern in fence_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()

    # No fence: look for a bare statement terminated by ';'. Word boundaries
    # stop prose such as "selected" from being taken as the statement start.
    match = re.search(
        r"\b(SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)\b.*?;",
        text,
        re.DOTALL | re.IGNORECASE,
    )
    if match:
        return match.group(0).strip()

    keywords = ['SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE']
    sql_lines = [
        line for line in text.split('\n')
        if any(k in line.upper() for k in keywords)
    ]
    if sql_lines:
        return ' '.join(sql_lines)
    return text.strip()


def execute_sql_query(sql_query, db_path=db_path):
    """Run *sql_query* against the SQLite file at *db_path*.

    Returns:
        ``(DataFrame, None)`` on success, ``(None, error_message)`` on failure.
        The error message is never empty, so callers can test its truthiness.
    """
    # NOTE(review): the SQL comes from an LLM and is executed as-is; consider
    # opening the database read-only if writes must be ruled out.
    try:
        # closing() guarantees the connection is released even when the
        # query raises (the plain close() call was skipped on error).
        with closing(sqlite3.connect(db_path)) as conn:
            df = pd.read_sql_query(sql_query, conn)
    except Exception as e:  # deliberate: any failure is fed back to the LLM
        return None, str(e) or repr(e)
    return df, None


def summarize_with_llm(table, columns, data, user_query):
    """Ask the LLM to summarize the first rows of *data* for *user_query*.

    *table* and *columns* are unused but kept for interface stability.
    *data* may be ``None`` or empty, in which case the preview sent to the
    model is the literal text "No data returned.".
    """
    if data is not None and not data.empty:
        preview = data.head(5).to_markdown(index=False)
    else:
        preview = "No data returned."
    prompt = (
        f"User query: {user_query}\n"
        f"SQL result preview \n{preview}\n"
        f"Summarize the result, referencing the user query and the preview.)."
    )
    resp = client.invoke([{"role": "user", "content": prompt}])
    return _response_text(resp)


def full_pipeline_stream(user_question):
    """Run the NL2SQL pipeline, yielding a progress tuple after each step.

    Each yielded tuple is
    ``(status, sql_markdown, summary, table, columns_csv, dataframe)``.
    """
    yield "Identifying relevant table and columns...", "", "", "", "", pd.DataFrame()

    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    columns_view = ", ".join(columns)
    yield f"Table '{table}' selected.", "", "", table, columns_view, pd.DataFrame()

    # build_sql_prompt takes the schema as its third argument; omitting it
    # raised TypeError before the first LLM call.
    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, results_df, error = "", None, None
    sql_view = ""
    for attempt in range(1, MAX_SQL_ATTEMPTS + 1):
        yield f"Generating SQL (attempt {attempt})...", "", "", table, columns_view, pd.DataFrame()
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        sql_query = extract_sql_query(_response_text(llm_resp))
        sql_view = f"```sql\n{sql_query}\n```"
        results_df, error = execute_sql_query(sql_query)
        if not error:
            yield "SQL executed successfully.", sql_view, "", table, columns_view, results_df
            break
        sql_prompt = build_sql_prompt(
            table, columns, schema, user_question, error_reason=error
        )
        yield f"Retrying due to error: {error}", sql_view, "", table, columns_view, pd.DataFrame()

    if not error:
        summary = summarize_with_llm(table, columns, results_df, user_question)
        yield "Summarization complete.", sql_view, summary, table, columns_view, results_df
    else:
        yield f"Final error: {error}", sql_view, "No summary due to error.", table, columns_view, pd.DataFrame()


def full_pipeline(user_question):
    """Gradio handler: run the NL2SQL pipeline, yielding partial UI updates.

    Yields three dicts keyed by output component: first the selected
    table/columns, then the SQL / status / result frame, then the summary.
    """
    # Step 1: identify table and columns and show them immediately.
    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    yield {
        table_output: gr.update(value=table),
        columns_output: gr.update(value=", ".join(columns)),
    }

    # Step 2: generate SQL, feeding execution errors back to the LLM.
    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, results_df, error = "", None, None
    for _ in range(MAX_SQL_ATTEMPTS):
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        sql_query = extract_sql_query(_response_text(llm_resp))
        results_df, error = execute_sql_query(sql_query)
        if not error:
            break
        sql_prompt = build_sql_prompt(
            table, columns, schema, user_question, error_reason=error
        )

    sql_view = f"\n{sql_query.strip()}\n"
    status_view = "Success" if not error else f"Query error: {error}"
    out_df = results_df if results_df is not None else pd.DataFrame()
    yield {
        sql_output: gr.update(value=sql_view),
        status_output: gr.update(value=status_view),
        results_output: gr.update(value=out_df),
    }

    # Step 3: summarize only real results; a failed query has nothing to
    # ground a summary on (same message as full_pipeline_stream).
    if error:
        summary = "No summary due to error."
    else:
        summary = summarize_with_llm(
            table, columns, results_df, user_question
        ).strip()
    yield {
        summary_output: gr.update(value=summary),
    }


with gr.Blocks(title="NL2SQL Pipeline)") as gradio_interface:
    gr.Markdown("## NL2SQL Pipeline ")
    gr.Markdown("Enter a question about the water supply database. The agent will select relevant table/columns, generate and retry SQL on error, show results and a grounded summary.")
    with gr.Row():
        input_text = gr.Textbox(label="Enter your natural language question", lines=3)
    with gr.Row():
        submit_btn = gr.Button("Generate, Execute & Summarize", variant="primary")
    with gr.Row():
        table_output = gr.Textbox(label="Table Used", lines=1)
        columns_output = gr.Textbox(label="Columns Used", lines=2)
    with gr.Row():
        sql_output = gr.Textbox(label="Generated SQL Query", lines=5)
    with gr.Row():
        status_output = gr.Textbox(label="Execution Status", lines=2)
    with gr.Row():
        results_output = gr.Dataframe(label="Query Results", interactive=False)
    with gr.Row():
        summary_output = gr.Textbox(label="LLM-Grounded Summary", lines=5)
    with gr.Row():
        abort_btn = gr.Button("Abort / Stop Task")

    running_event = submit_btn.click(
        fn=full_pipeline,
        inputs=input_text,
        outputs=[sql_output, status_output, summary_output, table_output, columns_output, results_output],
    )
    # Cancelling the click event stops the running generator between yields.
    abort_btn.click(
        None,
        inputs=None,
        outputs=None,
        cancels=[running_event],
        queue=False,
    )

if __name__ == "__main__":
    gradio_interface.launch()