|
|
| import os
|
| import uuid
|
| import sqlite3
|
| import datetime
|
| import json
|
| import re
|
| import time
|
| from typing import Dict, Any, List, Tuple, Optional
|
| import hashlib
|
|
|
| import pandas as pd
|
| from groq import Groq
|
| from langchain_community.vectorstores import Chroma
|
| from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
|
|
|
|
|
# --- Application configuration --------------------------------------------
# SQLite database file built from the Olist CSVs (see init_db below).
DB_PATH = "olist.db"
# Directory expected to contain the raw Olist CSV files.
DATA_DIR = "data"

# Groq-hosted chat model used for SQL generation and review.
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"

# API key must come from the environment; fail fast at import time if absent.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found.")

# Raw Groq client (note: LLM calls below go through langchain's ChatGroq).
groq_client = Groq(api_key=GROQ_API_KEY)

# Sentence-transformers model used for all embeddings (RAG store + SQL cache).
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Plain-text audit logs for prompts and model/run events.
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
|
|
|
|
|
|
|
|
|
| import os
|
| import chromadb
|
| from chromadb.config import Settings
|
|
|
# On-disk location for Chroma's persisted collections; created if missing.
CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

# Process-wide Chroma client singleton, created lazily by get_chroma_client().
_CHROMA_CLIENT = None
|
|
|
def get_chroma_client():
    """Return the process-wide persistent Chroma client, creating it on first use."""
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is None:
        # Telemetry disabled; data persisted under CHROMA_PERSIST_DIR.
        client_settings = Settings(anonymized_telemetry=False)
        _CHROMA_CLIENT = chromadb.PersistentClient(
            path=CHROMA_PERSIST_DIR,
            settings=client_settings,
        )
    return _CHROMA_CLIENT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log_prompt(tag: str, prompt: str) -> None:
    """Append *prompt* to the prompt log file under a tagged, timestamped banner."""
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    banner = f"\n\n================ {tag} @ {stamp} ================\n"
    footer = "\n================ END PROMPT ================\n"
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(banner + prompt + footer)
|
|
|
|
|
def log_run_event(tag: str, content: str) -> None:
    """Append a run event (model response, final SQL, error info) to the run log."""
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    banner = f"\n\n================ {tag} @ {stamp} ================\n"
    footer = "\n================ END EVENT ================\n"
    with open(RUN_LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(banner + content + footer)
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """Create the user_feedback table if it does not already exist (idempotent)."""
    ddl = """
        CREATE TABLE IF NOT EXISTS user_feedback (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            created_at TEXT NOT NULL,
            question TEXT NOT NULL,
            generated_sql TEXT,
            model_answer TEXT,
            rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
            comment TEXT,
            corrected_sql TEXT
        )
    """
    conn.execute(ddl)
    conn.commit()
|
|
|
|
|
def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """
    Persist one piece of user feedback about a model answer / SQL query.

    The rating is lower-cased and must be 'good' or 'bad' (ValueError otherwise).
    A provided corrected_sql is stored as an external correction.
    """
    normalized_rating = rating.lower()
    if normalized_rating not in ("good", "bad"):
        raise ValueError("rating must be 'good' or 'bad'")

    created_at = datetime.datetime.now().isoformat(timespec="seconds")
    insert_stmt = """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    params = (
        created_at,
        question,
        generated_sql,
        model_answer,
        normalized_rating,
        comment,
        corrected_sql,
    )
    conn.execute(insert_stmt, params)
    conn.commit()
|
|
|
|
|
def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """
    Return the most recent feedback row for *question* as a dict, or None.

    Recency is determined by created_at (ISO-8601 strings sort correctly).
    """
    query = """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC
        LIMIT 1
    """
    row = conn.execute(query, (question,)).fetchone()
    if not row:
        return None

    field_names = (
        "created_at",
        "generated_sql",
        "model_answer",
        "rating",
        "comment",
        "corrected_sql",
    )
    return dict(zip(field_names, row))
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db() -> sqlite3.Connection:
    """
    Load all CSVs from the data/ folder into a local SQLite DB.

    Table names are derived from file names (without .csv). Missing files are
    skipped with a console message. Also ensures the user_feedback table exists.

    FIX: removed a stray debug `print(path)` that fired for every file,
    duplicating the "Loading ..." message right below it.

    Returns:
        An open sqlite3.Connection (check_same_thread=False so it can be
        shared with the web UI's worker threads).
    """
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)

    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]

    for fname in csv_files:
        path = os.path.join(DATA_DIR, fname)
        if not os.path.exists(path):
            print(f"CSV not found: {path} - skipping")
            continue

        # Table name = CSV file name without extension.
        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        df = pd.read_csv(path)
        df.to_sql(table_name, conn, if_exists="replace", index=False)

    init_feedback_table(conn)

    return conn
|
|
|
|
|
# Build (or refresh) the SQLite database from the CSVs at import time.
conn = init_db()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Curated, human-written documentation for each Olist table/column. Merged
# with live PRAGMA introspection in extract_schema_metadata() so the RAG
# documents carry semantic descriptions rather than bare column names.
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            "review_creation_date": "Date when the customer created the review.",
            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns and foreign keys via PRAGMA, merging in
    the curated descriptions from OLIST_DOCS where available.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cursor.fetchall()]

    metadata: Dict[str, Any] = {"tables": {}}

    for table_name in table_names:
        curated = OLIST_DOCS.get(table_name, {})
        description: str = curated.get(
            "description",
            f"Table '{table_name}' from Olist dataset."
        )
        curated_columns: Dict[str, str] = curated.get("columns", {})

        # Column list with curated text, falling back to a generic description.
        cursor.execute(f"PRAGMA table_info('{table_name}')")
        columns: Dict[str, str] = {}
        for info in cursor.fetchall():
            name = info[1]
            declared_type = info[2] or "TEXT"
            if name in curated_columns:
                columns[name] = curated_columns[name]
            else:
                columns[name] = f"Column '{name}' of type {declared_type}"

        # Declared foreign keys rendered as human-readable join hints.
        cursor.execute(f"PRAGMA foreign_key_list('{table_name}')")
        relationships: List[str] = [
            f"{table_name}.{fk[3]} → {fk[2]}.{fk[4]} (foreign key)"
            for fk in cursor.fetchall()
        ]

        metadata["tables"][table_name] = {
            "description": description,
            "columns": columns,
            "relationships": relationships,
        }

    return metadata
|
|
|
|
|
# Introspect the freshly-loaded DB, merging curated OLIST_DOCS descriptions.
schema_metadata = extract_schema_metadata(conn)
|
|
|
|
|
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render the schema metadata dict into a YAML-style string.

    Double quotes inside descriptions are converted to single quotes so the
    emitted values stay well-formed.
    """
    out: List[str] = ["tables:"]
    for table_name, info in metadata["tables"].items():
        out.append(f"  {table_name}:")
        safe_desc = info.get("description", "").replace('"', "'")
        out.append(f'    description: "{safe_desc}"')
        out.append("    columns:")
        for column, description in info.get("columns", {}).items():
            out.append('      {}: "{}"'.format(column, description.replace('"', "'")))
        relationships = info.get("relationships", [])
        if relationships:
            out.append("    relationships:")
            out.extend(
                '      - "{}"'.format(rel.replace('"', "'"))
                for rel in relationships
            )
    return "\n".join(out)
|
|
|
|
|
# YAML-ish rendering of the whole schema, used as a single global RAG doc.
schema_yaml = build_schema_yaml(schema_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table from schema_metadata.

    Each document contains the table name, its description, typed column
    descriptions, foreign-key relationships, and up to *sample_rows* example
    rows rendered as markdown (best-effort; failures fall back to a notice).

    Returns:
        (documents, metadatas) — parallel lists; each metadata dict carries
        doc_type="table_schema" and the table_name.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cursor.fetchall()]

    documents: List[str] = []
    doc_metadatas: List[Dict[str, Any]] = []

    for table_name in table_names:
        info = schema_metadata["tables"][table_name]
        description = info.get("description", "")
        described_cols = info.get("columns", {})
        relationships = info.get("relationships", [])

        # One "- name (TYPE): description" line per column.
        cursor.execute(f"PRAGMA table_info('{table_name}')")
        column_lines = []
        for col in cursor.fetchall():
            name = col[1]
            declared_type = col[2] or "TEXT"
            text = described_cols.get(name, f"Column '{name}' of type {declared_type}")
            column_lines.append(f"- {name} ({declared_type}): {text}")

        # Best-effort sample rows; to_markdown may fail (e.g. missing tabulate).
        try:
            preview = pd.read_sql_query(
                f"SELECT * FROM '{table_name}' LIMIT {sample_rows}",
                connection,
            )
            sample_text = preview.to_markdown(index=False)
        except Exception:
            sample_text = "(could not fetch sample rows)"

        rel_block = ""
        if relationships:
            rel_block = (
                "Relationships:\n"
                + "\n".join(f"- {rel}" for rel in relationships)
                + "\n"
            )

        documents.append(
            f"Table: {table_name}\n"
            f"Description: {description}\n\n"
            f"Columns:\n" + "\n".join(column_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        doc_metadatas.append({
            "doc_type": "table_schema",
            "table_name": table_name,
        })

    return documents, doc_metadatas
|
|
|
|
|
|
|
# Assemble the RAG corpus: one document per table plus one global schema doc.
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# Per-table schema documents (doc_type="table_schema").
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# One global document carrying the full schema YAML (doc_type="global_schema").
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
|
|
|
|
|
|
|
def build_store_final():
    """
    Create the schema RAG vector store and its retriever.

    The store is only populated when the persisted collection is empty, so
    repeated startups reuse existing embeddings.
    """
    embedder = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )

    store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embedder,
    )

    # NOTE(review): relies on the private _collection attribute to detect an
    # empty collection — confirm against the installed langchain version.
    if store._collection.count() == 0:
        store.add_texts(texts=RAG_TEXTS, metadatas=RAG_METADATAS)

    retriever = store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )

    return store, retriever
|
|
|
|
|
# Module-level RAG store/retriever used by get_rag_context().
rag_store, rag_retriever = build_store_final()
|
|
|
|
|
|
|
# NOTE(review): the original re-imported Chroma and HuggingFaceEmbeddings from
# the deprecated `langchain.vectorstores` / `langchain.embeddings` paths here,
# shadowing the `langchain_community` imports at the top of the file. The
# duplicates are removed; the names remain bound to the langchain_community
# implementations imported above.

# Lazily-created SQL answer cache (see get_sql_cache_store).
_sql_cache_store = None
_sql_embedding_fn = None

# Chroma collection name for the semantic SQL cache.
SQL_CACHE_COLLECTION = "sql_cache_mpnet"
|
|
|
def get_sql_cache_store():
    """Return the lazily-created Chroma store used as the semantic SQL cache."""
    global _sql_cache_store, _sql_embedding_fn

    if _sql_cache_store is None:
        # Embedding function is itself a lazy singleton (model load is slow).
        if _sql_embedding_fn is None:
            _sql_embedding_fn = HuggingFaceEmbeddings(
                model_name=EMBED_MODEL_NAME,
                encode_kwargs={"normalize_embeddings": True},
            )
        _sql_cache_store = Chroma(
            client=get_chroma_client(),
            collection_name=SQL_CACHE_COLLECTION,
            embedding_function=_sql_embedding_fn,
        )

    return _sql_cache_store
|
|
|
|
|
|
|
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """
    Return a copy of *metadata* safe for Chroma: None values are dropped,
    scalar types pass through, everything else is stringified.
    """
    cleaned = {}
    for key, value in (metadata or {}).items():
        if value is None:
            continue
        if isinstance(value, (bool, int, float, str)):
            cleaned[key] = value
        else:
            cleaned[key] = str(value)
    return cleaned
|
|
|
|
|
|
|
def normalize_question_strict(q: str) -> str:
    """
    Deterministic normalization for exact cache hits: lowercase, strip,
    remove punctuation, collapse internal whitespace.
    """
    if not q:
        return ""
    lowered = q.lower().strip()
    no_punct = re.sub(r"[^\w\s]", "", lowered)
    return re.sub(r"\s+", " ", no_punct)
|
|
|
|
|
|
|
|
|
|
|
def normalize_question_text(q: str) -> str:
    """Looser normalization: punctuation becomes whitespace, then runs collapse."""
    if not q:
        return ""
    text = q.strip().lower()
    text = re.sub(r"[^\w\s]", " ", text)
    return re.sub(r"\s+", " ", text).strip()
|
|
|
|
|
def compute_success_rate(md: dict) -> float:
    """Fraction of successful feedbacks in *md*; 0.0 when none recorded."""
    successes = md.get("success_count", 0) or 0
    total = md.get("total_feedbacks", 0) or 0
    return float(successes) / float(total) if total > 0 else 0.0
|
|
|
|
|
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
    """
    Insert a cache entry for a freshly-run query, before any feedback exists.

    Initial metrics: views=1, success_count=0, total_feedbacks=0, success_rate=1.0.
    Returns the generated entry id (uuid4 hex).
    """
    target_store = store if store is not None else get_sql_cache_store()

    entry_id = uuid.uuid4().hex
    entry_md = {
        "id": entry_id,
        "normalized_question": normalize_question_text(question),
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": time.time(),
        "views": 1,
        "success_count": 0,
        "total_feedbacks": 0,
        "success_rate": 1.0,
    }
    # Caller-supplied metadata wins over the defaults above.
    entry_md.update(extra_metadata or {})

    target_store.add_texts([question], metadatas=[entry_md])
    return entry_id
|
|
|
|
|
|
|
|
|
| import time
|
| import logging
|
| from typing import Optional, List, Dict, Any
|
| from difflib import SequenceMatcher
|
|
|
# Module logger for cache maintenance failures.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
| def _now_ts() -> float:
|
| return time.time()
|
|
|
def similarity_score(a: str, b: str) -> float:
    """Fuzzy similarity in [0, 1] between two strings; None is treated as empty."""
    left = a or ""
    right = b or ""
    return SequenceMatcher(None, left, right).ratio()
|
|
|
|
|
|
|
|
|
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """
    Add/overwrite a single document in the store under *cache_id*.

    Falls back to an id-less add_texts for vectorstore versions whose
    add_texts does not accept an ids= keyword (raises TypeError).
    """
    payload = sanitize_metadata_for_chroma(metadata)

    try:
        store.add_texts(texts=[text], metadatas=[payload], ids=[cache_id])
    except TypeError:
        # Older API: no ids= parameter — insert without an explicit id.
        store.add_texts(texts=[text], metadatas=[payload])
|
|
|
|
|
|
|
|
|
|
|
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """
    Upsert a cache entry keyed by a deterministic (question, sql) id so that
    re-caching the same pair overwrites instead of duplicating.

    Incoming *metadata* supplies feedback counters/timestamps; missing fields
    get neutral defaults. Returns the full cached record.
    """
    semantic_key = normalize_text(question)
    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    merged = {
        # Identity + payload.
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,
        # Strict-normalized question enables exact-match lookup.
        "normalized_question": normalize_question_strict(question),
        # Timestamps.
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": metadata.get("last_viewed_at", 0),
        # Feedback counters (neutral defaults when absent).
        "good_count": metadata.get("good_count", 0),
        "bad_count": metadata.get("bad_count", 0),
        "total_feedbacks": metadata.get("total_feedbacks", 0),
        "success_rate": metadata.get("success_rate", 0.5),
        "views": metadata.get("views", 0),
    }

    langchain_upsert(
        store=store,
        text=semantic_key,
        metadata=merged,
        cache_id=cache_id,
    )

    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": merged,
    }
