Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import sqlite3 | |
| import warnings | |
| import gradio as gr | |
| import pandas as pd | |
| from schema import schema | |
| from langchain_nvidia_ai_endpoints import ChatNVIDIA | |
# Silence library warnings so they do not clutter the app output.
warnings.filterwarnings("ignore")

# The NVIDIA API key is read from the environment rather than being stored in
# the source file. A key that has ever been committed or published must be
# treated as leaked and rotated.
API_KEY = os.environ.get("NVIDIA_API_KEY")

# SQLite database file that generated queries are executed against.
db_path = "wash_db.db"

# Shared chat client used for both SQL generation and result summarisation.
client = ChatNVIDIA(
    model="deepseek-ai/deepseek-r1",
    api_key=API_KEY,
    temperature=0.1,
    top_p=1,
    max_tokens=4096,
)
def get_table_names(schema: str):
    """Return every table name declared in *schema*, in order of appearance.

    A table name is the run of word characters that directly follows the
    literal text ``TABLE `` (upper-case, followed by a single space).
    """
    name_pattern = re.compile(r'TABLE (\w+)')
    return [hit.group(1) for hit in name_pattern.finditer(schema)]
def get_table_columns(schema: str, table: str):
    """Return the column names found in the definition of *table*.

    The definition is the text between ``TABLE <table> (`` and the first
    following ``)``. Every word in that text is treated as a column name
    unless it is one of a small set of known type/constraint keywords.
    Each name is returned once, in order of first appearance.

    Because the block ends at the first ``)``, a definition containing
    nested parentheses (for example ``references other(id)``) is cut off at
    that point, and the referenced table name is reported as a column.

    Args:
        schema: Schema text to search.
        table: Exact table name; it is matched literally.

    Returns:
        A list of column names, or an empty list when the table is not found.
    """
    # Escape the name so regex metacharacters in it cannot match other tables.
    m = re.search(rf'TABLE {re.escape(table)} \((.*?)\)', schema, re.DOTALL)
    if m is None:
        return []

    ignored = {"int", "primary", "key", "string", "bit", "real", "references"}
    columns = []
    seen = set()
    for word in re.findall(r'\w+', m.group(1)):
        # Skip keywords and words already collected, so a repeated name
        # is not listed twice.
        if word.lower() in ignored or word in seen:
            continue
        seen.add(word)
        columns.append(word)
    return columns
def agent_select_table(user_query, schema):
    """Pick the table whose name best matches the words of *user_query*.

    Each whitespace-separated word of the lower-cased query is tested for
    containment in each lower-cased table name. The table matched by the
    longest such word wins; on equal length the earlier match is kept.
    When no word matches, the first table in the schema is returned.

    Args:
        user_query: Natural-language question.
        schema: Schema text passed to ``get_table_names``.

    Returns:
        The selected table name, or an empty string when the schema
        declares no tables (instead of raising ``IndexError``).
    """
    tables = get_table_names(schema)
    if not tables:
        return ""

    # Tokenise once; the words do not change between tables.
    words = user_query.lower().split()
    best = ""
    best_len = 0
    for table in tables:
        lowered = table.lower()
        for word in words:
            if len(word) > best_len and word in lowered:
                best = table
                best_len = len(word)

    # Fall back to the first table when nothing matched.
    return best or tables[0]
def agent_select_columns(user_query, table, schema):
    """Choose the columns of *table* that the query appears to mention.

    A column is kept when any whitespace-separated word of the lower-cased
    query is a substring of the lower-cased column name. If no column
    qualifies, every column of the table is returned.
    """
    all_columns = get_table_columns(schema, table)

    def _is_mentioned(column_name):
        return any(token in column_name.lower() for token in user_query.lower().split())

    matched = [name for name in all_columns if _is_mentioned(name)]
    if matched:
        return matched
    # Nothing matched: fall back to the full column list.
    return all_columns
def build_sql_prompt(table, columns, schema, user_question, error_reason=None):
    """Assemble the text prompt that asks the model for a SQLite query.

    The prompt contains the schema and the user question. ``table`` and
    ``columns`` are accepted but are not interpolated into the prompt text.
    When ``error_reason`` is truthy, a trailing section reports the previous
    failure and asks for a corrected query.
    """
    sections = [
        "You are an expert SQL assistant.\n",
        f"Schema: {schema}\n",
        f"User question: {user_question}\n",
        "Write a valid SQLite SQL query answering the question using only the given table and columns.\n",
    ]
    if error_reason:
        sections.append(
            f"The previous SQL query failed with the error: {error_reason}\nPlease fix and regenerate the SQL only."
        )
    return "".join(sections)
