Shashwat-18 commited on
Commit
aed2722
·
verified ·
1 Parent(s): 7ec9af7

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -2084
app.py DELETED
@@ -1,2084 +0,0 @@
1
- # %%
2
- import os
3
- import uuid
4
- import sqlite3
5
- import datetime
6
- import json
7
- import re
8
- import time
9
- from typing import Dict, Any, List, Tuple, Optional
10
- import hashlib
11
-
12
- import pandas as pd
13
- from groq import Groq
14
- from langchain_community.vectorstores import Chroma
15
- from langchain_community.embeddings import HuggingFaceEmbeddings
16
-
17
# %%
# 0. Global config
DB_PATH = "olist.db"   # SQLite database file built from the CSVs below
DATA_DIR = "data"      # folder containing the raw Olist CSV files

# Groq Model
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"

# Fail fast at import time: every LLM call in this app needs this key.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found.")

groq_client = Groq(api_key=GROQ_API_KEY)

# Embedding model (used for both the schema RAG store and the SQL cache)
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Logging paths (append-only text logs, see log_prompt / log_run_event)
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
37
-
38
- # %%
39
- # 1. Logging helpers
40
-
41
-
42
def log_prompt(tag: str, prompt: str) -> None:
    """
    Append the full prompt to PROMPT_LOG_PATH, framed by a tagged,
    timestamped header and an END marker.
    """
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    entry = (
        f"\n\n================ {tag} @ {stamp} ================\n"
        + prompt
        + "\n================ END PROMPT ================\n"
    )
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as fh:
        fh.write(entry)
52
-
53
-
54
def log_run_event(tag: str, content: str) -> None:
    """
    Append a run event (model response, final SQL, error info, ...) to
    RUN_LOG_PATH, framed by a tagged, timestamped header and an END marker.
    """
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    entry = (
        f"\n\n================ {tag} @ {stamp} ================\n"
        + content
        + "\n================ END EVENT ================\n"
    )
    with open(RUN_LOG_PATH, "a", encoding="utf-8") as fh:
        fh.write(entry)
64
-
65
-
66
-
67
- # %%
68
- # 2. Feedback table + helpers
69
-
70
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """
    Idempotently create the user_feedback table used to capture ratings
    of model answers. Ratings are constrained to 'good'/'bad' at the SQL
    level via a CHECK constraint.
    """
    ddl = """
    CREATE TABLE IF NOT EXISTS user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        generated_sql TEXT,
        model_answer TEXT,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
        comment TEXT,
        corrected_sql TEXT
    )
    """
    conn.execute(ddl)
    conn.commit()
87
-
88
-
89
def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,  # "good" or "bad"
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """
    Persist user feedback about a particular model answer / SQL query.

    If corrected_sql is provided, it is treated as an external correction.
    Raises ValueError when rating is not 'good' or 'bad' (case-insensitive).
    """
    normalized_rating = rating.lower()
    if normalized_rating not in {"good", "bad"}:
        raise ValueError("rating must be 'good' or 'bad'")

    created_at = datetime.datetime.now().isoformat(timespec="seconds")
    params = (
        created_at,
        question,
        generated_sql,
        model_answer,
        normalized_rating,
        comment,
        corrected_sql,
    )
    conn.execute(
        """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """,
        params,
    )
    conn.commit()
118
-
119
-
120
def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """
    Return the most recent feedback row for this question, or None.

    Fix: created_at has only second resolution, so two feedback entries
    recorded within the same second previously came back in an arbitrary
    order; the autoincrement id is now used as a deterministic tie-breaker.
    """
    cur = conn.cursor()
    cur.execute(
        """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC, id DESC
        LIMIT 1
        """,
        (question,),
    )
    row = cur.fetchone()
    if not row:
        return None

    keys = (
        "created_at", "generated_sql", "model_answer",
        "rating", "comment", "corrected_sql",
    )
    return dict(zip(keys, row))
151
-
152
-
153
-
154
- # %%
155
- # 3. Database setup (from CSVs)
156
-
157
def init_db() -> sqlite3.Connection:
    """
    Load all Olist CSVs from DATA_DIR into a local SQLite DB at DB_PATH.

    Table names are derived from the file names (without the .csv suffix).
    Missing files are skipped with a warning instead of failing, and the
    feedback table is (re)created before the connection is returned.

    Fix: removed the stray bare `print(path)` debug line that duplicated
    the informative "Loading ..." / "CSV not found ..." messages.
    """
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)

    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]

    for fname in csv_files:
        path = os.path.join(DATA_DIR, fname)
        if not os.path.exists(path):
            print(f"CSV not found: {path} - skipping")
            continue

        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        df = pd.read_csv(path)
        # replace: re-running the app rebuilds tables from the CSVs
        df.to_sql(table_name, conn, if_exists="replace", index=False)

    init_feedback_table(conn)

    return conn