|
|
|
|
|
|
|
|
|
|
|
def find_cached_doc_by_sql(question: str, sql: str, store):
    """
    Exact cache lookup by the deterministic (question, sql) id.

    Returns the cached record dict, or None when the entry is absent or the
    underlying collection is unavailable.
    """
    cache_id = generate_cache_id(question, sql)
    collection = getattr(store, "_collection", None)

    if collection and hasattr(collection, "get"):
        try:
            result = collection.get(ids=[cache_id])
            metadatas = result.get("metadatas") if result else None
            if metadatas:
                md = metadatas[0]
                return {
                    "id": cache_id,
                    "question": question,
                    "sql": md.get("sql"),
                    "answer_md": md.get("answer_md"),
                    "metadata": md,
                }
        except Exception:
            # Treat any collection failure as a cache miss.
            pass

    return None
|
|
|
|
|
|
|
|
|
|
|
| import re
|
| import unicodedata
|
|
|
|
|
def normalize_text(text: str) -> str:
    """ASCII-fold accents, lowercase, turn punctuation into spaces, collapse runs."""
    if not text:
        return ""
    folded = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    folded = folded.lower()
    folded = re.sub(r"[^\w\s]", " ", folded)
    return re.sub(r"\s+", " ", folded).strip()
|
|
|
| import hashlib
|
|
|
def generate_cache_id(question: str, sql: str) -> str:
    """
    Deterministic cache id from the normalized question plus the raw SQL.

    SHA-1 is used purely as a stable fingerprint, not for security.
    """
    combined = f"{normalize_text(question)}||{(sql or '').strip()}"
    return hashlib.sha1(combined.encode("utf-8")).hexdigest()
