Shashwat-18 committed on
Commit
44eea1c
·
verified ·
1 Parent(s): 7df2bdc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +2187 -0
app.py ADDED
@@ -0,0 +1,2187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import os
3
+ import uuid
4
+ import sqlite3
5
+ import datetime
6
+ import json
7
+ import re
8
+ import time
9
+ from typing import Dict, Any, List, Tuple, Optional
10
+ import hashlib
11
+
12
+ import pandas as pd
13
+ from groq import Groq
14
+ from langchain_community.vectorstores import Chroma
15
+ from langchain_community.embeddings import HuggingFaceEmbeddings
16
+
17
+ # %%
18
# 0. Global config
# SQLite database file, and the folder the source CSV files are read from.
DB_PATH = "olist.db"
DATA_DIR = "data"

# Groq Model
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"

# The API key comes from the environment; loading this module fails fast
# with ValueError when it is missing or empty.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found.")

groq_client = Groq(api_key=GROQ_API_KEY)

# Embedding model
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Logging paths
# Both files are opened in append mode by the logging helpers.
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
37
+
38
+ # %%
39
+ # SINGLE, CORRECT, PERSISTENT CHROMA CLIENT
40
+
41
+ import os
42
+ import chromadb
43
+ from chromadb.config import Settings
44
+
45
# Absolute path of the on-disk Chroma directory; created up front if missing.
CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

# Slot for the shared client; filled on the first get_chroma_client() call.
_CHROMA_CLIENT = None
49
+
50
def get_chroma_client():
    """
    Return the shared persistent Chroma client, creating it on first use.

    The client is stored in the module-level _CHROMA_CLIENT so that later
    calls reuse the same instance. Telemetry is disabled in its settings.
    """
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is not None:
        return _CHROMA_CLIENT

    client_settings = Settings(anonymized_telemetry=False)
    _CHROMA_CLIENT = chromadb.PersistentClient(
        path=CHROMA_PERSIST_DIR,
        settings=client_settings,
    )
    return _CHROMA_CLIENT
60
+
61
+
62
+ # %%
63
+
64
+
65
+ # %%
66
+ # 1. Logging helpers
67
+
68
+
69
def log_prompt(tag: str, prompt: str) -> None:
    """
    Append a prompt to the prompt log file.

    The entry is framed by a header carrying `tag` and the current local
    time (second precision) and by a closing END PROMPT marker.
    """
    now_iso = datetime.datetime.now().isoformat(timespec="seconds")
    pieces = (
        f"\n\n================ {tag} @ {now_iso} ================\n",
        prompt,
        "\n================ END PROMPT ================\n",
    )
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as log_file:
        for piece in pieces:
            log_file.write(piece)
79
+
80
+
81
def log_run_event(tag: str, content: str) -> None:
    """
    Append one event (model response, final SQL, error text, ...) to the
    run log file.

    The entry starts with a header containing `tag` and the current local
    time (second precision) and ends with an END EVENT marker.
    """
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    opening = f"\n\n================ {tag} @ {stamp} ================\n"
    closing = "\n================ END EVENT ================\n"

    with open(RUN_LOG_PATH, "a", encoding="utf-8") as run_log:
        run_log.write(opening)
        run_log.write(content)
        run_log.write(closing)
91
+
92
+
93
+
94
+ # %%
95
+ # 2. Feedback table + helpers
96
+
97
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """
    Make sure the user_feedback table exists, then commit.

    The statement is CREATE TABLE IF NOT EXISTS, so an already existing
    table is left exactly as it is.
    """
    create_stmt = """
        CREATE TABLE IF NOT EXISTS user_feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            created_at TEXT NOT NULL,
            question TEXT NOT NULL,
            generated_sql TEXT,
            model_answer TEXT,
            rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
            comment TEXT,
            corrected_sql TEXT
        )
    """
    conn.execute(create_stmt)
    conn.commit()
114
+
115
+
116
def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,  # "good" or "bad"
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """
    Insert one feedback row for a model answer / SQL query and commit.

    Args:
        conn: open SQLite connection that already has a user_feedback table.
        question: the user's question.
        generated_sql: SQL the model produced.
        model_answer: answer text shown to the user.
        rating: "good" or "bad" (case-insensitive; stored lower-cased).
        comment: optional free-text remark.
        corrected_sql: optional externally supplied corrected query.

    Raises:
        ValueError: if `rating` is neither "good" nor "bad".
    """
    normalized_rating = rating.lower()
    if normalized_rating not in {"good", "bad"}:
        raise ValueError("rating must be 'good' or 'bad'")

    created_at = datetime.datetime.now().isoformat(timespec="seconds")
    insert_stmt = """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    row = (
        created_at,
        question,
        generated_sql,
        model_answer,
        normalized_rating,
        comment,
        corrected_sql,
    )
    conn.execute(insert_stmt, row)
    conn.commit()
145
+
146
+
147
def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """
    Return the most recent feedback row for this question (if any).

    Rows are matched on the exact `question` text. created_at is stored with
    one-second resolution, so several rows can share a timestamp; the
    autoincrement id is used as a tie-break so the row inserted last wins.

    Returns:
        A dict with keys created_at, generated_sql, model_answer, rating,
        comment and corrected_sql, or None when no feedback exists.
    """
    fields = (
        "created_at",
        "generated_sql",
        "model_answer",
        "rating",
        "comment",
        "corrected_sql",
    )
    cur = conn.cursor()
    cur.execute(
        """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC, id DESC
        LIMIT 1
        """,
        (question,),
    )
    row = cur.fetchone()
    if not row:
        return None

    return dict(zip(fields, row))
178
+
179
+
180
+
181
+ # %%
182
+ # 3. Database setup (from CSVs)
183
+
184
def init_db() -> sqlite3.Connection:
    """
    Open the SQLite database and (re)load the known Olist CSV files into it.

    Each CSV in the fixed list below is read from DATA_DIR and written to a
    table named after the file (extension removed), replacing any existing
    table of that name. Files that are missing are reported and skipped.
    The feedback table is ensured afterwards.

    Returns:
        The open connection (created with check_same_thread=False).
    """
    db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)

    dataset_files = (
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    )

    for csv_name in dataset_files:
        csv_path = os.path.join(DATA_DIR, csv_name)
        print(csv_path)
        if os.path.exists(csv_path):
            target_table = os.path.splitext(csv_name)[0]
            print(f"Loading {csv_path} into table {target_table}...")
            frame = pd.read_csv(csv_path)
            frame.to_sql(target_table, db_conn, if_exists="replace", index=False)
        else:
            print(f"CSV not found: {csv_path} - skipping")

    init_feedback_table(db_conn)

    return db_conn
219
+
220
+
221
# Module-level connection used by the code below; creating it runs the CSV
# load and ensures the feedback table at import time.
conn = init_db()
222
+
223
+
224
+
225
+ # %%
226
+ # 4. Manual docs for Olist tables
227
+
228
+ OLIST_DOCS: Dict[str, Dict[str, Any]] = {
229
+ "olist_customers_dataset": {
230
+ "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
231
+ "columns": {
232
+ "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
233
+ "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
234
+ "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
235
+ "customer_city": "Customer's city as captured at the time of the order or registration.",
236
+ "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
237
+ }
238
+ },
239
+ "olist_orders_dataset": {
240
+ "description": "Customer orders placed on the Olist marketplace, one row per order.",
241
+ "columns": {
242
+ "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
243
+ "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
244
+ "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
245
+ "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
246
+ "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
247
+ "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
248
+ "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
249
+ "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
250
+ }
251
+ },
252
+ "olist_order_items_dataset": {
253
+ "description": "Order line items, one row per product per order.",
254
+ "columns": {
255
+ "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
256
+ "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
257
+ "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
258
+ "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
259
+ "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
260
+ "price": "Item price paid by the customer for this line (in BRL, not including freight).",
261
+ "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
262
+ }
263
+ },
264
+ "olist_products_dataset": {
265
+ "description": "Product catalog with physical and category attributes, one row per product.",
266
+ "columns": {
267
+ "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
268
+ "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
269
+ "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
270
+ "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
271
+ "product_photos_qty": "Number of product images associated with the listing.",
272
+ "product_weight_g": "Product weight in grams.",
273
+ "product_length_cm": "Product length in centimeters (package dimension).",
274
+ "product_height_cm": "Product height in centimeters (package dimension).",
275
+ "product_width_cm": "Product width in centimeters (package dimension)."
276
+ }
277
+ },
278
+ "olist_order_reviews_dataset": {
279
+ "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
280
+ "columns": {
281
+ "review_id": "Primary key. Unique identifier for each review record.",
282
+ "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
283
+ "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
284
+ "review_comment_title": "Optional short text title or summary of the review.",
285
+ "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
286
+ "review_creation_date": "Date when the customer created the review.",
287
+ "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
288
+ }
289
+ },
290
+ "olist_order_payments_dataset": {
291
+ "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
292
+ "columns": {
293
+ "order_id": "Foreign key to olist_orders_dataset.order_id.",
294
+ "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
295
+ "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
296
+ "payment_installments": "Number of installments chosen by the customer for this payment.",
297
+ "payment_value": "Monetary amount paid in this payment record (in BRL)."
298
+ }
299
+ },
300
+ "product_category_name_translation": {
301
+ "description": "Lookup table mapping Portuguese product category names to English equivalents.",
302
+ "columns": {
303
+ "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
304
+ "product_category_name_english": "Translated product category name in English."
305
+ }
306
+ },
307
+ "olist_sellers_dataset": {
308
+ "description": "Seller master data, one row per seller operating on the Olist marketplace.",
309
+ "columns": {
310
+ "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
311
+ "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
312
+ "seller_city": "City where the seller is located.",
313
+ "seller_state": "State where the seller is located (two-letter Brazilian state code)."
314
+ }
315
+ },
316
+ "olist_geolocation_dataset": {
317
+ "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
318
+ "columns": {
319
+ "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
320
+ "geolocation_lat": "Latitude coordinate of the location.",
321
+ "geolocation_lng": "Longitude coordinate of the location.",
322
+ "geolocation_city": "City name for the location.",
323
+ "geolocation_state": "State code for the location (two-letter Brazilian state code)."
324
+ }
325
+ }
326
+ }
327
+
328
+
329
+
330
+ # %%
331
+ # 5. Schema extractor + metadata YAML
332
+
333
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Describe every table listed in sqlite_master.

    For each table the result holds a description, a column-name ->
    description mapping and a list of foreign-key relationship strings.
    Hand-written texts from OLIST_DOCS are used where available; otherwise a
    generic description is generated from the table / column name and type.

    Returns:
        {"tables": {table_name: {"description", "columns", "relationships"}}}
    """
    cur = connection.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [row[0] for row in cur.fetchall()]

    result: Dict[str, Any] = {"tables": {}}

    for table in table_names:
        cur.execute(f"PRAGMA table_info('{table}')")
        column_rows = cur.fetchall()

        # Manual docs for the table
        manual = OLIST_DOCS.get(table, {})
        description: str = manual.get(
            "description",
            f"Table '{table}' from Olist dataset.",
        )
        manual_columns: Dict[str, str] = manual.get("columns", {})

        columns_meta: Dict[str, str] = {}
        for info in column_rows:
            name = info[1]
            declared = info[2] or "TEXT"  # untyped columns are reported as TEXT
            columns_meta[name] = (
                manual_columns[name]
                if name in manual_columns
                else f"Column '{name}' of type {declared}"
            )

        cur.execute(f"PRAGMA foreign_key_list('{table}')")
        # foreign_key_list rows: index 2 = referenced table, 3 = from, 4 = to
        relationships: List[str] = [
            f"{table}.{fk[3]} → {fk[2]}.{fk[4]} (foreign key)"
            for fk in cur.fetchall()
        ]

        result["tables"][table] = {
            "description": description,
            "columns": columns_meta,
            "relationships": relationships,
        }

    return result
