hmm404 committed on
Commit
2e19fe3
·
verified ·
1 Parent(s): 530a106

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -242
app.py CHANGED
@@ -1,242 +1,242 @@
1
- import os
2
- import re
3
- import sqlite3
4
- import warnings
5
- import gradio as gr
6
- import pandas as pd
7
- from schema import schema
8
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
9
-
10
- warnings.filterwarnings("ignore")
11
- API_KEY = "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"
12
- db_path = "wash_db.db"
13
-
14
- client = ChatNVIDIA(
15
- model="deepseek-ai/deepseek-r1",
16
- api_key=API_KEY,
17
- temperature=0.1,
18
- top_p=1,
19
- max_tokens=4096,
20
- )
21
-
22
- def get_table_names(schema: str):
23
- return re.findall(r'TABLE (\w+)', schema)
24
-
25
- def get_table_columns(schema: str, table: str):
26
- m = re.search(rf'TABLE {table} \((.*?)\)', schema, re.DOTALL)
27
- if m:
28
- cols_block = m.group(1)
29
- cols = re.findall(r'(\w+)', cols_block)
30
- return [col for col in cols if col.lower() not in {"int", "primary", "key", "string", "bit", "real", "references"}]
31
- return []
32
-
33
- def agent_select_table(user_query, schema):
34
- tables = get_table_names(schema)
35
- # First, try longest keyword containment in table name
36
- best = ""
37
- best_len = 0
38
- for table in tables:
39
- for word in user_query.lower().split():
40
- if word in table.lower() and len(word) > best_len:
41
- best = table
42
- best_len = len(word)
43
- if best:
44
- return best
45
- # fallback: first table
46
- return tables[0]
47
-
48
- def agent_select_columns(user_query, table, schema):
49
- columns = get_table_columns(schema, table)
50
- selected = []
51
- for col in columns:
52
- if any(word in col.lower() for word in user_query.lower().split()):
53
- selected.append(col)
54
- return selected if selected else columns # fallback all columns
55
-
56
- def build_sql_prompt(table, columns, schema, user_question, error_reason=None):
57
- prompt = (
58
- f"You are an expert SQL assistant.\n"
59
- f"Schema: {schema}\n"
60
- # f"Columns: {', '.join(columns)}\n"
61
- f"User question: {user_question}\n"
62
- "Write a valid SQLite SQL query answering the question using only the given table and columns.\n"
63
- )
64
- if error_reason:
65
- prompt += f"The previous SQL query failed with the error: {error_reason}\nPlease fix and regenerate the SQL only."
66
- return prompt
67
-
68
- def extract_sql_query(text):
69
- patterns = [
70
- r"```sql\n(.*?)```",
71
- r"```\n(.*?)```",
72
- r"```(.*?)```",
73
- ]
74
-
75
- for pattern in patterns:
76
- match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
77
- if match:
78
- return match.group(1).strip()
79
- # Else, look for SELECT...;
80
- match = re.search(r"(SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER).*?;", text, re.DOTALL | re.IGNORECASE)
81
- if match:
82
- return match.group(0).strip()
83
- lines = text.split('\n')
84
- sql_lines = [l for l in lines if any(k in l.upper() for k in ['SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE'])]
85
- if sql_lines:
86
- return ' '.join(sql_lines)
87
- return text.strip()
88
-
89
- def execute_sql_query(sql_query, db_path=db_path):
90
- try:
91
- conn = sqlite3.connect(db_path)
92
- df = pd.read_sql_query(sql_query, conn)
93
- conn.close()
94
- return df, None
95
- except Exception as e:
96
- return None, str(e)
97
-
98
- def summarize_with_llm(table, columns, data, user_query):
99
- preview = data.head(5).to_markdown(index=False) if data is not None and not data.empty else "No data returned."
100
- prompt = (
101
- f"User query: {user_query}\n"
102
- f"SQL result preview \n{preview}\n"
103
- f"Summarize the result, referencing the user query and the preview.)."
104
- )
105
- resp = client.invoke([{"role": "user", "content": prompt}])
106
- return getattr(resp, "content", resp) if hasattr(resp, "content") else str(resp)
107
-
108
- # def full_pipeline(user_question):
109
- # table = agent_select_table(user_question, schema)
110
- # columns = agent_select_columns(user_question, table, schema)
111
- # yield {
112
- # table_output: gr.update(value=table),
113
- # columns_output: gr.update(value=", ".join(columns)),
114
- # }
115
- # sql_prompt = build_sql_prompt(table, columns, user_question)
116
- # sql_query, error = "", None
117
-
118
- # # Error-handling and retry loop
119
- # for _ in range(5):
120
- # llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
121
- # llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
122
- # sql_query = extract_sql_query(llm_text)
123
- # results_df, error = execute_sql_query(sql_query)
124
- # if not error:
125
- # break
126
- # sql_prompt = build_sql_prompt(table, columns, user_question, error_reason=error)
127
- # # Summarize
128
- # summary = summarize_with_llm(table, columns, results_df, user_question)
129
- # # Format outputs
130
- # columns_view = ", ".join(columns)
131
- # sql_view = f"```sql\n{sql_query}\n```"
132
- # status_view = f"Success" if not error else f"Query error: {error}"
133
- # out_df = results_df if results_df is not None else pd.DataFrame()
134
- # return sql_view, status_view, summary, table, columns_view, out_df
135
-
136
- def full_pipeline_stream(user_question):
137
- yield "Identifying relevant table and columns...", "", "", "", "", pd.DataFrame()
138
- table = agent_select_table(user_question, schema)
139
- columns = agent_select_columns(user_question, table, schema)
140
- yield f"Table '{table}' selected.", "", "", table, ", ".join(columns), pd.DataFrame()
141
-
142
- sql_prompt = build_sql_prompt(table, columns, user_question)
143
- sql_query, error = "", None
144
-
145
- for _ in range(5):
146
- yield f"Generating SQL (attempt {_+1})...", "", "", table, ", ".join(columns), pd.DataFrame()
147
- llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
148
- llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
149
- sql_query = extract_sql_query(llm_text)
150
- results_df, error = execute_sql_query(sql_query)
151
- if not error:
152
- yield f"SQL executed successfully.", f"``````", "", table, ", ".join(columns), results_df
153
- break
154
- sql_prompt = build_sql_prompt(table, columns, user_question, error_reason=error)
155
- yield f"Retrying due to error: {error}", f"``````", "", table, ", ".join(columns), pd.DataFrame()
156
-
157
- if not error:
158
- summary = summarize_with_llm(table, columns, results_df, user_question)
159
- yield "Summarization complete.", f"``````", summary, table, ", ".join(columns), results_df
160
- else:
161
- yield f"Final error: {error}", f"``````", "No summary due to error.", table, ", ".join(columns), pd.DataFrame()
162
- def full_pipeline(user_question):
163
- # Step 1: Identify table and columns first
164
- yield "", "", "", "", "", pd.DataFrame()
165
- table = agent_select_table(user_question, schema)
166
- columns = agent_select_columns(user_question, table, schema)
167
-
168
- # Immediately return only these two visible outputs
169
- yield {
170
- table_output: gr.update(value=table),
171
- columns_output: gr.update(value=", ".join(columns)),
172
- }
173
-
174
- # Step 2: Continue with downstream pipeline
175
- sql_prompt = build_sql_prompt(table, columns, schema, user_question)
176
- sql_query, error = "", None
177
-
178
- for _ in range(5):
179
- llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
180
- llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
181
- sql_query = extract_sql_query(llm_text)
182
- results_df, error = execute_sql_query(sql_query)
183
- if not error:
184
- break
185
- sql_prompt = build_sql_prompt(table, columns, schema, user_question, error_reason=error)
186
-
187
- sql_view = f"\n{sql_query.strip()}\n"
188
- status_view = "Success" if not error else f"Query error: {error}"
189
- out_df = results_df if results_df is not None else pd.DataFrame()
190
- yield {
191
- sql_output: gr.update(value=sql_view),
192
- status_output: gr.update(value=status_view),
193
- results_output: gr.update(value=out_df)
194
-
195
- }
196
- summary = summarize_with_llm(table, columns, results_df, user_question).strip()
197
-
198
-
199
-
200
-
201
- yield {
202
- # sql_output: gr.update(value=sql_view),
203
-
204
- summary_output: gr.update(value=summary),
205
-
206
- }
207
-
208
-
209
- with gr.Blocks(title="NL2SQL Pipeline)") as gradio_interface:
210
- gr.Markdown("## NL2SQL Pipeline ")
211
- gr.Markdown("Enter a question about the water supply database. The agent will select relevant table/columns, generate and retry SQL on error, show results and a grounded summary.")
212
- with gr.Row():
213
- input_text = gr.Textbox(label="Enter your natural language question", lines=3)
214
- with gr.Row():
215
- submit_btn = gr.Button("Generate, Execute & Summarize", variant="primary")
216
- with gr.Row():
217
- table_output = gr.Textbox(label="Table Used", lines=1)
218
- columns_output = gr.Textbox(label="Columns Used", lines=2)
219
- with gr.Row():
220
- sql_output = gr.Textbox(label="Generated SQL Query", lines=5)
221
- with gr.Row():
222
- status_output = gr.Textbox(label="Execution Status", lines=2)
223
- with gr.Row():
224
- results_output = gr.Dataframe(label="Query Results", interactive=False)
225
- with gr.Row():
226
- summary_output = gr.Textbox(label="LLM-Grounded Summary", lines=5)
227
- with gr.Row():
228
- abort_btn = gr.Button("Abort / Stop Task")
229
- running_event=submit_btn.click(
230
- fn=full_pipeline,
231
- inputs=input_text,
232
- outputs=[sql_output, status_output, summary_output, table_output, columns_output, results_output]
233
- )
234
- abort_btn.click(
235
- None,
236
- inputs=None,
237
- outputs=None,
238
- cancels=[running_event],
239
- queue=False
240
- )
241
- if __name__ == "__main__":
242
- gradio_interface.launch()
 