|
|
|
|
|
def rank_cached_candidates(candidates: list[dict]) -> dict:
    """
    Pick the best cached candidate with a deterministic quality-first ranking:
    highest success_rate, then most good votes, fewest bad votes, most recent.

    FIX: uses sorted() instead of list.sort() so the caller's list is no
    longer reordered as a hidden side effect.

    Raises IndexError on an empty candidate list (unchanged behavior).
    """
    ranked = sorted(
        candidates,
        key=lambda c: (
            -float(c["metadata"].get("success_rate", 0.5)),
            -int(c["metadata"].get("good_count", 0)),
            int(c["metadata"].get("bad_count", 0)),
            -float(c["metadata"].get("last_updated_at", 0)),
        ),
    )
    return ranked[0]
|
|
|
|
|
|
|
def retrieve_exact_cached_sql(question: str, store):
    """
    Exact-match cache lookup (on the strict-normalized question), still
    ranked by feedback quality when several entries match. None on miss.
    """
    normalized = normalize_question_strict(question)
    collection = store._collection

    try:
        result = collection.get(where={"normalized_question": normalized})
        metadatas = result.get("metadatas") if result else None
        if not metadatas:
            return None

        candidates = [
            {
                "matched_question": question,
                "sql": md["sql"],
                "answer_md": md.get("answer_md", ""),
                "distance": 0.0,  # exact match by definition
                "metadata": md,
            }
            for md in metadatas
        ]
        return rank_cached_candidates(candidates)

    except Exception:
        # Any store failure degrades to a cache miss.
        return None