383
+
384
+
385
# Schema description of the loaded database, computed once at import time.
schema_metadata = extract_schema_metadata(conn)
386
+
387
+
388
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render metadata dict into a YAML-style string.

    Layout (two spaces per nesting level):

        tables:
          <table>:
            description: "<text>"
            columns:
              <column>: "<text>"
            relationships:
              - "<text>"

    The relationships key is emitted only for tables that have any. Double
    quotes inside texts are replaced by single quotes so every value can be
    wrapped in double quotes.

    Args:
        metadata: dict with a "tables" mapping as produced by
            extract_schema_metadata.

    Returns:
        The rendered text, lines joined with "\\n" (no trailing newline).
    """
    def _dq(text: str) -> str:
        # Values are double-quoted, so embedded double quotes are downgraded.
        return text.replace('"', "'")

    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        lines.append(f'    description: "{_dq(tinfo.get("description", ""))}"')
        lines.append("    columns:")
        for col_name, col_desc in tinfo.get("columns", {}).items():
            lines.append(f'      {col_name}: "{_dq(col_desc)}"')
        rels = tinfo.get("relationships", [])
        if rels:
            lines.append("    relationships:")
            for rel in rels:
                lines.append(f'      - "{_dq(rel)}"')
    return "\n".join(lines)
408
+
409
+
410
# YAML-style rendering of schema_metadata, built once at import time.
schema_yaml = build_schema_yaml(schema_metadata)
411
+
412
+
413
+
414
+ # %%
415
+ # 6. Build schema documents for RAG (taking samples from the table)
416
+
417
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Produce one retrieval document (plus a metadata record) for every table
    listed in sqlite_master.

    A document is laid out as: table name, description, one bullet per
    column ("- name (type): description"), an optional relationships
    section, and up to `sample_rows` example rows rendered as markdown.
    If the sample cannot be produced for any reason, a placeholder text is
    used instead.

    Returns:
        (docs, metadatas): parallel lists; each metadata dict carries
        doc_type="table_schema" and the table_name.

    Raises:
        KeyError: if a table found in the database has no entry under
        schema_metadata["tables"].
    """
    cur = connection.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [row[0] for row in cur.fetchall()]

    texts: List[str] = []
    records: List[Dict[str, Any]] = []

    for table in table_names:
        info = schema_metadata["tables"][table]
        description = info.get("description", "")
        column_docs = info.get("columns", {})
        fk_lines = info.get("relationships", [])

        # Use PRAGMA to get types, then enrich with descriptions
        cur.execute(f"PRAGMA table_info('{table}')")
        bullets = []
        for column in cur.fetchall():
            name = column[1]
            declared = column[2] or "TEXT"
            about = column_docs.get(name, f"Column '{name}' of type {declared}")
            bullets.append(f"- {name} ({declared}): {about}")

        # Sample rows (best effort)
        try:
            preview = pd.read_sql_query(
                f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                connection,
            ).to_markdown(index=False)
        except Exception:
            preview = "(could not fetch sample rows)"

        # Relationships block
        if fk_lines:
            rel_section = (
                "Relationships:\n"
                + "\n".join(f"- {line}" for line in fk_lines)
                + "\n"
            )
        else:
            rel_section = ""

        parts = [
            f"Table: {table}\n",
            f"Description: {description}\n\n",
            "Columns:\n",
            "\n".join(bullets),
            "\n\n",
            f"{rel_section}\n",
            f"Example rows:\n{preview}\n",
        ]
        texts.append("".join(parts))
        records.append({"doc_type": "table_schema", "table_name": table})

    return texts, records
488
+
489
+
490
# Build RAG texts + metadata
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

# Parallel lists: RAG_METADATAS[i] describes RAG_TEXTS[i].
RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# 1) Per-table docs
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# 2) Global YAML as a separate doc
# Tagged doc_type="global_schema" so it can be told apart from table docs.
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
503
+
504
+
505
+ # %%
506
def build_store_final():
    """
    Open the "rag_schema_store" Chroma collection and build its retriever.

    The collection is filled from RAG_TEXTS / RAG_METADATAS only when it is
    currently empty. The retriever returns the top 3 matches restricted to
    documents whose doc_type is "table_schema".

    Returns:
        (rag_store, rag_retriever)
    """
    embedder = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )
    store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embedder,
    )

    # Seed only once; a non-empty collection is reused as-is.
    needs_seeding = store._collection.count() == 0
    if needs_seeding:
        store.add_texts(texts=RAG_TEXTS, metadatas=RAG_METADATAS)

    retriever_options = {"k": 3, "filter": {"doc_type": "table_schema"}}
    retriever = store.as_retriever(search_kwargs=retriever_options)

    return store, retriever
529
+
530
# Initialize once
# Runs at import time: constructs the embedding model and opens the
# schema collection, seeding it if empty.
rag_store, rag_retriever = build_store_final()
532
+
533
+
534
+ # %%
535
+ from langchain.vectorstores import Chroma
536
+ from langchain.embeddings import HuggingFaceEmbeddings
537
+
538
# Lazily created singletons, filled in by get_sql_cache_store().
_sql_cache_store = None
_sql_embedding_fn = None

# Name of the Chroma collection that holds cached question -> SQL entries.
SQL_CACHE_COLLECTION = "sql_cache_mpnet"
542
+
543
def get_sql_cache_store():
    """
    Return the shared SQL-cache vector store, creating it on first call.

    The embedding function and the store are kept in module-level variables
    so that later calls return the same objects.
    """
    global _sql_cache_store, _sql_embedding_fn

    if _sql_cache_store is None:
        if _sql_embedding_fn is None:
            _sql_embedding_fn = HuggingFaceEmbeddings(
                model_name=EMBED_MODEL_NAME,
                encode_kwargs={"normalize_embeddings": True},
            )

        _sql_cache_store = Chroma(
            client=get_chroma_client(),
            collection_name=SQL_CACHE_COLLECTION,
            embedding_function=_sql_embedding_fn,
        )

    return _sql_cache_store