def extract_sql_query(text):
    """Pull a SQL statement out of free-form model output.

    Strategies are tried in order and the first hit is returned:

    1. A fenced code block tagged ``sql`` (any case, optional trailing
       spaces, ``\\n`` or ``\\r\\n`` line ending).
    2. An untagged fenced block that starts on its own line.
    3. Any fenced block.
    4. The first span that starts with a SQL statement keyword (as a whole
       word) and ends at the next ``;``.
    5. All lines containing a common SQL keyword, joined with spaces.
    6. The stripped text itself.

    Args:
        text: Raw model response.

    Returns:
        The extracted query string.
    """
    flags = re.DOTALL | re.IGNORECASE

    # Reasoning models may wrap their scratch work in <think>...</think>;
    # drop it so a draft query inside it is not mistaken for the answer.
    text = re.sub(r"<think>.*?</think>", "", text, flags=flags)

    fence_patterns = (
        # Tolerate "```sql \n", "```sql\r\n" and tags such as "sqlite" so the
        # tag never leaks into the query through the generic pattern below.
        r"```[ \t]*sql\w*[ \t]*\r?\n(.*?)```",
        r"```[ \t]*\r?\n(.*?)```",
        r"```(.*?)```",
    )
    for pattern in fence_patterns:
        match = re.search(pattern, text, flags)
        if match:
            return match.group(1).strip()

    # No fence: look for "<KEYWORD> ... ;". Word boundaries prevent prose
    # such as "selected" or "updated" from starting a match.
    match = re.search(
        r"\b(?:SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)\b.*?;",
        text,
        flags,
    )
    if match:
        return match.group(0).strip()

    # Last resort: keep only the lines that look like SQL.
    keywords = ('SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE')
    sql_lines = [
        line for line in text.split('\n')
        if any(keyword in line.upper() for keyword in keywords)
    ]
    if sql_lines:
        return ' '.join(sql_lines)
    return text.strip()
def execute_sql_query(sql_query, db_path=db_path):
    """Run *sql_query* against the SQLite database and return the rows.

    Args:
        sql_query: SQL text to execute.
        db_path: Path of the SQLite database file; defaults to the
            module-level ``db_path``.

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` when
        connecting or executing raises. Errors are returned rather than
        raised so the caller can feed the message back to the model.
    """
    # NOTE(review): the query text comes from a language model and is run
    # as-is on a read-write connection; consider a read-only connection if
    # the database must never be modified.
    try:
        conn = sqlite3.connect(db_path)
        try:
            df = pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query fails.
            conn.close()
    except Exception as e:
        return None, str(e)
    return df, None
def summarize_with_llm(table, columns, data, user_query):
    """Ask the chat client for a short summary of a query result.

    The prompt holds the user query and a markdown preview of the first
    five rows of ``data``; when ``data`` is ``None`` or empty the preview is
    the text "No data returned.". ``table`` and ``columns`` are accepted
    but not used in the prompt.

    Returns the response's ``content`` attribute when it has one, otherwise
    the response converted with ``str``.
    """
    has_rows = data is not None and not data.empty
    if has_rows:
        preview = data.head(5).to_markdown(index=False)
    else:
        preview = "No data returned."

    prompt = "".join([
        f"User query: {user_query}\n",
        f"SQL result preview \n{preview}\n",
        f"Summarize the result, referencing the user query and the preview.).",
    ])
    resp = client.invoke([{"role": "user", "content": prompt}])

    if hasattr(resp, "content"):
        return resp.content
    return str(resp)
| # def full_pipeline(user_question): | |
| # table = agent_select_table(user_question, schema) | |
| # columns = agent_select_columns(user_question, table, schema) | |
| # yield { | |
| # table_output: gr.update(value=table), | |
| # columns_output: gr.update(value=", ".join(columns)), | |
| # } | |
| # sql_prompt = build_sql_prompt(table, columns, user_question) | |
| # sql_query, error = "", None | |
| # # Error-handling and retry loop | |
| # for _ in range(5): | |
| # llm_resp = client.invoke([{"role": "user", "content": sql_prompt}]) | |
| # llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp) | |
| # sql_query = extract_sql_query(llm_text) | |
| # results_df, error = execute_sql_query(sql_query) | |
| # if not error: | |
| # break | |
| # sql_prompt = build_sql_prompt(table, columns, user_question, error_reason=error) | |
| # # Summarize | |
| # summary = summarize_with_llm(table, columns, results_df, user_question) | |
| # # Format outputs | |
| # columns_view = ", ".join(columns) | |
| # sql_view = f"```sql\n{sql_query}\n```" | |
| # status_view = f"Success" if not error else f"Query error: {error}" | |
| # out_df = results_df if results_df is not None else pd.DataFrame() | |
| # return sql_view, status_view, summary, table, columns_view, out_df | |
def full_pipeline_stream(user_question):
    """Run the NL-to-SQL pipeline, yielding a progress tuple at each step.

    Every yielded value is a 6-tuple:
    ``(status message, SQL view, summary, table, columns, DataFrame)``.

    Steps: select a table and columns, ask the model for SQL (up to five
    attempts, feeding each execution error back into the next prompt), then
    summarise the result when a query succeeded.

    Args:
        user_question: Natural-language question from the user.
    """
    yield "Identifying relevant table and columns...", "", "", "", "", pd.DataFrame()

    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    columns_view = ", ".join(columns)
    yield f"Table '{table}' selected.", "", "", table, columns_view, pd.DataFrame()

    # build_sql_prompt takes the schema as its third argument; omitting it
    # shifts the question into the schema slot and raises TypeError.
    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, error, results_df = "", None, None
    sql_view = ""
    for attempt in range(5):
        yield f"Generating SQL (attempt {attempt + 1})...", "", "", table, columns_view, pd.DataFrame()
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        llm_text = llm_resp.content if hasattr(llm_resp, "content") else str(llm_resp)
        sql_query = extract_sql_query(llm_text)
        results_df, error = execute_sql_query(sql_query)
        # Show the query that was actually run inside a fenced SQL block.
        sql_view = f"```sql\n{sql_query}\n```"
        if not error:
            yield "SQL executed successfully.", sql_view, "", table, columns_view, results_df
            break
        # Feed the failure back to the model for the next attempt.
        sql_prompt = build_sql_prompt(table, columns, schema, user_question, error_reason=error)
        yield f"Retrying due to error: {error}", sql_view, "", table, columns_view, pd.DataFrame()

    if not error:
        summary = summarize_with_llm(table, columns, results_df, user_question)
        yield "Summarization complete.", sql_view, summary, table, columns_view, results_df
    else:
        yield f"Final error: {error}", sql_view, "No summary due to error.", table, columns_view, pd.DataFrame()