|
|
|
|
|
def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
):
    """
    Semantic cache lookup: fetch the 20 nearest cached questions, keep those
    within *max_distance*, and pick the winner by feedback quality first and
    embedding distance second. Returns None when nothing qualifies.
    """
    hits = store.similarity_search_with_score(normalize_text(question), k=20)

    pool = []
    for doc, raw_score in hits:
        dist = float(raw_score)
        if dist > max_distance:
            continue

        md = doc.metadata or {}
        # Entries without SQL (e.g. foreign docs) are useless here.
        if "sql" not in md:
            continue

        pool.append({
            "matched_question": doc.page_content,
            "sql": md["sql"],
            "answer_md": md.get("answer_md", ""),
            "distance": dist,
            "metadata": md,
            # Flattened quality metrics for ranking below.
            "success_rate": float(md.get("success_rate", 0.5)),
            "good": int(md.get("good_count", 0)),
            "bad": int(md.get("bad_count", 0)),
            "views": int(md.get("views", 0)),
            "last_updated": float(md.get("last_updated_at", 0)),
        })

    if not pool:
        return None

    def quality_key(candidate):
        return (
            -candidate["success_rate"],
            -candidate["good"],
            candidate["bad"],
            candidate["distance"],
            -candidate["last_updated"],
        )

    # min() with the same key returns the first-best element, exactly like
    # a stable sort followed by taking index 0.
    return min(pool, key=quality_key)
|
|
|
|
|
|
|
|
|
|
|
def increment_cache_views(metadata: dict, store):
    """
    Bump the view counters of a cached entry and upsert it back.

    Returns True on success, False when the metadata has no cache_id or the
    upsert fails (failure is logged, never raised).
    """
    if not metadata:
        return False

    cache_id = metadata.get("cache_id")
    if not cache_id:
        return False

    updated = dict(metadata)
    updated["views"] = int(updated.get("views", 0)) + 1
    updated["last_viewed_at"] = _now_ts()
    updated["last_updated_at"] = _now_ts()
    updated["saved_at"] = updated.get("saved_at", _now_ts())

    try:
        langchain_upsert(
            store=store,
            text=normalize_text(updated.get("normalized_question", "")),
            metadata=updated,
            cache_id=cache_id,
        )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False
|
|
|
|
|
|
|
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """
    Fold one piece of user feedback into an existing cache entry and, when the
    LLM produced a corrected query, cache the correction as a fresh entry.

    FIX: the success_rate computation previously read md["good_count"]
    directly, raising KeyError when the first feedback on an entry was
    negative (only the "good" branch ever created that key). Both counters
    are now read defensively with .get().
    """
    if not original_doc_md:
        return

    md = dict(original_doc_md["metadata"])
    cache_id = md["cache_id"]

    # Tally this feedback.
    if user_marked_good:
        md["good_count"] = md.get("good_count", 0) + 1
    else:
        md["bad_count"] = md.get("bad_count", 0) + 1

    md["total_feedbacks"] = md.get("total_feedbacks", 0) + 1
    md["success_rate"] = (
        md.get("good_count", 0) / md["total_feedbacks"]
        if md["total_feedbacks"] > 0 else 0.5
    )

    md["saved_at"] = md.get("saved_at", _now_ts())
    md["last_updated_at"] = _now_ts()

    # Write the updated counters back under the same cache id.
    langchain_upsert(
        store=store,
        text=normalize_text(question),
        metadata=md,
        cache_id=cache_id,
    )

    # A corrected query/answer pair becomes its own entry with one implicit
    # "good" vote so it outranks the entry that just received bad feedback.
    if llm_corrected_sql and llm_corrected_answer_md:
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )
|
|
|
|
|
|
|
|
|
|
|
| from langchain_groq import ChatGroq
|
| import re
|
| import gradio as gr
|
|
|
# LangChain chat wrapper around the Groq model; used for all prompt calls below.
llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
|
|
|
|
|
def get_rag_context(question: str) -> str:
    """Retrieve the most relevant schema docs and join them with separators."""
    retrieved = rag_retriever.invoke(question)
    separator = "\n\n---\n\n"
    return separator.join(doc.page_content for doc in retrieved)
|
|
|
def clean_sql(sql: str) -> str:
    """Strip markdown code fences and surrounding whitespace from SQL text."""
    cleaned = sql.strip()
    if "```" in cleaned:
        cleaned = cleaned.replace("```sql", "").replace("```", "").strip()
    return cleaned
|
|
|
def extract_sql_from_markdown(text: str) -> str:
    """
    Return the contents of the first ```sql ... ``` block (case-insensitive),
    or the whole stripped text when no such block exists.
    """
    fenced = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    return fenced.group(1).strip() if fenced else text.strip()
|
|
|
def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """
    Return the text after the first (case-insensitive) *marker* occurrence,
    or the whole stripped text when the marker is absent.
    """
    position = text.upper().find(marker.upper())
    if position < 0:
        return text.strip()
    return text[position + len(marker):].strip()
|
|
|
|
|
|
|
# Lower-cased phrases that mark a question as a dataset-overview request;
# matched as substrings by is_general_question().
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]
|
|
|
|
|
def is_general_question(question: str) -> bool:
    """
    Detect high-level descriptive questions that should be answered directly
    from schema context instead of generating SQL.
    """
    normalized = question.lower().strip()
    for keyword in GENERAL_DESC_KEYWORDS:
        if keyword in normalized:
            return True
    return False
|
|
|
|
|
def answer_general_question(question: str) -> str:
    """
    Use the RAG schema docs to generate a rich, high-level description
    of the Olist dataset for conceptual questions.

    Returns the model's Markdown answer, stripped of surrounding whitespace.
    """
    # Ground the answer in the most relevant table-schema documents.
    rag_context = get_rag_context(question)

    system_instructions = """

You are a data documentation expert.

You will be given:

- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).

- A high-level user question like "what is this dataset about?".

Your job:

- Write a clear, structured overview of the dataset.

- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).

- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).

- Target a non-technical person.

Do NOT write SQL. Answer in Markdown.

"""

    # Assemble instructions + grounding context + question into one prompt.
    prompt = (
        system_instructions
        + "\n\n=== SCHEMA CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nDetailed dataset overview:"
    )

    # Audit trail: log both the outgoing prompt and the raw model response.
    log_prompt("GENERAL_DATASET_QUESTION", prompt)

    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)

    return response.content.strip()