562
+
563
+
564
+ # %%
565
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """
    Return a copy of `metadata` reduced to simple scalar values.

    - None values are dropped.
    - int, float, bool and str values are kept unchanged.
    - anything else is converted with str().

    A None / empty `metadata` yields an empty dict.
    """
    scalar_types = (int, float, bool, str)
    return {
        key: (value if isinstance(value, scalar_types) else str(value))
        for key, value in (metadata or {}).items()
        if value is not None
    }
577
+
578
+
579
+ # %%
580
def normalize_question_strict(q: str) -> str:
    """
    Canonical form of a question used as the exact-match cache key.

    Steps: lower-case and trim, delete every character that is neither a
    word character nor whitespace, then collapse whitespace runs to a single
    space. Falsy input gives "".
    """
    if q:
        trimmed = q.lower().strip()
        no_punct = re.sub(r"[^\w\s]", "", trimmed)
        return re.sub(r"\s+", " ", no_punct)
    return ""
590
+
591
+
592
+ # %%
593
+ # Helper: normalize question
594
+
595
def normalize_question_text(q: str) -> str:
    """
    Loose normalisation of a question: trim and lower-case, turn every
    character that is neither a word character nor whitespace into a space,
    then collapse whitespace runs and trim again. Falsy input gives "".
    """
    if q:
        lowered = q.strip().lower()
        spaced = re.sub(r"[^\w\s]", " ", lowered)
        collapsed = re.sub(r"\s+", " ", spaced)
        return collapsed.strip()
    return ""
602
+
603
+ # Helper: compute success rate
604
def compute_success_rate(md: dict) -> float:
    """
    Fraction of positive feedback recorded in a cache-entry metadata dict.

    The numerator is md["success_count"]; when that key is absent (or None)
    md["good_count"] is used instead, because cache entries in this file are
    written with either name. The denominator is md["total_feedbacks"].

    Returns:
        successes / total as a float, or 0.0 when there is no feedback yet
        (total missing, None, zero or negative).
    """
    successes = md.get("success_count")
    if successes is None:
        # Entries written by the dedup cache path count positives here.
        successes = md.get("good_count")
    successes = successes or 0

    total = md.get("total_feedbacks", 0) or 0
    if total <= 0:
        return 0.0
    return float(successes) / float(total)
610
+
611
+ # Insert initial cache entry (no feedback yet)
612
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
    """
    Insert a cached entry when you run a query and want to cache it regardless of feedback.
    initial metrics: views=1, success_count=0, total_feedbacks=0, success_rate=1.0

    Args:
        question: the user's question; stored as the document text.
        sql: SQL that was executed.
        answer_md: rendered answer to cache.
        store: vector store to write to; defaults to get_sql_cache_store().
        extra_metadata: optional extra keys merged over the defaults.

    Returns:
        The generated id (uuid4 hex). It is used as the Chroma document id
        and is also stored in the metadata under "id" and "cache_id".
    """
    if store is None:
        store = get_sql_cache_store()

    ident = uuid.uuid4().hex
    md = {
        "id": ident,
        # Same key name the view counter reads, so this entry can be updated
        # in place later.
        "cache_id": ident,
        # Must use the strict normaliser: exact-match lookup filters on
        # normalize_question_strict(question), so any other normalisation
        # makes entries with inner punctuation unreachable.
        "normalized_question": normalize_question_strict(question),
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": time.time(),
        "views": 1,
        "success_count": 0,
        "total_feedbacks": 0,
        # NOTE(review): the dedup cache path defaults success_rate to 0.5;
        # 1.0 is kept here because it is the documented initial value.
        "success_rate": 1.0,
    }
    if extra_metadata:
        md.update(extra_metadata)

    # Drop None values / stringify non-scalars, and store under the returned
    # id so that the id can later be used to fetch or overwrite the entry.
    store.add_texts(
        [question],
        metadatas=[sanitize_metadata_for_chroma(md)],
        ids=[ident],
    )
    return ident
639
+
640
+
641
+
642
+ # %%
643
+ import time
644
+ import logging
645
+ from typing import Optional, List, Dict, Any
646
+ from difflib import SequenceMatcher
647
+
648
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
649
+
650
+ # -------------------------------------------------------------------------
651
+ # Utility helpers
652
+ # -------------------------------------------------------------------------
653
+
654
def _now_ts() -> float:
    """Return the current time as given by time.time()."""
    return time.time()
656
+
657
def similarity_score(a: str, b: str) -> float:
    """
    Similarity ratio of two strings computed with difflib.SequenceMatcher.
    Falsy arguments (None, "") are treated as the empty string.
    """
    left = a if a else ""
    right = b if b else ""
    matcher = SequenceMatcher(None, left, right)
    return matcher.ratio()
659
+
660
+ # -------------------------------------------------------------------------
661
+ # Persist helper
662
+ # -------------------------------------------------------------------------
663
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """
    Write one text with sanitised metadata to `store` under `cache_id`.

    store.add_texts is first called with an explicit id. If that call
    raises TypeError, it is repeated without the id.
    """
    payload = {
        "texts": [text],
        "metadatas": [sanitize_metadata_for_chroma(metadata)],
    }

    try:
        store.add_texts(ids=[cache_id], **payload)
    except TypeError:
        # Fallback for stores whose add_texts does not accept ids.
        store.add_texts(**payload)
682
+
683
+
684
+ # -------------------------------------------------------------------------
685
+ # Cache insert / update
686
+ # -------------------------------------------------------------------------
687
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """
    Insert or overwrite the cache entry for a (question, sql) pair.

    The entry id is generate_cache_id(question, sql), so writing the same
    pair again targets the same id. The stored document text is the
    semantically normalised question; the strictly normalised question is
    kept in the metadata as the exact-match key.

    Args:
        question: the user's question.
        sql: SQL to cache.
        answer_md: rendered answer to cache.
        metadata: previous metadata of the entry, used to carry over
            timestamps and counters; None or {} for a fresh entry.
        store: vector store to write to.

    Returns:
        dict with "question", "sql", "answer_md" and the "metadata" that was
        handed to the store (before sanitising).
    """
    # A fresh entry may come without previous metadata.
    prior = metadata or {}

    norm_q_semantic = normalize_text(question)
    norm_q_exact = normalize_question_strict(question)

    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    md = {
        # identity
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,

        # exact match key
        "normalized_question": norm_q_exact,

        # timestamps
        "saved_at": prior.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": prior.get("last_viewed_at", 0),

        # metrics
        "good_count": prior.get("good_count", 0),
        "bad_count": prior.get("bad_count", 0),
        "total_feedbacks": prior.get("total_feedbacks", 0),
        "success_rate": prior.get("success_rate", 0.5),
        "views": prior.get("views", 0),
    }

    langchain_upsert(
        store=store,
        text=norm_q_semantic,  # semantic vector
        metadata=md,
        cache_id=cache_id,
    )

    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": md,
    }
735
+
736
+ # -------------------------------------------------------------------------
737
+ # Find exact cached entry (question + SQL)
738
+ # -------------------------------------------------------------------------
739
+
740
def find_cached_doc_by_sql(question: str, sql: str, store):
    """
    Look up the cache entry whose id is generate_cache_id(question, sql).

    Args:
        question: the user's question.
        sql: the SQL text of the entry.
        store: vector store exposing the underlying collection as
            `_collection`.

    Returns:
        dict with "id", "question", "sql", "answer_md" and "metadata", or
        None when the store has no usable collection, nothing is stored
        under that id, or the lookup fails (the failure is logged).
    """
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)

    # Compare against None explicitly: the collection object's truth value
    # says nothing about whether a collection is available.
    if coll is None or not hasattr(coll, "get"):
        return None

    try:
        res = coll.get(ids=[cache_id])
    except Exception:
        # Best-effort lookup: report the problem instead of hiding it.
        _logger.warning("Cache lookup failed for id %s", cache_id, exc_info=True)
        return None

    metadatas = (res.get("metadatas") if res else None) or []
    if not metadatas or metadatas[0] is None:
        return None

    md = metadatas[0]
    return {
        "id": cache_id,
        "question": question,
        "sql": md.get("sql"),
        "answer_md": md.get("answer_md"),
        "metadata": md,
    }
760
+
761
+ # -------------------------------------------------------------------------
762
+ # Retrieve cached answers ranked primarily by success rate
763
+ # -------------------------------------------------------------------------
764
+
765
+ import re
766
+ import unicodedata
767
+
768
+
769
def normalize_text(text: str) -> str:
    """
    Normalise text for semantic matching and id generation.

    Steps: NFKD-decompose and drop everything that cannot be encoded as
    ASCII, lower-case, replace each character that is neither a word
    character nor whitespace by a space, collapse whitespace runs and trim.
    Falsy input gives "".
    """
    if not text:
        return ""

    decomposed = unicodedata.normalize("NFKD", text)
    ascii_only = decomposed.encode("ascii", "ignore").decode("ascii")
    spaced = re.sub(r"[^\w\s]", " ", ascii_only.lower())
    return re.sub(r"\s+", " ", spaced).strip()