def full_pipeline(user_question):
    """Run the NL-to-SQL pipeline, updating the UI in three stages.

    Each yielded value is a dict that maps module-level Gradio components
    to ``gr.update(...)`` values:

    1. ``table_output`` and ``columns_output`` once the table and columns
       are selected.
    2. ``sql_output``, ``status_output`` and ``results_output`` after SQL
       generation (up to five attempts, each execution error is fed back
       into the next prompt).
    3. ``summary_output`` with the model's summary, or a fixed message when
       every attempt failed.

    Args:
        user_question: Natural-language question from the user.
    """
    # Stage 1: identify the table and columns and show them right away.
    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    yield {
        table_output: gr.update(value=table),
        columns_output: gr.update(value=", ".join(columns)),
    }

    # Stage 2: generate SQL, retrying with the error message on failure.
    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, error, results_df = "", None, None
    for _attempt in range(5):
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        llm_text = llm_resp.content if hasattr(llm_resp, "content") else str(llm_resp)
        sql_query = extract_sql_query(llm_text)
        results_df, error = execute_sql_query(sql_query)
        if not error:
            break
        sql_prompt = build_sql_prompt(table, columns, schema, user_question, error_reason=error)

    sql_view = f"\n{sql_query.strip()}\n"
    status_view = "Success" if not error else f"Query error: {error}"
    out_df = results_df if results_df is not None else pd.DataFrame()
    yield {
        sql_output: gr.update(value=sql_view),
        status_output: gr.update(value=status_view),
        results_output: gr.update(value=out_df),
    }

    # Stage 3: summarise only when a query succeeded; there is no result
    # to ground a summary on after a failure.
    if error:
        summary = "No summary due to error."
    else:
        summary = summarize_with_llm(table, columns, results_df, user_question).strip()
    yield {
        summary_output: gr.update(value=summary),
    }
# Gradio UI: one question box, a run button, one output widget per pipeline
# stage, and an abort button that cancels the running pipeline.
with gr.Blocks(title="NL2SQL Pipeline") as gradio_interface:
    gr.Markdown("## NL2SQL Pipeline ")
    gr.Markdown("Enter a question about the water supply database. The agent will select relevant table/columns, generate and retry SQL on error, show results and a grounded summary.")
    with gr.Row():
        input_text = gr.Textbox(label="Enter your natural language question", lines=3)
    with gr.Row():
        submit_btn = gr.Button("Generate, Execute & Summarize", variant="primary")
    with gr.Row():
        table_output = gr.Textbox(label="Table Used", lines=1)
        columns_output = gr.Textbox(label="Columns Used", lines=2)
    with gr.Row():
        sql_output = gr.Textbox(label="Generated SQL Query", lines=5)
    with gr.Row():
        status_output = gr.Textbox(label="Execution Status", lines=2)
    with gr.Row():
        results_output = gr.Dataframe(label="Query Results", interactive=False)
    with gr.Row():
        summary_output = gr.Textbox(label="LLM-Grounded Summary", lines=5)
    with gr.Row():
        abort_btn = gr.Button("Abort / Stop Task")

    # full_pipeline yields dicts keyed by component, so every component it
    # may update has to be listed in ``outputs``.
    running_event = submit_btn.click(
        fn=full_pipeline,
        inputs=input_text,
        outputs=[sql_output, status_output, summary_output, table_output, columns_output, results_output],
    )
    # The abort button runs no function; it only cancels the pipeline event.
    abort_btn.click(
        None,
        inputs=None,
        outputs=None,
        cancels=[running_event],
        queue=False,
    )

if __name__ == "__main__":
    gradio_interface.launch()