1
+ import os
2
+ import re
3
+ import sqlite3
4
+ import warnings
5
+ import gradio as gr
6
+ import pandas as pd
7
+ from schema import schema
8
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
9
+
10
warnings.filterwarnings("ignore")

# SECURITY FIX: the API key was previously hard-coded in source and committed,
# so it must be considered compromised — rotate it. Read it from the
# environment instead of embedding it in the file.
API_KEY = os.environ.get("NVIDIA_API_KEY", "")

# Path to the SQLite database the pipeline queries.
db_path = "wash_db.db"

# Shared LLM client used for both SQL generation and result summarization.
client = ChatNVIDIA(
    model="deepseek-ai/deepseek-r1",
    api_key=API_KEY,
    temperature=0.1,  # low temperature: we want deterministic, parseable SQL
    top_p=1,
    max_tokens=4096,
)
21
+
22
def get_table_names(schema: str):
    """Return every table name declared as ``TABLE <name>`` in *schema*."""
    table_pattern = re.compile(r'TABLE (\w+)')
    return table_pattern.findall(schema)
24
+
25
def get_table_columns(schema: str, table: str):
    """Extract column names for *table* from the schema text.

    Returns an empty list when the table is not present. SQL type/keyword
    tokens found inside the column block are filtered out.
    """
    non_column_tokens = {"int", "primary", "key", "string", "bit", "real", "references"}
    found = re.search(rf'TABLE {table} \((.*?)\)', schema, re.DOTALL)
    if not found:
        return []
    tokens = re.findall(r'(\w+)', found.group(1))
    return [token for token in tokens if token.lower() not in non_column_tokens]