777
+
778
+ import hashlib
779
+
780
def generate_cache_id(question: str, sql: str) -> str:
    """
    Deterministic id for a (question, sql) pair: the SHA-1 hex digest of
    "<normalised question>||<stripped sql>" encoded as UTF-8. A falsy `sql`
    counts as the empty string.
    """
    normalized_q = normalize_text(question)
    trimmed_sql = sql.strip() if sql else ""
    digest = hashlib.sha1()
    digest.update(f"{normalized_q}||{trimmed_sql}".encode("utf-8"))
    return digest.hexdigest()
785
+
786
+
787
def rank_cached_candidates(candidates: list[dict]) -> dict:
    """
    Sort `candidates` in place, best first, and return the best one.

    Ordering, read from each candidate's "metadata" dict:
      1. higher success_rate (default 0.5)
      2. more good_count (default 0)
      3. fewer bad_count (default 0)
      4. more recent last_updated_at (default 0)

    Raises:
        IndexError: if `candidates` is empty.
    """
    def _quality_key(entry: dict):
        meta = entry["metadata"]
        return (
            -float(meta.get("success_rate", 0.5)),
            -int(meta.get("good_count", 0)),
            int(meta.get("bad_count", 0)),
            -float(meta.get("last_updated_at", 0)),
        )

    candidates.sort(key=_quality_key)
    best = candidates[0]
    return best
800
+
801
+
802
+
803
def retrieve_exact_cached_sql(question: str, store):
    """
    Exact question match, but still quality-ranked.

    Fetches every cache entry whose "normalized_question" metadata equals
    normalize_question_strict(question) and returns the best one according
    to rank_cached_candidates. Entries without an "sql" value are skipped so
    that one malformed entry cannot hide the valid ones.

    Returns:
        dict with "matched_question", "sql", "answer_md", "distance" (always
        0.0) and "metadata", or None when nothing matches or the lookup
        fails (the failure is logged).
    """
    norm_q = normalize_question_strict(question)
    coll = store._collection

    try:
        res = coll.get(where={"normalized_question": norm_q})
        metadatas = (res.get("metadatas") if res else None) or []

        candidates = [
            {
                "matched_question": question,
                "sql": md["sql"],
                "answer_md": md.get("answer_md", ""),
                "distance": 0.0,
                "metadata": md,
            }
            for md in metadatas
            if md and "sql" in md
        ]
        if not candidates:
            return None

        return rank_cached_candidates(candidates)

    except Exception:
        # Best-effort lookup: a cache failure must not break the request,
        # but it should be visible in the logs.
        _logger.warning("Exact cache lookup failed", exc_info=True)
        return None
829
+
830
+
831
def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
    k: int = 20,
):
    """Return the best semantically-similar cached entry, or ``None``.

    The normalized question is sent to ``store.similarity_search_with_score``
    and each hit's score is treated as a distance: hits with a score above
    ``max_distance`` or without an ``sql`` metadata field are discarded.
    Survivors are ranked quality-first: success rate, good count, bad count,
    then distance, then recency.

    Args:
        question: the user's question.
        store: vector store exposing ``similarity_search_with_score``.
        max_distance: largest accepted score.
        k: number of neighbours requested from the store (was a fixed 20).

    Returns:
        A candidate dict (``matched_question``, ``sql``, ``answer_md``,
        ``distance``, ``metadata`` plus numeric ranking fields) or ``None``
        when no hit qualifies.
    """
    norm_q = normalize_text(question)
    results = store.similarity_search_with_score(norm_q, k=k)

    candidates = []
    for doc, score in results:
        distance = float(score)
        if distance > max_distance:
            continue

        md = doc.metadata or {}
        if "sql" not in md:
            continue

        candidates.append({
            "matched_question": doc.page_content,
            "sql": md["sql"],
            "answer_md": md.get("answer_md", ""),
            "distance": distance,
            "metadata": md,

            # ranking signals, coerced to numbers
            "success_rate": float(md.get("success_rate", 0.5)),
            "good": int(md.get("good_count", 0)),
            "bad": int(md.get("bad_count", 0)),
            "views": int(md.get("views", 0)),
            "last_updated": float(md.get("last_updated_at", 0)),
        })

    if not candidates:
        return None

    # Quality first; min() keeps the earliest hit on ties, exactly like a
    # stable sort followed by [0].
    return min(
        candidates,
        key=lambda c: (
            -c["success_rate"],
            -c["good"],
            c["bad"],
            c["distance"],
            -c["last_updated"],
        ),
    )
881
+
882
+ # -------------------------------------------------------------------------
883
+ # Increment views
884
+ # -------------------------------------------------------------------------
885
+
886
def increment_cache_views(metadata: dict, store):
    """Bump the ``views`` counter of a cached entry and re-upsert it.

    Args:
        metadata: metadata dict of the cached entry; must contain
            ``cache_id``. It is copied, never mutated.
        store: vector store handed to ``langchain_upsert``.

    Returns:
        True when the upsert succeeded; False when the metadata is missing,
        has no ``cache_id`` or no question text, or the upsert raised.
    """
    if not metadata:
        return False

    cache_id = metadata.get("cache_id")
    if not cache_id:
        return False

    # The upsert re-writes the document text, so refuse to replace it with
    # an empty string when the stored question is unavailable.
    text = normalize_text(metadata.get("normalized_question", ""))
    if not text:
        _logger.warning(
            "increment_cache_views: entry %s has no normalized_question; skipping",
            cache_id,
        )
        return False

    # One timestamp for the whole update keeps the fields consistent.
    now = _now_ts()
    md = dict(metadata)
    md["views"] = int(md.get("views", 0)) + 1
    md["last_viewed_at"] = now
    md["last_updated_at"] = now
    md["saved_at"] = md.get("saved_at", now)

    try:
        langchain_upsert(
            store=store,
            text=text,
            metadata=md,
            cache_id=cache_id,
        )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False
911
+
912
+
913
+ # Update metrics on feedback
914
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """Apply user feedback to the cache.

    Two independent steps:
      1. If the original cached entry is known, update its good/bad counters,
         ``success_rate`` and timestamps and upsert it under its ``cache_id``.
      2. If the LLM produced a corrected SQL *and* answer, store them as a
         new entry seeded with one positive vote.

    Step 2 no longer depends on step 1: a correction is still cached when
    the original entry could not be located.

    Args:
        question: the question the feedback refers to.
        original_doc_md: cached document dict holding a ``"metadata"``
            mapping, or a falsy value when the entry was not found.
        user_marked_good: True for a positive vote, False for a negative one.
        llm_corrected_sql: corrected SQL, or None when there is none.
        llm_corrected_answer_md: answer for the corrected SQL, or None.
        store: vector store used for the upserts.
    """
    now = _now_ts()

    # ---- step 1: metrics on the original entry ----
    md = dict((original_doc_md or {}).get("metadata") or {})
    cache_id = md.get("cache_id")
    if cache_id:
        # Coerce counters: stored metadata may be absent or non-int.
        good = int(md.get("good_count", 0) or 0)
        bad = int(md.get("bad_count", 0) or 0)
        total = int(md.get("total_feedbacks", 0) or 0) + 1
        if user_marked_good:
            good += 1
        else:
            bad += 1

        md["good_count"] = good
        md["bad_count"] = bad
        md["total_feedbacks"] = total
        md["success_rate"] = good / total  # total >= 1 here

        md["saved_at"] = md.get("saved_at", now)  # preserve original save time
        md["last_updated_at"] = now

        langchain_upsert(
            store=store,
            text=normalize_text(question),
            metadata=md,
            cache_id=cache_id,
        )
    elif original_doc_md:
        _logger.warning("update_cache_on_feedback: cached entry has no cache_id; metrics not updated")

    # ---- step 2: corrected SQL becomes a NEW entry ----
    if llm_corrected_sql and llm_corrected_answer_md:
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": now,
            },
            store=store,
        )
969
+
970
+
971
+ # %%
972
+ # ### 8. Groq LLM via LangChain
973
+
974
+ from langchain_groq import ChatGroq
975
+ import re
976
+ import gradio as gr
977
+
978
+ llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
979
+
980
+ # %%
981
def get_rag_context(question: str) -> str:
    """Fetch schema documents for ``question`` and merge them into one string.

    Documents come from ``rag_retriever.invoke``; their ``page_content``
    values are joined with a ``---`` rule surrounded by blank lines.
    """
    separator = "\n\n---\n\n"
    retrieved = rag_retriever.invoke(question)
    chunks = [doc.page_content for doc in retrieved]
    return separator.join(chunks)