192
-
193
-
194
- conn = init_db()
195
-
196
-
197
-
198
- # %%
199
- # 4. Manual docs for Olist tables
200
-
201
- OLIST_DOCS: Dict[str, Dict[str, Any]] = {
202
- "olist_customers_dataset": {
203
- "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
204
- "columns": {
205
- "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
206
- "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
207
- "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
208
- "customer_city": "Customer's city as captured at the time of the order or registration.",
209
- "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
210
- }
211
- },
212
- "olist_orders_dataset": {
213
- "description": "Customer orders placed on the Olist marketplace, one row per order.",
214
- "columns": {
215
- "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
216
- "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
217
- "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
218
- "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
219
- "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
220
- "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
221
- "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
222
- "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
223
- }
224
- },
225
- "olist_order_items_dataset": {
226
- "description": "Order line items, one row per product per order.",
227
- "columns": {
228
- "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
229
- "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
230
- "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
231
- "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
232
- "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
233
- "price": "Item price paid by the customer for this line (in BRL, not including freight).",
234
- "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
235
- }
236
- },
237
- "olist_products_dataset": {
238
- "description": "Product catalog with physical and category attributes, one row per product.",
239
- "columns": {
240
- "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
241
- "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
242
- "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
243
- "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
244
- "product_photos_qty": "Number of product images associated with the listing.",
245
- "product_weight_g": "Product weight in grams.",
246
- "product_length_cm": "Product length in centimeters (package dimension).",
247
- "product_height_cm": "Product height in centimeters (package dimension).",
248
- "product_width_cm": "Product width in centimeters (package dimension)."
249
- }
250
- },
251
- "olist_order_reviews_dataset": {
252
- "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
253
- "columns": {
254
- "review_id": "Primary key. Unique identifier for each review record.",
255
- "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
256
- "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
257
- "review_comment_title": "Optional short text title or summary of the review.",
258
- "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
259
- "review_creation_date": "Date when the customer created the review.",
260
- "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
261
- }
262
- },
263
- "olist_order_payments_dataset": {
264
- "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
265
- "columns": {
266
- "order_id": "Foreign key to olist_orders_dataset.order_id.",
267
- "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
268
- "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
269
- "payment_installments": "Number of installments chosen by the customer for this payment.",
270
- "payment_value": "Monetary amount paid in this payment record (in BRL)."
271
- }
272
- },
273
- "product_category_name_translation": {
274
- "description": "Lookup table mapping Portuguese product category names to English equivalents.",
275
- "columns": {
276
- "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
277
- "product_category_name_english": "Translated product category name in English."
278
- }
279
- },
280
- "olist_sellers_dataset": {
281
- "description": "Seller master data, one row per seller operating on the Olist marketplace.",
282
- "columns": {
283
- "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
284
- "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
285
- "seller_city": "City where the seller is located.",
286
- "seller_state": "State where the seller is located (two-letter Brazilian state code)."
287
- }
288
- },
289
- "olist_geolocation_dataset": {
290
- "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
291
- "columns": {
292
- "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
293
- "geolocation_lat": "Latitude coordinate of the location.",
294
- "geolocation_lng": "Longitude coordinate of the location.",
295
- "geolocation_city": "City name for the location.",
296
- "geolocation_state": "State code for the location (two-letter Brazilian state code)."
297
- }
298
- }
299
- }
300
-
301
-
302
-
303
- # %%
304
- # 5. Schema extractor + metadata YAML
305
-
306
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns, and foreign-key relationships,
    enriching them with the manual OLIST_DOCS descriptions when available.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cursor.fetchall()]

    result: Dict[str, Any] = {"tables": {}}

    for table in table_names:
        docs = OLIST_DOCS.get(table, {})
        doc_columns: Dict[str, str] = docs.get("columns", {})

        # Column descriptions: manual docs win, generic fallback otherwise.
        cursor.execute(f"PRAGMA table_info('{table}')")
        columns_meta: Dict[str, str] = {}
        for info_row in cursor.fetchall():
            name, declared_type = info_row[1], info_row[2] or "TEXT"
            columns_meta[name] = doc_columns.get(
                name, f"Column '{name}' of type {declared_type}"
            )

        # FK edges as human-readable strings.
        cursor.execute(f"PRAGMA foreign_key_list('{table}')")
        relationships: List[str] = [
            f"{table}.{fk[3]} → {fk[2]}.{fk[4]} (foreign key)"
            for fk in cursor.fetchall()
        ]

        result["tables"][table] = {
            "description": docs.get(
                "description", f"Table '{table}' from Olist dataset."
            ),
            "columns": columns_meta,
            "relationships": relationships,
        }

    return result
356
-
357
-
358
- schema_metadata = extract_schema_metadata(conn)
359
-
360
-
361
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render the schema-metadata dict as a YAML-style string.

    Embedded double quotes are downgraded to single quotes so the emitted
    values stay valid inside the double-quoted YAML scalars.
    """
    def _q(text: str) -> str:
        return text.replace('"', "'")

    out: List[str] = ["tables:"]
    for table_name, info in metadata["tables"].items():
        out.append(f"  {table_name}:")
        desc = _q(info.get("description", ""))
        out.append(f'    description: "{desc}"')
        out.append("    columns:")
        for column, description in info.get("columns", {}).items():
            cleaned = _q(description)
            out.append(f'      {column}: "{cleaned}"')
        relationships = info.get("relationships", [])
        if relationships:
            out.append("    relationships:")
            out.extend(f'      - "{_q(rel)}"' for rel in relationships)
    return "\n".join(out)
381
-
382
-
383
- schema_yaml = build_schema_yaml(schema_metadata)
384
-
385
-
386
-
387
- # %%
388
- # 6. Build schema documents for RAG (taking samples from the table)
389
-
390
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table from schema_metadata.

    Each document carries the table name, its description, the columns with
    type + description, the FK relationships, and up to `sample_rows`
    example rows. Returns parallel lists of document texts and metadata.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cursor.fetchall()]

    docs: List[str] = []
    metadatas: List[Dict[str, Any]] = []

    for table in table_names:
        tmeta = schema_metadata["tables"][table]
        columns_meta = tmeta.get("columns", {})

        # Column lines: PRAGMA supplies the declared type, metadata the docs.
        cursor.execute(f"PRAGMA table_info('{table}')")
        col_lines = []
        for info_row in cursor.fetchall():
            name, ctype = info_row[1], info_row[2] or "TEXT"
            desc = columns_meta.get(name, f"Column '{name}' of type {ctype}")
            col_lines.append(f"- {name} ({ctype}): {desc}")

        # Example rows rendered as markdown; best-effort only.
        try:
            sample = pd.read_sql_query(
                f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                connection,
            )
            sample_text = sample.to_markdown(index=False)
        except Exception:
            sample_text = "(could not fetch sample rows)"

        relationships = tmeta.get("relationships", [])
        rel_block = ""
        if relationships:
            rel_lines = "\n".join(f"- {rel}" for rel in relationships)
            rel_block = f"Relationships:\n{rel_lines}\n"

        docs.append(
            f"Table: {table}\n"
            f"Description: {tmeta.get('description', '')}\n\n"
            f"Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        metadatas.append({
            "doc_type": "table_schema",
            "table_name": table,
        })

    return docs, metadatas
461
-
462
-
463
- # Build RAG texts + metadata
464
- schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)
465
-
466
- RAG_TEXTS: List[str] = []
467
- RAG_METADATAS: List[Dict[str, Any]] = []
468
-
469
- # 1) Per-table docs
470
- RAG_TEXTS.extend(schema_docs)
471
- RAG_METADATAS.extend(schema_doc_metas)
472
-
473
- # 2) Global YAML as a separate doc
474
- RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
475
- RAG_METADATAS.append({"doc_type": "global_schema"})
476
-
477
-
478
- # %%
479
- # 7. Build Chroma store + global RAG retriever
480
-
481
def build_store_final() -> Tuple[Chroma, Any]:
    """
    Build the in-memory production RAG store (fixed embedding model) and a
    retriever restricted to the per-table schema documents.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)

    # Random collection suffix avoids clashes with earlier in-process builds.
    store = Chroma.from_texts(
        RAG_TEXTS,
        embeddings,
        metadatas=RAG_METADATAS,
        collection_name=f"olist_schema_all_mpnet_{uuid.uuid4().hex[:8]}",
        persist_directory=None,  # in-memory
    )

    retriever = store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )
    return store, retriever
505
-
506
- rag_store, rag_retriever = build_store_final()
507
-
508
-
509
- # %%
510
- # 7b. SQL cache vector store (question --> SQL + answer_md)
511
-
512
- SQL_CACHE_COLLECTION = "sql_cache_mpnet"
513
- SQL_CACHE_PERSIST_DIR = "sql_cache_chroma" # directory on disk for cached SQL
514
-
515
- sql_cache_store: Optional[Chroma] = None
516
-
517
def get_sql_cache_store() -> Chroma:
    """
    Lazily create (or reuse) the persistent Chroma store that caches
    confirmed-good (question, sql, answer_md) triples.
    """
    global sql_cache_store

    if sql_cache_store is None:
        # Created on first use; reloaded from disk on later runs.
        sql_cache_store = Chroma(
            collection_name=SQL_CACHE_COLLECTION,
            embedding_function=HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME),
            persist_directory=SQL_CACHE_PERSIST_DIR,
        )
    return sql_cache_store
