Shashwat-18 committed on
Commit bd89b88 · verified · 1 Parent(s): 22e87d4

Upload app.py

Files changed (1)
  1. app.py +2360 -0
app.py ADDED
@@ -0,0 +1,2360 @@
# %%
import os
import uuid
import sqlite3
import datetime
import json
import re
import time
from typing import Dict, Any, List, Tuple, Optional
import hashlib

import pandas as pd
from groq import Groq
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# %%
# 0. Global config
DB_PATH = "olist.db"
DATA_DIR = "data"

# Groq model
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"

GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found.")

groq_client = Groq(api_key=GROQ_API_KEY)

# Embedding model
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Logging paths
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
# %%
import chromadb
from chromadb.config import Settings

CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

_CHROMA_CLIENT = None

def get_chroma_client():
    """Return a process-wide singleton Chroma client backed by local persistence."""
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is None:
        _CHROMA_CLIENT = chromadb.PersistentClient(
            path=CHROMA_PERSIST_DIR,
            settings=Settings(
                anonymized_telemetry=False,
            ),
        )
    return _CHROMA_CLIENT
# %%

def log_prompt(tag: str, prompt: str) -> None:
    """
    Append the full prompt to a log file, with a tag and timestamp.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(prompt)
        f.write("\n================ END PROMPT ================\n")


def log_run_event(tag: str, content: str) -> None:
    """
    Append model response, final SQL, and error info into a run log.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    with open(RUN_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(content)
        f.write("\n================ END EVENT ================\n")
# %%
# ==============================
# Chat message helpers
# ==============================

def assistant_msg(text: str):
    return (None, text)

def user_msg(text: str):
    return (text, None)

def add_assistant_state(history, text):
    history.append(assistant_msg(text))
    return history
# %%

def df_to_state(df: pd.DataFrame) -> dict:
    """Serialize a DataFrame into a plain dict so it can be kept in UI state."""
    if df is None or not isinstance(df, pd.DataFrame) or df.empty:
        return {"columns": [], "rows": []}

    cols = list(df.columns)
    rows = df.astype(str).values.tolist()

    if not rows:
        return {"columns": [], "rows": []}

    return {"columns": cols, "rows": rows}


def state_to_df(state: dict) -> pd.DataFrame:
    """Rebuild a DataFrame from the dict produced by df_to_state."""
    if not isinstance(state, dict):
        return pd.DataFrame()

    cols = state.get("columns")
    rows = state.get("rows")

    if not cols or rows is None:
        return pd.DataFrame()

    return pd.DataFrame(rows, columns=cols)
# %%

def init_feedback_table(conn: sqlite3.Connection) -> None:
    """
    Create (or upgrade) a table to capture user feedback on model answers.
    """
    conn.execute("""
        CREATE TABLE IF NOT EXISTS user_feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            created_at TEXT NOT NULL,
            question TEXT NOT NULL,
            generated_sql TEXT,
            model_answer TEXT,
            rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
            comment TEXT,
            corrected_sql TEXT
        )
    """)
    conn.commit()


def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,  # "good" or "bad"
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """
    Store user feedback about a particular model answer / SQL query.
    If corrected_sql is provided, it is treated as an external correction.
    """
    rating = rating.lower()
    if rating not in ("good", "bad"):
        raise ValueError("rating must be 'good' or 'bad'")

    ts = datetime.datetime.now().isoformat(timespec="seconds")
    conn.execute(
        """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """,
        (ts, question, generated_sql, model_answer, rating, comment, corrected_sql),
    )
    conn.commit()


def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """
    Return the most recent feedback row for this question (if any).
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC
        LIMIT 1
        """,
        (question,),
    )
    row = cur.fetchone()
    if not row:
        return None

    return {
        "created_at": row[0],
        "generated_sql": row[1],
        "model_answer": row[2],
        "rating": row[3],
        "comment": row[4],
        "corrected_sql": row[5],
    }
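A quick standalone sanity check (an illustrative sketch, using an in-memory database so nothing touches `olist.db`) that the `CHECK(rating IN ('good','bad'))` constraint in the schema above rejects any other rating value:

```python
import sqlite3

# Minimal mirror of the user_feedback schema above, against :memory:.
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL
    )
""")

conn.execute(
    "INSERT INTO user_feedback (created_at, question, rating) VALUES (?, ?, ?)",
    ("2024-01-01T00:00:00", "top sellers?", "good"),
)  # accepted

try:
    conn.execute(
        "INSERT INTO user_feedback (created_at, question, rating) VALUES (?, ?, ?)",
        ("2024-01-01T00:00:00", "top sellers?", "meh"),
    )
    rejected = False
except sqlite3.IntegrityError:
    rejected = True  # CHECK constraint fired
```

This is why `record_feedback` can validate in Python first and still rely on the database as a second line of defense.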
# %%

def init_db() -> sqlite3.Connection:
    """
    Load all CSVs from the data/ folder into a local SQLite DB.
    Table names are derived from file names (without .csv).
    """
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)

    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]

    for fname in csv_files:
        path = os.path.join(DATA_DIR, fname)
        if not os.path.exists(path):
            print(f"CSV not found: {path} - skipping")
            continue

        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        df = pd.read_csv(path)
        df.to_sql(table_name, conn, if_exists="replace", index=False)

    init_feedback_table(conn)

    return conn


conn = init_db()
# %%
# 4. Manual docs for Olist tables

OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            "review_creation_date": "Date when the customer created the review.",
            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
# %%

def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns and foreign key relationships.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    metadata: Dict[str, Any] = {"tables": {}}

    for table in tables:
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()

        # Manual docs for the table
        table_docs = OLIST_DOCS.get(table, {})
        table_desc: str = table_docs.get(
            "description",
            f"Table '{table}' from Olist dataset."
        )
        column_docs: Dict[str, str] = table_docs.get("columns", {})

        columns_meta: Dict[str, str] = {}
        for c in cols:
            col_name = c[1]
            col_type = c[2] or "TEXT"

            if col_name in column_docs:
                columns_meta[col_name] = column_docs[col_name]
            else:
                columns_meta[col_name] = f"Column '{col_name}' of type {col_type}"

        cursor.execute(f"PRAGMA foreign_key_list('{table}')")
        fk_rows = cursor.fetchall()
        relationships: List[str] = []
        for fk in fk_rows:
            ref_table = fk[2]
            from_col = fk[3]
            to_col = fk[4]
            relationships.append(
                f"{table}.{from_col} → {ref_table}.{to_col} (foreign key)"
            )

        metadata["tables"][table] = {
            "description": table_desc,
            "columns": columns_meta,
            "relationships": relationships,
        }

    return metadata


schema_metadata = extract_schema_metadata(conn)


def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render metadata dict into a YAML-style string.
    """
    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        desc = tinfo.get("description", "").replace('"', "'")
        lines.append(f'    description: "{desc}"')
        lines.append("    columns:")
        for col_name, col_desc in tinfo.get("columns", {}).items():
            col_desc_clean = col_desc.replace('"', "'")
            lines.append(f'      {col_name}: "{col_desc_clean}"')
        rels = tinfo.get("relationships", [])
        if rels:
            lines.append("    relationships:")
            for rel in rels:
                rel_clean = rel.replace('"', "'")
                lines.append(f'      - "{rel_clean}"')
    return "\n".join(lines)


schema_yaml = build_schema_yaml(schema_metadata)
# %%
# 6. Build schema documents for RAG (taking samples from the tables)

def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table, using schema_metadata.

    Each document includes:
    - Table name
    - Table description
    - Columns with type + description
    - Relationships (FKs)
    - A few sample rows
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    docs: List[str] = []
    metadatas: List[Dict[str, Any]] = []

    for table in tables:
        tmeta = schema_metadata["tables"][table]
        table_desc = tmeta.get("description", "")
        columns_meta = tmeta.get("columns", {})
        relationships = tmeta.get("relationships", [])

        # Use PRAGMA to get types, then enrich with descriptions
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()

        col_lines = []
        for c in cols:
            col_name = c[1]
            col_type = c[2] or "TEXT"
            col_desc = columns_meta.get(col_name, f"Column '{col_name}' of type {col_type}")
            col_lines.append(f"- {col_name} ({col_type}): {col_desc}")

        # Sample rows
        try:
            sample_df = pd.read_sql_query(
                f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                connection,
            )
            sample_text = sample_df.to_markdown(index=False)
        except Exception:
            sample_text = "(could not fetch sample rows)"

        # Relationships block
        rel_block = ""
        if relationships:
            rel_block = "Relationships:\n" + "\n".join(
                f"- {rel}" for rel in relationships
            ) + "\n"

        doc_text = (
            f"Table: {table}\n"
            f"Description: {table_desc}\n\n"
            "Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )

        docs.append(doc_text)
        metadatas.append({
            "doc_type": "table_schema",
            "table_name": table,
        })

    return docs, metadatas


# Build RAG texts + metadata
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# 1) Per-table docs
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# 2) Global YAML as a separate doc
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
# %%
def build_store_final():
    """Build (or reopen) the persistent schema RAG store and a filtered retriever."""
    embedding_model = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )

    rag_store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embedding_model,
    )

    # Only index once; the collection persists across restarts.
    if rag_store._collection.count() == 0:
        rag_store.add_texts(
            texts=RAG_TEXTS,
            metadatas=RAG_METADATAS,
        )

    rag_retriever = rag_store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )

    return rag_store, rag_retriever

# Initialize once
rag_store, rag_retriever = build_store_final()
# %%

_sql_cache_store = None
_sql_embedding_fn = None

SQL_CACHE_COLLECTION = "sql_cache_mpnet"

def get_sql_cache_store():
    """Lazily build (and memoize) the Chroma collection used as a SQL answer cache."""
    global _sql_cache_store, _sql_embedding_fn

    if _sql_cache_store is not None:
        return _sql_cache_store

    if _sql_embedding_fn is None:
        _sql_embedding_fn = HuggingFaceEmbeddings(
            model_name=EMBED_MODEL_NAME,
            encode_kwargs={"normalize_embeddings": True},
        )

    _sql_cache_store = Chroma(
        client=get_chroma_client(),
        collection_name=SQL_CACHE_COLLECTION,
        embedding_function=_sql_embedding_fn,
    )

    return _sql_cache_store
# %%
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """Chroma metadata values must be scalars: drop Nones, stringify non-scalars."""
    safe = {}
    for k, v in (metadata or {}).items():
        if v is None:
            continue
        if isinstance(v, (int, float, bool, str)):
            safe[k] = v
        else:
            safe[k] = str(v)
    return safe
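A standalone sanity check for the sanitizer above (the function is restated here so the cell runs on its own): `None` values are dropped and non-scalar values such as lists are stringified, since Chroma only accepts scalar metadata.

```python
# Mirror of sanitize_metadata_for_chroma for a quick standalone check.
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    safe = {}
    for k, v in (metadata or {}).items():
        if v is None:
            continue  # Chroma rejects None metadata values
        if isinstance(v, (int, float, bool, str)):
            safe[k] = v
        else:
            safe[k] = str(v)  # lists/dicts become their string repr
    return safe

md = {"views": 3, "ok": True, "note": None, "tags": ["a", "b"]}
clean = sanitize_metadata_for_chroma(md)
```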
# %%
def normalize_question_strict(q: str) -> str:
    """
    Deterministic normalization for exact cache hits.
    """
    if not q:
        return ""
    q = q.lower().strip()
    q = re.sub(r"[^\w\s]", "", q)
    q = re.sub(r"\s+", " ", q)
    return q
# %%
# Helper: normalize question

def normalize_question_text(q: str) -> str:
    if not q:
        return ""
    q = q.strip().lower()
    q = re.sub(r"[^\w\s]", " ", q)
    q = re.sub(r"\s+", " ", q).strip()
    return q
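A standalone sanity check contrasting the two normalizers above (restated so the cell runs on its own): the strict form deletes punctuation outright ("TOP-5" becomes "top5"), while the text form replaces it with spaces ("top 5"), which is friendlier for embedding lookups.

```python
import re

# Mirrors of the two normalizers above.
def normalize_question_strict(q: str) -> str:
    if not q:
        return ""
    q = q.lower().strip()
    q = re.sub(r"[^\w\s]", "", q)   # punctuation deleted
    q = re.sub(r"\s+", " ", q)
    return q

def normalize_question_text(q: str) -> str:
    if not q:
        return ""
    q = q.strip().lower()
    q = re.sub(r"[^\w\s]", " ", q)  # punctuation becomes a space
    q = re.sub(r"\s+", " ", q).strip()
    return q

strict = normalize_question_strict("  What's the TOP-5?? ")
text = normalize_question_text("  What's the TOP-5?? ")
# strict -> "whats the top5", text -> "what s the top 5"
```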
# Helper: compute success rate
def compute_success_rate(md: dict) -> float:
    sc = md.get("success_count", 0) or 0
    tf = md.get("total_feedbacks", 0) or 0
    if tf <= 0:
        return 0.0
    return float(sc) / float(tf)
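A quick standalone check of the success-rate helper above (restated so the cell runs on its own): it is simply the fraction of feedbacks that were positive, falling back to 0.0 when no feedback exists yet.

```python
# Mirror of compute_success_rate for a standalone check.
def compute_success_rate(md: dict) -> float:
    sc = md.get("success_count", 0) or 0
    tf = md.get("total_feedbacks", 0) or 0
    if tf <= 0:
        return 0.0  # no feedback yet -> treat as unknown/zero
    return float(sc) / float(tf)

rate = compute_success_rate({"success_count": 3, "total_feedbacks": 4})
cold = compute_success_rate({})
# rate -> 0.75, cold -> 0.0
```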
# Insert initial cache entry (no feedback yet)
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: Optional[dict] = None):
    """
    Insert a cached entry when you run a query and want to cache it regardless of feedback.
    Initial metrics: views=1, success_count=0, total_feedbacks=0, success_rate=1.0.
    """
    if store is None:
        store = get_sql_cache_store()

    ident = uuid.uuid4().hex
    norm = normalize_question_text(question)
    md = {
        "id": ident,
        "normalized_question": norm,
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": time.time(),
        "views": 1,
        "success_count": 0,
        "total_feedbacks": 0,
        "success_rate": 1.0,
    }
    if extra_metadata:
        md.update(extra_metadata)

    # Use the store's own API so embeddings are computed consistently.
    store.add_texts([question], metadatas=[md])
    return ident
# %%
import logging
from difflib import SequenceMatcher

_logger = logging.getLogger(__name__)

# -------------------------------------------------------------------------
# Utility helpers
# -------------------------------------------------------------------------

def _now_ts() -> float:
    return time.time()

def similarity_score(a: str, b: str) -> float:
    """Fuzzy string similarity in [0, 1] via difflib's SequenceMatcher."""
    return SequenceMatcher(None, (a or ""), (b or "")).ratio()
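A standalone check of `similarity_score` above (restated so the cell runs on its own): `SequenceMatcher.ratio()` is `2 * matches / total_length`, so identical strings score 1.0, and "abcd" vs "abce" (3 matching characters out of 8 total) scores 0.75.

```python
from difflib import SequenceMatcher

# Mirror of similarity_score for a standalone check.
def similarity_score(a: str, b: str) -> float:
    return SequenceMatcher(None, (a or ""), (b or "")).ratio()

same = similarity_score("top sellers", "top sellers")
close = similarity_score("abcd", "abce")
# same -> 1.0, close -> 0.75
```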
# -------------------------------------------------------------------------
# Persist helper
# -------------------------------------------------------------------------
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    safe_md = sanitize_metadata_for_chroma(metadata)

    try:
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
            ids=[cache_id],
        )
    except TypeError:
        # Fall back for store wrappers whose add_texts does not accept ids.
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
        )
# -------------------------------------------------------------------------
# Cache insert / update
# -------------------------------------------------------------------------
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    norm_q_semantic = normalize_text(question)
    norm_q_exact = normalize_question_strict(question)

    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    md = {
        # identity
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,

        # exact match key
        "normalized_question": norm_q_exact,

        # timestamps
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": metadata.get("last_viewed_at", 0),

        # metrics
        "good_count": metadata.get("good_count", 0),
        "bad_count": metadata.get("bad_count", 0),
        "total_feedbacks": metadata.get("total_feedbacks", 0),
        "success_rate": metadata.get("success_rate", 0.5),
        "views": metadata.get("views", 0),
    }

    langchain_upsert(
        store=store,
        text=norm_q_semantic,  # semantic vector
        metadata=md,
        cache_id=cache_id,
    )

    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": md,
    }
# -------------------------------------------------------------------------
# Find exact cached entry (question + SQL)
# -------------------------------------------------------------------------

def find_cached_doc_by_sql(question: str, sql: str, store):
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)

    if coll and hasattr(coll, "get"):
        try:
            res = coll.get(ids=[cache_id])
            if res and res.get("metadatas"):
                md = res["metadatas"][0]
                return {
                    "id": cache_id,
                    "question": question,
                    "sql": md.get("sql"),
                    "answer_md": md.get("answer_md"),
                    "metadata": md,
                }
        except Exception:
            pass

    return None
# -------------------------------------------------------------------------
# Retrieve cached answers ranked primarily by success rate
# -------------------------------------------------------------------------

import unicodedata


def normalize_text(text: str) -> str:
    """Accent-fold to ASCII, lowercase, and collapse whitespace for semantic matching."""
    if not text:
        return ""
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = text.lower()
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def generate_cache_id(question: str, sql: str) -> str:
    """Deterministic id for a (question, SQL) pair so upserts deduplicate."""
    q = normalize_text(question)
    s = (sql or "").strip()
    key = f"{q}||{s}".encode("utf-8")
    return hashlib.sha1(key).hexdigest()
823
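# %%
# Sanity check (illustrative, standalone): cache keys must be stable under
# case/whitespace/punctuation noise in the question. The normalization and
# hashing below mirror normalize_text / generate_cache_id inline so this
# cell runs on its own.
import hashlib
import re
import unicodedata

def _demo_norm(text: str) -> str:
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = text.lower()
    text = re.sub(r"[^\w\s]", " ", text)
    return re.sub(r"\s+", " ", text).strip()

_k1 = hashlib.sha1(f"{_demo_norm('How many orders?')}||SELECT 1".encode("utf-8")).hexdigest()
_k2 = hashlib.sha1(f"{_demo_norm('  how many ORDERS ')}||SELECT 1".encode("utf-8")).hexdigest()
assert _k1 == _k2, "normalized questions should map to the same cache id"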
def rank_cached_candidates(candidates: list[dict]) -> dict:
    """
    Deterministic quality-first ranking.
    """
    candidates.sort(
        key=lambda c: (
            -float(c["metadata"].get("success_rate", 0.5)),
            -int(c["metadata"].get("good_count", 0)),
            int(c["metadata"].get("bad_count", 0)),
            -float(c["metadata"].get("last_updated_at", 0)),
        )
    )
    return candidates[0]

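# %%
# Toy illustration (standalone data, same sort key as rank_cached_candidates):
# higher success_rate wins first; good/bad counts and recency only break ties.
_toy = [
    {"metadata": {"success_rate": 0.50, "good_count": 1, "bad_count": 1, "last_updated_at": 9.0}},
    {"metadata": {"success_rate": 0.90, "good_count": 9, "bad_count": 1, "last_updated_at": 1.0}},
]
_toy.sort(
    key=lambda c: (
        -float(c["metadata"].get("success_rate", 0.5)),
        -int(c["metadata"].get("good_count", 0)),
        int(c["metadata"].get("bad_count", 0)),
        -float(c["metadata"].get("last_updated_at", 0)),
    )
)
assert _toy[0]["metadata"]["success_rate"] == 0.90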
def retrieve_exact_cached_sql(question: str, store):
    """
    Exact question match, but still quality-ranked.
    """
    norm_q = normalize_question_strict(question)
    coll = store._collection

    try:
        res = coll.get(where={"normalized_question": norm_q})
        if not res or not res.get("metadatas"):
            return None

        candidates = []
        for md in res["metadatas"]:
            candidates.append({
                "matched_question": question,
                "sql": md["sql"],
                "answer_md": md.get("answer_md", ""),
                "distance": 0.0,
                "metadata": md,
            })

        return rank_cached_candidates(candidates)

    except Exception:
        return None

def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
):
    norm_q = normalize_text(question)
    results = store.similarity_search_with_score(norm_q, k=20)

    candidates = []

    for doc, score in results:
        distance = float(score)
        if distance > max_distance:
            continue

        md = doc.metadata or {}
        if "sql" not in md:
            continue

        candidates.append({
            "matched_question": doc.page_content,
            "sql": md["sql"],
            "answer_md": md.get("answer_md", ""),
            "distance": distance,
            "metadata": md,

            # ranking signals (forced numeric)
            "success_rate": float(md.get("success_rate", 0.5)),
            "good": int(md.get("good_count", 0)),
            "bad": int(md.get("bad_count", 0)),
            "views": int(md.get("views", 0)),
            "last_updated": float(md.get("last_updated_at", 0)),
        })

    if not candidates:
        return None

    # Quality-first ranking
    candidates.sort(
        key=lambda c: (
            -c["success_rate"],
            -c["good"],
            c["bad"],
            c["distance"],
            -c["last_updated"],
        )
    )

    return candidates[0]

# -------------------------------------------------------------------------
# Increment views
# -------------------------------------------------------------------------

def increment_cache_views(metadata: dict, store):
    if not metadata:
        return False

    cache_id = metadata.get("cache_id")
    if not cache_id:
        return False

    md = dict(metadata)
    md["views"] = int(md.get("views", 0)) + 1
    md["last_viewed_at"] = _now_ts()
    md["last_updated_at"] = _now_ts()
    md["saved_at"] = md.get("saved_at", _now_ts())

    try:
        langchain_upsert(
            store=store,
            text=normalize_text(md.get("normalized_question", "")),
            metadata=md,
            cache_id=cache_id,
        )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False

# Update metrics on feedback
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    if not original_doc_md:
        return

    md = dict(original_doc_md["metadata"])
    cache_id = md["cache_id"]

    # ---- feedback counts ----
    if user_marked_good:
        md["good_count"] = md.get("good_count", 0) + 1
    else:
        md["bad_count"] = md.get("bad_count", 0) + 1

    md["total_feedbacks"] = md.get("total_feedbacks", 0) + 1
    # use .get() so the first "bad" feedback does not KeyError on good_count
    md["success_rate"] = (
        md.get("good_count", 0) / md["total_feedbacks"]
        if md["total_feedbacks"] > 0 else 0.5
    )

    # ---- timestamps ----
    md["saved_at"] = md.get("saved_at", _now_ts())  # preserve first save time
    md["last_updated_at"] = _now_ts()

    langchain_upsert(
        store=store,
        text=normalize_text(question),
        metadata=md,
        cache_id=cache_id,
    )

    # -------------------------
    # Corrected SQL --> NEW ENTRY
    # -------------------------
    if llm_corrected_sql and llm_corrected_answer_md:
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )

# %%
import datetime

def log_pipeline(event: str, detail: str = ""):
    """
    Centralized pipeline logger.
    Prints to the console; can be routed to a file later.
    """
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    msg = f"[{ts}] PIPELINE::{event}"
    if detail:
        msg += f" | {detail}"
    print(msg)

# %%


# %%
# ### 8. Groq LLM via LangChain

from langchain_groq import ChatGroq
import re
import gradio as gr

llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)

# %%
def get_rag_context(question: str) -> str:
    """
    Retrieve the most relevant schema documents for the question.
    """
    docs = rag_retriever.invoke(question)
    return "\n\n---\n\n".join(d.page_content for d in docs)

def clean_sql(sql: str) -> str:
    sql = sql.strip()
    if "```" in sql:
        sql = sql.replace("```sql", "").replace("```", "").strip()
    return sql

def extract_sql_from_markdown(text: str) -> str:
    """
    Extract the first ```sql ... ``` block from LLM output.
    If none is found, return the whole text.
    """
    match = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return text.strip()

def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """
    Return everything after the given marker as the explanation.
    """
    idx = text.upper().find(marker.upper())
    if idx == -1:
        return text.strip()
    return text[idx + len(marker):].strip()

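# %%
# Standalone check of the SQL-extraction pattern used above: it should pull
# the body of the first ```sql fenced block and ignore trailing explanation
# text that follows the fence.
import re as _re

_llm_out = "Here is the query:\n```sql\nSELECT 1;\n```\nEXPLANATION: trivial."
_m = _re.search(r"```sql(.*?)```", _llm_out, flags=_re.DOTALL | _re.IGNORECASE)
assert _m is not None and _m.group(1).strip() == "SELECT 1;"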
# 6b. General / descriptive questions

GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    Detect high-level descriptive questions where we should answer
    directly from schema context instead of generating SQL.
    """
    q = question.lower().strip()
    return any(key in q for key in GENERAL_DESC_KEYWORDS)

def answer_general_question(question: str) -> str:
    """
    Use the RAG schema docs to generate a rich, high-level description
    of the Olist dataset for conceptual questions.
    """
    rag_context = get_rag_context(question)

    system_instructions = """
You are a data documentation expert.

You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".

Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.

Do NOT write SQL. Answer in Markdown.
"""

    prompt = (
        system_instructions
        + "\n\n=== SCHEMA CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nDetailed dataset overview:"
    )

    log_prompt("GENERAL_DATASET_QUESTION", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)

    return response.content.strip()

# %%
# ### 9. SQL execution / validation

def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Execute SQL on SQLite and return a DataFrame, or an error message.
    """
    try:
        df = pd.read_sql_query(sql, connection)
        return df, None
    except Exception as e:
        return None, str(e)

def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Basic SQL validator:
    - Uses EXPLAIN QUERY PLAN to detect syntax or schema issues.
    """
    try:
        cursor = connection.cursor()
        cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True, None
    except Exception as e:
        return False, str(e)

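# %%
# Standalone illustration of the EXPLAIN QUERY PLAN trick on an in-memory
# database (the pipeline itself runs against the Olist connection `conn`):
# planning succeeds for valid SQL and raises for unknown tables or columns,
# without ever executing the query.
import sqlite3 as _sqlite3

_demo_conn = _sqlite3.connect(":memory:")
_demo_conn.execute("CREATE TABLE demo_orders (order_id TEXT, status TEXT)")

def _plans_ok(sql: str) -> bool:
    try:
        _demo_conn.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True
    except _sqlite3.Error:
        return False

assert _plans_ok("SELECT COUNT(*) FROM demo_orders WHERE status = 'delivered'")
assert not _plans_ok("SELECT * FROM no_such_table")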
# %%
# ### 10. SQL Generation + Repair with LLM
def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    validation_payload: dict,
    decision_reason: str,
    rag_context: str,
) -> tuple[str, str]:
    """
    Fix SQL ONLY after the LLM has decided it is necessary.
    """

    log_pipeline("SQL_CORRECTION_START")

    evidence_md = ""
    for chk in validation_payload.get("equivalence", []):
        evidence_md += chk["df"].to_markdown(index=False) + "\n\n"

    prompt = f"""
Original question:
{question}

Original SQL:
```sql
{generated_sql}
```

User feedback:
{user_feedback_comment}

Reason the SQL is considered incorrect:
{decision_reason}

Validation evidence:
{evidence_md}

Schema:
{rag_context}

TASK:
- Fix the SQL ONLY if necessary.
- If the SQL is already correct, return it unchanged.
- Explain your reasoning after an "EXPLANATION:" marker.

Return the corrected SQL and the explanation.
"""

    resp = llm.invoke(prompt)

    corrected_sql = extract_sql_from_markdown(resp.content) or generated_sql
    corrected_sql = clean_sql(corrected_sql)

    explanation = extract_explanation_after_marker(resp.content, "EXPLANATION:")

    log_pipeline(
        "SQL_CORRECTION_DONE",
        f"sql_changed={corrected_sql.strip() != generated_sql.strip()}"
    )

    return corrected_sql, explanation

# %%


# %%
# Generate SQL

def generate_sql(question: str, rag_context: str) -> str:
    """
    Generate SQL using the LLM, pulling in any past user feedback
    (corrected SQL) for this question as external guidance.
    """
    # External correction from past feedback, if any
    last_fb = get_last_feedback_for_question(conn, question)
    if last_fb and last_fb.get("corrected_sql"):
        previous_feedback_block = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:

Previous generated SQL:
{last_fb['generated_sql']}

Corrected SQL (preferred reference):
{last_fb['corrected_sql']}

User comment & prior explanation:
{last_fb.get('comment') or '(none)'}

You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""
    else:
        previous_feedback_block = ""

    system_instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.

You will be given:

- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.

Your job:

- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema context.
- Use correct JOINs (LEFT, RIGHT, INNER, OUTER, FULL OUTER, etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations (1.0 * numerator / denominator),
  and round to 2 decimals when appropriate.

{previous_feedback_block}
"""

    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nSQL query:"
    )

    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)

    sql = clean_sql(response.content)
    return sql

# %%
# Repair SQL

def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to correct a failing SQL query.
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:

- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:

- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema context.
- Use correct JOINs (LEFT, RIGHT, INNER, OUTER, FULL OUTER, etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""

    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== PREVIOUS (FAILING) SQL ===\n"
        + bad_sql
        + "\n\n=== SQLITE ERROR ===\n"
        + error_message
        + "\n\nCorrected SQL query:"
    )

    log_prompt("SQL_REPAIR", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)

    sql = clean_sql(response.content)
    return sql

# %%
# ### 11. Result summarization

def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.
    """
    system_instructions = """
You are a senior data analyst.

You will be given:

- The user's question.
- The final SQL that was executed.
- A small preview of the query result (as a Markdown table, if available).
- Optional error information if the query failed.

Your job:

- Provide a clear, concise answer in Markdown.
- If the result is numeric / aggregated, explain what it means in business terms.
- If there was an error, explain it simply and suggest how the user could rephrase.
- Do NOT show raw SQL unless it is helpful to the user.
"""

    # Build a Markdown table preview if we have data
    if df is not None and not df.empty:
        preview_rows = min(len(df), 50)
        df_preview_md = df.head(preview_rows).to_markdown(index=False)
    else:
        df_preview_md = "(no rows returned)"

    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== EXECUTED SQL ===\n"
        + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n"
        + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n"
        + rag_context
    )

    if error:
        prompt += "\n\n=== ERROR ===\n" + error

    # Logging helpers assumed to exist
    log_prompt("RESULT_SUMMARY", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)

    return response.content.strip()

# %%
def backend_pipeline(question: str):
    """
    STRICT cache-first backend.
    ALWAYS returns EXACTLY 10 values.
    """

    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "",
            "",
            "",
            "",
            {},
            4,
            "**Feedback attempts remaining: 4**",
            "",
        )

    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"
    store = get_sql_cache_store()

    # ---------------- CACHE LOOKUP ----------------
    cached = retrieve_exact_cached_sql(question, store)
    if cached is None:
        cached = retrieve_best_cached_sql(question, store, max_distance=0.25)

    # ---------------- CACHE HIT ----------------
    if cached:
        increment_cache_views(cached["metadata"], store)

        rag_context = get_rag_context(question)
        sql = cached["sql"]

        df, err = execute_sql(sql, conn)
        if err:
            df = pd.DataFrame()

        md = cached["metadata"]

        answer_md = cached.get("answer_md", "")

        debug_md = f"""
### Query Provenance (Cache)

**Matched question:** `{cached.get("matched_question", "")}`
**Success rate:** {md.get("success_rate", 0.5):.2f}
**Good / Bad:** {md.get("good_count", 0)} / {md.get("bad_count", 0)}
**Views:** {md.get("views", 0)}

**SQL used:**
```sql
{sql}
```
"""

        return (
            answer_md,
            df,
            sql,
            rag_context,
            question,
            answer_md,
            df_to_state(df),
            attempts_left,
            attempts_text,
            debug_md,
        )

    # ---------------- LLM FLOW ----------------
    rag_context = get_rag_context(question)
    sql = generate_sql(question, rag_context)

    is_valid, err = validate_sql(sql, conn)
    if not is_valid and err:
        sql = repair_sql(question, rag_context, sql, err)

    df, exec_error = execute_sql(sql, conn)
    if exec_error:
        df = pd.DataFrame()

    answer_md = summarize_results(
        question=question,
        sql=sql,
        df=df if not exec_error else None,
        rag_context=rag_context,
        error=exec_error,
    )

    # Cache LLM result
    cache_sql_answer_dedup(
        question=question,
        sql=sql,
        answer_md=answer_md,
        metadata={
            "good_count": 0,
            "bad_count": 0,
            "total_feedbacks": 0,
            "success_rate": 0.5,
            "views": 1,
            "saved_at": time.time(),
        },
        store=store,
    )

    debug_md = f"""
### Query Provenance (LLM Generated)

**Source:** LLM
**Execution status:** {"Error" if exec_error else "Success"}

**SQL used:**
```sql
{sql}
```
"""

    return (
        answer_md,
        df,
        sql,
        rag_context,
        question,
        answer_md,
        df_to_state(df),
        attempts_left,
        attempts_text,
        debug_md,
    )

# %%
import re

def looks_like_sql(text: str) -> bool:
    if not text:
        return False
    return bool(re.search(
        r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b",
        text,
        flags=re.I,
    ))


def is_feedback_sufficient(feedback_text: str) -> bool:
    """
    Decide whether feedback is actionable WITHOUT an LLM.
    """
    if not feedback_text:
        return False

    text = feedback_text.strip().lower()

    # Long feedback → sufficient
    if len(text) >= 60:
        return True

    # SQL pasted
    if looks_like_sql(text):
        return True

    # Numbers + comparison language → sufficient
    if re.search(r"\b\d+\b", text) and any(
        k in text for k in [
            "last year", "previous", "this year",
            "increase", "decrease", "higher", "lower"
        ]
    ):
        return True

    # Analytical keywords
    signal_words = [
        "filter", "where", "year", "month", "should", "instead",
        "expected", "wrong", "missing", "count", "distinct",
        "join", "exclude", "include", "only"
    ]

    signals = sum(1 for w in signal_words if w in text)
    if signals >= 1 and len(text) >= 20:
        return True

    return False

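# %%
# Standalone check of the SQL-detection heuristic used by the feedback gate:
# the word-boundary pattern should match real SQL (or SQL-ish vocabulary)
# but not ordinary prose.
import re as _re

_SQL_HINT = _re.compile(
    r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b",
    _re.I,
)
assert _SQL_HINT.search("SELECT COUNT(*) FROM olist_orders_dataset")
assert _SQL_HINT.search("please join orders to customers")  # 'join' still triggers
assert not _SQL_HINT.search("the totals look too low to me")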
# %%
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Deterministic follow-up question to ask the user when feedback is vague.
    Returns a friendly prompt that the UI can display to the user.
    """
    base = (
        "Thanks — I need a bit more detail to act on this feedback.\n\n"
        "Please tell me one (or more) of the following so I can check and correct the result:\n\n"
        "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n"
        "   the **time range** (year/month), or the **filters** applied?\n"
        "2. If you expected a different number, what was the expected number (and how was it computed)?\n"
        "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n"
        "Examples you can copy-paste:\n"
    )
    examples = (
        "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n"
        "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n"
        "- \"Please exclude canceled orders (order_status = 'canceled').\"\n"
        "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n"
    )
    hint = "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare."
    prompt = base + examples + hint
    if sample_feedback:
        prompt = f"I saw your feedback: \"{sample_feedback}\"\n\n" + prompt
    return prompt

# %%
def interpret_feedback_intent(question: str, feedback_text: str) -> dict:
    """
    Convert business feedback into a testable hypothesis.
    NO SQL generation here.
    """

    system_prompt = """
You are a senior analytics reviewer.

Your task:
- Identify what the user is CLAIMING.
- Do NOT assume they are correct.
- Do NOT write SQL.

Possible intents:
- comparison_claim (mentions last year / previous / higher / lower)
- overcount
- undercount
- missing_entities
- wrong_filter
- wrong_time_range
- unclear

Return JSON ONLY.
"""

    prompt = f"""
USER QUESTION:
{question}

USER FEEDBACK:
{feedback_text}

Return JSON exactly like this:
{{
    "intent": "comparison_claim | overcount | undercount | missing_entities | wrong_filter | wrong_time_range | unclear",
    "confidence": "low | medium | high"
}}
"""

    resp = llm.invoke(system_prompt + "\n\n" + prompt)
    print(resp)

    try:
        return json.loads(resp.content)
    except Exception:
        return {
            "intent": "unclear",
            "confidence": "low",
        }

# %%
def run_validation_queries(
    intent_info: dict,
    question: str,
    original_sql: str,
    user_feedback: str,
    rag_context: str,
) -> dict:
    """
    Generate AND EXECUTE validation SQL.
    Separates equivalence checks from context checks.
    Uses the USER FEEDBACK explicitly for context validation.
    """

    log_pipeline(
        "VALIDATION_START",
        f"intent={intent_info.get('intent')}"
    )

    system_prompt = f"""
You are validating a business claim.

User intent:
{intent_info.get("intent")}

Original question:
{question}

User feedback (IMPORTANT):
{user_feedback}

Original SQL:
{original_sql}

Schema context:
{rag_context}

TASK:
- Generate 2-4 SQL queries.
- Mark EACH query explicitly as one of:
  - "equivalence":
    * Alternative formulations of the SAME question
    * Used to verify SQL correctness (joins, filters, logic)
  - "context":
    * Used to VERIFY or CHALLENGE the user's feedback
    * If the user mentions years, comparisons, growth, or missing entities,
      generate year-over-year or comparative queries.
- DO NOT fix or modify the original SQL.
- These queries are ONLY for validation and analysis.

IMPORTANT:
- Context queries MUST be driven by the user's feedback.
- If the user claims a number (for example, a previous year's value), generate queries that can verify that claim from the data.
- If the user provides a query to test in their feedback, use it as well. Do not assume it is correct.

Return JSON ONLY in this format:
{{
    "checks": [
        {{
            "type": "equivalence | context",
            "sql": "SELECT ...",
            "label": "short description of what this validates"
        }}
    ]
}}
"""

    resp = llm.invoke(system_prompt)
    raw = resp.content.strip()
    print(raw)

    # Strip ```json fences if present (only from the edges, so a literal
    # "json" inside the payload is left untouched)
    if raw.startswith("```"):
        raw = raw.strip("`").strip()
        if raw.lower().startswith("json"):
            raw = raw[4:].strip()

    try:
        parsed = json.loads(raw)
    except Exception as e:
        log_pipeline("VALIDATION_PARSE_FAILED", str(e))
        return {
            "equivalence": [],
            "context": [],
        }

    equivalence = []
    context = []

    for chk in parsed.get("checks", []):
        sql = chk.get("sql")
        chk_type = chk.get("type")
        label = chk.get("label", "unnamed check")

        if not sql or chk_type not in ("equivalence", "context"):
            continue

        df, err = execute_sql(sql, conn)
        print(df)

        log_pipeline(
            "VALIDATION_SQL_EXECUTED",
            f"type={chk_type} | label={label}"
        )

        if err or df is None or df.empty:
            continue

        payload = {
            "label": label,
            "sql": sql,
            "df": df,
        }

        if chk_type == "equivalence":
            equivalence.append(payload)
        else:
            context.append(payload)

    log_pipeline(
        "VALIDATION_DONE",
        f"equivalence={len(equivalence)}, context={len(context)}"
    )

    return {
        "equivalence": equivalence,
        "context": context,
    }

# %%
def safe_json_load(text: str) -> dict | None:
    try:
        return json.loads(text)
    except Exception:
        pass

    # Try extracting the first JSON block
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None

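# %%
# Standalone check of the brace-extraction fallback used by safe_json_load:
# LLMs often wrap JSON in prose, so we grab the outermost {...} span and
# parse that instead of the whole response.
import json as _json
import re as _re

_noisy = 'Sure! Here is the decision:\n{"decision": "correct_sql", "confidence": "high"}\nHope that helps.'
_m = _re.search(r"\{[\s\S]*\}", _noisy)
assert _m is not None
assert _json.loads(_m.group(0))["decision"] == "correct_sql"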
# %%
def decide_after_validation_with_llm(
    question: str,
    original_sql: str,
    user_feedback: str,
    intent_info: dict,
    original_df: pd.DataFrame,
    validation_payload: dict,
    rag_context: str,
) -> dict:
    """
    Use the LLM to decide whether the SQL must change or an explanation is sufficient.
    Returns a structured decision.
    """

    log_pipeline("DECISION_LLM_START")

    # Build evidence
    evidence_md = "### Original Result\n"
    evidence_md += original_df.to_markdown(index=False)

    for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
        evidence_md += f"\n\n### Equivalence Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)

    for i, chk in enumerate(validation_payload.get("context", []), 1):
        evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)

    system_prompt = """
You are a senior analytics and data analyst.

Your task:
- Decide whether the ORIGINAL SQL is logically correct, given the original question, the initial result, the user feedback, and the equivalence/context evidence shown to you.
- Do NOT judge based on expectations or growth alone.
- Different definitions (e.g. customers vs customers-with-orders) are NOT errors.
- Only mark the SQL incorrect if the evidence shows a true flaw in the results or the logic.

CRITICAL RULES:
- You must return VALID JSON.
- DO NOT add explanations outside the JSON.
- DO NOT use markdown.
- DO NOT include extra text.

Return JSON ONLY in this format:
{
    "decision": "correct_sql | not_correct_sql",
    "confidence": "high | medium | low",
    "reason": "provide an explanation here"
}
"""

    prompt = f"""
Original question:
{question}

Original SQL:
{original_sql}

User feedback:
{user_feedback}

Interpreted intent:
{intent_info}

Schema context:
{rag_context}

Below is the validation evidence.
"""

    resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md)

    parsed = safe_json_load(resp.content)

    if not parsed:
        decision = {
            "decision": "correct_sql",
            "confidence": "low",
            "reason": "LLM response was not valid JSON"
        }
    else:
        decision = parsed

    log_pipeline(
        "DECISION_LLM_DONE",
        f"decision={decision.get('decision')} | confidence={decision.get('confidence')} | reason={decision.get('reason')}"
    )

    return decision

+ # %%
1915
+ def explain_validation_with_llm(
1916
+ question: str,
1917
+ feedback: str,
1918
+ original_df: pd.DataFrame,
1919
+ validation_payload: dict,
1920
+ schema_context: str,
1921
+ ) -> str:
1922
+ """
1923
+ Business explanation of validation results.
1924
+ """
1925
+
1926
+ log_pipeline("EXPLAIN_START")
1927
+
1928
+ evidence_md = (
1929
+ "## Pipeline: Feedback --> Validation -> Explanation\n\n"
1930
+ "### Baseline — Original Result\n"
1931
+ )
1932
+ evidence_md += original_df.to_markdown(index=False)
1933
+
1934
+ for i, chk in enumerate(validation_payload.get("context", []), 1):
1935
+ evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
1936
+ evidence_md += chk["df"].to_markdown(index=False)
1937
+
1938
+ for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
1939
+ evidence_md += f"\n\n### Verification Check {i}: {chk['label']}\n"
1940
+ evidence_md += chk["df"].to_markdown(index=False)
1941
+
1942
+ system_prompt = """
1943
+ You are a senior data analyst explaining results to a business user.
1944
+
1945
+ Rules:
1946
+ - DO NOT mention SQL mechanics
1947
+ - Compare numbers explicitly
1948
+ - Explain trends year-over-year.
1949
+ - Explain what the actual numbers are and what it means for the business.
1950
+ - If the user is concerned about the SQL or the numbers in their feedback, based on verification and context checks, explain the results to them.
1951
+ """
1952
+
1953
+ prompt = f"""
1954
+ Original question:
1955
+ {question}
1956
+
1957
+ User feedback:
1958
+ {feedback}
1959
+
1960
+ Schema of the data:
1961
+ {schema_context}
1962
+ """
1963
+
1964
+ resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md)
1965
+
1966
+ log_pipeline("EXPLAIN_DONE")
1967
+
1968
+ return resp.content.strip()
1969
+
1970
+
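A note on the evidence-building step above: `DataFrame.to_markdown` delegates to the optional `tabulate` package, so that dependency must be available at runtime. A minimal illustration of assembling one evidence section, with a plain-text fallback when `tabulate` is missing:

```python
import pandas as pd

# Illustrative data only; column names are placeholders.
df = pd.DataFrame({"year": [2022, 2023], "revenue": [10.5, 12.0]})

try:
    # DataFrame.to_markdown requires the optional 'tabulate' dependency
    table = df.to_markdown(index=False)
except ImportError:
    # fall back to plain text if tabulate is not installed
    table = df.to_string(index=False)

evidence_md = "### Baseline: Original Result\n" + table
```

The app itself calls `to_markdown` directly, so `tabulate` should be listed alongside `pandas` in the Space's requirements.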
+
+ # %%
+ def feedback_pipeline_interactive(
+     feedback_rating: str,
+     feedback_comment: str,
+     last_sql: str,
+     last_rag_context: str,
+     last_question: str,
+     last_answer_md: str,
+     last_df_state: dict,
+     feedback_sql: str,
+     attempts_left: int,
+ ):
+     """
+     Always returns exactly 10 values.
+     No booleans.
+     followup_signal belongs to {"FOLLOWUP", "NONE"}.
+     """
+
+     last_df = state_to_df(last_df_state)
+     rating = (feedback_rating or "").strip().lower()
+     comment = (feedback_comment or "").strip()
+     attempts_left = max(0, int(attempts_left or 0))
+
+     # ---------- INVALID RATING ----------
+     if rating not in {"correct", "wrong"}:
+         return (
+             last_answer_md,
+             last_df,
+             last_sql,
+             last_rag_context,
+             last_question,
+             last_answer_md,
+             df_to_state(last_df),
+             "NONE",
+             "",
+             attempts_left,
+         )
+
+     # ---------- CORRECT ----------
+     if rating == "correct":
+         update_cache_on_feedback(
+             question=last_question,
+             original_doc_md=find_cached_doc_by_sql(
+                 last_question, last_sql, get_sql_cache_store()
+             ),
+             user_marked_good=True,
+             llm_corrected_sql=None,
+             llm_corrected_answer_md=None,
+             store=get_sql_cache_store(),
+         )
+
+         md = "**Confirmed correct by user**\n\n" + last_answer_md
+         return (
+             md,
+             last_df,
+             last_sql,
+             last_rag_context,
+             last_question,
+             md,
+             df_to_state(last_df),
+             "NONE",
+             "",
+             attempts_left,
+         )
+
+     # ---------- WRONG ----------
+     attempts_left = max(0, attempts_left - 1)
+
+     if not is_feedback_sufficient(comment):
+         return (
+             last_answer_md,
+             last_df,
+             last_sql,
+             last_rag_context,
+             last_question,
+             last_answer_md,
+             df_to_state(last_df),
+             "FOLLOWUP",
+             build_followup_prompt_for_user(comment),
+             attempts_left,
+         )
+
+     # ---------- VALIDATION ----------
+     intent_info = interpret_feedback_intent(last_question, comment)
+
+     validation = run_validation_queries(
+         intent_info=intent_info,
+         question=last_question,
+         original_sql=last_sql,
+         user_feedback=comment,
+         rag_context=last_rag_context,
+     )
+
+     decision = decide_after_validation_with_llm(
+         question=last_question,
+         original_sql=last_sql,
+         user_feedback=comment,
+         intent_info=intent_info,
+         original_df=last_df,
+         validation_payload=validation,
+         rag_context=last_rag_context,
+     )
+
+     # ---------- EXPLAIN ----------
+     if decision.get("decision") == "correct_sql":
+         explanation = explain_validation_with_llm(
+             question=last_question,
+             feedback=comment,
+             original_df=last_df,
+             validation_payload=validation,
+             schema_context=last_rag_context,
+         )
+
+         return (
+             explanation,
+             last_df,
+             last_sql,
+             last_rag_context,
+             last_question,
+             explanation,
+             df_to_state(last_df),
+             "NONE",
+             "",
+             attempts_left,
+         )
+
+     # ---------- SQL CORRECTION ----------
+     corrected_sql, explanation = review_and_correct_sql_with_llm(
+         question=last_question,
+         generated_sql=last_sql,
+         user_feedback_comment=comment,
+         validation_payload=validation,
+         decision_reason=decision.get("reason", ""),
+         rag_context=last_rag_context,
+     )
+
+     df_new, err = execute_sql(corrected_sql, conn)
+
+     final_md = "**SQL corrected**\n\n" + summarize_results(
+         question=last_question,
+         sql=corrected_sql,
+         df=df_new if not err else None,
+         rag_context=last_rag_context,
+         error=err,
+     )
+
+     return (
+         final_md,
+         df_new if not err else pd.DataFrame(),
+         corrected_sql,
+         last_rag_context,
+         last_question,
+         final_md,
+         df_to_state(df_new if not err else pd.DataFrame()),
+         "NONE",
+         "",
+         attempts_left,
+     )
+
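The pipeline passes DataFrames between turns as JSON-safe dicts via `df_to_state` / `state_to_df`, which are defined earlier in app.py. A plausible sketch of that round trip (hypothetical names and layout, using pandas' "split" orientation):

```python
import pandas as pd


def df_to_state_sketch(df: pd.DataFrame) -> dict:
    # Hypothetical stand-in for df_to_state: serialize to a JSON-safe
    # dict in "split" layout, with {} meaning "no result".
    if df is None or df.empty:
        return {}
    return df.to_dict(orient="split")


def state_to_df_sketch(state: dict) -> pd.DataFrame:
    # Hypothetical stand-in for state_to_df: rebuild the frame,
    # or return an empty one for the {} sentinel.
    if not state:
        return pd.DataFrame()
    return pd.DataFrame(state["data"], columns=state["columns"])
```

Whatever the real implementation, the invariant the handlers depend on is that `state_to_df(df_to_state(df))` reproduces the original frame and that an empty state yields an empty DataFrame.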
+
+ # %%
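`is_feedback_sufficient` gates the negative-feedback path in `feedback_pipeline_interactive`; it is defined earlier in app.py. As a rough illustration only (an assumption, not the real gate), a minimal heuristic might simply require a few words of actionable detail before triggering validation:

```python
def is_feedback_sufficient_sketch(comment: str) -> bool:
    # Hypothetical stand-in: treat very short comments as insufficient,
    # so the pipeline asks a follow-up question instead of validating.
    words = (comment or "").strip().split()
    return len(words) >= 3
```

The real check may be LLM-based; the point is only that insufficient comments route to the `"FOLLOWUP"` signal rather than to the validation queries.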
+ import gradio as gr
+ import pandas as pd
+
+ # -------------------------------
+ # Helpers
+ # -------------------------------
+ def attempts_text(attempts: int) -> str:
+     if attempts <= 0:
+         return "**Feedback exhausted. Please ask a new question.**"
+     return f"**Feedback attempts remaining: {attempts}**"
+
+ # -------------------------------
+ # CHAT TURN
+ # -------------------------------
+ def chat_turn(user_input, history, state):
+     history = history or []
+
+     history.append({"role": "user", "content": user_input})
+     history.append({"role": "assistant", "content": "🔎 Thinking…"})
+
+     (
+         answer_md,
+         _df,
+         sql,
+         rag,
+         q,
+         ans_state,
+         df_state,
+         attempts,
+         attempts_md,
+         debug_md,
+     ) = backend_pipeline(user_input)
+
+     history.pop()
+     history.append({"role": "assistant", "content": answer_md})
+
+     new_state = {
+         "last_sql": sql,
+         "last_rag": rag,
+         "last_question": q,
+         "last_answer": ans_state,
+         "last_df": df_state,
+         "attempts": attempts,
+     }
+
+     df = state_to_df(df_state)
+
+     return history, new_state, attempts_md, debug_md, df
+
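The placeholder-then-replace pattern in `chat_turn` can be shown with plain role/content dicts, independent of Gradio (the question and answer strings below are illustrative only):

```python
# Chat history as role/content dicts ("messages" format): append a
# placeholder, then swap it for the real answer once the pipeline returns.
history = []
history.append({"role": "user", "content": "What was revenue in 2023?"})
history.append({"role": "assistant", "content": "🔎 Thinking…"})

answer_md = "Revenue in 2023 was 12.0M."  # stands in for the pipeline's answer
history.pop()  # drop the placeholder
history.append({"role": "assistant", "content": answer_md})
```

Because the placeholder is always the last element, `history.pop()` safely removes it before the final answer is appended.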
+
+ # -------------------------------
+ # FEEDBACK TURN
+ # -------------------------------
+ def feedback_turn(feedback_rating, feedback_text, history, state):
+     history = history or []
+
+     # ---------- Attempts exhausted ----------
+     if state["attempts"] <= 0:
+         history.append({
+             "role": "assistant",
+             "content": "Feedback attempts are exhausted. Please ask a new question.",
+         })
+         return history, state, attempts_text(0), pd.DataFrame()
+
+     # ---------- BAD without text (consume attempt) ----------
+     if feedback_rating == "wrong" and not (feedback_text or "").strip():
+         new_attempts = max(0, state["attempts"] - 1)
+
+         history.append({
+             "role": "assistant",
+             "content": "Please provide details before submitting negative feedback.",
+         })
+
+         new_state = dict(state)
+         new_state["attempts"] = new_attempts
+
+         return history, new_state, attempts_text(new_attempts), state_to_df(state["last_df"])
+
+     # ---------- Normal feedback ----------
+     history.append({
+         "role": "user",
+         "content": f"{feedback_rating.upper()} feedback",
+     })
+     history.append({
+         "role": "assistant",
+         "content": "Validating feedback…",
+     })
+
+     (
+         answer_md,
+         _,
+         sql_new,
+         rag_new,
+         q_new,
+         ans_state,
+         df_state,
+         followup_flag,
+         followup_prompt,
+         attempts_new,
+     ) = feedback_pipeline_interactive(
+         feedback_rating=feedback_rating,
+         feedback_comment=feedback_text,
+         last_sql=state["last_sql"],
+         last_rag_context=state["last_rag"],
+         last_question=state["last_question"],
+         last_answer_md=state["last_answer"],
+         last_df_state=state["last_df"],
+         feedback_sql="",
+         attempts_left=state["attempts"],
+     )
+
+     history.pop()
+
+     history.append({
+         "role": "assistant",
+         "content": followup_prompt if followup_flag == "FOLLOWUP" else answer_md,
+     })
+
+     new_state = {
+         "last_sql": sql_new,
+         "last_rag": rag_new,
+         "last_question": q_new,
+         "last_answer": ans_state,
+         "last_df": df_state,
+         "attempts": attempts_new,
+     }
+
+     df = state_to_df(df_state)
+
+     return history, new_state, attempts_text(attempts_new), df
+
+
+ # -------------------------------
+ # UI
+ # -------------------------------
+ with gr.Blocks(analytics_enabled=False) as demo:
+     gr.Markdown("# Analytics Assistant Chatbot")
+
+     # type="messages" matches the role/content dicts built in the handlers above
+     chat = gr.Chatbot(height=520, type="messages")
+
+     attempts_md = gr.Markdown(value=attempts_text(4))
+
+     show_debug = gr.Checkbox(label="Show SQL / Cache info", value=False)
+     debug_panel = gr.Markdown(visible=False)
+
+     result_df = gr.Dataframe(
+         label="Query Results",
+         wrap=True,
+         interactive=False,
+     )
+
+     state = gr.State({
+         "last_sql": "",
+         "last_rag": "",
+         "last_question": "",
+         "last_answer": "",
+         "last_df": {},
+         "attempts": 4,
+     })
+
+     user_input = gr.Textbox(
+         placeholder="Ask a question…",
+         lines=1,
+         show_label=False,
+     )
+
+     feedback_input = gr.Textbox(
+         placeholder="Details (mandatory if the result seems wrong to you)…",
+         lines=2,
+         show_label=False,
+     )
+
+     # Buttons
+     with gr.Row():
+         gr.Column(scale=6)
+         with gr.Column(scale=2, min_width=110):
+             good_btn = gr.Button("Good", size="sm")
+         with gr.Column(scale=2, min_width=110):
+             bad_btn = gr.Button("Bad", size="sm")
+
+     # Toggle debug panel
+     show_debug.change(
+         lambda show, md: gr.update(value=md, visible=show),
+         inputs=[show_debug, debug_panel],
+         outputs=debug_panel,
+     )
+
+     # Question submit
+     user_input.submit(
+         chat_turn,
+         inputs=[user_input, chat, state],
+         outputs=[chat, state, attempts_md, debug_panel, result_df],
+         queue=False,
+     )
+
+     # GOOD feedback
+     good_btn.click(
+         feedback_turn,
+         inputs=[gr.State("correct"), feedback_input, chat, state],
+         outputs=[chat, state, attempts_md, result_df],
+         queue=False,
+     )
+
+     # BAD feedback
+     bad_btn.click(
+         feedback_turn,
+         inputs=[gr.State("wrong"), feedback_input, chat, state],
+         outputs=[chat, state, attempts_md, result_df],
+         queue=False,
+     )
+
+
+ # %%
+ if __name__ == "__main__":
+     demo.launch()