|
|
|
|
|
|
|
|
|
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Execute SQL on SQLite.

    Returns (DataFrame, None) on success, (None, error_message) on failure.
    """
    try:
        result = pd.read_sql_query(sql, connection)
    except Exception as exc:
        return None, str(exc)
    return result, None
|
|
|
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Cheap SQL validation via EXPLAIN QUERY PLAN (catches syntax and schema
    errors without running the query).

    Returns (True, None) when the plan compiles, else (False, error_message).
    """
    try:
        connection.cursor().execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception as exc:
        return False, str(exc)
    return True, None
|
|
|
|
|
|
|
|
|
def build_sql_review_prompt(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> str:
    """
    Build the prompt that asks the LLM to re-examine previously generated SQL
    in light of negative user feedback.

    The model is explicitly allowed to conclude either that the original query
    was correct (user mistaken) or that it was wrong (and supply a fix).

    FIX: the original prompt opened a ```sql fence around the generated SQL
    but never closed it, leaving the rest of the prompt inside an unterminated
    code block from the model's point of view. The fence is now closed.
    """
    prompt = f"""
You previously generated the following SQL for a SQLite database:

```sql
{generated_sql}
```

The user now says this query is WRONG and provided this feedback:
"{user_feedback_comment}"

TASKS:

Compare the SQL with the database schema (given in the context) and the user's feedback.

Decide whether the query is actually correct or incorrect.
Make sure that the user clearly specifies why they think it is incorrect.
It can be if they are unsatisfied with the numbers, or the logic is incorrect, or SQL is invalid etc.
If any reason is not clearly specified, return the previous result as it is.

If it is already correct, keep it unchanged and explain why the user might be mistaken.

If it is incorrect (e.g., wrong joins, missing filters, wrong aggregation), fix it.

Produce a corrected SQL query that better answers the question.

If the original query is already correct, just repeat the same SQL.

Explain in a few sentences WHO is correct (you or the user) and WHY.

DATABASE SCHEMA (partial, from RAG):
{rag_context}

User question:
{question}

Return your answer in this format:

CORRECTED_SQL:

-- your (possibly unchanged) SQL here
SELECT ...

EXPLANATION:
Your explanation here, clearly stating whether:

the original query was correct or not, and

how your corrected SQL addresses the issue (or why it didn't need changes).
"""
    return prompt.strip()
|
|
|
|
|
|
|
|
|
|
|
def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> Tuple[str, str]:
    """
    Have the LLM re-examine `generated_sql` in light of the user's feedback.

    The model may keep the SQL unchanged (user mistaken) or propose a fix;
    either way it also produces a textual explanation.

    Returns:
        (corrected_sql, explanation) — corrected_sql falls back to the
        original SQL whenever the model response yields nothing usable.
    """
    review_prompt = build_sql_review_prompt(
        question=question,
        generated_sql=generated_sql,
        user_feedback_comment=user_feedback_comment,
        rag_context=rag_context,
    )
    log_prompt("SQL_REVIEW_FEEDBACK", review_prompt)

    response = llm.invoke(review_prompt)
    raw_reply = response.content
    log_run_event("RAW_MODEL_RESPONSE_SQL_REVIEW", raw_reply)

    # Prefer whatever SQL the model wrote; fall back to the original SQL
    # on any extraction miss or an effectively empty result.
    candidate = extract_sql_from_markdown(raw_reply)
    corrected_sql = clean_sql(candidate) if candidate else generated_sql
    if not corrected_sql.strip():
        corrected_sql = generated_sql

    explanation = extract_explanation_after_marker(raw_reply, marker="EXPLANATION:")
    return corrected_sql, explanation
|
|
|
|
|
|
|
|
|
def generate_sql(question: str, rag_context: str) -> str:
    """
    Produce one SQLite query for `question` via the LLM.

    Any stored user feedback for this question (a previously corrected SQL)
    is injected into the prompt as external guidance so the model can avoid
    repeating a past mistake.
    """
    prior = get_last_feedback_for_question(conn, question)

    feedback_section = ""
    if prior and prior.get("corrected_sql"):
        feedback_section = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:

Previous generated SQL:
{prior['generated_sql']}

Corrected SQL (preferred reference):
{prior['corrected_sql']}

User comment & prior explanation:
{prior.get('comment') or '(none)'}

You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""

    instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.

You will be given:

- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.

Your job:

- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
and round to 2 decimals when appropriate.

{feedback_section}
"""

    full_prompt = "".join(
        [
            instructions,
            "\n\n=== RAG CONTEXT ===\n",
            rag_context,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\nSQL query:",
        ]
    )

    log_prompt("SQL_GENERATION", full_prompt)
    model_reply = llm.invoke(full_prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", model_reply.content)

    return clean_sql(model_reply.content)
|
|
|
|
|
|
|
|
|
|
|
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to rewrite a failing SQLite query.

    The model is shown the schema context, the question, the failing SQL and
    the SQLite error message, and must return only the corrected query.
    """
    instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:

- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:

- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""

    repair_prompt = "".join(
        [
            instructions,
            "\n\n=== RAG CONTEXT ===\n",
            rag_context,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\n=== PREVIOUS (FAILING) SQL ===\n",
            bad_sql,
            "\n\n=== SQLITE ERROR ===\n",
            error_message,
            "\n\nCorrected SQL query:",
        ]
    )

    log_prompt("SQL_REPAIR", repair_prompt)
    reply = llm.invoke(repair_prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", reply.content)

    return clean_sql(reply.content)
|
|
|
|
|
|
|
|
|
|
|
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Turn an executed query (or its failure) into a concise Markdown answer.

    A preview of at most 50 result rows is rendered as a Markdown table and
    handed to the LLM alongside the question, the SQL, and the schema
    context; an error section is appended when execution failed.
    """
    instructions = """
You are a senior data analyst.

You will be given:

The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.

Your job:

Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user.
"""

    # Cap the preview at 50 rows so the prompt stays small.
    if df is None or df.empty:
        preview_md = "(no rows returned)"
    else:
        preview_md = df.head(min(len(df), 50)).to_markdown(index=False)

    summary_prompt = "".join(
        [
            instructions,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\n=== EXECUTED SQL ===\n",
            sql,
            "\n\n=== QUERY RESULT PREVIEW ===\n",
            preview_md,
            "\n\n=== RAG CONTEXT (schema) ===\n",
            rag_context,
        ]
    )
    if error:
        summary_prompt += "\n\n=== ERROR ===\n" + error

    log_prompt("RESULT_SUMMARY", summary_prompt)

    reply = llm.invoke(summary_prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", reply.content)

    return reply.content.strip()
|
|
|
|
|
|
|
def backend_pipeline(question: str):
    """
    STRICT cache-first backend.

    Priority:
    1. Exact question cache hit
    2. Best semantic cache hit
    3. LLM fallback

    Returns a 13-tuple consumed by `run_and_render`:
    (answer_md, table_df, sql, rag_context, question, answer_state, df_state,
     cached_matches, dropdown_choices, dropdown_visible, attempts_left,
     attempts_text, cached_stats_update)
    """
    # Guard: empty/whitespace-only input short-circuits with a blank UI state.
    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "",
            "",
            "",
            "",
            pd.DataFrame(),
            [],
            [],
            False,
            4,
            "**Feedback attempts remaining: 4**",
            gr.update(value="", visible=False),
        )

    # Every new question resets the feedback-attempt budget.
    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"

    # Dataset-overview questions bypass the SQL pipeline entirely.
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return (
            overview_md,
            pd.DataFrame(),
            "",
            "",
            question,
            overview_md,
            pd.DataFrame(),
            [],
            [],
            False,
            attempts_left,
            attempts_text,
            gr.update(value="", visible=False),
        )

    store = get_sql_cache_store()

    # 1) Exact-match cache lookup.
    cached = retrieve_exact_cached_sql(question, store)

    # 2) Fall back to the closest semantic match within a distance threshold.
    if cached is None:
        cached = retrieve_best_cached_sql(
            question=question,
            store=store,
            max_distance=0.25,
        )

    if cached:
        # Best-effort view counter; never let stats bookkeeping break the answer.
        try:
            increment_cache_views(cached["metadata"], store=store)
        except Exception:
            pass

        rag_context = get_rag_context(question)

        header = (
            "### Cache Hit\n"
            f"- **Matched question:** \"{cached['matched_question']}\"\n"
            f"- **Success rate:** {cached['metadata'].get('success_rate', 0.5):.2f}\n"
            f"- **Similarity distance:** {cached['distance']:.4f}\n\n"
            "---\n\n"
        )

        answer_md = header + (cached.get("answer_md") or "")

        # Re-run the cached SQL so the displayed table reflects current data.
        try:
            df, exec_error = execute_sql(cached["sql"], conn)
            if exec_error:
                df = pd.DataFrame()
                answer_md += f"\n\n Error re-running cached SQL: `{exec_error}`"
        except Exception as e:
            df = pd.DataFrame()
            answer_md += f"\n\n Exception re-running cached SQL: `{e}`"

        md = cached["metadata"]

        # Human-readable stats card for the matched cache entry.
        stats_md = (
            f"**Cached entry stats**\n\n"
            f"- **Success rate:** {md.get('success_rate', 0.5):.2f} \n"
            f"- **Total feedbacks:** {md.get('total_feedbacks', 0)} \n"
            f"- **Good / Bad:** {md.get('good_count', 0)} / {md.get('bad_count', 0)} \n"
            f"- **Views:** {md.get('views', 0)} \n"
            f"- **Saved at:** "
            f"{datetime.datetime.fromtimestamp(md.get('saved_at')).strftime('%Y-%m-%d %H:%M') if md.get('saved_at') else 'unknown'} \n"
            f"- **Last updated:** "
            f"{datetime.datetime.fromtimestamp(md.get('last_updated_at')).strftime('%Y-%m-%d %H:%M') if md.get('last_updated_at') else 'unknown'}\n\n"
            f"**SQL preview:**\n\n```sql\n{cached['sql']}\n```\n"
        )

        return (
            answer_md,
            df,
            cached["sql"],
            rag_context,
            question,
            answer_md,
            df,
            [],
            [],
            False,
            attempts_left,
            attempts_text,
            gr.update(value=stats_md, visible=True),
        )

    # 3) Cache miss: generate fresh SQL with the LLM.
    rag_context = get_rag_context(question)

    sql = generate_sql(question, rag_context)
    original_sql = sql  # NOTE(review): unused after this point; kept as in original.

    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False

    # One repair round-trip if the generated SQL fails validation.
    if not is_valid and validation_error:
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)
        if repaired_valid:
            sql = repaired_sql
            repaired = True
            validation_error = None
        else:
            validation_error = repaired_error or validation_error

    df, exec_error = (None, None)
    if not validation_error:
        df, exec_error = execute_sql(sql, conn)
    else:
        exec_error = validation_error

    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )

    # Status footer appended below the natural-language summary.
    sql_status = []
    if exec_error:
        sql_status.append(f"**Error:** `{exec_error}`")
    else:
        sql_status.append("Query ran successfully.")
        if repaired:
            sql_status.append("_Note: SQL was auto-repaired._")

    sql_status.append("\n**Final SQL used:**\n")
    sql_status.append(f"```sql\n{sql}\n```")

    answer_md = summary_text + "\n\n---\n\n" + "\n".join(sql_status)
    df_preview = df if df is not None and exec_error is None else pd.DataFrame()

    # Persist the fresh answer in the cache (best effort; log on failure).
    try:
        cache_sql_answer_dedup(
            question=question,
            sql=sql,
            answer_md=answer_md,
            metadata={
                "good_count": 0,
                "bad_count": 0,
                "total_feedbacks": 0,
                "success_rate": 0.5,
                "views": 1,
                "saved_at": _now_ts(),
            },
            store=store,
        )
    except Exception:
        _logger.exception("backend_pipeline: failed to cache LLM result")

    return (
        answer_md,
        df_preview,
        sql,
        rag_context,
        question,
        answer_md,
        df_preview,
        [],
        [],
        False,
        attempts_left,
        attempts_text,
        gr.update(value="", visible=False),
    )
|
|
|
|
|
|
|
| def _looks_like_sql(text: str) -> bool:
|
| """Quick heuristic: does text contain SQL keywords / SELECT ?"""
|
| if not text:
|
| return False
|
| return bool(re.search(r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b", text, flags=re.I))
|
|
|
|
|
def is_feedback_sufficient(feedback_text: str) -> bool:
    """
    Heuristic to decide whether the user's free-text feedback is actionable.

    Returns True when any of these hold:
    - the trimmed text is long (>= 60 characters),
    - it looks like pasted SQL,
    - it is at least 20 characters and mentions at least one signal word
      (e.g. 'filter', 'year', 'should', 'instead', 'missing', 'wrong').
    """
    if not feedback_text:
        return False

    text = feedback_text.strip()

    # Long feedback or pasted SQL is accepted outright.
    if len(text) >= 60 or _looks_like_sql(text):
        return True

    signal_words = (
        "filter", "where", "year", "month", "should", "instead", "expected",
        "wrong", "missing", "aggregate", "sum", "avg", "count", "distinct",
        "join", "left join", "inner join", "group by", "order by", "date",
        "range", "exclude", "include", "only",
    )
    lowered = text.lower()
    has_signal = any(word in lowered for word in signal_words)
    return has_signal and len(text) >= 20
|
|
|
|
|
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Deterministic follow-up question to show when the user's feedback is
    too vague to act on.

    Args:
        sample_feedback: The vague feedback text; when non-empty it is
            echoed back at the top of the prompt.

    Returns:
        A friendly clarification prompt for the UI to display.
    """
    sections = [
        (
            "Thanks — I need a bit more detail to act on this feedback.\n\n"
            "Please tell me one (or more) of the following so I can check and correct the result:\n\n"
            "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n"
            " the **time range** (year/month), or the **filters** applied?\n"
            "2. If you expected a different number, what was the expected number (and how was it computed)?\n"
            "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n"
            "Examples you can copy-paste:\n"
        ),
        (
            "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n"
            "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n"
            "- \"Please exclude canceled orders (order_status = 'canceled').\"\n"
            "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n"
        ),
        "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare.",
    ]
    prompt = "".join(sections)

    # Echo the original (vague) feedback so the user keeps context.
    if sample_feedback:
        return f"I saw your feedback: \"{sample_feedback}\"\n\n" + prompt
    return prompt
|
|
|
|
|
|
|
|
|
def feedback_pipeline_interactive(
    feedback_rating: str,
    feedback_comment: str,
    last_sql: str,
    last_rag_context: str,
    last_question: str,
    last_answer_md: str,
    last_df: pd.DataFrame,
    feedback_sql: str,
    attempts_left: int,
):
    """
    Process one round of user feedback on the most recent answer.

    Returns a 10-tuple consumed by `handle_feedback`:
    (answer_md, df, sql, rag_context, question, answer_state, df_state,
     awaiting_followup, followup_prompt, attempts_left)
    """
    rating = (feedback_rating or "").strip().lower()
    comment = (feedback_comment or "").strip()
    attempts_left = int(attempts_left or 0)

    # Nothing to give feedback on yet: echo current state unchanged.
    if not last_question or not last_sql:
        return (
            last_answer_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # Require an explicit Correct/Wrong choice before doing anything.
    if rating not in ("correct", "wrong"):
        return (
            last_answer_md + "\n\n Please select **Correct** or **Wrong**.",
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # Positive feedback: bump cache stats, record it, and stop here.
    if rating == "correct":
        original_doc = find_cached_doc_by_sql(
            last_question, last_sql, store=get_sql_cache_store()
        )

        update_cache_on_feedback(
            question=last_question,
            original_doc_md=original_doc,
            user_marked_good=True,
            llm_corrected_sql=None,
            llm_corrected_answer_md=None,
            store=get_sql_cache_store(),
        )

        record_feedback(
            conn=conn,
            question=last_question,
            generated_sql=last_sql,
            model_answer=last_answer_md,
            rating="good",
            comment=comment or None,
            corrected_sql=None,
        )

        return (
            last_answer_md + "\n\n **Feedback recorded as GOOD.**",
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            False,
            "",
            attempts_left,
        )

    # Negative feedback consumes one attempt (never going below zero).
    attempts_left = max(0, attempts_left - 1)

    # Out of attempts: proceed with whatever comment we have.
    if attempts_left == 0:
        comment = comment or "User marked result as wrong."

    # Vague feedback and attempts remain: ask a clarifying follow-up
    # instead of guessing what the user meant.
    if attempts_left > 0 and not is_feedback_sufficient(comment):
        return (
            last_answer_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            True,
            build_followup_prompt_for_user(comment),
            attempts_left,
        )

    original_doc = find_cached_doc_by_sql(
        last_question, last_sql, store=get_sql_cache_store()
    )

    # Let the LLM arbitrate: keep or correct the SQL, plus an explanation.
    corrected_sql, explanation = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=comment,
        rag_context=last_rag_context,
    )

    corrected_sql = corrected_sql or last_sql
    df_new, exec_error = execute_sql(corrected_sql, conn)

    if exec_error:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=None,
            rag_context=last_rag_context,
            error=exec_error,
        )
        df_new = pd.DataFrame()
    else:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=df_new,
            rag_context=last_rag_context,
            error=None,
        )

    # Only store a correction in the cache if the SQL actually changed.
    update_cache_on_feedback(
        question=last_question,
        original_doc_md=original_doc,
        user_marked_good=False,
        llm_corrected_sql=(
            corrected_sql if corrected_sql.strip() != last_sql.strip() else None
        ),
        llm_corrected_answer_md=(
            answer_core if corrected_sql.strip() != last_sql.strip() else None
        ),
        store=get_sql_cache_store(),
    )

    record_feedback(
        conn=conn,
        question=last_question,
        generated_sql=last_sql,
        model_answer=last_answer_md,
        rating="bad",
        comment=comment + "\n\nLLM explanation:\n" + (explanation or ""),
        corrected_sql=corrected_sql,
    )

    final_md = (
        answer_core
        + "\n\n---\n\n"
        + f"**Final corrected SQL:**\n```sql\n{corrected_sql}\n```\n\n"
        + "### LLM Review Explanation\n"
        + (explanation or "")
    )

    return (
        final_md,
        df_new,
        corrected_sql,
        last_rag_context,
        last_question,
        final_md,
        df_new,
        False,
        "",
        attempts_left,
    )
|
|
|
|
|
|
|
| import gradio as gr
|
| import pandas as pd
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Olist Analytics Assistant (RAG + SQL + Feedback)")

    # Hidden state carried between the run and feedback callbacks.
    last_sql_state = gr.State("")
    last_rag_state = gr.State("")
    last_question_state = gr.State("")
    last_answer_state = gr.State("")
    last_df_state = gr.State(pd.DataFrame())

    # Remaining feedback attempts for the current question.
    attempts_state = gr.State(4)
    # NOTE(review): appears to always stay "" — confirm whether it is still needed.
    feedback_sql_state = gr.State("")

    # Question input on the left, answer + result table on the right.
    with gr.Row():
        with gr.Column(scale=1):
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total number of customers",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            answer_out = gr.Markdown()
            table_out = gr.Dataframe()

    attempts_display = gr.Markdown("**Feedback attempts remaining: 4**")
    # Cache-entry statistics; only made visible on a cache hit.
    cached_stats_md = gr.Markdown(visible=False)

    gr.Markdown("### Feedback")

    feedback_rating = gr.Radio(
        ["Correct", "Wrong"],
        label="Is the answer correct?",
        value=None,
    )
    feedback_comment = gr.Textbox(
        label="Explain (required if Wrong)",
        lines=3,
    )
    feedback_btn = gr.Button("Submit feedback")

    # Follow-up widgets, revealed only when feedback is too vague to act on.
    followup_prompt_md = gr.Markdown(visible=False)
    followup_input = gr.Textbox(
        label="Please clarify",
        visible=False,
        lines=4,
    )
    followup_submit_btn = gr.Button(
        "Submit follow-up",
        visible=False,
    )

    # Banner shown once the per-question feedback budget is used up.
    exhausted_md = gr.Markdown(
        "**You have exhausted your feedback attempts. Please ask a new question to continue.**",
        visible=False,
    )
|
|
|
|
|
    def reset_feedback_ui():
        """
        Updates for the 7 feedback widgets, in the order used by the click
        outputs: (feedback_rating, feedback_comment, feedback_btn,
        followup_input, followup_submit_btn, followup_prompt_md, exhausted_md).

        Clears rating/comment and shows the normal feedback controls while
        hiding the follow-up and 'exhausted' widgets.
        """
        return (
            gr.update(value=None, visible=True),  # feedback_rating cleared
            gr.update(value="", visible=True),    # feedback_comment cleared
            gr.update(visible=True),              # feedback_btn
            gr.update(visible=False),             # followup_input
            gr.update(visible=False),             # followup_submit_btn
            gr.update(visible=False),             # followup_prompt_md
            gr.update(visible=False),             # exhausted_md
        )
|
|
|
    def show_followup_ui(prompt: str):
        """
        Hide the normal feedback controls and reveal the follow-up input
        together with the given clarification prompt. Same 7-widget order
        as reset_feedback_ui.
        """
        return (
            gr.update(visible=False),               # feedback_rating
            gr.update(visible=False),               # feedback_comment
            gr.update(visible=False),               # feedback_btn
            gr.update(value="", visible=True),      # followup_input cleared + shown
            gr.update(visible=True),                # followup_submit_btn
            gr.update(value=prompt, visible=True),  # followup_prompt_md
            gr.update(visible=False),               # exhausted_md
        )
|
|
|
    def show_exhausted_ui():
        """
        Hide every feedback control and show only the 'attempts exhausted'
        banner. Same 7-widget order as reset_feedback_ui.
        """
        return (
            gr.update(visible=False),  # feedback_rating
            gr.update(visible=False),  # feedback_comment
            gr.update(visible=False),  # feedback_btn
            gr.update(visible=False),  # followup_input
            gr.update(visible=False),  # followup_submit_btn
            gr.update(visible=False),  # followup_prompt_md
            gr.update(visible=True),   # exhausted_md
        )
|
|
|
|
|
    def run_and_render(question):
        """
        Adapter between backend_pipeline's 13-tuple and the 17 Gradio
        outputs: drops the three unused cached-match fields and appends
        the 7 feedback-widget resets.
        """
        (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            _cached_matches,    # unused in this UI
            _dropdown_choices,  # unused in this UI
            _dropdown_visible,  # unused in this UI
            attempts,
            attempts_text,
            cached_stats_update,
        ) = backend_pipeline(question)

        return (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            attempts,
            attempts_text,
            cached_stats_update,
            *reset_feedback_ui(),
        )
|
|
|
    # Run button: compute an answer and fully reset the feedback UI.
    # Output order matches run_and_render's return: 10 data/state targets
    # followed by the 7 feedback widgets.
    submit_btn.click(
        run_and_render,
        inputs=[question_in],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )
|
|
|
|
|
    def handle_feedback(
        rating,
        comment,
        last_sql,
        last_rag,
        last_question,
        last_answer,
        last_df,
        feedback_sql,
        attempts_left,
    ):
        """
        Run one feedback round and translate the result into the 17 Gradio
        outputs (10 data/state values + 7 feedback-widget updates).
        """
        (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            awaiting_followup,
            followup_prompt,
            attempts_new,
        ) = feedback_pipeline_interactive(
            rating,
            comment,
            last_sql,
            last_rag,
            last_question,
            last_answer,
            last_df,
            feedback_sql,
            attempts_left,
        )

        attempts_md = f"**Feedback attempts remaining: {attempts_new}**"

        # No attempts left: lock the feedback UI behind the exhausted banner.
        if attempts_new <= 0:
            ui = show_exhausted_ui()
            return (
                answer_md,
                df_new,
                sql_new,
                rag_new,
                q_new,
                ans_state,
                df_state,
                attempts_new,
                attempts_md,
                # NOTE(review): this passes the component object itself as the
                # output value, not a gr.update(...) — confirm intended.
                cached_stats_md,
                *ui,
            )

        # Vague feedback: switch to the follow-up clarification UI.
        if awaiting_followup:
            ui = show_followup_ui(followup_prompt)
            return (
                answer_md,
                df_new,
                sql_new,
                rag_new,
                q_new,
                ans_state,
                df_state,
                attempts_new,
                attempts_md,
                cached_stats_md,
                *ui,
            )

        # Normal completion: reset the feedback controls for the next round.
        ui = reset_feedback_ui()
        return (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            attempts_new,
            attempts_md,
            cached_stats_md,
            *ui,
        )
|
|
|
    # Primary feedback submission: the comment comes from feedback_comment.
    feedback_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            feedback_comment,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )

    # Follow-up submission reuses handle_feedback, but sources the comment
    # from followup_input (the clarification box) instead of feedback_comment.
    followup_submit_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            followup_input,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=[
            answer_out,
            table_out,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            attempts_state,
            attempts_display,
            cached_stats_md,
            feedback_rating,
            feedback_comment,
            feedback_btn,
            followup_input,
            followup_submit_btn,
            followup_prompt_md,
            exhausted_md,
        ],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script (not on import).
    demo.launch()
|
|
|
|
|
|
|