# Source: app.py uploaded by Shashwat-18 (commit 44eea1c, verified).
# %%
import os
import uuid
import sqlite3
import datetime
import json
import re
import time
from typing import Dict, Any, List, Tuple, Optional
import hashlib
import pandas as pd
from groq import Groq
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
# %%
# 0. Global config
# Local SQLite database file and the folder holding the source CSVs.
DB_PATH = "olist.db"
DATA_DIR = "data"

# Groq chat model; the API key must be provided through the environment.
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast at import time instead of on the first LLM call.
    raise ValueError("GROQ_API_KEY not found.")
groq_client = Groq(api_key=GROQ_API_KEY)

# Sentence-transformers model used by the embedding stores below.
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Append-only text logs: one for prompts, one for responses / run events.
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
# %%
# SINGLE, CORRECT, PERSISTENT CHROMA CLIENT
import os
import chromadb
from chromadb.config import Settings
CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

# Process-wide Chroma client, created lazily by get_chroma_client().
_CHROMA_CLIENT = None


def get_chroma_client():
    """Return the shared persistent Chroma client, creating it on first use.

    The client stores its data under CHROMA_PERSIST_DIR and has anonymized
    telemetry disabled. Later calls return the same instance.
    """
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is not None:
        return _CHROMA_CLIENT
    telemetry_off = Settings(anonymized_telemetry=False)
    _CHROMA_CLIENT = chromadb.PersistentClient(
        path=CHROMA_PERSIST_DIR,
        settings=telemetry_off,
    )
    return _CHROMA_CLIENT
# %%
# %%
# 1. Logging helpers
def _append_log_entry(path: str, tag: str, body: str, footer: str) -> None:
    """Append one tagged, timestamped entry to the text log at *path*.

    Layout: a header line with *tag* and the local time (to the second),
    then *body*, then a footer line containing *footer*. A ``None`` body is
    written as an empty string so that logging can never crash the caller.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    text = "" if body is None else str(body)
    with open(path, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(text)
        f.write(f"\n================ {footer} ================\n")


def log_prompt(tag: str, prompt: str) -> None:
    """Append the full prompt to PROMPT_LOG_PATH, with a tag and timestamp."""
    _append_log_entry(PROMPT_LOG_PATH, tag, prompt, "END PROMPT")


def log_run_event(tag: str, content: str) -> None:
    """Append a model response / final SQL / error text to RUN_LOG_PATH."""
    _append_log_entry(RUN_LOG_PATH, tag, content, "END EVENT")
# %%
# 2. Feedback table + helpers
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """Ensure the ``user_feedback`` table exists and commit.

    Columns: autoincrement ``id``, ``created_at``, ``question``, the model's
    ``generated_sql`` and ``model_answer``, a ``rating`` constrained to
    'good'/'bad', an optional ``comment`` and an optional ``corrected_sql``.
    Idempotent thanks to ``IF NOT EXISTS``.
    """
    ddl = """
    CREATE TABLE IF NOT EXISTS user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        generated_sql TEXT,
        model_answer TEXT,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
        comment TEXT,
        corrected_sql TEXT
    )
    """
    conn.execute(ddl)
    conn.commit()
def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,  # "good" or "bad"
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """Persist one user rating of a model answer and commit.

    *rating* is matched case-insensitively and stored lower-cased; anything
    other than 'good'/'bad' raises ValueError before touching the database.
    A supplied *corrected_sql* is stored as an external correction.
    """
    normalized_rating = rating.lower()
    if normalized_rating != "good" and normalized_rating != "bad":
        raise ValueError("rating must be 'good' or 'bad'")

    created_at = datetime.datetime.now().isoformat(timespec="seconds")
    row = (
        created_at,
        question,
        generated_sql,
        model_answer,
        normalized_rating,
        comment,
        corrected_sql,
    )
    insert_sql = """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    conn.execute(insert_sql, row)
    conn.commit()
def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """Return the most recent feedback row for *question*, or None.

    ``created_at`` only has one-second resolution, so several rows can share
    the same timestamp; ``id DESC`` breaks such ties in favour of the row that
    was inserted last.

    Returns:
        Dict with keys created_at, generated_sql, model_answer, rating,
        comment and corrected_sql, or None when no feedback exists.
    """
    fields = (
        "created_at",
        "generated_sql",
        "model_answer",
        "rating",
        "comment",
        "corrected_sql",
    )
    row = conn.execute(
        """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC, id DESC
        LIMIT 1
        """,
        (question,),
    ).fetchone()
    if row is None:
        return None
    return dict(zip(fields, row))
# %%
# 3. Database setup (from CSVs)
def init_db(db_path: str = DB_PATH, data_dir: str = DATA_DIR) -> sqlite3.Connection:
    """Load the Olist CSVs into a local SQLite database and return the connection.

    Each CSV listed below that exists in *data_dir* is written to a table
    named after the file (without ``.csv``), replacing any existing table of
    that name. Missing files are reported and skipped. Afterwards the
    ``user_feedback`` table is created if needed.

    Args:
        db_path: SQLite database file (defaults to DB_PATH).
        data_dir: Folder containing the CSVs (defaults to DATA_DIR).

    Returns:
        An open connection with the same-thread check disabled
        (``check_same_thread=False``).
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    csv_files = (
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    )
    for fname in csv_files:
        path = os.path.join(data_dir, fname)
        if not os.path.exists(path):
            print(f"CSV not found: {path} - skipping")
            continue
        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        pd.read_csv(path).to_sql(table_name, conn, if_exists="replace", index=False)
    init_feedback_table(conn)
    return conn


# Module-level connection shared by the helpers below.
conn = init_db()
# %%
# 4. Manual docs for Olist tables
# Hand-written table/column documentation for the Olist CSV tables, keyed by
# table name (CSV file name without ".csv"). Injected into the schema metadata
# and RAG documents that the LLM sees, so wording errors here can steer SQL.
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            # review_id repeats across orders in the published data, so it must
            # not be presented to the LLM as a primary key.
            "review_id": "Identifier of the review. Not guaranteed unique in this dataset (some review_ids repeat across orders), so do not treat it as a primary key.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            # Per the Olist data dictionary: creation = survey sent to the
            # customer; answer = customer's response to that survey.
            "review_creation_date": "Date when the satisfaction survey was sent to the customer.",
            "review_answer_timestamp": "Timestamp when the customer answered the satisfaction survey."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
# %%
# 5. Schema extractor + metadata YAML
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """Introspect SQLite tables, columns and foreign-key relationships.

    Table and column descriptions come from OLIST_DOCS when available;
    otherwise generic placeholders are generated.

    Returns:
        ``{"tables": {name: {"description": str,
                             "columns": {col: description},
                             "relationships": ["t.c -> ref.col (foreign key)", ...]}}}``
        with one entry per table listed in ``sqlite_master``.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]

    metadata: Dict[str, Any] = {"tables": {}}
    for table in tables:
        # Double-quoted identifier with embedded quotes doubled: safe for any name.
        quoted = '"' + table.replace('"', '""') + '"'

        table_docs = OLIST_DOCS.get(table, {})
        table_desc: str = table_docs.get(
            "description",
            f"Table '{table}' from Olist dataset."
        )
        column_docs: Dict[str, str] = table_docs.get("columns", {})

        cursor.execute(f"PRAGMA table_info({quoted})")
        columns_meta: Dict[str, str] = {}
        for c in cursor.fetchall():
            col_name = c[1]
            col_type = c[2] or "TEXT"
            columns_meta[col_name] = column_docs.get(
                col_name, f"Column '{col_name}' of type {col_type}"
            )

        # foreign_key_list rows: (id, seq, table, from, to, ...)
        cursor.execute(f"PRAGMA foreign_key_list({quoted})")
        relationships: List[str] = [
            f"{table}.{fk[3]} -> {fk[2]}.{fk[4]} (foreign key)"
            for fk in cursor.fetchall()
        ]

        metadata["tables"][table] = {
            "description": table_desc,
            "columns": columns_meta,
            "relationships": relationships,
        }
    return metadata