536
-
537
-
538
-
539
- # %%
540
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """
    Coerce metadata values into Chroma-safe scalars: None becomes "",
    str/int/float/bool pass through, everything else is stringified.
    Accepts None for `metadata` (treated as empty).
    """
    def _coerce(value):
        if value is None:
            return ""
        if isinstance(value, (str, int, float, bool)):
            return value
        return str(value)

    return {key: _coerce(val) for key, val in (metadata or {}).items()}
550
-
551
-
552
- # %%
553
- # Helper: normalize question
554
-
555
def normalize_question_text(q: str) -> str:
    """
    Lower-case, replace punctuation with spaces, and collapse whitespace.
    Falsy input (None, "") yields "".
    """
    if not q:
        return ""
    lowered = q.strip().lower()
    no_punct = re.sub(r"[^\w\s]", " ", lowered)
    return re.sub(r"\s+", " ", no_punct).strip()
562
-
563
- # Helper: compute success rate
564
def compute_success_rate(md: dict) -> float:
    """
    success_count / total_feedbacks from a cache-metadata dict.
    Returns 0.0 when there is no feedback yet (or counters are missing/None).
    """
    successes = md.get("success_count", 0) or 0
    total = md.get("total_feedbacks", 0) or 0
    return float(successes) / float(total) if total > 0 else 0.0
570
-
571
- # Insert initial cache entry (no feedback yet)
572
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
    """
    Insert a cache entry for a freshly-answered question (no feedback yet).

    Initial metrics: views=1, success_count=0, total_feedbacks=0,
    success_rate=0.0. Returns the generated entry id.
    """
    target = store if store is not None else get_sql_cache_store()

    ident = uuid.uuid4().hex
    md = {
        "id": ident,
        "normalized_question": normalize_question_text(question),
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": time.time(),
        "views": 1,
        "success_count": 0,
        "total_feedbacks": 0,
        "success_rate": 0.0,
    }
    md.update(extra_metadata or {})

    # Go through the wrapper API so embeddings are handled for us.
    target.add_texts([question], metadatas=[md])
    return ident
599
-
600
-
601
-
602
- # %%
603
- import time
604
- import logging
605
- from typing import Optional, List, Dict, Any
606
- from difflib import SequenceMatcher
607
-
608
- _logger = logging.getLogger(__name__)
609
-
610
- # -------------------------------------------------------------------------
611
- # Utility helpers
612
- # -------------------------------------------------------------------------
613
-
614
- def _now_ts() -> float:
615
- return time.time()
616
-
617
def similarity_score(a: str, b: str) -> float:
    """Fuzzy similarity in [0, 1] via difflib; None is treated as ''."""
    left = a or ""
    right = b or ""
    return SequenceMatcher(None, left, right).ratio()
619
-
620
- # -------------------------------------------------------------------------
621
- # Persist helper
622
- # -------------------------------------------------------------------------
623
-
624
- def _maybe_persist_store(store) -> None:
625
- try:
626
- if hasattr(store, "persist"):
627
- store.persist()
628
- return
629
- coll = getattr(store, "_collection", None)
630
- if coll and hasattr(coll, "persist"):
631
- coll.persist()
632
- return
633
- except Exception:
634
- pass
635
-
636
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """
    Upsert one document through the LangChain wrapper (so embeddings are
    always computed), then best-effort persist.

    Older wrappers whose add_texts lacks an `ids=` parameter fall back to a
    plain add.
    """
    safe_md = sanitize_metadata_for_chroma(metadata)
    kwargs = {"texts": [text], "metadatas": [safe_md]}

    try:
        store.add_texts(ids=[cache_id], **kwargs)
    except TypeError:
        # add_texts signature without ids support
        store.add_texts(**kwargs)

    _maybe_persist_store(store)
661
-
662
- # -------------------------------------------------------------------------
663
- # Cache insert / update
664
- # -------------------------------------------------------------------------
665
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """
    Upsert a (question, sql, answer_md) cache entry keyed by a deterministic
    id over (normalized question, sql), so repeated saves update the same
    document instead of duplicating it.

    Returns a summary dict including the metadata actually written.
    """
    norm_q = normalize_text(question)
    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    md = {
        "sql": sql,
        "answer_md": answer_md,
        "cache_id": cache_id,
        # timestamps: saved_at is set once, last_updated_at on every write
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        # feedback metrics (defaults for a brand-new entry)
        "good_count": metadata.get("good_count", 0),
        "bad_count": metadata.get("bad_count", 0),
        "total_feedbacks": metadata.get("total_feedbacks", 0),
        "success_rate": metadata.get("success_rate", 0.5),
        "views": metadata.get("views", 0),
    }

    langchain_upsert(store=store, text=norm_q, metadata=md, cache_id=cache_id)

    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": md,
    }
707
-
708
-
709
- # -------------------------------------------------------------------------
710
- # Find exact cached entry (question + SQL)
711
- # -------------------------------------------------------------------------
712
-
713
def find_cached_doc_by_sql(question: str, sql: str, store):
    """
    Exact cache lookup by the deterministic id over (question, sql).
    Returns a result dict, or None when the entry is absent or the store
    exposes no raw collection access.
    """
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)
    if coll is None or not hasattr(coll, "get"):
        return None

    try:
        res = coll.get(ids=[cache_id])
    except Exception:
        return None

    metadatas = (res or {}).get("metadatas")
    if not metadatas:
        return None

    md = metadatas[0]
    return {
        "id": cache_id,
        "question": question,
        "sql": md.get("sql"),
        "answer_md": md.get("answer_md"),
        "metadata": md,
    }
733
-
734
- # -------------------------------------------------------------------------
735
- # Retrieve cached answers ranked primarily by success rate
736
- # -------------------------------------------------------------------------
737
-
738
import re
import unicodedata


def normalize_text(text: str) -> str:
    """
    Normalize free text for cache keys: strip accents (NFKD → ASCII),
    lower-case, replace punctuation with spaces, collapse whitespace.
    """
    if not text:
        return ""
    ascii_text = (
        unicodedata.normalize("NFKD", text)
        .encode("ascii", "ignore")
        .decode("ascii")
        .lower()
    )
    ascii_text = re.sub(r"[^\w\s]", " ", ascii_text)
    return re.sub(r"\s+", " ", ascii_text).strip()


import hashlib


def generate_cache_id(question: str, sql: str) -> str:
    """
    Deterministic SHA-1 id over (normalized question, stripped sql), so the
    same question/SQL pair always maps to the same cache document.
    """
    key = f"{normalize_text(question)}||{(sql or '').strip()}"
    return hashlib.sha1(key.encode("utf-8")).hexdigest()