987
+
988
def clean_sql(sql: str) -> str:
    """Strip Markdown code fences (and text around them) from model output.

    * No fence present: the text is returned stripped.
    * A complete fenced block is present: only the content of the first
      block is returned, so explanatory prose around it never ends up in
      the SQL. An optional ``sql``/``sqlite``/``sqlite3`` language tag is
      removed case-insensitively.
    * Only stray or unclosed fence markers: the markers (with optional tag)
      are removed and the rest is returned.

    ``None`` is treated as the empty string.
    """
    text = (sql or "").strip()
    if "```" not in text:
        return text

    fence_open = r"```(?:sqlite3|sqlite|sql)?"
    block = re.search(fence_open + r"(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    if block and block.group(1).strip():
        return block.group(1).strip()

    # Unclosed or empty fence: fall back to just deleting the markers.
    return re.sub(fence_open, "", text, flags=re.IGNORECASE).strip()
993
+
994
def extract_sql_from_markdown(text: str) -> str:
    """Extract SQL from the first fenced code block of an LLM reply.

    Preference order:
      1. the first block tagged ``sql``, ``sqlite`` or ``sqlite3``
         (case-insensitive; the tag itself is not part of the result);
      2. otherwise the first fenced block of any kind (an optional language
         tag on the opening line is dropped);
      3. otherwise the whole text.

    The result is always stripped. ``None`` is treated as "".
    """
    text = text or ""

    tagged = re.search(
        r"```(?:sqlite3|sqlite|sql)(.*?)```",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if tagged:
        return tagged.group(1).strip()

    untagged = re.search(r"```(?:[\w+-]*\n)?(.*?)```", text, flags=re.DOTALL)
    if untagged:
        return untagged.group(1).strip()

    return text.strip()
1003
+
1004
def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """Return the text following the first case-insensitive ``marker``.

    When the marker is absent the whole text is returned. The result is
    stripped in both cases.

    The marker is located on the original string. Searching in
    ``text.upper()`` is unsafe because upper-casing can change the string
    length (e.g. "ß" -> "SS"), which shifts the slice index.
    """
    found = re.search(re.escape(marker), text, flags=re.IGNORECASE)
    if found is None:
        return text.strip()
    return text[found.end():].strip()
1012
+
1013
+ # 6b. General / descriptive questions
1014
+
1015
# Lower-case phrases that mark a question as a request for a dataset overview.
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    Detect high-level descriptive questions where we should answer
    directly from schema context instead of generating SQL.

    Matching is a case-insensitive substring test against
    GENERAL_DESC_KEYWORDS. Whitespace runs in the question (double spaces,
    tabs, newlines) are collapsed first so they cannot defeat the match.
    Falsy input returns False.
    """
    if not question:
        return False
    q = " ".join(question.lower().split())
    return any(key in q for key in GENERAL_DESC_KEYWORDS)
1034
+
1035
+
1036
def answer_general_question(question: str) -> str:
    """
    Answer a conceptual "what is this dataset?" style question.

    Schema documents retrieved for the question are combined with fixed
    documentation-writer instructions into one prompt; the prompt and the
    raw model reply are logged, and the stripped reply is returned.
    """
    schema_context = get_rag_context(question)

    system_instructions = """
You are a data documentation expert.

You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".

Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.

Do NOT write SQL. Answer in Markdown.
"""

    prompt = "".join(
        [
            system_instructions,
            "\n\n=== SCHEMA CONTEXT ===\n",
            schema_context,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\nDetailed dataset overview:",
        ]
    )

    log_prompt("GENERAL_DATASET_QUESTION", prompt)

    reply = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", reply.content)

    return reply.content.strip()
1074
+
1075
+ # %%
1076
+ # ### 9. SQL execution / validation
1077
+
1078
+ def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
1079
+ """
1080
+ Execute SQL on SQLite and return a DataFrame, else return an error.
1081
+ """
1082
+ try:
1083
+ df = pd.read_sql_query(sql, connection)
1084
+ return df, None
1085
+ except Exception as e:
1086
+ return None, str(e)
1087
+
1088
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Check a query without running it.

    The statement is prefixed with EXPLAIN QUERY PLAN and executed on a new
    cursor; any exception (syntax error, unknown table/column, ...) marks
    the SQL as invalid.

    Returns ``(True, None)`` when accepted, ``(False, error_message)`` otherwise.
    """
    try:
        plan_statement = f"EXPLAIN QUERY PLAN {sql}"
        connection.cursor().execute(plan_statement)
    except Exception as exc:
        return False, str(exc)
    return True, None
1099
+
1100
+ # %%
1101
+ # ### 10. SQL Generation + Repair with LLM (feedback-aware)
1102
+
1103
def build_sql_review_prompt(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> str:
    """
    Prompt to let the LLM compare its SQL with the user's feedback,
    decide if the join/logic is wrong, and produce a corrected SQL + explanation.

    We explicitly allow the model to say:
    - "query was already correct, user mistaken" OR
    - "query was wrong, here is the fix".

    The previous query is wrapped in a properly closed ```sql fence, and the
    requested answer format asks for the corrected SQL inside a ```sql fence
    followed by an ``EXPLANATION:`` section, so a fence-based SQL extractor
    and an ``EXPLANATION:`` marker split can both parse the reply.

    Returns:
        The prompt text, stripped of leading/trailing whitespace.
    """
    prompt = f"""
You previously generated the following SQL for a SQLite database:

```sql
{generated_sql}
```

The user now says this query is WRONG and provided this feedback:
"{user_feedback_comment}"

TASKS:

Compare the SQL with the database schema (given in the context) and the user's feedback.

Decide whether the query is actually correct or incorrect.
Make sure that the user clearly specifies why they think it is incorrect.
It can be if they are unsatisfied with the numbers, or the logic is incorrect, or SQL is invalid etc.
If any reason is not clearly specified, return the previous result as it is.

If it is already correct, keep it unchanged and explain why the user might be mistaken.

If it is incorrect (e.g., wrong joins, missing filters, wrong aggregation), fix it.

Produce a corrected SQL query that better answers the question.

If the original query is already correct, just repeat the same SQL.

Explain in a few sentences WHO is correct (you or the user) and WHY.

DATABASE SCHEMA (partial, from RAG):
{rag_context}

User question:
{question}

Return your answer in exactly this format (keep the SQL inside the fenced block):

CORRECTED_SQL:

```sql
-- your (possibly unchanged) SQL here
SELECT ...
```

EXPLANATION:
Your explanation here, clearly stating whether:

the original query was correct or not, and

how your corrected SQL addresses the issue (or why it didn't need changes).
"""
    return prompt.strip()
1166
+
1167
+
1168
+ # %%
1169
+ # Review and Correct SQL based on Feedback
1170
+
1171
def _parse_corrected_sql(response_text: str, fallback_sql: str) -> str:
    """Pull the corrected SQL out of a review reply, or return ``fallback_sql`` when none is found."""
    candidate = clean_sql(extract_sql_from_markdown(response_text or ""))

    # If the model ignored the fenced format, the helpers above hand back the
    # whole reply. Cut it down to what sits between the "CORRECTED_SQL:" label
    # and the "EXPLANATION:" section so labels and prose are never executed.
    candidate = re.split(r"EXPLANATION\s*:", candidate, maxsplit=1, flags=re.IGNORECASE)[0]
    candidate = re.sub(
        r"\A.*?CORRECTED_SQL\s*:",
        "",
        candidate,
        count=1,
        flags=re.IGNORECASE | re.DOTALL,
    ).strip()

    return candidate if candidate else fallback_sql


def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> Tuple[str, str]:
    """
    Ask the LLM to compare its SQL with user's feedback, decide what is wrong (or not),
    and propose a corrected SQL (possibly unchanged) + explanation.

    The prompt and the raw reply are logged. The SQL is taken from the
    first fenced block of the reply; for unfenced replies the
    ``CORRECTED_SQL:`` / ``EXPLANATION:`` labels are used to isolate it.
    When nothing usable is found, ``generated_sql`` is returned unchanged.

    Returns:
        corrected_sql, explanation
    """
    prompt = build_sql_review_prompt(
        question=question,
        generated_sql=generated_sql,
        user_feedback_comment=user_feedback_comment,
        rag_context=rag_context,
    )

    log_prompt("SQL_REVIEW_FEEDBACK", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REVIEW", response.content)

    corrected_sql = _parse_corrected_sql(response.content, generated_sql)

    explanation = extract_explanation_after_marker(
        response.content,
        marker="EXPLANATION:",
    )
    return corrected_sql, explanation
1206
+
1207
+ # %%
1208
+ # Generate SQL
1209
+
1210
def generate_sql(question: str, rag_context: str) -> str:
    """
    Generate SQL using LLM, but also pull in any past user feedback
    (corrected SQL) for this question as external guidance.

    The latest feedback row for the question (if it carries a corrected SQL)
    is embedded in the instructions. The prompt and raw reply are logged.
    The reply is reduced to SQL by taking the first fenced block when there
    is one and stripping leftover fence markers, so surrounding prose from
    the model does not end up in the query.

    Returns:
        The SQL text extracted from the model reply.
    """
    # External correction from past feedback, if any
    last_fb = get_last_feedback_for_question(conn, question)
    if last_fb and last_fb.get("corrected_sql"):
        previous_feedback_block = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:

Previous generated SQL:
{last_fb.get('generated_sql') or '(unknown)'}

Corrected SQL (preferred reference):
{last_fb['corrected_sql']}

User comment & prior explanation:
{last_fb.get('comment') or '(none)'}

You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""
    else:
        previous_feedback_block = ""

    system_instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.

You will be given:

- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.

Your job:

- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
and round to 2 decimals when appropriate.
- Return ONLY the SQL query, no explanation.

{previous_feedback_block}
"""

    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nSQL query:"
    )

    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)

    # Prefer the fenced block (drops any prose), then remove stray fences.
    return clean_sql(extract_sql_from_markdown(response.content))
1271
+
1272
+
1273
+ # %%
1274
+ # Repair SQL
1275
+
1276
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to correct a failing SQL query.

    The prompt carries the schema context, the question, the failing SQL and
    the SQLite error. The prompt and raw reply are logged. As in the other
    LLM paths, the reply is reduced to the first fenced block (when present)
    and stray fence markers are stripped, in case the model wraps the query
    in Markdown or prose despite the instructions.

    Returns:
        The repaired SQL text (not validated here).
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:

- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:

- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""

    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== PREVIOUS (FAILING) SQL ===\n"
        + bad_sql
        + "\n\n=== SQLITE ERROR ===\n"
        + error_message
        + "\n\nCorrected SQL query:"
    )

    log_prompt("SQL_REPAIR", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)

    return clean_sql(extract_sql_from_markdown(response.content))
1324
+
1325
+
1326
+ # %%
1327
+ ### 11. Result summarization
1328
+
1329
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
    max_preview_rows: int = 50,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.

    Args:
        question: the user's question.
        sql: the SQL that was executed (or attempted).
        df: query result, or None when the query failed.
        rag_context: schema context appended to the prompt.
        error: execution/validation error text, if any.
        max_preview_rows: maximum number of result rows shown to the model
            (default 50, the previously fixed limit).

    Returns:
        The model's Markdown answer, stripped.
    """
    system_instructions = """
You are a senior data analyst.

You will be given:

The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.

Your job:

Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user.
"""

    # Build a markdown table preview if we have data
    if df is not None and not df.empty:
        preview = df.head(max_preview_rows)
        try:
            df_preview_md = preview.to_markdown(index=False)
        except ImportError:
            # to_markdown needs the optional 'tabulate' package; degrade to
            # a plain-text table instead of failing the whole answer.
            df_preview_md = preview.to_string(index=False)
        if len(df) > len(preview):
            # Tell the model the table is partial so it does not report
            # totals computed from the preview only.
            df_preview_md += f"\n\n(showing first {len(preview)} of {len(df)} rows)"
    else:
        df_preview_md = "(no rows returned)"

    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== EXECUTED SQL ===\n"
        + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n"
        + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n"
        + rag_context
    )

    if error:
        prompt += "\n\n=== ERROR ===\n" + error

    log_prompt("RESULT_SUMMARY", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)

    return response.content.strip()
1386
+
1387
+
1388
+ # %%
1389
def _format_cache_timestamp(ts) -> str:
    """Render an epoch timestamp as 'YYYY-MM-DD HH:MM'; 'unknown' when missing or invalid."""
    if not ts:
        return "unknown"
    try:
        return datetime.datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M")
    except (TypeError, ValueError, OverflowError, OSError):
        return "unknown"


def _pipeline_result(answer_md, df, sql, rag_context, question, attempts_left, attempts_text, stats_md=""):
    """Assemble the 13-value tuple returned by backend_pipeline (stats panel shown only when stats_md is set)."""
    return (
        answer_md,
        df,
        sql,
        rag_context,
        question,
        answer_md,
        df,
        [],
        [],
        False,
        attempts_left,
        attempts_text,
        gr.update(value=stats_md, visible=bool(stats_md)),
    )


def backend_pipeline(question: str):
    """
    STRICT cache-first backend.

    Priority:
    1. Exact question cache hit
    2. Best semantic cache hit
    3. LLM fallback

    Returns a 13-tuple: displayed answer, result frame, SQL, RAG context,
    question, last answer, last frame, two empty lists, ``False``, attempts
    left (always reset to 4), the attempts label, and a ``gr.update`` for
    the cached-stats panel.

    An LLM-generated query is written to the cache only when it executed
    without error. Caching a failing query would make every repeat of the
    question return the same error from the cache instead of trying again.
    """
    # ------------------------------------------------------------------
    # Guards
    # ------------------------------------------------------------------
    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "",
            "",
            "",
            "",
            pd.DataFrame(),
            [],
            [],
            False,
            4,
            "**Feedback attempts remaining: 4**",
            gr.update(value="", visible=False),
        )

    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"

    # ------------------------------------------------------------------
    # General / descriptive questions
    # ------------------------------------------------------------------
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return _pipeline_result(
            overview_md, pd.DataFrame(), "", "", question, attempts_left, attempts_text
        )

    store = get_sql_cache_store()

    # ------------------------------------------------------------------
    # STEP 0A: EXACT CACHE LOOKUP (deterministic)
    # STEP 0B: SEMANTIC CACHE LOOKUP (ranked) when 0A misses
    # ------------------------------------------------------------------
    cached = retrieve_exact_cached_sql(question, store)
    if cached is None:
        cached = retrieve_best_cached_sql(
            question=question,
            store=store,
            max_distance=0.25,
        )

    # ------------------------------------------------------------------
    # CACHE HIT PATH
    # ------------------------------------------------------------------
    if cached:
        md = cached["metadata"]

        # View counting is best-effort, but failures are logged, not hidden.
        try:
            increment_cache_views(md, store=store)
        except Exception:
            _logger.exception("backend_pipeline: failed to increment cache views")

        rag_context = get_rag_context(question)

        try:
            success_rate = float(md.get("success_rate", 0.5))
        except (TypeError, ValueError):
            success_rate = 0.5

        header = (
            "### Cache Hit\n"
            f"- **Matched question:** \"{cached['matched_question']}\"\n"
            f"- **Success rate:** {success_rate:.2f}\n"
            f"- **Similarity distance:** {cached['distance']:.4f}\n\n"
            "---\n\n"
        )
        answer_md = header + (cached.get("answer_md") or "")

        # execute_sql reports failures through its second return value.
        df, exec_error = execute_sql(cached["sql"], conn)
        if exec_error:
            df = pd.DataFrame()
            answer_md += f"\n\n Error re-running cached SQL: `{exec_error}`"

        stats_md = (
            f"**Cached entry stats**\n\n"
            f"- **Success rate:** {success_rate:.2f} \n"
            f"- **Total feedbacks:** {md.get('total_feedbacks', 0)} \n"
            f"- **Good / Bad:** {md.get('good_count', 0)} / {md.get('bad_count', 0)} \n"
            f"- **Views:** {md.get('views', 0)} \n"
            f"- **Saved at:** {_format_cache_timestamp(md.get('saved_at'))} \n"
            f"- **Last updated:** {_format_cache_timestamp(md.get('last_updated_at'))}\n\n"
            f"**SQL preview:**\n\n```sql\n{cached['sql']}\n```\n"
        )

        return _pipeline_result(
            answer_md, df, cached["sql"], rag_context, question,
            attempts_left, attempts_text, stats_md,
        )

    # ------------------------------------------------------------------
    # STEP 1: LLM FLOW (NO CACHE HIT)
    # ------------------------------------------------------------------
    rag_context = get_rag_context(question)

    sql = generate_sql(question, rag_context)

    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False

    if not is_valid and validation_error:
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)
        if repaired_valid:
            sql = repaired_sql
            repaired = True
            validation_error = None
        else:
            validation_error = repaired_error or validation_error

    df, exec_error = (None, None)
    if not validation_error:
        df, exec_error = execute_sql(sql, conn)
    else:
        exec_error = validation_error

    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )

    sql_status = []
    if exec_error:
        sql_status.append(f"**Error:** `{exec_error}`")
    else:
        sql_status.append("Query ran successfully.")
    if repaired:
        sql_status.append("_Note: SQL was auto-repaired._")

    sql_status.append("\n**Final SQL used:**\n")
    sql_status.append(f"```sql\n{sql}\n```")

    answer_md = summary_text + "\n\n---\n\n" + "\n".join(sql_status)
    df_preview = df if df is not None and exec_error is None else pd.DataFrame()

    # ------------------------------------------------------------------
    # STEP 2: CACHE LLM RESULT (successful runs only)
    # ------------------------------------------------------------------
    if not exec_error:
        try:
            cache_sql_answer_dedup(
                question=question,
                sql=sql,
                answer_md=answer_md,
                metadata={
                    "good_count": 0,
                    "bad_count": 0,
                    "total_feedbacks": 0,
                    "success_rate": 0.5,
                    "views": 1,
                    "saved_at": _now_ts(),
                },
                store=store,
            )
        except Exception:
            _logger.exception("backend_pipeline: failed to cache LLM result")

    return _pipeline_result(
        answer_md, df_preview, sql, rag_context, question, attempts_left, attempts_text
    )