# Introspect the freshly loaded database once at import time; the result
# feeds the YAML summary and the per-table RAG documents built below.
schema_metadata = extract_schema_metadata(conn)
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """Render the schema metadata dict as a YAML-style string.

    The output nests with two-space steps (tables -> table -> description /
    columns / relationships) so the hierarchy is preserved for the LLM.
    Values are double-quoted; embedded double quotes become single quotes
    and line breaks become spaces so each value stays on one line.
    """
    def _quote(value: Any) -> str:
        text = str(value).replace('"', "'").replace("\r", " ").replace("\n", " ")
        return f'"{text}"'

    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        lines.append(f"    description: {_quote(tinfo.get('description', ''))}")
        lines.append("    columns:")
        for col_name, col_desc in tinfo.get("columns", {}).items():
            lines.append(f"      {col_name}: {_quote(col_desc)}")
        rels = tinfo.get("relationships", [])
        if rels:
            lines.append("    relationships:")
            lines.extend(f"      - {_quote(rel)}" for rel in rels)
    return "\n".join(lines)
# Whole-schema YAML summary; added to the RAG texts further down as a
# separate "global_schema" document.
schema_yaml = build_schema_yaml(schema_metadata)
# %%
# 6. Build schema documents for RAG (taking samples from the table)
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Build one RAG document per user table from *schema_metadata*.

    Each document contains the table name, description, columns with type
    and description, foreign-key relationships and a few sample rows.

    SQLite's internal tables (names starting with ``sqlite_``, e.g. the
    ``sqlite_sequence`` table created for AUTOINCREMENT columns) and tables
    missing from *schema_metadata* are skipped rather than raising KeyError.

    Returns:
        (docs, metadatas): parallel lists; each metadata dict is
        ``{"doc_type": "table_schema", "table_name": <table>}``.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    known_tables = schema_metadata.get("tables", {})

    docs: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    for table in tables:
        if table.startswith("sqlite_"):
            continue
        tmeta = known_tables.get(table)
        if tmeta is None:
            continue

        table_desc = tmeta.get("description", "")
        columns_meta = tmeta.get("columns", {})
        relationships = tmeta.get("relationships", [])
        quoted = '"' + table.replace('"', '""') + '"'

        # PRAGMA supplies the declared types; descriptions come from metadata.
        cursor.execute(f"PRAGMA table_info({quoted})")
        col_lines = []
        for c in cursor.fetchall():
            col_name = c[1]
            col_type = c[2] or "TEXT"
            col_desc = columns_meta.get(col_name, f"Column '{col_name}' of type {col_type}")
            col_lines.append(f"- {col_name} ({col_type}): {col_desc}")

        try:
            sample_df = pd.read_sql_query(
                f"SELECT * FROM {quoted} LIMIT {int(sample_rows)}",
                connection,
            )
        except Exception:
            sample_text = "(could not fetch sample rows)"
        else:
            try:
                sample_text = sample_df.to_markdown(index=False)
            except ImportError:
                # to_markdown needs the optional "tabulate" package.
                sample_text = sample_df.to_string(index=False)

        rel_block = ""
        if relationships:
            rel_block = "Relationships:\n" + "\n".join(
                f"- {rel}" for rel in relationships
            ) + "\n"

        doc_text = (
            f"Table: {table}\n"
            f"Description: {table_desc}\n\n"
            f"Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        docs.append(doc_text)
        metadatas.append({
            "doc_type": "table_schema",
            "table_name": table,
        })
    return docs, metadatas
# Build RAG texts + metadata as parallel lists: one document per table
# followed by a single whole-schema YAML document.
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)
RAG_TEXTS: List[str] = [*schema_docs, "SCHEMA_METADATA_YAML:\n" + schema_yaml]
RAG_METADATAS: List[Dict[str, Any]] = [*schema_doc_metas, {"doc_type": "global_schema"}]
# %%
def build_store_final():
    """Open the persistent schema RAG collection and build its retriever.

    The Chroma collection "rag_schema_store" is seeded from RAG_TEXTS /
    RAG_METADATAS only when it is empty, so an already-populated collection
    is reused as is. The retriever returns the top 3 documents whose
    ``doc_type`` is "table_schema".

    Returns:
        (rag_store, rag_retriever)
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )
    store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embeddings,
    )
    collection_is_empty = store._collection.count() == 0
    if collection_is_empty:
        store.add_texts(texts=RAG_TEXTS, metadatas=RAG_METADATAS)
    retriever = store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )
    return store, retriever


# Initialize once at import time.
rag_store, rag_retriever = build_store_final()
# %%
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
# Lazily created singletons for the SQL-answer cache collection.
_sql_cache_store = None
_sql_embedding_fn = None
SQL_CACHE_COLLECTION = "sql_cache_mpnet"


def get_sql_cache_store():
    """Return the shared SQL-cache vector store, creating it on first call.

    The embedding model (EMBED_MODEL_NAME, normalized embeddings) and the
    Chroma collection SQL_CACHE_COLLECTION are each built once and reused.
    NOTE(review): ``Chroma`` / ``HuggingFaceEmbeddings`` resolve to whichever
    module-level import of those names ran last — confirm the intended one.
    """
    global _sql_cache_store, _sql_embedding_fn
    if _sql_cache_store is None:
        if _sql_embedding_fn is None:
            _sql_embedding_fn = HuggingFaceEmbeddings(
                model_name=EMBED_MODEL_NAME,
                encode_kwargs={"normalize_embeddings": True},
            )
        _sql_cache_store = Chroma(
            client=get_chroma_client(),
            collection_name=SQL_CACHE_COLLECTION,
            embedding_function=_sql_embedding_fn,
        )
    return _sql_cache_store
# %%
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """Return a copy of *metadata* restricted to Chroma-friendly values.

    ``None`` values are dropped; int/float/bool/str values are kept as-is;
    anything else is converted with ``str()``. A falsy *metadata* yields {}.
    """
    cleaned = {}
    for key, value in (metadata or {}).items():
        if value is None:
            continue
        is_primitive = isinstance(value, (int, float, bool, str))
        cleaned[key] = value if is_primitive else str(value)
    return cleaned
# %%
def normalize_question_strict(q: str) -> str:
    """Deterministically normalize *q* for exact cache hits.

    Lower-cases and trims, deletes every character that is neither a word
    character nor whitespace, then collapses whitespace runs to one space.
    Falsy input returns "".
    """
    if not q:
        return ""
    without_punct = re.sub(r"[^\w\s]", "", q.lower().strip())
    return re.sub(r"\s+", " ", without_punct)
# %%
# Helper: normalize question
def normalize_question_text(q: str) -> str:
    """Lower-case *q*, turn punctuation into spaces and collapse whitespace.

    Falsy input returns "".
    """
    if not q:
        return ""
    spaced = re.sub(r"[^\w\s]", " ", q.strip().lower())
    return " ".join(spaced.split())
# Helper: compute success rate
def compute_success_rate(md: dict) -> float:
    """Return success_count / total_feedbacks from *md*, or 0.0 with no feedback.

    Missing or ``None`` counters are treated as 0.
    """
    successes = md.get("success_count", 0) or 0
    total = md.get("total_feedbacks", 0) or 0
    return float(successes) / float(total) if total > 0 else 0.0
# Insert initial cache entry (no feedback yet)
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
    """Cache a freshly executed query before any user feedback exists.

    Initial metrics: views=1, success_count=0, good_count=0, bad_count=0,
    total_feedbacks=0, success_rate=1.0. *extra_metadata* (if given) is
    merged on top.

    The entry carries a ``cache_id`` and is stored under that id, so that
    increment_cache_views / update_cache_on_feedback can address it later.
    ``normalized_question`` uses the same strict normalization as
    retrieve_exact_cached_sql, so exact lookups can find this entry.
    Metadata is sanitized (no ``None`` / non-primitive values) before writing.

    Returns:
        The generated cache id (hex string).
    """
    if store is None:
        store = get_sql_cache_store()
    ident = uuid.uuid4().hex
    now = time.time()
    md = {
        "id": ident,
        "cache_id": ident,
        "normalized_question": normalize_question_strict(question),
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": now,
        "last_updated_at": now,
        "views": 1,
        "success_count": 0,
        "good_count": 0,
        "bad_count": 0,
        "total_feedbacks": 0,
        # NOTE(review): un-rated entries start at 1.0, which ranks them above
        # rated entries (rank_cached_candidates sorts on success_rate first).
        "success_rate": 1.0,
    }
    if extra_metadata:
        md.update(extra_metadata)
    store.add_texts(
        [question],
        metadatas=[sanitize_metadata_for_chroma(md)],
        ids=[str(md.get("cache_id") or ident)],
    )
    return ident
# %%
import time
import logging
from typing import Optional, List, Dict, Any
from difflib import SequenceMatcher
_logger = logging.getLogger(__name__)
# -------------------------------------------------------------------------
# Utility helpers
# -------------------------------------------------------------------------
def _now_ts() -> float:
return time.time()
def similarity_score(a: str, b: str) -> float:
return SequenceMatcher(None, (a or ""), (b or "")).ratio()
# -------------------------------------------------------------------------
# Persist helper
# -------------------------------------------------------------------------
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """Write one cache entry to *store* under the explicit id *cache_id*.

    *metadata* is sanitized for Chroma first. Passing ``ids`` lets a store
    that upserts by id overwrite the previous version of the entry.

    If the store's ``add_texts`` rejects the call with TypeError (e.g. no
    ``ids`` support), the entry is inserted without an id. That insert cannot
    replace the earlier version, so duplicates may accumulate; the fallback
    is logged as a warning so the condition is visible.
    """
    safe_md = sanitize_metadata_for_chroma(metadata)
    try:
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
            ids=[cache_id],
        )
    except TypeError:
        _logger.warning(
            "Store %s rejected add_texts(ids=...); inserting cache entry %s "
            "without an id (duplicates possible)",
            type(store).__name__,
            cache_id,
            exc_info=True,
        )
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
        )
# -------------------------------------------------------------------------
# Cache insert / update
# -------------------------------------------------------------------------
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """Insert or overwrite the cache entry for the (question, sql) pair.

    The entry id is ``generate_cache_id(question, sql)`` and is passed to the
    store via langchain_upsert, so repeating the same pair targets the same
    id instead of generating a new one. Counters and timestamps come from
    *metadata* when present (``None`` is treated as {}); otherwise they are
    initialised, with a neutral success_rate of 0.5.

    The stored text is ``normalize_text(question)`` (used for semantic
    search); ``normalized_question`` holds the strict form used for exact
    lookups.

    NOTE(review): the write replaces the whole metadata of an existing entry,
    so counters not supplied in *metadata* reset to their defaults.

    Returns:
        Dict with question, sql, answer_md and the stored metadata.
    """
    metadata = metadata or {}
    norm_q_semantic = normalize_text(question)
    norm_q_exact = normalize_question_strict(question)
    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    md = {
        # identity
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,
        # exact match key
        "normalized_question": norm_q_exact,
        # timestamps
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": metadata.get("last_viewed_at", 0),
        # metrics
        "good_count": metadata.get("good_count", 0),
        "bad_count": metadata.get("bad_count", 0),
        "total_feedbacks": metadata.get("total_feedbacks", 0),
        "success_rate": metadata.get("success_rate", 0.5),
        "views": metadata.get("views", 0),
    }

    langchain_upsert(
        store=store,
        text=norm_q_semantic,  # embedded for semantic search
        metadata=md,
        cache_id=cache_id,
    )
    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": md,
    }
# -------------------------------------------------------------------------
# Find exact cached entry (question + SQL)
# -------------------------------------------------------------------------
def find_cached_doc_by_sql(question: str, sql: str, store):
    """Look up the cache entry stored for exactly this (question, sql) pair.

    Uses the deterministic id from generate_cache_id and reads it directly
    from the store's underlying collection (``store._collection``).

    Returns:
        Dict with id, question, sql, answer_md and metadata; or None when the
        store has no usable collection, the entry does not exist, or the
        lookup fails (failures are logged, not raised — best effort).
    """
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)
    if coll is None or not hasattr(coll, "get"):
        return None

    try:
        res = coll.get(ids=[cache_id])
    except Exception:
        _logger.warning("find_cached_doc_by_sql: lookup of %s failed", cache_id, exc_info=True)
        return None

    metadatas = (res or {}).get("metadatas") or []
    md = metadatas[0] if metadatas else None
    if not md:
        return None
    return {
        "id": cache_id,
        "question": question,
        "sql": md.get("sql"),
        "answer_md": md.get("answer_md"),
        "metadata": md,
    }
# -------------------------------------------------------------------------
# Retrieve cached answers ranked primarily by success rate
# -------------------------------------------------------------------------
import re
import unicodedata
def normalize_text(text: str) -> str:
    """Fold *text* to lower-case ASCII words separated by single spaces.

    Accents are stripped via NFKD decomposition (non-ASCII remnants are
    dropped), punctuation becomes spaces, whitespace is collapsed and the
    result trimmed. Falsy input returns "".
    """
    if not text:
        return ""
    folded = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    folded = re.sub(r"[^\w\s]", " ", folded.lower())
    return re.sub(r"\s+", " ", folded).strip()


import hashlib


def generate_cache_id(question: str, sql: str) -> str:
    """Deterministic SHA-1 hex id for a (question, sql) pair.

    Key = normalize_text(question) + "||" + stripped SQL (None -> "").
    """
    key = "{}||{}".format(normalize_text(question), (sql or "").strip())
    return hashlib.sha1(key.encode("utf-8")).hexdigest()
def rank_cached_candidates(candidates: list[dict]) -> dict:
"""
Deterministic quality-first ranking.
"""
candidates.sort(
key=lambda c: (
-float(c["metadata"].get("success_rate", 0.5)),
-int(c["metadata"].get("good_count", 0)),
int(c["metadata"].get("bad_count", 0)),
-float(c["metadata"].get("last_updated_at", 0)),
)
)
return candidates[0]
def retrieve_exact_cached_sql(question: str, store):
    """Return the best cached answer whose strict-normalized question matches exactly.

    Entries are fetched from ``store._collection`` by ``normalized_question``
    and ranked with rank_cached_candidates. Entries without a ``sql`` value
    are skipped individually instead of discarding the whole result.

    Returns:
        Candidate dict (matched_question, sql, answer_md, distance=0.0,
        metadata) or None when nothing matches or the lookup fails (failures
        are logged — best effort).
    """
    norm_q = normalize_question_strict(question)
    if not norm_q:
        return None
    coll = getattr(store, "_collection", None)
    if coll is None:
        return None

    try:
        res = coll.get(where={"normalized_question": norm_q})
    except Exception:
        _logger.warning("retrieve_exact_cached_sql: lookup failed for %r", norm_q, exc_info=True)
        return None

    candidates = [
        {
            "matched_question": question,
            "sql": md["sql"],
            "answer_md": md.get("answer_md", ""),
            "distance": 0.0,
            "metadata": md,
        }
        for md in ((res or {}).get("metadatas") or [])
        if md and md.get("sql")
    ]
    if not candidates:
        return None
    return rank_cached_candidates(candidates)
def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
):
    """Semantic cache lookup, ranked by answer quality rather than closeness.

    Searches the 20 nearest cached questions to ``normalize_text(question)``,
    keeps hits with distance <= *max_distance* whose metadata contains
    ``sql``, then orders them by: higher success_rate, more good feedback,
    fewer bad, smaller distance, most recently updated.

    Returns:
        The best candidate dict, or None if no hit qualifies.
    """
    hits = store.similarity_search_with_score(normalize_text(question), k=20)

    pool = []
    for doc, raw_score in hits:
        dist = float(raw_score)
        if dist > max_distance:
            continue
        meta = doc.metadata or {}
        if "sql" not in meta:
            continue
        pool.append({
            "matched_question": doc.page_content,
            "sql": meta["sql"],
            "answer_md": meta.get("answer_md", ""),
            "distance": dist,
            "metadata": meta,
            # ranking signals, forced to numbers
            "success_rate": float(meta.get("success_rate", 0.5)),
            "good": int(meta.get("good_count", 0)),
            "bad": int(meta.get("bad_count", 0)),
            "views": int(meta.get("views", 0)),
            "last_updated": float(meta.get("last_updated_at", 0)),
        })

    if not pool:
        return None

    def quality_first(entry):
        return (
            -entry["success_rate"],
            -entry["good"],
            entry["bad"],
            entry["distance"],
            -entry["last_updated"],
        )

    pool.sort(key=quality_first)
    return pool[0]
# -------------------------------------------------------------------------
# Increment views
# -------------------------------------------------------------------------
def increment_cache_views(metadata: dict, store):
    """Record one more view of a cached entry.

    Bumps ``views`` and sets ``last_viewed_at`` and ``last_updated_at`` to a
    single timestamp; ``saved_at`` is preserved (initialised if missing).

    When the store exposes its Chroma collection, only the metadata is
    updated, so the stored question text and embedding stay exactly as they
    were cached (re-embedding ``normalized_question`` would replace them with
    a differently normalized string). Otherwise the entry is re-upserted via
    langchain_upsert.

    Returns:
        True on success; False when *metadata* is empty, has no ``cache_id``,
        or the write fails (the failure is logged).
    """
    if not metadata:
        return False
    cache_id = metadata.get("cache_id")
    if not cache_id:
        return False

    now = _now_ts()
    md = dict(metadata)
    md["views"] = int(md.get("views") or 0) + 1
    md["last_viewed_at"] = now
    md["last_updated_at"] = now
    md.setdefault("saved_at", now)

    try:
        coll = getattr(store, "_collection", None)
        if coll is not None and hasattr(coll, "update"):
            coll.update(ids=[cache_id], metadatas=[sanitize_metadata_for_chroma(md)])
        else:
            langchain_upsert(
                store=store,
                text=normalize_text(md.get("normalized_question", "")),
                metadata=md,
                cache_id=cache_id,
            )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False
# Update metrics on feedback
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """Apply a user's good/bad rating to a cached entry and store any correction.

    1. Metrics: increments ``good_count`` or ``bad_count`` plus
       ``total_feedbacks``, recomputes ``success_rate`` = good / total and
       refreshes ``last_updated_at`` (``saved_at`` is preserved). Missing
       counters count as 0. Entries without a ``cache_id`` cannot be
       addressed; that is logged and the metric update is skipped.
    2. Correction: if both corrected SQL and answer are given AND the SQL
       differs from the cached SQL, it is cached as a new entry seeded with
       one good vote. An unchanged "correction" (the reviewer kept the
       original query) is not re-cached, so it cannot overwrite the entry
       and wipe the feedback just recorded.

    Args:
        original_doc_md: candidate dict as returned by the cache lookups
            (must contain a ``metadata`` dict).
    """
    if not original_doc_md:
        return

    md = dict(original_doc_md.get("metadata") or {})
    cache_id = md.get("cache_id")

    if cache_id:
        good = int(md.get("good_count") or 0)
        bad = int(md.get("bad_count") or 0)
        total = int(md.get("total_feedbacks") or 0) + 1
        if user_marked_good:
            good += 1
        else:
            bad += 1

        now = _now_ts()
        md["good_count"] = good
        md["bad_count"] = bad
        md["total_feedbacks"] = total
        md["success_rate"] = good / total  # total >= 1 here
        md.setdefault("saved_at", now)
        md["last_updated_at"] = now

        langchain_upsert(
            store=store,
            text=normalize_text(question),
            metadata=md,
            cache_id=cache_id,
        )
    else:
        _logger.warning(
            "update_cache_on_feedback: cached entry has no cache_id; feedback counts not stored"
        )

    # Corrected SQL --> NEW ENTRY (only when it actually changed)
    original_sql = (md.get("sql") or "").strip()
    corrected_sql = (llm_corrected_sql or "").strip()
    if corrected_sql and llm_corrected_answer_md and corrected_sql != original_sql:
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )
# %%
# ### 8. Groq LLM via LangChain
from langchain_groq import ChatGroq
import re
import gradio as gr
# Chat model used by the LLM helpers below (dataset overview, SQL generation,
# SQL review), configured from the module-level Groq settings.
llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
# %%
def get_rag_context(question: str) -> str:
    """Return the schema documents retrieved for *question*.

    The page contents are joined with a blank-line / "---" separator.
    """
    separator = "\n\n---\n\n"
    retrieved = rag_retriever.invoke(question)
    return separator.join([doc.page_content for doc in retrieved])
def clean_sql(sql: str) -> str:
    """Strip Markdown code fences (and any surrounding prose) from LLM SQL output.

    - If the text contains a complete fenced block (```sql, ```SQL,
      ```sqlite or a bare ```), the contents of the first block are returned,
      so explanations written before/after the fence are dropped.
    - Otherwise any stray fence markers (e.g. an unclosed ```sql) are removed.
    - The result is whitespace-stripped; empty/None input yields "".
    """
    if not sql:
        return ""
    text = sql.strip()
    if "```" not in text:
        return text
    # Opening fence: an "sql" tag (even inline), or any language tag that is
    # followed by a newline; otherwise the fence content starts immediately.
    match = re.search(
        r"```(?:sql\b|[\w+-]*[ \t]*\n)?(.*?)```",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if match:
        return match.group(1).strip()
    return re.sub(r"```(?:sql\b)?", "", text, flags=re.IGNORECASE).strip()
def extract_sql_from_markdown(text: str) -> str:
    """Return the SQL inside the best Markdown code fence of *text*.

    Preference order:
      1. the first fenced block tagged ``sql`` or ``sqlite`` (any case);
      2. otherwise the first untagged fenced block;
      3. otherwise the whole text.
    Blocks tagged with another language (e.g. ``python``) are ignored.
    The result is whitespace-stripped; empty/None input yields "".
    """
    if not text:
        return ""
    # finditer pairs opening/closing fences left to right.
    blocks = [
        (m.group(1).lower(), m.group(2))
        for m in re.finditer(r"```([\w+-]*)(.*?)```", text, flags=re.DOTALL)
    ]
    for wanted_tags in (("sql", "sqlite"), ("",)):
        for tag, body in blocks:
            if tag in wanted_tags:
                return body.strip()
    return text.strip()
def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """Return the stripped text following the first *marker* (case-insensitive).

    If *marker* does not occur, the whole stripped text is returned. The
    match is made on the original string (not an upper-cased copy), so
    characters whose upper-case form has a different length (e.g. "ß" -> "SS")
    cannot shift the slice position. Empty/None input yields "".
    """
    if not text:
        return ""
    match = re.search(re.escape(marker), text, flags=re.IGNORECASE)
    if match is None:
        return text.strip()
    return text[match.end():].strip()
# 6b. General / descriptive questions
# Phrases that mark a high-level "describe the dataset" question.
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """True when *question* (case-insensitive) contains any descriptive phrase.

    Such questions are answered from the schema context instead of via SQL.
    """
    lowered = question.lower().strip()
    for phrase in GENERAL_DESC_KEYWORDS:
        if phrase in lowered:
            return True
    return False
def answer_general_question(question: str) -> str:
    """Answer a high-level "what is this dataset?" question without SQL.

    Retrieves schema documents for *question*, wraps them in a
    documentation-expert prompt, logs the prompt and the raw reply, and
    returns the model's Markdown answer with surrounding whitespace removed.
    """
    schema_context = get_rag_context(question)
    system_instructions = """
You are a data documentation expert.
You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".
Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.
Do NOT write SQL. Answer in Markdown.
"""
    prompt = "".join(
        [
            system_instructions,
            "\n\n=== SCHEMA CONTEXT ===\n",
            schema_context,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\nDetailed dataset overview:",
        ]
    )
    log_prompt("GENERAL_DATASET_QUESTION", prompt)
    reply = llm.invoke(prompt).content
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", reply)
    return reply.strip()
# %%
# ### 9. SQL execution / validation
# String literals, quoted identifiers and comments are blanked out before the
# keyword checks, so text such as 'delete me' or "-- update totals" cannot
# cause false positives and keywords cannot be hidden inside them. Alternation
# plus re.sub's left-to-right scan means whichever construct starts first wins.
_SQL_LITERALS_AND_COMMENTS = re.compile(
    r"'(?:[^']|'')*'"        # single-quoted string ('' escapes a quote)
    r'|"(?:[^"]|"")*"'       # double-quoted identifier
    r"|--[^\n]*"             # line comment
    r"|/\*.*?(?:\*/|\Z)",    # block comment (possibly unterminated)
    re.DOTALL,
)
_SQL_WRITE_KEYWORDS = re.compile(
    r"\b(insert|update|delete|drop|alter|create|attach|detach|pragma|vacuum|reindex)\b",
    re.IGNORECASE,
)


def _read_only_violation(sql: str) -> Optional[str]:
    """Return why *sql* is not an allowed read-only query, or None if it is."""
    if not isinstance(sql, str):
        return "SQL query must be a string."
    if not sql.strip():
        return "Empty SQL query."
    code = _SQL_LITERALS_AND_COMMENTS.sub(" ", sql)
    first_word = re.match(r"[\s(]*([A-Za-z]+)", code)
    if first_word is None or first_word.group(1).lower() not in ("select", "with"):
        return "Only read-only SELECT/WITH queries are allowed."
    forbidden = _SQL_WRITE_KEYWORDS.search(code)
    if forbidden:
        return (
            f"Refusing to run a query containing {forbidden.group(1).upper()}: "
            "only read-only queries are allowed."
        )
    return None


def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Execute a read-only query on SQLite.

    The SQL is LLM-generated and therefore untrusted: before touching the
    database it must start with SELECT/WITH and contain no data- or
    schema-modifying keyword (outside strings/comments). Rejections are
    reported through the normal error channel.

    Returns:
        (DataFrame, None) on success, or (None, error message) on rejection
        or execution failure.
    """
    violation = _read_only_violation(sql)
    if violation:
        return None, violation
    try:
        df = pd.read_sql_query(sql, connection)
        return df, None
    except Exception as e:
        return None, str(e)
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """Dry-run *sql* through SQLite's planner to catch syntax/schema errors.

    ``EXPLAIN QUERY PLAN`` compiles the statement without running it.

    Returns:
        (True, None) when the planner accepts it, else (False, error message).
    """
    try:
        planner = connection.cursor()
        planner.execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception as exc:
        return False, str(exc)
    return True, None
# %%
# ### 10. SQL Generation + Repair with LLM (feedback-aware)
def build_sql_review_prompt(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> str:
    """Build the prompt asking the LLM to re-check its SQL against user feedback.

    The model may conclude either that the query was already correct (user
    mistaken) or that it was wrong and must be fixed. The previous SQL is
    shown in a properly closed ```sql fence, and the model is told to return
    its corrected SQL in a ```sql fence under ``CORRECTED_SQL:`` followed by
    ``EXPLANATION:``, so the SQL can be extracted without picking up the
    explanation text.
    """
    prompt = f"""
You previously generated the following SQL for a SQLite database:
```sql
{generated_sql}
```
The user now says this query is WRONG and provided this feedback:
"{user_feedback_comment}"

TASKS:
1. Compare the SQL with the database schema (given in the context) and the user's feedback.
2. Decide whether the query is actually correct or incorrect.
3. Make sure that the user clearly specifies why they think it is incorrect.
   It can be if they are unsatisfied with the numbers, or the logic is incorrect, or SQL is invalid etc.
   If any reason is not clearly specified, return the previous result as it is.
4. If it is already correct, keep it unchanged and explain why the user might be mistaken.
5. If it is incorrect (e.g., wrong joins, missing filters, wrong aggregation), fix it.
6. Produce a corrected SQL query that better answers the question.
   If the original query is already correct, just repeat the same SQL.
7. Explain in a few sentences WHO is correct (you or the user) and WHY.

DATABASE SCHEMA (partial, from RAG):
{rag_context}

User question:
{question}

Return your answer in exactly this format:
CORRECTED_SQL:
```sql
-- your (possibly unchanged) SQL here
SELECT ...
```
EXPLANATION:
Your explanation here, clearly stating whether:
- the original query was correct or not, and
- how your corrected SQL addresses the issue (or why it didn't need changes).
"""
    return prompt.strip()
# %%
# Review and Correct SQL based on Feedback
def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    rag_context: str,
) -> Tuple[str, str]:
    """Ask the LLM to re-check *generated_sql* against the user's complaint.

    The model may keep the SQL (user mistaken) or rewrite it. The corrected
    SQL is taken from the first fenced code block in the reply. If the model
    did not use a fence, the text after ``CORRECTED_SQL:`` and before
    ``EXPLANATION:`` is used instead, so the explanation prose never ends up
    in the SQL. An empty result falls back to *generated_sql*.

    Returns:
        (corrected_sql, explanation): the explanation is the text after
        ``EXPLANATION:`` (or the whole reply if that marker is missing).
    """
    prompt = build_sql_review_prompt(
        question=question,
        generated_sql=generated_sql,
        user_feedback_comment=user_feedback_comment,
        rag_context=rag_context,
    )
    log_prompt("SQL_REVIEW_FEEDBACK", prompt)
    response = llm.invoke(prompt)
    content = response.content or ""
    log_run_event("RAW_MODEL_RESPONSE_SQL_REVIEW", content)

    fenced = re.search(
        r"```[ \t]*(?:sqlite|sql)?[ \t]*\n?(.*?)```",
        content,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        candidate = fenced.group(1)
    else:
        # No fence: keep only what precedes EXPLANATION:, minus the SQL label.
        head = re.split(r"EXPLANATION:", content, maxsplit=1, flags=re.IGNORECASE)[0]
        label = re.search(r"CORRECTED_SQL:", head, flags=re.IGNORECASE)
        candidate = head[label.end():] if label else head

    corrected_sql = clean_sql(candidate)
    if not corrected_sql.strip():
        corrected_sql = generated_sql

    explanation = extract_explanation_after_marker(
        content,
        marker="EXPLANATION:",
    )
    return corrected_sql, explanation
# %%
# Generate SQL
def generate_sql(question: str, rag_context: str) -> str:
    """
    Ask the LLM for a single SQLite query answering `question`.

    If the feedback store holds a corrected SQL for this exact question, both the
    previously generated SQL and the correction are injected into the prompt as
    guidance so the model does not repeat the same mistake.

    Returns:
        The model output after `clean_sql` post-processing.
    """
    feedback_section = ""
    past = get_last_feedback_for_question(conn, question)
    if past and past.get("corrected_sql"):
        prior_comment = past.get("comment") or "(none)"
        feedback_section = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:
Previous generated SQL:
{past['generated_sql']}
Corrected SQL (preferred reference):
{past['corrected_sql']}
User comment & prior explanation:
{prior_comment}
You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""

    instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.
You will be given:
- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.
Your job:
- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
and round to 2 decimals when appropriate.
{feedback_section}
"""

    full_prompt = "".join(
        (
            instructions,
            "\n\n=== RAG CONTEXT ===\n",
            rag_context,
            "\n\n=== USER QUESTION ===\n",
            question,
            "\n\nSQL query:",
        )
    )
    log_prompt("SQL_GENERATION", full_prompt)

    raw_text = llm.invoke(full_prompt).content
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", raw_text)
    return clean_sql(raw_text)
# %%
# Repair SQL
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Have the LLM rewrite a query that failed validation.

    The prompt carries the schema context, the question, the failing SQL, and
    the SQLite error text. The model is told to reply with bare SQL only.

    Returns:
        The model output after `clean_sql` post-processing.
    """
    instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.
You will be given:
- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.
Your job:
- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""
    sections = (
        instructions,
        "\n\n=== RAG CONTEXT ===\n",
        rag_context,
        "\n\n=== USER QUESTION ===\n",
        question,
        "\n\n=== PREVIOUS (FAILING) SQL ===\n",
        bad_sql,
        "\n\n=== SQLITE ERROR ===\n",
        error_message,
        "\n\nCorrected SQL query:",
    )
    full_prompt = "".join(sections)
    log_prompt("SQL_REPAIR", full_prompt)

    raw_text = llm.invoke(full_prompt).content
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", raw_text)
    return clean_sql(raw_text)
# %%
### 11. Result summarization
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.

    Args:
        question: The user's natural-language question.
        sql: The SQL that was (or was attempted to be) executed.
        df: Query result, or None when nothing was executed. At most the
            first 50 rows are shown to the model.
        rag_context: Schema context passed along for grounding.
        error: Optional error text; when given it is appended to the prompt.

    Returns:
        The model's Markdown answer with surrounding whitespace stripped.
    """
    system_instructions = """
You are a senior data analyst.
You will be given:
The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.
Your job:
Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user.
"""
    max_preview_rows = 50
    if df is not None and not df.empty:
        preview = df.head(max_preview_rows)  # head() already caps at len(df)
        try:
            df_preview_md = preview.to_markdown(index=False)
        except ImportError:
            # DataFrame.to_markdown depends on the optional `tabulate` package;
            # fall back to plain text instead of failing the whole answer.
            df_preview_md = preview.to_string(index=False)
    else:
        df_preview_md = "(no rows returned)"

    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== EXECUTED SQL ===\n"
        + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n"
        + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n"
        + rag_context
    )
    if error:
        prompt += "\n\n=== ERROR ===\n" + error

    log_prompt("RESULT_SUMMARY", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)
    return response.content.strip()
# %%
def _format_cache_ts(ts: Optional[float]) -> str:
    """Format a cache-metadata Unix timestamp as 'YYYY-MM-DD HH:MM' ('unknown' when unset)."""
    if not ts:
        return "unknown"
    return datetime.datetime.fromtimestamp(ts).strftime("%Y-%m-%d %H:%M")


def backend_pipeline(question: str):
    """
    STRICT cache-first backend for one user question.

    Priority:
    1. Exact question cache hit
    2. Best semantic cache hit (max_distance=0.25)
    3. LLM fallback: generate -> validate -> (repair once) -> execute -> summarize

    Returns a 13-tuple consumed by the Gradio UI, in this order:
        answer_md, df, sql, rag_context, question_state, answer_state, df_state,
        cached_matches, dropdown_choices, dropdown_visible,
        attempts_left, attempts_text, cached_stats_update
    The matches/dropdown slots are always [], [], False in this function.
    """
    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"

    # ------------------------------------------------------------------
    # Guards
    # ------------------------------------------------------------------
    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "",
            "",
            "",
            "",
            pd.DataFrame(),
            [],
            [],
            False,
            attempts_left,
            attempts_text,
            gr.update(value="", visible=False),
        )

    # ------------------------------------------------------------------
    # General / descriptive questions (no SQL involved)
    # ------------------------------------------------------------------
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return (
            overview_md,
            pd.DataFrame(),
            "",
            "",
            question,
            overview_md,
            pd.DataFrame(),
            [],
            [],
            False,
            attempts_left,
            attempts_text,
            gr.update(value="", visible=False),
        )

    store = get_sql_cache_store()

    # STEP 0A: exact cache lookup (deterministic)
    cached = retrieve_exact_cached_sql(question, store)
    # STEP 0B: semantic cache lookup (ranked)
    if cached is None:
        cached = retrieve_best_cached_sql(
            question=question,
            store=store,
            max_distance=0.25,
        )

    # ------------------------------------------------------------------
    # CACHE HIT PATH
    # ------------------------------------------------------------------
    if cached:
        md = cached["metadata"]
        try:
            increment_cache_views(md, store=store)
        except Exception:
            # Best-effort counter: never block the answer, but don't hide the failure.
            _logger.exception("backend_pipeline: failed to increment cache views")

        rag_context = get_rag_context(question)

        # NOTE(review): an exact hit may not carry a numeric distance; format defensively.
        distance = cached.get("distance")
        distance_txt = f"{distance:.4f}" if isinstance(distance, (int, float)) else "n/a"
        header = (
            "### Cache Hit\n"
            f"- **Matched question:** \"{cached['matched_question']}\"\n"
            f"- **Success rate:** {md.get('success_rate', 0.5):.2f}\n"
            f"- **Similarity distance:** {distance_txt}\n\n"
            "---\n\n"
        )
        answer_md = header + (cached.get("answer_md") or "")

        # Re-run the cached SQL so the table reflects the current database contents.
        try:
            df, exec_error = execute_sql(cached["sql"], conn)
            if exec_error:
                df = pd.DataFrame()
                answer_md += f"\n\n Error re-running cached SQL: `{exec_error}`"
        except Exception as e:
            df = pd.DataFrame()
            answer_md += f"\n\n Exception re-running cached SQL: `{e}`"

        stats_md = (
            f"**Cached entry stats**\n\n"
            f"- **Success rate:** {md.get('success_rate', 0.5):.2f} \n"
            f"- **Total feedbacks:** {md.get('total_feedbacks', 0)} \n"
            f"- **Good / Bad:** {md.get('good_count', 0)} / {md.get('bad_count', 0)} \n"
            f"- **Views:** {md.get('views', 0)} \n"
            f"- **Saved at:** {_format_cache_ts(md.get('saved_at'))} \n"
            f"- **Last updated:** {_format_cache_ts(md.get('last_updated_at'))}\n\n"
            f"**SQL preview:**\n\n```sql\n{cached['sql']}\n```\n"
        )
        return (
            answer_md,
            df,
            cached["sql"],
            rag_context,
            question,
            answer_md,
            df,
            [],
            [],
            False,
            attempts_left,
            attempts_text,
            gr.update(value=stats_md, visible=True),
        )

    # ------------------------------------------------------------------
    # STEP 1: LLM FLOW (NO CACHE HIT)
    # ------------------------------------------------------------------
    rag_context = get_rag_context(question)
    sql = generate_sql(question, rag_context)

    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False
    if not is_valid and validation_error:
        # Single repair attempt; keep the original SQL if the repair is also invalid.
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)
        if repaired_valid:
            sql = repaired_sql
            repaired = True
            validation_error = None
        else:
            validation_error = repaired_error or validation_error

    df = None
    if validation_error:
        exec_error = validation_error
    else:
        df, exec_error = execute_sql(sql, conn)

    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )

    sql_status = []
    if exec_error:
        sql_status.append(f"**Error:** `{exec_error}`")
    else:
        sql_status.append("Query ran successfully.")
    if repaired:
        sql_status.append("_Note: SQL was auto-repaired._")
    sql_status.append("\n**Final SQL used:**\n")
    sql_status.append(f"```sql\n{sql}\n```")
    answer_md = summary_text + "\n\n---\n\n" + "\n".join(sql_status)
    df_preview = df if df is not None and exec_error is None else pd.DataFrame()

    # ------------------------------------------------------------------
    # STEP 2: CACHE LLM RESULT (only if it actually ran)
    # ------------------------------------------------------------------
    if exec_error:
        # A failing query must not be cached: the exact-match lookup above would
        # keep serving it for this question instead of giving the LLM another try.
        _logger.warning("backend_pipeline: not caching failed SQL for question %r", question)
    else:
        try:
            cache_sql_answer_dedup(
                question=question,
                sql=sql,
                answer_md=answer_md,
                metadata={
                    "good_count": 0,
                    "bad_count": 0,
                    "total_feedbacks": 0,
                    "success_rate": 0.5,
                    "views": 1,
                    "saved_at": _now_ts(),
                },
                store=store,
            )
        except Exception:
            _logger.exception("backend_pipeline: failed to cache LLM result")

    return (
        answer_md,
        df_preview,
        sql,
        rag_context,
        question,
        answer_md,
        df_preview,
        [],
        [],
        False,
        attempts_left,
        attempts_text,
        gr.update(value="", visible=False),
    )