758
-
759
def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
):
    """
    Semantic cache lookup: among the 10 nearest cached questions within
    `max_distance`, return the candidate with the highest success_rate
    (smaller distance breaks ties). Returns None when nothing qualifies.
    """
    hits = store.similarity_search_with_score(normalize_text(question), k=10)
    if not hits:
        return None

    best = None
    for doc, score in hits:
        if score > max_distance:
            continue
        md = doc.metadata or {}
        if "sql" not in md:
            continue

        candidate = {
            "matched_question": doc.page_content,
            "sql": md["sql"],
            "answer_md": md.get("answer_md", ""),
            "distance": float(score),
            "metadata": md,
            "success_rate": float(md.get("success_rate", 0.5)),
        }
        # Track the running best instead of sorting the whole list; the
        # strict '<' keeps the earliest candidate on exact ties, matching
        # a stable sort by (-success_rate, distance).
        key = (-candidate["success_rate"], candidate["distance"])
        if best is None or key < (-best["success_rate"], best["distance"]):
            best = candidate

    return best
794
-
795
-
796
- # -------------------------------------------------------------------------
797
- # Increment views
798
- # -------------------------------------------------------------------------
799
-
800
def increment_cache_views(metadata: dict, store):
    """
    Bump the view counter of an existing cache entry and re-upsert it.

    Fix: the re-upsert previously used normalize_text(md.get("question", ""))
    as the document text, but entries written by cache_sql_answer_dedup carry
    no "question" key in their metadata — so every view bump rewrote the
    stored document text to "", breaking future similarity search for that
    entry. We now recover the existing document text from the underlying
    collection and only fall back to the metadata-derived text.

    Returns True on success, False when metadata is unusable or the upsert
    fails (failure is logged, never raised).
    """
    if not metadata:
        return False

    cache_id = metadata.get("cache_id")
    sql = metadata.get("sql")
    if not cache_id or not sql:
        return False

    md = dict(metadata)
    md["views"] = int(md.get("views", 0)) + 1
    md["last_updated_at"] = _now_ts()
    # preserve the original save time
    md["saved_at"] = md.get("saved_at", _now_ts())

    # Recover the stored document text so the upsert does not clobber it.
    doc_text = ""
    coll = getattr(store, "_collection", None)
    if coll is not None and hasattr(coll, "get"):
        try:
            res = coll.get(ids=[cache_id])
            if res and res.get("documents"):
                doc_text = res["documents"][0] or ""
        except Exception:
            pass
    if not doc_text:
        doc_text = normalize_text(md.get("question", ""))

    try:
        langchain_upsert(
            store=store,
            text=doc_text,
            metadata=md,
            cache_id=cache_id,
        )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False
827
-
828
-
829
- # Update metrics on feedback
830
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """
    Fold one user feedback into an existing cache entry's metrics and, when
    a corrected SQL + answer is supplied, cache the correction as a fresh
    entry with a perfect initial success rate.

    Fix: success_rate was computed from md["good_count"], which raised a
    KeyError on a "bad" vote against an entry that had never received a
    "good" one; the counters are now read with .get defaults.
    """
    if not original_doc_md:
        return

    md = dict(original_doc_md["metadata"])
    cache_id = md["cache_id"]

    # ---- feedback counts (defensive defaults) ----
    if user_marked_good:
        md["good_count"] = md.get("good_count", 0) + 1
    else:
        md["bad_count"] = md.get("bad_count", 0) + 1

    md["total_feedbacks"] = md.get("total_feedbacks", 0) + 1
    md["success_rate"] = (
        md.get("good_count", 0) / md["total_feedbacks"]
        if md["total_feedbacks"] > 0 else 0.5
    )

    # ---- timestamps ----
    md["saved_at"] = md.get("saved_at", _now_ts())  # preserve first save time
    md["last_updated_at"] = _now_ts()

    langchain_upsert(
        store=store,
        text=normalize_text(question),
        metadata=md,
        cache_id=cache_id,
    )

    # -------------------------
    # Corrected SQL → NEW ENTRY
    # -------------------------
    if llm_corrected_sql and llm_corrected_answer_md:
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )
885
-
886
-
887
- # %%
888
- # ### 8. Groq LLM via LangChain
889
-
890
- from langchain_groq import ChatGroq
891
- import re
892
- import gradio as gr
893
-
894
- llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
895
-
896
- # %%
897
def get_rag_context(question: str) -> str:
    """
    Fetch the schema documents most relevant to the question and join them
    into one context string separated by '---' dividers.
    """
    documents = rag_retriever.invoke(question)
    parts = [doc.page_content for doc in documents]
    return "\n\n---\n\n".join(parts)
903
-
904
def clean_sql(sql: str) -> str:
    """
    Strip whitespace and remove markdown code fences (```sql / ```) if any.
    """
    cleaned = sql.strip()
    if "```" not in cleaned:
        return cleaned
    return cleaned.replace("```sql", "").replace("```", "").strip()
909
-
910
def extract_sql_from_markdown(text: str) -> str:
    """
    Return the contents of the first ```sql ... ``` block in LLM output;
    when no fenced block is found, return the whole text stripped.
    """
    fenced = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    return fenced.group(1).strip() if fenced else text.strip()
919
-
920
def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """
    Return everything after the first (case-insensitive) occurrence of
    `marker`; when the marker is absent, return the whole text stripped.
    """
    position = text.upper().find(marker.upper())
    if position < 0:
        return text.strip()
    return text[position + len(marker):].strip()
928
-
929
- # 6b. General / descriptive questions
930
-
931
# Phrases that signal a high-level "describe the data" question rather than
# an analytical one that needs SQL.
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    True when the question is a high-level descriptive one that should be
    answered from schema context instead of generated SQL.
    """
    lowered = question.lower().strip()
    for keyword in GENERAL_DESC_KEYWORDS:
        if keyword in lowered:
            return True
    return False
950
-
951
-
952
def answer_general_question(question: str) -> str:
    """
    Answer a conceptual "what is this dataset" style question directly from
    the RAG schema docs — no SQL generation. Returns Markdown text.
    """
    rag_context = get_rag_context(question)

    system_instructions = """
You are a data documentation expert.

You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".

Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.

Do NOT write SQL. Answer in Markdown.
"""

    prompt = "".join([
        system_instructions,
        "\n\n=== SCHEMA CONTEXT ===\n",
        rag_context,
        "\n\n=== USER QUESTION ===\n",
        question,
        "\n\nDetailed dataset overview:",
    ])

    log_prompt("GENERAL_DATASET_QUESTION", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)

    return response.content.strip()
990
-
991
- # %%
992
- # ### 9. SQL execution / validation
993
-
994
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run *sql* against *connection* and return ``(dataframe, error)``.

    Exactly one element of the pair is non-None: the result frame on
    success, or the stringified exception message on failure.
    """
    try:
        result = pd.read_sql_query(sql, connection)
    except Exception as exc:
        return None, str(exc)
    return result, None