1606
+
1607
+
1608
+ # %%
1609
+ def _looks_like_sql(text: str) -> bool:
1610
+ """Quick heuristic: does text contain SQL keywords / SELECT ?"""
1611
+ if not text:
1612
+ return False
1613
+ return bool(re.search(r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b", text, flags=re.I))
1614
+
1615
+
1616
+ def is_feedback_sufficient(feedback_text: str) -> bool:
1617
+ """
1618
+ Heuristic to decide whether the user's free-text feedback is actionable.
1619
+
1620
+ Returns True if:
1621
+ - length >= 20 characters AND contains a signal word (e.g., 'filter', 'year', 'should', 'instead', 'missing', 'wrong', digits),
1622
+ OR
1623
+ - it looks like SQL (user pasted corrected SQL),
1624
+ OR
1625
+ - length >= 60 characters (long feedback).
1626
+ """
1627
+ if not feedback_text:
1628
+ return False
1629
+
1630
+ text = feedback_text.strip()
1631
+ if len(text) >= 60:
1632
+ return True
1633
+
1634
+ if _looks_like_sql(text):
1635
+ return True
1636
+
1637
+ # look for signal words that indicate specificity
1638
+ signal_words = [
1639
+ "filter", "where", "year", "month", "should", "instead", "expected",
1640
+ "wrong", "missing", "aggregate", "sum", "avg", "count", "distinct",
1641
+ "join", "left join", "inner join", "group by", "order by", "date",
1642
+ "range", "exclude", "include", "only"
1643
+ ]
1644
+ lower = text.lower()
1645
+ signals = sum(1 for w in signal_words if w in lower)
1646
+ if signals >= 1 and len(text) >= 20:
1647
+ return True
1648
+
1649
+ # short hits like "numbers look off" are insufficient
1650
+ return False
1651
+
1652
+
1653
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Build the fixed follow-up message shown when feedback is too vague.

    The message asks what looks wrong, lists copy-paste examples and ends
    with a hint about pasting corrected SQL. When ``sample_feedback`` is
    non-empty, the message is prefixed with a line quoting it.
    """
    intro_lines = [
        "Thanks — I need a bit more detail to act on this feedback.\n\n",
        "Please tell me one (or more) of the following so I can check and correct the result:\n\n",
        "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n",
        " the **time range** (year/month), or the **filters** applied?\n",
        "2. If you expected a different number, what was the expected number (and how was it computed)?\n",
        "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n",
        "Examples you can copy-paste:\n",
    ]
    example_lines = [
        "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n",
        "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n",
        "- \"Please exclude canceled orders (order_status = 'canceled').\"\n",
        "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n",
    ]
    closing_hint = "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare."

    message = "".join(intro_lines) + "".join(example_lines) + closing_hint
    if not sample_feedback:
        return message
    return f"I saw your feedback: \"{sample_feedback}\"\n\n" + message
1678
+
1679
+
1680
+
1681
+ # %%
1682
def feedback_pipeline_interactive(
    feedback_rating: str,
    feedback_comment: str,
    last_sql: str,
    last_rag_context: str,
    last_question: str,
    last_answer_md: str,
    last_df: pd.DataFrame,
    feedback_sql: str,
    attempts_left: int,
):
    """
    Handle a "Correct" / "Wrong" rating on the last answer.

    Flow:
      * No previous question/SQL, or no valid rating: state is returned
        unchanged (with a hint when the rating is missing).
      * "correct": the cached entry gets a positive vote and the feedback is
        recorded; attempts are not consumed.
      * "wrong": one attempt is consumed. While attempts remain, vague
        comments trigger a follow-up question. Otherwise the LLM reviews the
        SQL, the (possibly corrected) query is executed and summarized, the
        cache receives a negative vote, and the feedback is recorded.

    A corrected query is promoted to the cache, and recorded as the
    preferred correction, only when it differs from the previous SQL AND
    executed successfully. Cache-update failures are logged and do not
    prevent the feedback from being recorded.

    ``feedback_sql`` is accepted for interface compatibility and not used.

    Returns a 10-tuple: displayed answer, frame, SQL, RAG context, question,
    last answer, last frame, awaiting-follow-up flag, follow-up prompt,
    attempts left.
    """
    rating = (feedback_rating or "").strip().lower()
    comment = (feedback_comment or "").strip()
    attempts_left = int(attempts_left or 0)

    def _unchanged(display_md, awaiting_followup=False, followup_prompt=""):
        # Keep all "last_*" state as it was; only the displayed text varies.
        return (
            display_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            awaiting_followup,
            followup_prompt,
            attempts_left,
        )

    # ---------------- Guard ----------------
    if not last_question or not last_sql:
        return _unchanged(last_answer_md)

    if rating not in ("correct", "wrong"):
        return _unchanged(last_answer_md + "\n\n Please select **Correct** or **Wrong**.")

    # ============================================================
    # CORRECT -> no attempt decrement
    # ============================================================
    if rating == "correct":
        try:
            store = get_sql_cache_store()
            original_doc = find_cached_doc_by_sql(last_question, last_sql, store=store)
            update_cache_on_feedback(
                question=last_question,
                original_doc_md=original_doc,
                user_marked_good=True,
                llm_corrected_sql=None,
                llm_corrected_answer_md=None,
                store=store,
            )
        except Exception:
            _logger.exception("feedback_pipeline_interactive: cache update failed (good feedback)")

        record_feedback(
            conn=conn,
            question=last_question,
            generated_sql=last_sql,
            model_answer=last_answer_md,
            rating="good",
            comment=comment or None,
            corrected_sql=None,
        )

        return _unchanged(last_answer_md + "\n\n **Feedback recorded as GOOD.**")

    # ============================================================
    # WRONG -> decrement immediately
    # ============================================================
    attempts_left = max(0, attempts_left - 1)

    # Attempts exhausted -> force the LLM review even without a comment
    if attempts_left == 0:
        comment = comment or "User marked result as wrong."

    # Insufficient feedback -> follow-up (only if attempts remain)
    if attempts_left > 0 and not is_feedback_sufficient(comment):
        return _unchanged(
            last_answer_md,
            awaiting_followup=True,
            followup_prompt=build_followup_prompt_for_user(comment),
        )

    # ============================================================
    # Run LLM review
    # ============================================================
    corrected_sql, explanation = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=comment,
        rag_context=last_rag_context,
    )

    corrected_sql = corrected_sql or last_sql
    df_new, exec_error = execute_sql(corrected_sql, conn)

    answer_core = summarize_results(
        question=last_question,
        sql=corrected_sql,
        df=None if exec_error else df_new,
        rag_context=last_rag_context,
        error=exec_error or None,
    )
    if exec_error:
        df_new = pd.DataFrame()

    # Only a correction that differs AND actually runs may be promoted;
    # otherwise a broken query would enter the cache with a perfect score.
    sql_changed = corrected_sql.strip() != last_sql.strip()
    promote_correction = sql_changed and not exec_error

    try:
        store = get_sql_cache_store()
        original_doc = find_cached_doc_by_sql(last_question, last_sql, store=store)
        update_cache_on_feedback(
            question=last_question,
            original_doc_md=original_doc,
            user_marked_good=False,
            llm_corrected_sql=corrected_sql if promote_correction else None,
            llm_corrected_answer_md=answer_core if promote_correction else None,
            store=store,
        )
    except Exception:
        _logger.exception("feedback_pipeline_interactive: cache update failed (bad feedback)")

    record_feedback(
        conn=conn,
        question=last_question,
        generated_sql=last_sql,
        model_answer=last_answer_md,
        rating="bad",
        comment=comment + "\n\nLLM explanation:\n" + (explanation or ""),
        # A failing "correction" must not become the preferred reference
        # for future SQL generation.
        corrected_sql=None if exec_error else corrected_sql,
    )

    final_md = (
        answer_core
        + "\n\n---\n\n"
        + f"**Final corrected SQL:**\n```sql\n{corrected_sql}\n```\n\n"
        + "### LLM Review Explanation\n"
        + (explanation or "")
    )

    return (
        final_md,
        df_new,
        corrected_sql,
        last_rag_context,
        last_question,
        final_md,
        df_new,
        False,
        "",
        attempts_left,
    )
1872
+
1873
+
1874
+ # %%
1875
+ import gradio as gr
1876
+ import pandas as pd
1877
+
1878
with gr.Blocks() as demo:
    # Top-level layout of the app. Components are declared in the order they
    # are laid out, and the callbacks wired later in this block refer to them
    # by these variable names.
    gr.Markdown("# Olist Analytics Assistant (RAG + SQL + Feedback)")

    # ==================== STATE ====================
    # Non-visual gr.State holders. They are passed as inputs/outputs of the
    # click handlers so the feedback flow can see the most recent SQL, RAG
    # context, question, answer markdown and result table.
    last_sql_state = gr.State("")
    last_rag_state = gr.State("")
    last_question_state = gr.State("")
    last_answer_state = gr.State("")
    last_df_state = gr.State(pd.DataFrame())

    # Remaining feedback attempts. The initial value 4 is repeated in the
    # text of `attempts_display` below; keep the two in sync.
    attempts_state = gr.State(4)
    # Passed as an input to the feedback handlers; no callback visible in
    # this section writes to it, so it stays "" unless set elsewhere.
    feedback_sql_state = gr.State("")

    # ==================== MAIN UI ====================
    # Two columns: question entry on the left (scale=1), answer markdown and
    # result table on the right (scale=2).
    with gr.Row():
        with gr.Column(scale=1):
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total number of customers",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            answer_out = gr.Markdown()
            table_out = gr.Dataframe()

    attempts_display = gr.Markdown("**Feedback attempts remaining: 4**")
    # Starts hidden; the run callback supplies an update for it.
    cached_stats_md = gr.Markdown(visible=False)

    # ==================== FEEDBACK ====================
    gr.Markdown("### Feedback")

    # value=None means no option is pre-selected.
    feedback_rating = gr.Radio(
        ["Correct", "Wrong"],
        label="Is the answer correct?",
        value=None,
    )
    feedback_comment = gr.Textbox(
        label="Explain (required if Wrong)",
        lines=3,
    )
    feedback_btn = gr.Button("Submit feedback")

    # ==================== FOLLOW-UP ====================
    # All follow-up widgets start hidden; the UI helper functions defined
    # after this section toggle their visibility.
    followup_prompt_md = gr.Markdown(visible=False)
    followup_input = gr.Textbox(
        label="Please clarify",
        visible=False,
        lines=4,
    )
    followup_submit_btn = gr.Button(
        "Submit follow-up",
        visible=False,
    )

    # Notice shown in place of the feedback widgets once no attempts remain.
    exhausted_md = gr.Markdown(
        "**You have exhausted your feedback attempts. Please ask a new question to continue.**",
        visible=False,
    )
1937
+
1938
+ # ==================== UI HELPERS ====================
1939
+ def reset_feedback_ui():
1940
+ return (
1941
+ gr.update(value=None, visible=True), # rating
1942
+ gr.update(value="", visible=True), # comment
1943
+ gr.update(visible=True), # submit
1944
+ gr.update(visible=False), # followup input
1945
+ gr.update(visible=False), # followup btn
1946
+ gr.update(visible=False), # followup prompt
1947
+ gr.update(visible=False), # exhausted
1948
+ )
1949
+
1950
+ def show_followup_ui(prompt: str):
1951
+ return (
1952
+ gr.update(visible=False), # rating
1953
+ gr.update(visible=False), # comment
1954
+ gr.update(visible=False), # submit
1955
+ gr.update(value="", visible=True), # followup input
1956
+ gr.update(visible=True), # followup btn
1957
+ gr.update(value=prompt, visible=True), # followup prompt
1958
+ gr.update(visible=False), # exhausted
1959
+ )
1960
+
1961
+ def show_exhausted_ui():
1962
+ return (
1963
+ gr.update(visible=False), # rating
1964
+ gr.update(visible=False), # comment
1965
+ gr.update(visible=False), # submit
1966
+ gr.update(visible=False), # followup input
1967
+ gr.update(visible=False), # followup btn
1968
+ gr.update(visible=False), # followup prompt
1969
+ gr.update(visible=True), # exhausted
1970
+ )
1971
+
1972
+ # ==================== RUN PIPELINE ====================
1973
+ def run_and_render(question):
1974
+ (
1975
+ answer_md,
1976
+ df,
1977
+ sql,
1978
+ rag,
1979
+ q,
1980
+ answer_state,
1981
+ df_state,
1982
+ _cached_matches,
1983
+ _dropdown_choices,
1984
+ _dropdown_visible,
1985
+ attempts,
1986
+ attempts_text,
1987
+ cached_stats_update,
1988
+ ) = backend_pipeline(question)
1989
+
1990
+ return (
1991
+ answer_md,
1992
+ df,
1993
+ sql,
1994
+ rag,
1995
+ q,
1996
+ answer_state,
1997
+ df_state,
1998
+ attempts,
1999
+ attempts_text,
2000
+ cached_stats_update,
2001
+ *reset_feedback_ui(),
2002
+ )
2003
+
2004
+ submit_btn.click(
2005
+ run_and_render,
2006
+ inputs=[question_in],
2007
+ outputs=[
2008
+ answer_out,
2009
+ table_out,
2010
+ last_sql_state,
2011
+ last_rag_state,
2012
+ last_question_state,
2013
+ last_answer_state,
2014
+ last_df_state,
2015
+ attempts_state,
2016
+ attempts_display,
2017
+ cached_stats_md,
2018
+ feedback_rating,
2019
+ feedback_comment,
2020
+ feedback_btn,
2021
+ followup_input,
2022
+ followup_submit_btn,
2023
+ followup_prompt_md,
2024
+ exhausted_md,
2025
+ ],
2026
+ )
2027
+
2028
+ # ==================== FEEDBACK HANDLER ====================
2029
+ def handle_feedback(
2030
+ rating,
2031
+ comment,
2032
+ last_sql,
2033
+ last_rag,
2034
+ last_question,
2035
+ last_answer,
2036
+ last_df,
2037
+ feedback_sql,
2038
+ attempts_left,
2039
+ ):
2040
+ (
2041
+ answer_md,
2042
+ df_new,
2043
+ sql_new,
2044
+ rag_new,
2045
+ q_new,
2046
+ ans_state,
2047
+ df_state,
2048
+ awaiting_followup,
2049
+ followup_prompt,
2050
+ attempts_new,
2051
+ ) = feedback_pipeline_interactive(
2052
+ rating,
2053
+ comment,
2054
+ last_sql,
2055
+ last_rag,
2056
+ last_question,
2057
+ last_answer,
2058
+ last_df,
2059
+ feedback_sql,
2060
+ attempts_left,
2061
+ )
2062
+
2063
+ attempts_md = f"**Feedback attempts remaining: {attempts_new}**"
2064
+
2065
+ # Exhausted
2066
+ if attempts_new <= 0:
2067
+ ui = show_exhausted_ui()
2068
+ return (
2069
+ answer_md,
2070
+ df_new,
2071
+ sql_new,
2072
+ rag_new,
2073
+ q_new,
2074
+ ans_state,
2075
+ df_state,
2076
+ attempts_new,
2077
+ attempts_md,
2078
+ cached_stats_md,
2079
+ *ui,
2080
+ )
2081
+
2082
+ # Follow-up only
2083
+ if awaiting_followup:
2084
+ ui = show_followup_ui(followup_prompt)
2085
+ return (
2086
+ answer_md,
2087
+ df_new,
2088
+ sql_new,
2089
+ rag_new,
2090
+ q_new,
2091
+ ans_state,
2092
+ df_state,
2093
+ attempts_new,
2094
+ attempts_md,
2095
+ cached_stats_md,
2096
+ *ui,
2097
+ )
2098
+
2099
+ # Normal reset
2100
+ ui = reset_feedback_ui()
2101
+ return (
2102
+ answer_md,
2103
+ df_new,
2104
+ sql_new,
2105
+ rag_new,
2106
+ q_new,
2107
+ ans_state,
2108
+ df_state,
2109
+ attempts_new,
2110
+ attempts_md,
2111
+ cached_stats_md,
2112
+ *ui,
2113
+ )
2114
+
2115
+ feedback_btn.click(
2116
+ handle_feedback,
2117
+ inputs=[
2118
+ feedback_rating,
2119
+ feedback_comment,
2120
+ last_sql_state,
2121
+ last_rag_state,
2122
+ last_question_state,
2123
+ last_answer_state,
2124
+ last_df_state,
2125
+ feedback_sql_state,
2126
+ attempts_state,
2127
+ ],
2128
+ outputs=[
2129
+ answer_out,
2130
+ table_out,
2131
+ last_sql_state,
2132
+ last_rag_state,
2133
+ last_question_state,
2134
+ last_answer_state,
2135
+ last_df_state,
2136
+ attempts_state,
2137
+ attempts_display,
2138
+ cached_stats_md,
2139
+ feedback_rating,
2140
+ feedback_comment,
2141
+ feedback_btn,
2142
+ followup_input,
2143
+ followup_submit_btn,
2144
+ followup_prompt_md,
2145
+ exhausted_md,
2146
+ ],
2147
+ )
2148
+
2149
+ followup_submit_btn.click(
2150
+ handle_feedback,
2151
+ inputs=[
2152
+ feedback_rating,
2153
+ followup_input,
2154
+ last_sql_state,
2155
+ last_rag_state,
2156
+ last_question_state,
2157
+ last_answer_state,
2158
+ last_df_state,
2159
+ feedback_sql_state,
2160
+ attempts_state,
2161
+ ],
2162
+ outputs=[
2163
+ answer_out,
2164
+ table_out,
2165
+ last_sql_state,
2166
+ last_rag_state,
2167
+ last_question_state,
2168
+ last_answer_state,
2169
+ last_df_state,
2170
+ attempts_state,
2171
+ attempts_display,
2172
+ cached_stats_md,
2173
+ feedback_rating,
2174
+ feedback_comment,
2175
+ feedback_btn,
2176
+ followup_input,
2177
+ followup_submit_btn,
2178
+ followup_prompt_md,
2179
+ exhausted_md,
2180
+ ],
2181
+ )
2182
+
2183
+ # %%
2184
# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not start the UI.
if __name__ == "__main__":
    demo.launch()
2186
+
2187
+