# %%
def _looks_like_sql(text: str) -> bool:
"""Quick heuristic: does text contain SQL keywords / SELECT ?"""
if not text:
return False
return bool(re.search(r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b", text, flags=re.I))
# Words indicating the user said something specific about the query.
_FEEDBACK_SIGNAL_WORDS = (
    "filter", "where", "year", "month", "should", "instead", "expected",
    "wrong", "missing", "aggregate", "sum", "avg", "count", "distinct",
    "join", "left join", "inner join", "group by", "order by", "date",
    "range", "exclude", "include", "only",
)
# Anchored at a word start so e.g. "account" / "update" / "somewhere" do not
# count as "count" / "date" / "where"; suffixes ("counts", "dates") still match.
_FEEDBACK_SIGNAL_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(w) for w in _FEEDBACK_SIGNAL_WORDS) + r")",
    re.IGNORECASE,
)


def is_feedback_sufficient(feedback_text: str) -> bool:
    """
    Heuristic to decide whether the user's free-text feedback is actionable.

    Returns True if any of these hold (after stripping whitespace):
    - length >= 60 characters (long feedback);
    - it looks like SQL (user pasted a corrected query), per `_looks_like_sql`;
    - length >= 20 characters AND it contains a signal word (e.g. 'filter',
      'year', 'should', 'instead', 'missing', 'wrong') or any digit
      (an expected number or a year).
    Short, generic comments like "numbers look off" return False.
    """
    if not feedback_text:
        return False
    text = feedback_text.strip()
    if len(text) >= 60:
        return True
    if _looks_like_sql(text):
        return True
    if len(text) < 20:
        return False
    has_signal_word = _FEEDBACK_SIGNAL_RE.search(text) is not None
    has_digit = any(ch.isdigit() for ch in text)
    return has_signal_word or has_digit
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Build the fixed clarification request shown when feedback is too vague.

    Args:
        sample_feedback: The user's original feedback. When non-empty, it is
            quoted back at the top of the message.

    Returns:
        A Markdown message for the UI. The text is deterministic: the same
        input always yields the same string.
    """
    intro = (
        "Thanks — I need a bit more detail to act on this feedback.\n\n"
        "Please tell me one (or more) of the following so I can check and correct the result:\n\n"
    )
    questions = (
        "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n"
        " the **time range** (year/month), or the **filters** applied?\n"
        "2. If you expected a different number, what was the expected number (and how was it computed)?\n"
        "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n"
    )
    sample_lines = [
        "I think the query should count DISTINCT customer_unique_id, not customer_id.",
        "This looks off for year 2018 — I expected the count for 2018 to be ~40k.",
        "Please exclude canceled orders (order_status = 'canceled').",
        "SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;",
    ]
    examples = "Examples you can copy-paste:\n" + "".join(
        f'- "{line}"\n' for line in sample_lines
    )
    closing = "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare."

    message = intro + questions + examples + closing
    if not sample_feedback:
        return message
    return f'I saw your feedback: "{sample_feedback}"\n\n' + message
# %%
def feedback_pipeline_interactive(
    feedback_rating: str,
    feedback_comment: str,
    last_sql: str,
    last_rag_context: str,
    last_question: str,
    last_answer_md: str,
    last_df: pd.DataFrame,
    feedback_sql: str,
    attempts_left: int,
):
    """
    Handle one round of user feedback on the last answer.

    - "correct": record GOOD feedback and update the cache; attempts unchanged.
    - "wrong": decrement attempts. Then either ask a follow-up question (vague
      feedback with attempts remaining) or run an LLM review and show the
      corrected result. When attempts reach 0, the review is forced.

    `feedback_sql` is accepted for interface compatibility but is not used here.

    Returns a 10-tuple:
        (display_md, display_df, sql, rag_context, question,
         answer_state, df_state, awaiting_followup, followup_prompt, attempts_left)
    """
    rating = (feedback_rating or "").strip().lower()
    comment = (feedback_comment or "").strip()
    attempts_left = int(attempts_left or 0)

    def _echo(display_md, attempts, awaiting=False, prompt=""):
        # Keep the previous result; only the displayed text / follow-up flags change.
        return (
            display_md,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            last_answer_md,
            last_df,
            awaiting,
            prompt,
            attempts,
        )

    # ---------------- Guard ----------------
    if not last_question or not last_sql:
        return _echo(last_answer_md, attempts_left)

    if rating not in ("correct", "wrong"):
        return _echo(
            last_answer_md + "\n\n Please select **Correct** or **Wrong**.",
            attempts_left,
        )

    store = get_sql_cache_store()

    # ============================================================
    # CORRECT -> no attempt decrement
    # ============================================================
    if rating == "correct":
        original_doc = find_cached_doc_by_sql(last_question, last_sql, store=store)
        update_cache_on_feedback(
            question=last_question,
            original_doc_md=original_doc,
            user_marked_good=True,
            llm_corrected_sql=None,
            llm_corrected_answer_md=None,
            store=store,
        )
        record_feedback(
            conn=conn,
            question=last_question,
            generated_sql=last_sql,
            model_answer=last_answer_md,
            rating="good",
            comment=comment or None,
            corrected_sql=None,
        )
        return _echo(
            last_answer_md + "\n\n **Feedback recorded as GOOD.**",
            attempts_left,
        )

    # ============================================================
    # WRONG -> decrement immediately
    # ============================================================
    attempts_left = max(0, attempts_left - 1)
    if attempts_left == 0:
        # Attempts exhausted: force an LLM review even without a useful comment.
        comment = comment or "User marked result as wrong."
    elif not is_feedback_sufficient(comment):
        # Vague feedback with attempts remaining: ask the user to clarify first.
        return _echo(
            last_answer_md,
            attempts_left,
            awaiting=True,
            prompt=build_followup_prompt_for_user(comment),
        )

    # ============================================================
    # Run LLM review
    # ============================================================
    original_doc = find_cached_doc_by_sql(last_question, last_sql, store=store)
    corrected_sql, explanation = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=comment,
        rag_context=last_rag_context,
    )
    corrected_sql = corrected_sql or last_sql

    df_new, exec_error = execute_sql(corrected_sql, conn)
    if exec_error:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=None,
            rag_context=last_rag_context,
            error=exec_error,
        )
        df_new = pd.DataFrame()
    else:
        answer_core = summarize_results(
            question=last_question,
            sql=corrected_sql,
            df=df_new,
            rag_context=last_rag_context,
            error=None,
        )

    sql_changed = corrected_sql.strip() != last_sql.strip()
    # Only a correction that actually executed is fed back. A broken query would
    # otherwise be cached and reused as the "preferred reference" by generate_sql.
    usable_correction = sql_changed and not exec_error
    update_cache_on_feedback(
        question=last_question,
        original_doc_md=original_doc,
        user_marked_good=False,
        llm_corrected_sql=corrected_sql if usable_correction else None,
        llm_corrected_answer_md=answer_core if usable_correction else None,
        store=store,
    )
    record_feedback(
        conn=conn,
        question=last_question,
        generated_sql=last_sql,
        model_answer=last_answer_md,
        rating="bad",
        comment=comment + "\n\nLLM explanation:\n" + (explanation or ""),
        corrected_sql=None if exec_error else corrected_sql,
    )

    final_md = (
        answer_core
        + "\n\n---\n\n"
        + f"**Final corrected SQL:**\n```sql\n{corrected_sql}\n```\n\n"
        + "### LLM Review Explanation\n"
        + (explanation or "")
    )
    return (
        final_md,
        df_new,
        corrected_sql,
        last_rag_context,
        last_question,
        final_md,
        df_new,
        False,
        "",
        attempts_left,
    )
# %%
import gradio as gr
import pandas as pd
with gr.Blocks() as demo:
    gr.Markdown("# Olist Analytics Assistant (RAG + SQL + Feedback)")

    # ==================== STATE ====================
    # Per-session values carried from one event handler to the next.
    last_sql_state = gr.State("")
    last_rag_state = gr.State("")
    last_question_state = gr.State("")
    last_answer_state = gr.State("")
    last_df_state = gr.State(pd.DataFrame())
    attempts_state = gr.State(4)
    # Passed to feedback_pipeline_interactive; no handler in this UI writes to it.
    feedback_sql_state = gr.State("")

    # ==================== MAIN UI ====================
    with gr.Row():
        with gr.Column(scale=1):
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total number of customers",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            answer_out = gr.Markdown()
            table_out = gr.Dataframe()
            attempts_display = gr.Markdown("**Feedback attempts remaining: 4**")
            cached_stats_md = gr.Markdown(visible=False)

    # ==================== FEEDBACK ====================
    gr.Markdown("### Feedback")
    feedback_rating = gr.Radio(
        ["Correct", "Wrong"],
        label="Is the answer correct?",
        value=None,
    )
    feedback_comment = gr.Textbox(
        label="Explain (required if Wrong)",
        lines=3,
    )
    feedback_btn = gr.Button("Submit feedback")

    # ==================== FOLLOW-UP ====================
    followup_prompt_md = gr.Markdown(visible=False)
    followup_input = gr.Textbox(
        label="Please clarify",
        visible=False,
        lines=4,
    )
    followup_submit_btn = gr.Button(
        "Submit follow-up",
        visible=False,
    )
    exhausted_md = gr.Markdown(
        "**You have exhausted your feedback attempts. Please ask a new question to continue.**",
        visible=False,
    )

    # ==================== UI HELPERS ====================
    # Each helper returns 7 updates, in order:
    # rating, comment, submit, followup input, followup btn, followup prompt, exhausted.
    def reset_feedback_ui():
        return (
            gr.update(value=None, visible=True),  # rating
            gr.update(value="", visible=True),  # comment
            gr.update(visible=True),  # submit
            gr.update(visible=False),  # followup input
            gr.update(visible=False),  # followup btn
            gr.update(visible=False),  # followup prompt
            gr.update(visible=False),  # exhausted
        )

    def show_followup_ui(prompt: str):
        return (
            gr.update(visible=False),  # rating
            gr.update(visible=False),  # comment
            gr.update(visible=False),  # submit
            gr.update(value="", visible=True),  # followup input
            gr.update(visible=True),  # followup btn
            gr.update(value=prompt, visible=True),  # followup prompt
            gr.update(visible=False),  # exhausted
        )

    def show_exhausted_ui():
        return (
            gr.update(visible=False),  # rating
            gr.update(visible=False),  # comment
            gr.update(visible=False),  # submit
            gr.update(visible=False),  # followup input
            gr.update(visible=False),  # followup btn
            gr.update(visible=False),  # followup prompt
            gr.update(visible=True),  # exhausted
        )

    # Every handler below returns values for exactly these components, in this order.
    shared_outputs = [
        answer_out,
        table_out,
        last_sql_state,
        last_rag_state,
        last_question_state,
        last_answer_state,
        last_df_state,
        attempts_state,
        attempts_display,
        cached_stats_md,
        feedback_rating,
        feedback_comment,
        feedback_btn,
        followup_input,
        followup_submit_btn,
        followup_prompt_md,
        exhausted_md,
    ]

    # ==================== RUN PIPELINE ====================
    def run_and_render(question):
        (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            _cached_matches,
            _dropdown_choices,
            _dropdown_visible,
            attempts,
            attempts_text,
            cached_stats_update,
        ) = backend_pipeline(question)
        return (
            answer_md,
            df,
            sql,
            rag,
            q,
            answer_state,
            df_state,
            attempts,
            attempts_text,
            cached_stats_update,
            *reset_feedback_ui(),
        )

    submit_btn.click(
        run_and_render,
        inputs=[question_in],
        outputs=shared_outputs,
    )

    # ==================== FEEDBACK HANDLER ====================
    def handle_feedback(
        rating,
        comment,
        last_sql,
        last_rag,
        last_question,
        last_answer,
        last_df,
        feedback_sql,
        attempts_left,
    ):
        (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            awaiting_followup,
            followup_prompt,
            attempts_new,
        ) = feedback_pipeline_interactive(
            rating,
            comment,
            last_sql,
            last_rag,
            last_question,
            last_answer,
            last_df,
            feedback_sql,
            attempts_left,
        )
        attempts_md = f"**Feedback attempts remaining: {attempts_new}**"

        # Exhaustion takes precedence over a pending follow-up.
        if attempts_new <= 0:
            ui = show_exhausted_ui()
        elif awaiting_followup:
            ui = show_followup_ui(followup_prompt)
        else:
            ui = reset_feedback_ui()

        return (
            answer_md,
            df_new,
            sql_new,
            rag_new,
            q_new,
            ans_state,
            df_state,
            attempts_new,
            attempts_md,
            # Explicit update instead of returning the component object itself;
            # the cached-entry stats describe the pre-feedback result, so hide them.
            gr.update(visible=False),
            *ui,
        )

    feedback_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            feedback_comment,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=shared_outputs,
    )

    # The follow-up text replaces the original comment; the rating radio keeps
    # its (hidden) value, so the same handler applies.
    followup_submit_btn.click(
        handle_feedback,
        inputs=[
            feedback_rating,
            followup_input,
            last_sql_state,
            last_rag_state,
            last_question_state,
            last_answer_state,
            last_df_state,
            feedback_sql_state,
            attempts_state,
        ],
        outputs=shared_outputs,
    )
# %%
if __name__ == "__main__":
    # Start the Gradio server only when this file is run directly, not on import.
    demo.launch()