1003
-
1004
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """Cheap validity probe using ``EXPLAIN QUERY PLAN``.

    Returns ``(True, None)`` when SQLite can plan the statement, or
    ``(False, error_message)`` on syntax / schema problems. Nothing is
    actually executed against the data.
    """
    try:
        connection.cursor().execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception as exc:
        return False, str(exc)
    return True, None
1015
-
1016
- # %%
1017
- # ### 10. SQL Generation + Repair with LLM (feedback-aware)
1018
-
1019
def build_sql_review_prompt(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> str:
    """
    Build the prompt that asks the LLM to re-examine previously generated
    SQL in light of negative user feedback.

    The model is explicitly allowed to conclude either that the query was
    already correct (user mistaken) or that it was wrong, in which case it
    must produce a fix.

    Parameters
    ----------
    question : str
        The user's original natural-language question.
    generated_sql : str
        The SQL the user flagged as wrong.
    user_feedback_comment : str
        The user's free-text complaint.
    rag_context : str
        Schema documentation retrieved for the question.

    Returns
    -------
    str
        The fully rendered prompt.

    Fix: the original template opened a ```sql fence around the generated
    SQL but never closed it, which swallowed the rest of the prompt inside
    the code fence. It also asked for bare SQL in the answer while the
    caller extracts it with extract_sql_from_markdown(), which looks for a
    ```sql fenced block — so the corrected SQL is now requested fenced.
    """
    prompt = f"""
You previously generated the following SQL for a SQLite database:

```sql
{generated_sql}
```

The user now says this query is WRONG and provided this feedback:
"{user_feedback_comment}"

TASKS:

Compare the SQL with the database schema (given in the context) and the user's feedback.

Decide whether the query is actually correct or incorrect.
Make sure that the user clearly specifies why they think it is incorrect.
It can be if they are unsatisfied with the numbers, or the logic is incorrect, or SQL is invalid etc.
If any reason is not clearly specified, return the previous result as it is.

If it is already correct, keep it unchanged and explain why the user might be mistaken.

If it is incorrect (e.g., wrong joins, missing filters, wrong aggregation), fix it.

Produce a corrected SQL query that better answers the question.

If the original query is already correct, just repeat the same SQL.

Explain in a few sentences WHO is correct (you or the user) and WHY.

DATABASE SCHEMA (partial, from RAG):
{rag_context}

User question:
{question}

Return your answer in this format:

CORRECTED_SQL:

```sql
-- your (possibly unchanged) SQL here
SELECT ...
```

EXPLANATION:
Your explanation here, clearly stating whether:

the original query was correct or not, and

how your corrected SQL addresses the issue (or why it didn't need changes).
"""
    return prompt.strip()
1082
-
1083
-
1084
- # %%
1085
- # Review and Correct SQL based on Feedback
1086
-
1087
def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> Tuple[str, str]:
    """Run the feedback-review loop against the LLM.

    The model compares the previously generated SQL with the user's
    complaint and either confirms it or proposes a replacement.

    Returns
    -------
    (corrected_sql, explanation)
        ``corrected_sql`` falls back to the original query whenever no
        usable SQL can be extracted from the model's reply.
    """
    review_prompt = build_sql_review_prompt(
        question=question,
        generated_sql=generated_sql,
        user_feedback_comment=user_feedback_comment,
        rag_context=rag_context,
    )

    log_prompt("SQL_REVIEW_FEEDBACK", review_prompt)
    reply = llm.invoke(review_prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REVIEW", reply.content)

    # Prefer SQL pulled from a fenced block; keep the original on failure.
    candidate = extract_sql_from_markdown(reply.content)
    corrected_sql = clean_sql(candidate) if candidate else generated_sql
    if not corrected_sql.strip():
        corrected_sql = generated_sql

    explanation = extract_explanation_after_marker(reply.content, marker="EXPLANATION:")
    return corrected_sql, explanation
1122
-
1123
- # %%
1124
- # Generate SQL
1125
-
1126
def generate_sql(question: str, rag_context: str) -> str:
    """
    Generate SQL using LLM, but also pull in any past user feedback
    (corrected SQL) for this question as external guidance.

    Parameters
    ----------
    question : str
        Natural-language question from the user.
    rag_context : str
        Schema documentation retrieved for the question.

    Returns
    -------
    str
        A cleaned SQL string (markdown fences stripped by clean_sql).

    NOTE(review): relies on module-level `conn`, `llm`, and helpers
    `get_last_feedback_for_question`, `clean_sql`, `log_prompt`,
    `log_run_event` defined elsewhere in this file.
    """
    # External correction from past feedback, if any.
    # Assumes get_last_feedback_for_question returns a dict-like row or
    # None — TODO confirm against its definition.
    last_fb = get_last_feedback_for_question(conn, question)
    if last_fb and last_fb.get("corrected_sql"):
        # Inject the previously corrected SQL as guidance so the model
        # does not repeat a mistake the user already flagged.
        previous_feedback_block = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:

Previous generated SQL:
{last_fb['generated_sql']}

Corrected SQL (preferred reference):
{last_fb['corrected_sql']}

User comment & prior explanation:
{last_fb.get('comment') or '(none)'}

You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""
    else:
        previous_feedback_block = ""

    system_instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.

You will be given:

- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.

Your job:

- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
and round to 2 decimals when appropriate.

{previous_feedback_block}
"""

    # Stitch instructions, schema context and the question together.
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nSQL query:"
    )

    # Full prompt and raw model response are logged for auditability.
    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)

    sql = clean_sql(response.content)
    return sql
1187
-
1188
-
1189
- # %%
1190
- # Repair SQL
1191
-
1192
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to correct a failing SQL query.

    Parameters
    ----------
    question : str
        The user's original natural-language question.
    rag_context : str
        Schema documentation retrieved for the question.
    bad_sql : str
        The previously generated SQL that failed validation/execution.
    error_message : str
        The SQLite error text produced by the failing query.

    Returns
    -------
    str
        A cleaned candidate replacement query (fences stripped by
        clean_sql); correctness is re-validated by the caller.
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:

- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:

- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""

    # Stitch instructions, schema, question, failing SQL and error together.
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== PREVIOUS (FAILING) SQL ===\n"
        + bad_sql
        + "\n\n=== SQLITE ERROR ===\n"
        + error_message
        + "\n\nCorrected SQL query:"
    )

    # Full prompt and raw model response are logged for auditability.
    log_prompt("SQL_REPAIR", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)

    sql = clean_sql(response.content)
    return sql
1240
-
1241
-
1242
- # %%
1243
- ### 11. Result summarization
1244
-
1245
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.

    Parameters
    ----------
    question : str
        The user's natural-language question.
    sql : str
        The SQL that was executed.
    df : Optional[pd.DataFrame]
        The query result, or None when execution failed.
    rag_context : str
        Schema documentation used during generation.
    error : Optional[str]
        Optional validation/execution error message; when present it is
        appended to the prompt so the model can explain the failure.

    Returns
    -------
    str
        A Markdown summary aimed at a business reader.
    """
    system_instructions = """
You are a senior data analyst.

You will be given:

The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.

Your job:

Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user.
"""

    # Build a markdown table preview if we have data; cap at 50 rows so the
    # prompt stays small.
    if df is not None and not df.empty:
        preview_rows = min(len(df), 50)
        df_preview_md = df.head(preview_rows).to_markdown(index=False)
    else:
        df_preview_md = "(no rows returned)"

    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== EXECUTED SQL ===\n"
        + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n"
        + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n"
        + rag_context
    )

    if error:
        prompt += "\n\n=== ERROR ===\n" + error

    # Logging helpers assumed to exist
    log_prompt("RESULT_SUMMARY", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)

    return response.content.strip()
1302
-
1303
-
1304
- # %%
1305
def backend_pipeline(question: str):
    """
    STRICT cache-first backend.

    Behavior:
    - Try ONE strict cache hit (ranked by success_rate, distance).
    - If hit → reuse SQL, increment views.
    - Else → run LLM → execute → summarize → cache ALWAYS.

    Returns a 13-tuple consumed positionally by the Gradio callback
    `run_and_render`:
        (answer_md, result_df, final_sql, rag_context, question,
         answer_md_state, df_state, cached_matches, dropdown_choices,
         dropdown_visible, attempts_left, attempts_text,
         cached_stats gr.update)
    The cached_matches / dropdown slots are always empty here and are
    discarded by the UI.

    NOTE(review): relies on module-level `conn`, `gr`, and cache helpers
    (get_sql_cache_store, retrieve_best_cached_sql, increment_cache_views,
    cache_sql_answer_dedup, _now_ts, _logger) defined elsewhere in the
    file.
    """

    # ----------------------------
    # Guards
    # ----------------------------
    # Empty / whitespace-only question: return a placeholder 13-tuple.
    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "", "", "", "", pd.DataFrame(),
            [], [], False,
            4, "**Feedback attempts remaining: 4**",
            gr.update(value="", visible=False),
        )

    # Every fresh question resets the feedback-attempt budget.
    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"

    # ----------------------------
    # General questions
    # ----------------------------
    # High-level "describe the dataset" questions bypass SQL entirely.
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return (
            overview_md,
            pd.DataFrame(),
            "",
            "",
            question,
            overview_md,
            pd.DataFrame(),
            [], [], False,
            attempts_left,
            attempts_text,
            gr.update(value="", visible=False),
        )

    store = get_sql_cache_store()

    # ----------------------------
    # STEP 0: STRICT CACHE LOOKUP
    # ----------------------------
    # Best-effort: any lookup failure degrades to the LLM path.
    try:
        cached = retrieve_best_cached_sql(
            question=question,
            store=store,
            max_distance=0.25,  # strict similarity threshold
        )
    except Exception as e:
        log_run_event("CACHE_LOOKUP_ERROR", str(e))
        cached = None

    if cached:
        # increment views (best-effort; bookkeeping failures are ignored)
        try:
            increment_cache_views(cached["metadata"], store=store)
        except Exception:
            pass

        rag_context = get_rag_context(question)

        # Banner explaining why a cached answer is being shown.
        header = (
            "### Cache Hit\n"
            f"- **Matched question:** \"{cached['matched_question']}\"\n"
            f"- **Success rate:** {cached['metadata'].get('success_rate', 0.5):.2f}\n"
            f"- **Similarity distance:** {cached['distance']:.4f}\n\n"
            "---\n\n"
        )

        answer_md = header + (cached.get("answer_md") or "")

        # Re-run the cached SQL so the preview reflects current data.
        try:
            df, exec_error = execute_sql(cached["sql"], conn)
            if exec_error:
                df = pd.DataFrame()
                answer_md += f"\n\n Error re-running cached SQL: `{exec_error}`"
        except Exception as e:
            df = pd.DataFrame()
            answer_md += f"\n\n Exception re-running cached SQL: `{e}`"

        # Human-readable stats panel for the cached entry.
        md = cached["metadata"]
        stats_md = (
            f"**Cached entry stats**\n\n"
            f"- **Success rate:** {md.get('success_rate',0.5):.2f} \n"
            f"- **Total feedbacks:** {md.get('total_feedbacks',0)} \n"
            f"- **Good / Bad:** {md.get('good_count',0)} / {md.get('bad_count',0)} \n"
            f"- **Views:** {md.get('views',0)} \n"
            f"- **Saved at:** "
            f"{datetime.datetime.fromtimestamp(md.get('saved_at')).strftime('%Y-%m-%d %H:%M') if md.get('saved_at') else 'unknown'} \n"
            f"- **Last updated:** "
            f"{datetime.datetime.fromtimestamp(md.get('last_updated_at')).strftime('%Y-%m-%d %H:%M') if md.get('last_updated_at') else 'unknown'}\n\n"
            f"**SQL preview:**\n\n```sql\n{cached['sql']}\n```\n"

        )

        return (
            answer_md,
            df,
            cached["sql"],
            rag_context,
            question,
            answer_md,
            df,
            [], [], False,
            attempts_left,
            attempts_text,
            gr.update(value=stats_md, visible=True),
        )

    # ----------------------------
    # STEP 1: LLM FLOW
    # ----------------------------
    rag_context = get_rag_context(question)

    sql = generate_sql(question, rag_context)
    original_sql = sql  # kept for reference; not used further below

    # Validate first; on failure, give the LLM one repair attempt.
    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False

    if not is_valid and validation_error:
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)
        if repaired_valid:
            sql = repaired_sql
            repaired = True
            validation_error = None
        else:
            # Keep the most informative error for the summary step.
            validation_error = repaired_error or validation_error

    df, exec_error = (None, None)
    if not validation_error:
        df, exec_error = execute_sql(sql, conn)
    else:
        exec_error = validation_error

    # Natural-language summary (also handles the error case).
    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )

    # Status footer appended below the summary.
    sql_status = []
    if exec_error:
        sql_status.append(f"**Error:** `{exec_error}`")
    else:
        sql_status.append("Query ran successfully.")
    if repaired:
        sql_status.append("_Note: SQL was auto-repaired._")
    sql_status.append("\n**Final SQL used:**\n")
    sql_status.append(f"```sql\n{sql}\n```")

    answer_md = summary_text + "\n\n---\n\n" + "\n".join(sql_status)
    df_preview = df if df is not None and exec_error is None else pd.DataFrame()

    # ----------------------------
    # STEP 2: CACHE ALWAYS (neutral confidence)
    # ----------------------------
    # New entries start at a neutral 0.5 success rate; feedback adjusts it.
    try:
        cache_sql_answer_dedup(
            question=question,
            sql=sql,
            answer_md=answer_md,
            metadata={
                "good_count": 0,
                "bad_count": 0,
                "total_feedbacks": 0,
                "success_rate": 0.5,
                "views": 1,
                "saved_at": _now_ts(),
            },
            store=store,
        )
    except Exception:
        _logger.exception("backend_pipeline: failed to cache LLM result")

    return (
        answer_md,
        df_preview,
        sql,
        rag_context,
        question,
        answer_md,
        df_preview,
        [], [], False,
        attempts_left,
        attempts_text,
        gr.update(value="", visible=False),
    )