32
+
33
def agent_select_table(user_query, schema):
    """Pick the schema table whose name contains the longest query word.

    Scans every (table, word) pair and keeps the table matched by the
    longest word; falls back to the first table when nothing matches.
    """
    tables = get_table_names(schema)
    words = user_query.lower().split()
    best_table, best_len = "", 0
    for candidate in tables:
        lowered = candidate.lower()
        for word in words:
            if len(word) > best_len and word in lowered:
                best_table, best_len = candidate, len(word)
    # NOTE(review): raises IndexError when the schema declares no tables —
    # same as the original behavior; confirm the schema is never empty.
    return best_table if best_table else tables[0]
47
+
48
def agent_select_columns(user_query, table, schema):
    """Return the columns of *table* mentioned by any query word.

    Falls back to the full column list when no word matches any column.
    """
    columns = get_table_columns(schema, table)
    words = user_query.lower().split()
    matched = [col for col in columns if any(word in col.lower() for word in words)]
    return matched if matched else columns
55
+
56
def build_sql_prompt(table, columns, schema, user_question, error_reason=None):
    """Compose the LLM prompt that requests a SQLite query for the schema.

    When *error_reason* is given, a repair instruction is appended so the
    model regenerates a corrected query.
    """
    parts = [
        "You are an expert SQL assistant.\n",
        f"Schema: {schema}\n",
        f"User question: {user_question}\n",
        "Write a valid SQLite SQL query answering the question using only the given table and columns.\n",
    ]
    if error_reason:
        parts.append(
            f"The previous SQL query failed with the error: {error_reason}\n"
            "Please fix and regenerate the SQL only."
        )
    return "".join(parts)
67
+
68
def extract_sql_query(text):
    """Pull a SQL statement out of a free-form LLM response.

    Strategy, in order: fenced code blocks, then a keyword-led statement
    terminated by ';', then any lines containing SQL keywords joined
    together, and finally the stripped text itself as a last resort.
    """
    fence_patterns = (
        r"```sql\n(.*?)```",
        r"```\n(.*?)```",
        r"```(.*?)```",
    )
    for pattern in fence_patterns:
        found = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if found:
            return found.group(1).strip()

    # No fenced block: look for a statement ending in ';'
    found = re.search(r"(SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER).*?;", text, re.DOTALL | re.IGNORECASE)
    if found:
        return found.group(0).strip()

    # Last resort: stitch together any lines that look like SQL.
    keywords = ['SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE']
    sql_lines = [line for line in text.split('\n') if any(k in line.upper() for k in keywords)]
    if sql_lines:
        return ' '.join(sql_lines)
    return text.strip()
88
+
89
def execute_sql_query(sql_query, db_path=db_path):
    """Run *sql_query* against the SQLite database at *db_path*.

    Returns (DataFrame, None) on success or (None, error_message) on failure.

    BUG FIX: the original closed the connection only on the success path, so
    every failed query leaked a sqlite3 connection. sqlite3's ``with conn:``
    only wraps a transaction (it does not close), so close in ``finally``.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            return pd.read_sql_query(sql_query, conn), None
        finally:
            conn.close()
    except Exception as e:
        # Surface the error text so the caller can feed it back to the LLM.
        return None, str(e)
97
+
98
def summarize_with_llm(table, columns, data, user_query):
    """Ask the LLM for a summary grounded in a small preview of the results."""
    if data is not None and not data.empty:
        preview = data.head(5).to_markdown(index=False)
    else:
        preview = "No data returned."
    prompt = (
        f"User query: {user_query}\n"
        f"SQL result preview \n{preview}\n"
        f"Summarize the result, referencing the user query and the preview.)."
    )
    resp = client.invoke([{"role": "user", "content": prompt}])
    return getattr(resp, "content", resp) if hasattr(resp, "content") else str(resp)
107
+
108
+ # def full_pipeline(user_question):
109
+ # table = agent_select_table(user_question, schema)
110
+ # columns = agent_select_columns(user_question, table, schema)
111
+ # yield {
112
+ # table_output: gr.update(value=table),
113
+ # columns_output: gr.update(value=", ".join(columns)),
114
+ # }
115
+ # sql_prompt = build_sql_prompt(table, columns, user_question)
116
+ # sql_query, error = "", None
117
+
118
+ # # Error-handling and retry loop
119
+ # for _ in range(5):
120
+ # llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
121
+ # llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
122
+ # sql_query = extract_sql_query(llm_text)
123
+ # results_df, error = execute_sql_query(sql_query)
124
+ # if not error:
125
+ # break
126
+ # sql_prompt = build_sql_prompt(table, columns, user_question, error_reason=error)
127
+ # # Summarize
128
+ # summary = summarize_with_llm(table, columns, results_df, user_question)
129
+ # # Format outputs
130
+ # columns_view = ", ".join(columns)
131
+ # sql_view = f"```sql\n{sql_query}\n```"
132
+ # status_view = f"Success" if not error else f"Query error: {error}"
133
+ # out_df = results_df if results_df is not None else pd.DataFrame()
134
+ # return sql_view, status_view, summary, table, columns_view, out_df
135
+
136
def full_pipeline_stream(user_question):
    """Streaming pipeline: yields (status, sql, summary, table, columns, df)
    tuples as each stage completes, for use with a 6-output Gradio handler.

    BUG FIXES:
    - build_sql_prompt() was called without ``schema`` (its signature is
      (table, columns, schema, user_question, ...)), which raised TypeError
      on every run; ``schema`` is now passed through.
    - The SQL status messages showed an empty fenced block (``````); they
      now embed the generated query.
    """
    yield "Identifying relevant table and columns...", "", "", "", "", pd.DataFrame()
    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    cols_view = ", ".join(columns)
    yield f"Table '{table}' selected.", "", "", table, cols_view, pd.DataFrame()

    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, error, results_df = "", None, None
    sql_view = ""

    # Generate SQL, feeding execution errors back to the LLM, up to 5 tries.
    for attempt in range(5):
        yield f"Generating SQL (attempt {attempt + 1})...", "", "", table, cols_view, pd.DataFrame()
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
        sql_query = extract_sql_query(llm_text)
        results_df, error = execute_sql_query(sql_query)
        sql_view = f"```sql\n{sql_query}\n```"
        if not error:
            yield "SQL executed successfully.", sql_view, "", table, cols_view, results_df
            break
        sql_prompt = build_sql_prompt(table, columns, schema, user_question, error_reason=error)
        yield f"Retrying due to error: {error}", sql_view, "", table, cols_view, pd.DataFrame()

    if not error:
        summary = summarize_with_llm(table, columns, results_df, user_question)
        yield "Summarization complete.", sql_view, summary, table, cols_view, results_df
    else:
        yield f"Final error: {error}", sql_view, "No summary due to error.", table, cols_view, pd.DataFrame()
162
def full_pipeline(user_question):
    """Gradio generator: select table/columns, generate and retry SQL, then
    summarize.

    Yields dicts keyed by Gradio components so each stage's outputs are
    pushed to the UI as soon as they are ready.
    """
    # Stage 1: pick the table and columns, surface them immediately.
    table = agent_select_table(user_question, schema)
    columns = agent_select_columns(user_question, table, schema)
    yield {
        table_output: gr.update(value=table),
        columns_output: gr.update(value=", ".join(columns)),
    }

    # Stage 2: generate SQL, retrying up to 5 times with the DB error fed
    # back into the prompt so the LLM can self-correct.
    sql_prompt = build_sql_prompt(table, columns, schema, user_question)
    sql_query, error = "", None
    for _ in range(5):
        llm_resp = client.invoke([{"role": "user", "content": sql_prompt}])
        llm_text = getattr(llm_resp, "content", llm_resp) if hasattr(llm_resp, "content") else str(llm_resp)
        sql_query = extract_sql_query(llm_text)
        results_df, error = execute_sql_query(sql_query)
        if not error:
            break
        sql_prompt = build_sql_prompt(table, columns, schema, user_question, error_reason=error)

    # Stage 3: surface the SQL, execution status, and result rows.
    sql_view = f"\n{sql_query.strip()}\n"
    status_view = "Success" if not error else f"Query error: {error}"
    out_df = results_df if results_df is not None else pd.DataFrame()
    yield {
        sql_output: gr.update(value=sql_view),
        status_output: gr.update(value=status_view),
        results_output: gr.update(value=out_df),
    }

    # Stage 4: grounded summary of the (possibly empty) result.
    summary = summarize_with_llm(table, columns, results_df, user_question).strip()
    yield {
        summary_output: gr.update(value=summary),
    }
207
+
208
+
209
# BUG FIX: the window title contained a stray ')' ("NL2SQL Pipeline)").
with gr.Blocks(title="NL2SQL Pipeline") as gradio_interface:
    gr.Markdown("## NL2SQL Pipeline ")
    gr.Markdown("Enter a question about the water supply database. The agent will select relevant table/columns, generate and retry SQL on error, show results and a grounded summary.")
    with gr.Row():
        input_text = gr.Textbox(label="Enter your natural language question", lines=3)
    with gr.Row():
        submit_btn = gr.Button("Generate, Execute & Summarize", variant="primary")
    with gr.Row():
        table_output = gr.Textbox(label="Table Used", lines=1)
        columns_output = gr.Textbox(label="Columns Used", lines=2)
    with gr.Row():
        sql_output = gr.Textbox(label="Generated SQL Query", lines=5)
    with gr.Row():
        status_output = gr.Textbox(label="Execution Status", lines=2)
    with gr.Row():
        results_output = gr.Dataframe(label="Query Results", interactive=False)
    with gr.Row():
        summary_output = gr.Textbox(label="LLM-Grounded Summary", lines=5)
    with gr.Row():
        abort_btn = gr.Button("Abort / Stop Task")

    # Keep the click event handle so the Abort button can cancel the
    # (potentially long-running) generator via ``cancels``.
    running_event = submit_btn.click(
        fn=full_pipeline,
        inputs=input_text,
        outputs=[sql_output, status_output, summary_output, table_output, columns_output, results_output],
    )
    abort_btn.click(
        None,
        inputs=None,
        outputs=None,
        cancels=[running_event],
        queue=False,
    )

if __name__ == "__main__":
    gradio_interface.launch()