1504
-
1505
-
1506
- # %%
1507
- def _looks_like_sql(text: str) -> bool:
1508
- """Quick heuristic: does text contain SQL keywords / SELECT ?"""
1509
- if not text:
1510
- return False
1511
- return bool(re.search(r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b", text, flags=re.I))
1512
-
1513
-
1514
def is_feedback_sufficient(feedback_text: str) -> bool:
    """
    Heuristic: is the user's free-text feedback actionable?

    Accepts feedback that is long (>= 60 chars), looks like pasted SQL,
    or is moderately long (>= 20 chars) while mentioning at least one
    "signal" word that indicates a specific complaint. Short vague notes
    like "numbers look off" are rejected.
    """
    if not feedback_text:
        return False

    trimmed = feedback_text.strip()

    # Long feedback is assumed to carry enough detail on its own.
    if len(trimmed) >= 60:
        return True

    # A pasted SQL snippet is directly actionable.
    if _looks_like_sql(trimmed):
        return True

    # Words that indicate the user is pointing at something specific.
    signal_words = [
        "filter", "where", "year", "month", "should", "instead", "expected",
        "wrong", "missing", "aggregate", "sum", "avg", "count", "distinct",
        "join", "left join", "inner join", "group by", "order by", "date",
        "range", "exclude", "include", "only"
    ]
    lowered = trimmed.lower()
    has_signal = any(word in lowered for word in signal_words)
    return has_signal and len(trimmed) >= 20
1549
-
1550
-
1551
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Build the deterministic clarification message shown when a user's
    "wrong" feedback is too vague to act on.

    When *sample_feedback* is provided, it is echoed at the top so the
    user can see which feedback triggered the follow-up.
    """
    pieces = [
        "Thanks — I need a bit more detail to act on this feedback.\n\n",
        "Please tell me one (or more) of the following so I can check and correct the result:\n\n",
        "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n",
        "   the **time range** (year/month), or the **filters** applied?\n",
        "2. If you expected a different number, what was the expected number (and how was it computed)?\n",
        "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n",
        "Examples you can copy-paste:\n",
        "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n",
        "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n",
        "- \"Please exclude canceled orders (order_status = 'canceled').\"\n",
        "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n",
        "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare.",
    ]
    message = "".join(pieces)
    if sample_feedback:
        message = f"I saw your feedback: \"{sample_feedback}\"\n\n" + message
    return message
1576
-
1577
-
1578
-
1579
- # %%
1580
def feedback_pipeline_interactive(
    feedback_rating: str,
    feedback_comment: str,
    last_sql: str,
    last_rag_context: str,
    last_question: str,
    last_answer_md: str,
    last_df: pd.DataFrame,
    feedback_sql: str,
    attempts_left: int,
):
    """
    Handle one feedback submission ("Correct" / "Wrong") on the last answer.

    Flow:
    - "correct": record GOOD feedback and update cache stats; no attempt
      is consumed.
    - "wrong": consume one attempt. If the comment is too vague and
      attempts remain, return a follow-up prompt instead of re-running.
      Otherwise run an LLM review that may produce corrected SQL, re-run
      it, and persist the outcome.

    Returns a 10-tuple consumed by the Gradio `handle_feedback` callback:
        (answer_md, result_df, sql, rag_context, question,
         answer_md_state, df_state, awaiting_followup,
         followup_prompt, attempts_left)

    NOTE(review): relies on module-level `conn` plus cache/feedback
    helpers (find_cached_doc_by_sql, update_cache_on_feedback,
    record_feedback, get_sql_cache_store) defined elsewhere in the file.
    `feedback_sql` is accepted but not used in this implementation.
    """
    # Normalize inputs defensively; Gradio may hand over None.
    rating = (feedback_rating or "").strip().lower()
    comment = (feedback_comment or "").strip()
    attempts_left = int(attempts_left or 0)

    # ---------------- Guard ----------------
    # Nothing to review yet: echo current state unchanged.
    if not last_question or not last_sql:
        return (
            last_answer_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # Reject anything except the two expected ratings.
    if rating not in ("correct", "wrong"):
        return (
            last_answer_md + "\n\n Please select **Correct** or **Wrong**.",
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # ============================================================
    # CORRECT -> no attempt decrement
    # ============================================================
    if rating == "correct":
        # Locate the cached entry backing this answer (may be None).
        original_doc = find_cached_doc_by_sql(
            last_question, last_sql, store=get_sql_cache_store()
        )

        update_cache_on_feedback(
            question=last_question,
            original_doc_md=original_doc,
            user_marked_good=True,
            llm_corrected_sql=None,
            llm_corrected_answer_md=None,
            store=get_sql_cache_store(),
        )

        # Persist the positive rating in the feedback table.
        record_feedback(
            conn=conn,
            question=last_question,
            generated_sql=last_sql,
            model_answer=last_answer_md,
            rating="good",
            comment=comment or None,
            corrected_sql=None,
        )

        return (
            last_answer_md + "\n\n **Feedback recorded as GOOD.**",
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # ============================================================
    # WRONG -> decrement immediately
    # ============================================================
    attempts_left = max(0, attempts_left - 1)

    # ============================================================
    # Attempts exhausted → FORCE LLM
    # ============================================================
    # With no attempts left we skip the follow-up step and always review,
    # substituting a generic comment if the user gave none.
    if attempts_left == 0:
        comment = comment or "User marked result as wrong."

    # ============================================================
    # Insufficient feedback -> FOLLOW-UP (only if attempts remain)
    # ============================================================
    if attempts_left > 0 and not is_feedback_sufficient(comment):
        return (
            last_answer_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            True,  # awaiting follow-up
            build_followup_prompt_for_user(comment),
            attempts_left,
        )

    # ============================================================
    # Run LLM review
    # ============================================================
    original_doc = find_cached_doc_by_sql(
        last_question, last_sql, store=get_sql_cache_store()
    )

    corrected_sql, explanation = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=comment,
        rag_context=last_rag_context,
    )

    # Defensive fallback, then re-execute whatever SQL we ended up with.
    corrected_sql = corrected_sql or last_sql
    df_new, exec_error = execute_sql(corrected_sql, conn)

    if exec_error:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=None,
            rag_context=last_rag_context,
            error=exec_error,
        )
        df_new = pd.DataFrame()
    else:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=df_new,
            rag_context=last_rag_context,
            error=None,
        )

    # Only push corrections into the cache when the SQL actually changed.
    update_cache_on_feedback(
        question=last_question,
        original_doc_md=original_doc,
        user_marked_good=False,
        llm_corrected_sql=(
            corrected_sql if corrected_sql.strip() != last_sql.strip() else None
        ),
        llm_corrected_answer_md=(
            answer_core if corrected_sql.strip() != last_sql.strip() else None
        ),
        store=get_sql_cache_store(),
    )

    # Persist the negative rating together with the LLM's explanation.
    record_feedback(
        conn=conn,
        question=last_question,
        generated_sql=last_sql,
        model_answer=last_answer_md,
        rating="bad",
        comment=comment + "\n\nLLM explanation:\n" + (explanation or ""),
        corrected_sql=corrected_sql,
    )

    final_md = (
        answer_core
        + "\n\n---\n\n"
        + f"**Final corrected SQL:**\n```sql\n{corrected_sql}\n```\n\n"
        + "### LLM Review Explanation\n"
        + (explanation or "")
    )

    return (
        final_md,
        df_new,
        corrected_sql,
        last_rag_context,
        last_question,
        final_md,
        df_new,
        False,
        "",
        attempts_left,
    )
1770
-
1771
-
1772
- # %%
1773
- import gradio as gr
1774
- import pandas as pd
1775
-
1776
# Gradio front-end: question box + answer pane, plus a feedback loop with a
# bounded number of attempts and an optional follow-up clarification step.
with gr.Blocks() as demo:
    gr.Markdown("# Olist Analytics Assistant (RAG + SQL + Feedback)")

    # ==================== STATE ====================
    # Hidden per-session state carried between the run and feedback callbacks.
    last_sql_state = gr.State("")             # SQL behind the last answer
    last_rag_state = gr.State("")             # RAG context used for it
    last_question_state = gr.State("")        # the question it answered
    last_answer_state = gr.State("")          # rendered answer markdown
    last_df_state = gr.State(pd.DataFrame())  # result preview dataframe

    attempts_state = gr.State(4)              # remaining feedback attempts
    feedback_sql_state = gr.State("")         # reserved for user-pasted SQL

    # ==================== MAIN UI ====================
    with gr.Row():
        with gr.Column(scale=1):
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total number of customers",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            answer_out = gr.Markdown()
            table_out = gr.Dataframe()

    attempts_display = gr.Markdown("**Feedback attempts remaining: 4**")
    cached_stats_md = gr.Markdown(visible=False)  # shown only on cache hits

    # ==================== FEEDBACK ====================
    gr.Markdown("### Feedback")

    feedback_rating = gr.Radio(
        ["Correct", "Wrong"],
        label="Is the answer correct?",
        value=None,
    )
    feedback_comment = gr.Textbox(
        label="Explain (required if Wrong)",
        lines=3,
    )
    feedback_btn = gr.Button("Submit feedback")

    # ==================== FOLLOW-UP ====================
    # Hidden until vague feedback triggers a clarification request.
    followup_prompt_md = gr.Markdown(visible=False)
    followup_input = gr.Textbox(
        label="Please clarify",
        visible=False,
        lines=4,
    )
    followup_submit_btn = gr.Button(
        "Submit follow-up",
        visible=False,
    )

    exhausted_md = gr.Markdown(
        "**You have exhausted your feedback attempts. Please ask a new question to continue.**",
        visible=False,
    )

    # ==================== UI HELPERS ====================
    # Each helper returns 7 gr.update objects, always in this order:
    # (rating, comment, submit, followup input, followup btn,
    #  followup prompt, exhausted banner).
    def reset_feedback_ui():
        # Default state: feedback widgets visible, follow-up hidden.
        return (
            gr.update(value=None, visible=True), # rating
            gr.update(value="", visible=True),   # comment
            gr.update(visible=True),             # submit
            gr.update(visible=False),            # followup input
            gr.update(visible=False),            # followup btn
            gr.update(visible=False),            # followup prompt
            gr.update(visible=False),            # exhausted
        )

    def show_followup_ui(prompt: str):
        # Swap the feedback widgets for the clarification prompt/input.
        return (
            gr.update(visible=False),              # rating
            gr.update(visible=False),              # comment
            gr.update(visible=False),              # submit
            gr.update(value="", visible=True),     # followup input
            gr.update(visible=True),               # followup btn
            gr.update(value=prompt, visible=True), # followup prompt
            gr.update(visible=False),              # exhausted
        )

    def show_exhausted_ui():
        # Hide everything except the "attempts exhausted" banner.
        return (
            gr.update(visible=False),  # rating
            gr.update(visible=False),  # comment
            gr.update(visible=False),  # submit
            gr.update(visible=False),  # followup input
            gr.update(visible=False),  # followup btn
            gr.update(visible=False),  # followup prompt
            gr.update(visible=True),   # exhausted
        )

    # ==================== RUN PIPELINE ====================
    def run_and_render(question):
        # Unpack backend_pipeline's 13-tuple; the cached-matches/dropdown
        # slots are unused by this UI and deliberately discarded.
        (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            _cached_matches,
            _dropdown_choices,
            _dropdown_visible,
            attempts,
            attempts_text,
            cached_stats_update,
        ) = backend_pipeline(question)

        # Re-emit in the order of submit_btn.click's outputs list, then
        # append the 7 feedback-UI resets.
        return (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            attempts,
            attempts_text,
            cached_stats_update,
            *reset_feedback_ui(),
        )

    submit_btn.click(
        run_and_render,
        inputs=[question_in],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )

    # ==================== FEEDBACK HANDLER ====================
    def handle_feedback(
        rating,
        comment,
        last_sql,
        last_rag,
        last_question,
        last_answer,
        last_df,
        feedback_sql,
        attempts_left,
    ):
        # Delegate the business logic; this wrapper only maps the result
        # onto the right combination of UI visibility updates.
        (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            awaiting_followup,
            followup_prompt,
            attempts_new,
        ) = feedback_pipeline_interactive(
            rating,
            comment,
            last_sql,
            last_rag,
            last_question,
            last_answer,
            last_df,
            feedback_sql,
            attempts_left,
        )

        attempts_md = f"**Feedback attempts remaining: {attempts_new}**"

        # Exhausted
        # NOTE(review): the component object `cached_stats_md` is returned
        # as an output value (not a gr.update) in all three branches —
        # confirm this is intended for the Gradio version in use.
        if attempts_new <= 0:
            ui = show_exhausted_ui()
            return (
                answer_md,
                df_new,
                sql_new,
                rag_new,
                q_new,
                ans_state,
                df_state,
                attempts_new,
                attempts_md,
                cached_stats_md,
                *ui,
            )

        # Follow-up only
        if awaiting_followup:
            ui = show_followup_ui(followup_prompt)
            return (
                answer_md,
                df_new,
                sql_new,
                rag_new,
                q_new,
                ans_state,
                df_state,
                attempts_new,
                attempts_md,
                cached_stats_md,
                *ui,
            )

        # Normal reset
        ui = reset_feedback_ui()
        return (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            attempts_new,
            attempts_md,
            cached_stats_md,
            *ui,
        )

    feedback_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            feedback_comment,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )

    # Follow-up submissions reuse handle_feedback, feeding the clarification
    # text (followup_input) in place of the original comment.
    followup_submit_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            followup_input,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )
2080
-
2081
-
2082
- # %%
2083
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()