# %%
import os
import uuid
import sqlite3
import datetime
import json
import re
import time
from typing import Dict, Any, List, Tuple, Optional
import hashlib

import pandas as pd
from groq import Groq
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# %%
# 0. Global config
DB_PATH = "olist.db"
DATA_DIR = "data"

# Groq Model
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast at import time: nothing downstream works without the API key.
    raise ValueError("GROQ_API_KEY not found.")

groq_client = Groq(api_key=GROQ_API_KEY)

# Embedding model
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Logging paths
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"

# %%
import os
import chromadb
from chromadb.config import Settings

CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

# Process-wide singleton so all vector stores share one persistent client.
_CHROMA_CLIENT = None


def get_chroma_client():
    """Return the shared persistent chromadb client, creating it lazily on first use."""
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is None:
        _CHROMA_CLIENT = chromadb.PersistentClient(
            path=CHROMA_PERSIST_DIR,
            settings=Settings(
                anonymized_telemetry=False,
            ),
        )
    return _CHROMA_CLIENT


# %%
def log_prompt(tag: str, prompt: str) -> None:
    """
    Append the full prompt to a log file, with a tag and timestamp.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(prompt)
        f.write("\n================ END PROMPT ================\n")


def log_run_event(tag: str, content: str) -> None:
    """
    Append model response, final SQL, and error info into a run log.
    """
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    with open(RUN_LOG_PATH, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(content)
        f.write("\n================ END EVENT ================\n")


# %%
# ==============================
# Chat message helpers
# ==============================
def assistant_msg(text: str):
    # Gradio-style chat pair (user, assistant): assistant-only message.
    return (None, text)


def user_msg(text: str):
    # User-only message pair.
    return (text, None)


def add_assistant_state(history, text):
    # Append an assistant message in place and return history for chaining.
    history.append(assistant_msg(text))
    return history


# %%
import pandas as pd


def df_to_state(df: pd.DataFrame) -> dict:
    """Serialize a DataFrame into a JSON-safe {'columns', 'rows'} dict (cells stringified)."""
    if df is None or not isinstance(df, pd.DataFrame) or df.empty:
        return {"columns": [], "rows": []}
    cols = list(df.columns)
    # astype(str) makes every cell JSON/Gradio-state friendly.
    rows = df.astype(str).values.tolist()
    if not rows:
        return {"columns": [], "rows": []}
    return {"columns": cols, "rows": rows}


def state_to_df(state: dict) -> pd.DataFrame:
    """Inverse of df_to_state; returns an empty DataFrame on malformed input."""
    if not isinstance(state, dict):
        return pd.DataFrame()
    cols = state.get("columns")
    rows = state.get("rows")
    if not cols or rows is None:
        return pd.DataFrame()
    return pd.DataFrame(rows, columns=cols)


# %%
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """
    Create (or upgrade) a table to capture user feedback on model answers.
    """
    conn.execute("""
    CREATE TABLE IF NOT EXISTS user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        generated_sql TEXT,
        model_answer TEXT,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
        comment TEXT,
        corrected_sql TEXT
    )
    """)
    conn.commit()
""" rating = rating.lower() if rating not in ("good", "bad"): raise ValueError("rating must be 'good' or 'bad'") ts = datetime.datetime.now().isoformat(timespec="seconds") conn.execute( """ INSERT INTO user_feedback ( created_at, question, generated_sql, model_answer, rating, comment, corrected_sql ) VALUES (?, ?, ?, ?, ?, ?, ?) """, (ts, question, generated_sql, model_answer, rating, comment, corrected_sql), ) conn.commit() def get_last_feedback_for_question( conn: sqlite3.Connection, question: str, ) -> Optional[Dict[str, Any]]: """ Return the most recent feedback row for this question (if any). """ cur = conn.cursor() cur.execute( """ SELECT created_at, generated_sql, model_answer, rating, comment, corrected_sql FROM user_feedback WHERE question = ? ORDER BY created_at DESC LIMIT 1 """, (question,), ) row = cur.fetchone() if not row: return None return { "created_at": row[0], "generated_sql": row[1], "model_answer": row[2], "rating": row[3], "comment": row[4], "corrected_sql": row[5], } # %% def init_db() -> sqlite3.Connection: """ Load all CSVs from the data/ folder into a local SQLite DB. Table names are derived from file names (without .csv). """ conn = sqlite3.connect(DB_PATH, check_same_thread=False) csv_files = [ "olist_customers_dataset.csv", "olist_orders_dataset.csv", "olist_order_items_dataset.csv", "olist_products_dataset.csv", "olist_order_reviews_dataset.csv", "olist_order_payments_dataset.csv", "product_category_name_translation.csv", "olist_sellers_dataset.csv", "olist_geolocation_dataset.csv", ] for fname in csv_files: path = os.path.join(DATA_DIR, fname) print(path) if not os.path.exists(path): print(f"CSV not found: {path} - skipping") continue table_name = os.path.splitext(fname)[0] print(f"Loading {path} into table {table_name}...") df = pd.read_csv(path) df.to_sql(table_name, conn, if_exists="replace", index=False) init_feedback_table(conn) return conn conn = init_db() # %% # 4. 
# %%
# 4. Manual docs for Olist tables
#
# Hand-written documentation for each Olist table, keyed by SQLite table name.
# Consumed by extract_schema_metadata() to enrich PRAGMA introspection with
# human-readable descriptions for the RAG schema documents.
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ).",
        },
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout.",
        },
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL).",
        },
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension).",
        },
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            "review_creation_date": "Date when the customer created the review.",
            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable).",
        },
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL).",
        },
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English.",
        },
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code).",
        },
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code).",
        },
    },
}
# %%
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns and foreign key relationships.
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    metadata: Dict[str, Any] = {"tables": {}}
    for table in tables:
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()
        # Manual docs for the table
        table_docs = OLIST_DOCS.get(table, {})
        table_desc: str = table_docs.get(
            "description", f"Table '{table}' from Olist dataset."
        )
        column_docs: Dict[str, str] = table_docs.get("columns", {})
        columns_meta: Dict[str, str] = {}
        for c in cols:
            # PRAGMA table_info row: (cid, name, type, notnull, dflt_value, pk)
            col_name = c[1]
            col_type = c[2] or "TEXT"
            if col_name in column_docs:
                columns_meta[col_name] = column_docs[col_name]
            else:
                columns_meta[col_name] = f"Column '{col_name}' of type {col_type}"
        # NOTE(review): tables loaded via pandas.to_sql carry no FK constraints,
        # so this PRAGMA is expected to return nothing for the CSV tables.
        cursor.execute(f"PRAGMA foreign_key_list('{table}')")
        fk_rows = cursor.fetchall()
        relationships: List[str] = []
        for fk in fk_rows:
            ref_table = fk[2]
            from_col = fk[3]
            to_col = fk[4]
            relationships.append(
                f"{table}.{from_col} → {ref_table}.{to_col} (foreign key)"
            )
        metadata["tables"][table] = {
            "description": table_desc,
            "columns": columns_meta,
            "relationships": relationships,
        }
    return metadata


schema_metadata = extract_schema_metadata(conn)


def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render metadata dict into a YAML-style string.
    """
    # NOTE(review): indent widths below reconstructed to conventional YAML
    # nesting — confirm against the original source's literal spacing.
    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        desc = tinfo.get("description", "").replace('"', "'")
        lines.append(f'    description: "{desc}"')
        lines.append("    columns:")
        for col_name, col_desc in tinfo.get("columns", {}).items():
            col_desc_clean = col_desc.replace('"', "'")
            lines.append(f'      {col_name}: "{col_desc_clean}"')
        rels = tinfo.get("relationships", [])
        if rels:
            lines.append("    relationships:")
            for rel in rels:
                rel_clean = rel.replace('"', "'")
                lines.append(f'      - "{rel_clean}"')
    return "\n".join(lines)


schema_yaml = build_schema_yaml(schema_metadata)
# %%
# 6. Build schema documents for RAG (taking samples from the table)
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table, using schema_metadata.

    Each document includes:
    - Table name
    - Table description
    - Columns with type + description
    - Relationships (FKs)
    - A few sample rows
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    docs: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    for table in tables:
        tmeta = schema_metadata["tables"][table]
        table_desc = tmeta.get("description", "")
        columns_meta = tmeta.get("columns", {})
        relationships = tmeta.get("relationships", [])
        # Use PRAGMA to get types, then enrich with descriptions
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()
        col_lines = []
        for c in cols:
            col_name = c[1]
            col_type = c[2] or "TEXT"
            col_desc = columns_meta.get(col_name, f"Column '{col_name}' of type {col_type}")
            col_lines.append(f"- {col_name} ({col_type}): {col_desc}")
        # Sample rows
        try:
            sample_df = pd.read_sql_query(
                f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                connection,
            )
            sample_text = sample_df.to_markdown(index=False)
        except Exception:
            # Best-effort: the doc is still useful without example rows.
            sample_text = "(could not fetch sample rows)"
        # Relationships block
        rel_block = ""
        if relationships:
            rel_block = "Relationships:\n" + "\n".join(
                f"- {rel}" for rel in relationships
            ) + "\n"
        doc_text = (
            f"Table: {table}\n"
            f"Description: {table_desc}\n\n"
            f"Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        docs.append(doc_text)
        metadatas.append({
            "doc_type": "table_schema",
            "table_name": table,
        })
    return docs, metadatas


# Build RAG texts + metadata
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# 1) Per-table docs
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# 2) Global YAML as a separate doc
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
# %%
def build_store_final():
    """Create the schema RAG vector store and a retriever restricted to per-table docs."""
    embedding_model = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )
    rag_store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embedding_model,
    )
    # Only seed the collection once; the Chroma client is persistent on disk.
    # NOTE(review): relies on the private `_collection` attribute of the
    # langchain Chroma wrapper.
    if rag_store._collection.count() == 0:
        rag_store.add_texts(
            texts=RAG_TEXTS,
            metadatas=RAG_METADATAS,
        )
    rag_retriever = rag_store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )
    return rag_store, rag_retriever


# Initialize once
rag_store, rag_retriever = build_store_final()

# %%
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# Lazily-initialized singletons for the SQL answer cache.
_sql_cache_store = None
_sql_embedding_fn = None
SQL_CACHE_COLLECTION = "sql_cache_mpnet"


def get_sql_cache_store():
    """Return the shared Chroma collection used as a semantic SQL-answer cache."""
    global _sql_cache_store, _sql_embedding_fn
    if _sql_cache_store is not None:
        return _sql_cache_store
    if _sql_embedding_fn is None:
        _sql_embedding_fn = HuggingFaceEmbeddings(
            model_name=EMBED_MODEL_NAME,
            encode_kwargs={"normalize_embeddings": True},
        )
    _sql_cache_store = Chroma(
        client=get_chroma_client(),
        collection_name=SQL_CACHE_COLLECTION,
        embedding_function=_sql_embedding_fn,
    )
    return _sql_cache_store


# %%
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """
    Coerce a metadata dict to Chroma-compatible scalar values.

    Chroma metadata accepts only str/int/float/bool: None values are dropped
    and anything else is stringified.
    """
    safe = {}
    for k, v in (metadata or {}).items():
        if v is None:
            continue
        if isinstance(v, (int, float, bool)):
            safe[k] = v
        elif isinstance(v, str):
            safe[k] = v
        else:
            safe[k] = str(v)
    return safe
""" if not q: return "" q = q.lower().strip() q = re.sub(r"[^\w\s]", "", q) q = re.sub(r"\s+", " ", q) return q # %% # Helper: normalize question def normalize_question_text(q: str) -> str: if not q: return "" q = q.strip().lower() q = re.sub(r"[^\w\s]", " ", q) q = re.sub(r"\s+", " ", q).strip() return q # Helper: compute success rate def compute_success_rate(md: dict) -> float: sc = md.get("success_count", 0) or 0 tf = md.get("total_feedbacks", 0) or 0 if tf <= 0: return 0.0 return float(sc) / float(tf) # Insert initial cache entry (no feedback yet) def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None): """ Insert a cached entry when you run a query and want to cache it regardless of feedback. initial metrics: views=1, success_count=0, total_feedbacks=0, success_rate=1.0 """ if store is None: store = get_sql_cache_store() ident = uuid.uuid4().hex norm = normalize_question_text(question) md = { "id": ident, "normalized_question": norm, "sql": sql, "answer_md": answer_md, "saved_at": time.time(), "views": 1, "success_count": 0, "total_feedbacks": 0, "success_rate": 1.0, } if extra_metadata: md.update(extra_metadata) # Use store's API; store.add_texts([question], metadatas=[md]) return ident # %% import time import logging from typing import Optional, List, Dict, Any from difflib import SequenceMatcher _logger = logging.getLogger(__name__) # ------------------------------------------------------------------------- # Utility helpers # ------------------------------------------------------------------------- def _now_ts() -> float: return time.time() def similarity_score(a: str, b: str) -> float: return SequenceMatcher(None, (a or ""), (b or "")).ratio() # ------------------------------------------------------------------------- # Persist helper # ------------------------------------------------------------------------- def langchain_upsert( store, text: str, metadata: dict, cache_id: str, ): safe_md = 
# -------------------------------------------------------------------------
# Persist helper
# -------------------------------------------------------------------------
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """
    Upsert one document into a langchain vector store under a fixed id.

    Falls back to an id-less add_texts when the store's add_texts does not
    accept an `ids` kwarg (older wrappers raise TypeError).
    """
    safe_md = sanitize_metadata_for_chroma(metadata)
    try:
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
            ids=[cache_id],
        )
    except TypeError:
        store.add_texts(
            texts=[text],
            metadatas=[safe_md],
        )


# -------------------------------------------------------------------------
# Cache insert / update
# -------------------------------------------------------------------------
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """
    Insert/replace a cache entry keyed by generate_cache_id(question, sql).

    The semantic vector is built from the loosely-normalized question; the
    strictly-normalized question is stored for exact-match lookups.
    Note: normalize_text / generate_cache_id are defined later in this file —
    fine at call time, but this function cannot be called before that cell runs.
    """
    norm_q_semantic = normalize_text(question)
    norm_q_exact = normalize_question_strict(question)
    cache_id = generate_cache_id(question, sql)
    now = _now_ts()
    md = {
        # identity
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,
        # exact match key
        "normalized_question": norm_q_exact,
        # timestamps
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": metadata.get("last_viewed_at", 0),
        # metrics
        "good_count": metadata.get("good_count", 0),
        "bad_count": metadata.get("bad_count", 0),
        "total_feedbacks": metadata.get("total_feedbacks", 0),
        "success_rate": metadata.get("success_rate", 0.5),
        "views": metadata.get("views", 0),
    }
    langchain_upsert(
        store=store,
        text=norm_q_semantic,  # semantic vector
        metadata=md,
        cache_id=cache_id,
    )
    return {
        "question": question,
        "sql": sql,
        "answer_md": answer_md,
        "metadata": md,
    }


# -------------------------------------------------------------------------
# Find exact cached entry (question + SQL)
# -------------------------------------------------------------------------
def find_cached_doc_by_sql(question: str, sql: str, store):
    """Look up the cache entry for this exact (question, sql) pair by id, or None."""
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)
    if coll and hasattr(coll, "get"):
        try:
            res = coll.get(ids=[cache_id])
            if res and res.get("metadatas"):
                md = res["metadatas"][0]
                return {
                    "id": cache_id,
                    "question": question,
                    "sql": md.get("sql"),
                    "answer_md": md.get("answer_md"),
                    "metadata": md,
                }
        except Exception:
            # Best-effort lookup: any store error degrades to a cache miss.
            pass
    return None
# -------------------------------------------------------------------------
# Retrieve cached answers ranked primarily by success rate
# -------------------------------------------------------------------------
import re
import unicodedata


def normalize_text(text: str) -> str:
    """
    Semantic-cache normalization.

    Lowercases, strips accents (NFKD -> ASCII), replaces punctuation with
    spaces and collapses runs of whitespace. Returns "" for falsy input.
    """
    if not text:
        return ""
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = text.lower()
    text = re.sub(r"[^\w\s]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


import hashlib


def generate_cache_id(question: str, sql: str) -> str:
    """
    Deterministic cache id for a (question, sql) pair.

    The question is normalized so trivially-different phrasings collide on
    purpose; SHA-1 is fine here — this is a cache key, not a security hash.
    """
    q = normalize_text(question)
    s = (sql or "").strip()
    key = f"{q}||{s}".encode("utf-8")
    return hashlib.sha1(key).hexdigest()


def rank_cached_candidates(candidates: list[dict]) -> Optional[dict]:
    """
    Deterministic quality-first ranking of cache candidates.

    Order: highest success_rate, then most good votes, fewest bad votes,
    most recently updated. Missing metrics fall back to neutral defaults.

    Fixes over the previous version:
    - returns None for an empty candidate list instead of raising IndexError;
    - uses sorted() instead of in-place .sort(), so the caller's list is
      not mutated as a side effect.
    """
    if not candidates:
        return None
    ranked = sorted(
        candidates,
        key=lambda c: (
            -float(c["metadata"].get("success_rate", 0.5)),
            -int(c["metadata"].get("good_count", 0)),
            int(c["metadata"].get("bad_count", 0)),
            -float(c["metadata"].get("last_updated_at", 0)),
        ),
    )
    return ranked[0]
""" norm_q = normalize_question_strict(question) coll = store._collection try: res = coll.get(where={"normalized_question": norm_q}) if not res or not res.get("metadatas"): return None candidates = [] for md in res["metadatas"]: candidates.append({ "matched_question": question, "sql": md["sql"], "answer_md": md.get("answer_md", ""), "distance": 0.0, "metadata": md, }) return rank_cached_candidates(candidates) except Exception: return None def retrieve_best_cached_sql( question: str, store, max_distance: float = 0.25, ): norm_q = normalize_text(question) results = store.similarity_search_with_score(norm_q, k=20) candidates = [] for doc, score in results: distance = float(score) if distance > max_distance: continue md = doc.metadata or {} if "sql" not in md: continue candidates.append({ "matched_question": doc.page_content, "sql": md["sql"], "answer_md": md.get("answer_md", ""), "distance": distance, "metadata": md, # ranking signals (FORCED numeric) "success_rate": float(md.get("success_rate", 0.5)), "good": int(md.get("good_count", 0)), "bad": int(md.get("bad_count", 0)), "views": int(md.get("views", 0)), "last_updated": float(md.get("last_updated_at", 0)), }) if not candidates: return None # QUALITY-FIRST RANKING candidates.sort( key=lambda c: ( -c["success_rate"], -c["good"], c["bad"], c["distance"], -c["last_updated"], ) ) return candidates[0] # ------------------------------------------------------------------------- # Increment views # ------------------------------------------------------------------------- def increment_cache_views(metadata: dict, store): if not metadata: return False cache_id = metadata.get("cache_id") if not cache_id: return False md = dict(metadata) md["views"] = int(md.get("views", 0)) + 1 md["last_viewed_at"] = _now_ts() md["last_updated_at"] = _now_ts() md["saved_at"] = md.get("saved_at", _now_ts()) try: langchain_upsert( store=store, text=normalize_text(md.get("normalized_question", "")), metadata=md, cache_id=cache_id, ) return True 
# Update metrics on feedback
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """
    Fold one user feedback event into an existing cache entry's metrics and,
    when a correction is supplied, insert the corrected SQL as a fresh entry.

    original_doc_md is the candidate dict returned by the retrieve_* helpers
    (its "metadata" key must carry a "cache_id").
    """
    if not original_doc_md:
        return
    md = dict(original_doc_md["metadata"])
    cache_id = md["cache_id"]
    # ---- feedback counts ----
    if user_marked_good:
        md["good_count"] = md.get("good_count", 0) + 1
    else:
        md["bad_count"] = md.get("bad_count", 0) + 1
    md["total_feedbacks"] = md.get("total_feedbacks", 0) + 1
    # BUGFIX: the previous version read md["good_count"] unconditionally,
    # which raised KeyError on the first "bad" feedback for an entry that
    # had never received a "good" one. Use .get() with a 0 default.
    md["success_rate"] = (
        md.get("good_count", 0) / md["total_feedbacks"]
        if md["total_feedbacks"] > 0
        else 0.5
    )
    # ---- timestamps ----
    md["saved_at"] = md.get("saved_at", _now_ts())  # preserve original save time
    md["last_updated_at"] = _now_ts()
    langchain_upsert(
        store=store,
        text=normalize_text(question),
        metadata=md,
        cache_id=cache_id,
    )
    # -------------------------
    # Corrected SQL --> NEW ENTRY
    # -------------------------
    if llm_corrected_sql and llm_corrected_answer_md:
        # Seed the corrected entry as already-successful (1 good / 1 total).
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 1.0,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )


# %%
import datetime


def log_pipeline(event: str, detail: str = ""):
    """
    Centralized pipeline logger. Shows in console + can be routed to file later.
    """
    ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    msg = f"[{ts}] PIPELINE::{event}"
    if detail:
        msg += f" | {detail}"
    print(msg)


# %%
# %%
# ### 8. Groq LLM via LangChain
from langchain_groq import ChatGroq
import re
import gradio as gr

llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
""" docs = rag_retriever.invoke(question) return "\n\n---\n\n".join(d.page_content for d in docs) def clean_sql(sql: str) -> str: sql = sql.strip() if "```" in sql: sql = sql.replace("```sql", "").replace("```", "").strip() return sql def extract_sql_from_markdown(text: str) -> str: """ Extract the first ```sql ... ``` block from LLM output. If not found, return the whole text. """ match = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE) if match: return match.group(1).strip() return text.strip() def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str: """ After the given marker, return the rest of the text as explanation. """ idx = text.upper().find(marker.upper()) if idx == -1: return text.strip() return text[idx + len(marker):].strip() # 6b. General / descriptive questions GENERAL_DESC_KEYWORDS = [ "what is this dataset about", "what is this data about", "describe this dataset", "describe the dataset", "dataset overview", "data overview", "summary of the dataset", "explain this dataset", ] def is_general_question(question: str) -> bool: """ Detect high-level descriptive questions where we should answer directly from schema context instead of generating SQL. """ q = question.lower().strip() return any(key in q for key in GENERAL_DESC_KEYWORDS) def answer_general_question(question: str) -> str: """ Use the RAG schema docs to generate a rich, high-level description of the Olist dataset for conceptual questions. """ rag_context = get_rag_context(question) system_instructions = """ You are a data documentation expert. You will be given: - Schema documentation for the Olist dataset (tables, descriptions, columns, relationships). - A high-level user question like "what is this dataset about?". Your job: - Write a clear, structured overview of the dataset. - Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation). 
def answer_general_question(question: str) -> str:
    """
    Use the RAG schema docs to generate a rich, high-level description
    of the Olist dataset for conceptual questions.
    """
    rag_context = get_rag_context(question)
    # NOTE(review): internal line breaks of this prompt literal reconstructed —
    # whitespace-only differences do not change the prompt's meaning.
    system_instructions = """
You are a data documentation expert.

You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".

Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.

Do NOT write SQL. Answer in Markdown.
"""
    prompt = (
        system_instructions
        + "\n\n=== SCHEMA CONTEXT ===\n" + rag_context
        + "\n\n=== USER QUESTION ===\n" + question
        + "\n\nDetailed dataset overview:"
    )
    log_prompt("GENERAL_DATASET_QUESTION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)
    return response.content.strip()


# %%
# ### 9. SQL execution / validation
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Execute SQL on SQLite and return a DataFrame, else return an error.

    Returns (df, None) on success or (None, error_message) on failure.
    """
    try:
        df = pd.read_sql_query(sql, connection)
        return df, None
    except Exception as e:
        return None, str(e)


def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Basic SQL validator:
    - Uses EXPLAIN QUERY PLAN to detect syntax or schema issues.

    Returns (True, None) when the statement compiles, else (False, error).
    """
    try:
        cursor = connection.cursor()
        # EXPLAIN QUERY PLAN compiles the statement without running it fully.
        cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True, None
    except Exception as e:
        return False, str(e)
""" log_pipeline("SQL_CORRECTION_START") evidence_md = "" for chk in validation_payload.get("equivalence", []): evidence_md += chk["df"].to_markdown(index=False) + "\n\n" prompt = f""" Original question: {question} Original SQL: ```sql {generated_sql} User feedback: {user_feedback_comment} Reason SQL is considered incorrect: {decision_reason} Validation evidence: {evidence_md} Schema: {rag_context} TASK: Fix the SQL ONLY if necessary If SQL is already correct, return it unchanged Explain your reasoning Return corrected SQL and explanation. """ resp = llm.invoke(prompt) corrected_sql = extract_sql_from_markdown(resp.content) or generated_sql corrected_sql = clean_sql(corrected_sql) explanation = extract_explanation_after_marker(resp.content, "EXPLANATION:") log_pipeline( "SQL_CORRECTION_DONE", f"sql_changed={corrected_sql.strip() != generated_sql.strip()}" ) return corrected_sql, explanation # %% # %% # Generate SQL def generate_sql(question: str, rag_context: str) -> str: """ Generate SQL using LLM, but also pull in any past user feedback (corrected SQL) for this question as external guidance. """ # External correction from past feedback, if any last_fb = get_last_feedback_for_question(conn, question) if last_fb and last_fb.get("corrected_sql"): previous_feedback_block = f""" EXTERNAL USER FEEDBACK FROM PAST RUNS: Previous generated SQL: {last_fb['generated_sql']} Corrected SQL (preferred reference): {last_fb['corrected_sql']} User comment & prior explanation: {last_fb.get('comment') or '(none)'} You must avoid repeating the same mistake and should follow the logic of the corrected SQL when appropriate, while still reasoning from the schema. """ else: previous_feedback_block = "" system_instructions = f""" You are a senior data analyst writing SQL for a SQLite database. You will be given: - A description of available tables, columns, and relationships (schema + YAML metadata). - A natural language question from the user. 
Your job: - Write ONE valid SQLite SQL query that answers the question. - ONLY use tables and columns that exist in the schema_context. - Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions. - Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations. - Always use floating-point division for percentage calculations using 1.0 * numerator / denominator, and round to 2 decimals when appropriate. {previous_feedback_block} """
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n" + rag_context
        + "\n\n=== USER QUESTION ===\n" + question
        + "\n\nSQL query:"
    )
    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)
    # `clean_sql` (defined elsewhere in this file) strips markdown/noise
    # from the raw model output.
    sql = clean_sql(response.content)
    return sql

# %%
# Repair SQL
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to correct a failing SQL query.

    Called by backend_pipeline when validate_sql reports an error; the
    SQLite error text is fed back to the model verbatim.
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:
- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:
- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown. """
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n" + rag_context
        + "\n\n=== USER QUESTION ===\n" + question
        + "\n\n=== PREVIOUS (FAILING) SQL ===\n" + bad_sql
        + "\n\n=== SQLITE ERROR ===\n" + error_message
        + "\n\nCorrected SQL query:"
    )
    log_prompt("SQL_REPAIR", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)
    sql = clean_sql(response.content)
    return sql

# %%
### 11. Result summarization
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.

    `df` may be None (execution failed); `error`, when given, is appended
    to the prompt so the model can explain the failure to the user.
    """
    system_instructions = """
You are a senior data analyst.

You will be given:
The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.

Your job:
Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user. """
    # Build a markdown table preview if we have data
    if df is not None and not df.empty:
        # Cap the preview at 50 rows to keep the prompt small.
        preview_rows = min(len(df), 50)
        df_preview_md = df.head(preview_rows).to_markdown(index=False)
    else:
        df_preview_md = "(no rows returned)"
    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n" + question
        + "\n\n=== EXECUTED SQL ===\n" + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n" + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n" + rag_context
    )
    if error:
        prompt += "\n\n=== ERROR ===\n" + error
    # Logging helpers assumed to exist
    log_prompt("RESULT_SUMMARY", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)
    return response.content.strip()

# %%
def backend_pipeline(question: str):
    """
    STRICT cache-first backend.
    ALWAYS returns EXACTLY 10 values.
""" if not question or not question.strip(): return ( "Please type a question.", pd.DataFrame(), "", "", "", "", {}, 4, "**Feedback attempts remaining: 4**", "", ) attempts_left = 4 attempts_text = f"**Feedback attempts remaining: {attempts_left}**" store = get_sql_cache_store() # ---------------- CACHE LOOKUP ---------------- cached = retrieve_exact_cached_sql(question, store) if cached is None: cached = retrieve_best_cached_sql(question, store, max_distance=0.25) # ---------------- CACHE HIT ---------------- if cached: increment_cache_views(cached["metadata"], store) rag_context = get_rag_context(question) sql = cached["sql"] df, err = execute_sql(sql, conn) if err: df = pd.DataFrame() md = cached["metadata"] answer_md = cached.get("answer_md", "") debug_md = f""" ### Query Provenance (Cache) **Matched question:** `{cached.get("matched_question", "")}` **Success rate:** {md.get("success_rate", 0.5):.2f} **Good / Bad:** {md.get("good_count", 0)} / {md.get("bad_count", 0)} **Views:** {md.get("views", 0)} **SQL used:** ```sql {sql} """ return ( answer_md, df, sql, rag_context, question, answer_md, df_to_state(df), attempts_left, attempts_text, debug_md, ) # ---------------- LLM FLOW ---------------- rag_context = get_rag_context(question) sql = generate_sql(question, rag_context) is_valid, err = validate_sql(sql, conn) if not is_valid and err: sql = repair_sql(question, rag_context, sql, err) df, exec_error = execute_sql(sql, conn) if exec_error: df = pd.DataFrame() answer_md = summarize_results( question=question, sql=sql, df=df if not exec_error else None, rag_context=rag_context, error=exec_error, ) # Cache LLM result cache_sql_answer_dedup( question=question, sql=sql, answer_md=answer_md, metadata={ "good_count": 0, "bad_count": 0, "total_feedbacks": 0, "success_rate": 0.5, "views": 1, "saved_at": time.time(), }, store=store, ) debug_md = f""" Query Provenance (LLM Generated) Source: LLM Execution status: {"Error" if exec_error else "Success"} SQL used: {sql} 
""" return ( answer_md, df, sql, rag_context, question, answer_md, df_to_state(df), attempts_left, attempts_text, debug_md, ) # %% import re def looks_like_sql(text: str) -> bool: if not text: return False return bool(re.search( r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b", text, flags=re.I, )) def is_feedback_sufficient(feedback_text: str) -> bool: """ Decide whether feedback is actionable WITHOUT LLM. """ if not feedback_text: return False text = feedback_text.strip().lower() # Long feedback → sufficient if len(text) >= 60: return True # SQL pasted if looks_like_sql(text): return True # Numbers + comparison language → sufficient if re.search(r"\b\d+\b", text) and any( k in text for k in [ "last year", "previous", "this year", "increase", "decrease", "higher", "lower" ] ): return True # Analytical keywords signal_words = [ "filter", "where", "year", "month", "should", "instead", "expected", "wrong", "missing", "count", "distinct", "join", "exclude", "include", "only" ] signals = sum(1 for w in signal_words if w in text) if signals >= 1 and len(text) >= 20: return True return False # %% def build_followup_prompt_for_user(sample_feedback: str = "") -> str: """ Deterministic follow-up question to ask the user when feedback is vague. Returns a friendly prompt that the UI can display to the user. """ base = ( "Thanks — I need a bit more detail to act on this feedback.\n\n" "Please tell me one (or more) of the following so I can check and correct the result:\n\n" "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n" " the **time range** (year/month), or the **filters** applied?\n" "2. If you expected a different number, what was the expected number (and how was it computed)?\n" "3. 
If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n"
        "Examples you can copy-paste:\n"
    )
    examples = (
        "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n"
        "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n"
        "- \"Please exclude canceled orders (order_status = 'canceled').\"\n"
        "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n"
    )
    hint = "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare."
    prompt = base + examples + hint
    # Echo the user's original (vague) feedback back at the top, if any.
    if sample_feedback:
        prompt = f"I saw your feedback: \"{sample_feedback}\"\n\n" + prompt
    return prompt

# %%
def interpret_feedback_intent(question: str, feedback_text: str) -> dict:
    """
    Convert business feedback into a testable hypothesis.
    NO SQL generation here.

    Returns a dict with "intent" and "confidence" keys; falls back to
    {"intent": "unclear", "confidence": "low"} when the model's reply
    is not parseable JSON.
    """
    system_prompt = """
You are a senior analytics reviewer.

Your task:
- Identify what the user is CLAIMING.
- Do NOT assume they are correct.
- Do NOT write SQL.

Possible intents:
- comparison_claim (mentions last year / previous / higher / lower)
- overcount
- undercount
- missing_entities
- wrong_filter
- wrong_time_range
- unclear

Return JSON ONLY. """
    prompt = f"""
USER QUESTION:
{question}

USER FEEDBACK:
{feedback_text}

Return JSON exactly like this:
{{ "intent": "comparison_claim | overcount | undercount | missing_entities | wrong_filter | wrong_time_range | unclear", "confidence": "low | medium | high" }} """
    resp = llm.invoke(system_prompt + "\n\n" + prompt)
    # NOTE(review): debug print left in place (behavior-preserving pass).
    print(resp)
    try:
        return json.loads(resp.content)
    except Exception:
        return {
            "intent": "unclear",
            "confidence": "low",
        }

# %%
def run_validation_queries(
    intent_info: dict,
    question: str,
    original_sql: str,
    user_feedback: str,
    rag_context: str,
) -> dict:
    """
    Generate AND EXECUTE validation SQL.
    Separates equivalence checks vs context checks.
    Uses USER FEEDBACK explicitly for context validation. """
    log_pipeline(
        "VALIDATION_START",
        f"intent={intent_info.get('intent')}"
    )
    system_prompt = f"""
You are validating a business claim.

User intent: {intent_info.get("intent")}

Original question:
{question}

User feedback (IMPORTANT):
{user_feedback}

Original SQL:
{original_sql}

Schema context:
{rag_context}

TASK:
- Generate 2-4 SQL queries.
- Mark EACH query explicitly as one of:
- "equivalence":
* Alternative formulations of the SAME question
* Used to verify SQL correctness (joins, filters, logic)
- "context":
* Used to VERIFY or CHALLENGE the user's feedback
* If the user mentions years, comparisons, growth, or missing entities, generate year-over-year or comparative queries.
- DO NOT fix or modify the original SQL.
- These queries are ONLY for validation and analysis.

IMPORTANT:
- Context queries MUST be driven by the user's feedback.
- If the user claims a number, for example, values of a previous year, then generate queries that can verify that claim from the data.
- If the user provides some query to test in the feedback for the original question, use it as well. Do not assume it is correct.
Return JSON ONLY in this format:
{{ "checks": [ {{ "type": "equivalence | context", "sql": "SELECT ...", "label": "short description of what this validates" }} ] }} """
    resp = llm.invoke(system_prompt)
    raw = resp.content.strip()
    # NOTE(review): debug print left in place (behavior-preserving pass).
    print(raw)
    # Strip ```json fences if present
    # NOTE(review): strip("`") + replace("json", ..., 1) is fragile — it would
    # also mangle a payload whose first "json" occurs inside the content.
    if raw.startswith("```"):
        raw = raw.strip("`")
        raw = raw.replace("json", "", 1).strip()
    try:
        parsed = json.loads(raw)
    except Exception as e:
        log_pipeline("VALIDATION_PARSE_FAILED", str(e))
        # Parse failure degrades gracefully to "no evidence".
        return {
            "equivalence": [],
            "context": [],
        }
    equivalence = []
    context = []
    for chk in parsed.get("checks", []):
        sql = chk.get("sql")
        chk_type = chk.get("type")
        label = chk.get("label", "unnamed check")
        # Skip malformed or unrecognized check entries.
        if not sql or chk_type not in ("equivalence", "context"):
            continue
        df, err = execute_sql(sql, conn)
        print(df)
        log_pipeline(
            "VALIDATION_SQL_EXECUTED",
            f"type={chk_type} | label={label}"
        )
        # Failed or empty checks are silently dropped — only successful,
        # non-empty results become evidence.
        if err or df is None or df.empty:
            continue
        payload = {
            "label": label,
            "sql": sql,
            "df": df,
        }
        if chk_type == "equivalence":
            equivalence.append(payload)
        else:
            context.append(payload)
    log_pipeline(
        "VALIDATION_DONE",
        f"equivalence={len(equivalence)}, context={len(context)}"
    )
    return {
        "equivalence": equivalence,
        "context": context,
    }

# %%
def safe_json_load(text: str) -> dict | None:
    """Parse JSON leniently: direct parse first, then the first {...} span."""
    try:
        return json.loads(text)
    except Exception:
        pass
    # Try extracting first JSON block
    # Greedy match: grabs from the first '{' to the LAST '}' in the text.
    match = re.search(r"\{[\s\S]*\}", text)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            return None
    return None

# %%
def decide_after_validation_with_llm(
    question: str,
    original_sql: str,
    user_feedback: str,
    intent_info: dict,
    original_df: pd.DataFrame,
    validation_payload: dict,
    rag_context: str,
) -> dict:
    """
    Use LLM to decide whether SQL must change or explanation is sufficient.
    Returns a structured decision. """
    log_pipeline("DECISION_LLM_START")
    # Build evidence
    # Markdown dossier: original result, then each equivalence check,
    # then each context check, all as tables.
    evidence_md = "### Original Result\n"
    evidence_md += original_df.to_markdown(index=False)
    for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
        evidence_md += f"\n\n### Equivalence Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)
    for i, chk in enumerate(validation_payload.get("context", []), 1):
        evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)
    # NOTE(review): prompt literal below contains typos ("explnations",
    # "DO NO use markdown", "evidences shows") — runtime string, left untouched
    # in this documentation-only pass.
    system_prompt = """
You are a senior analytics and data analyst.

Your task:
- Decide whether the ORIGINAL SQL is logically correct or not, given the original question, the initial result, the feedback and the evidences of the equivalence check and context shown to you.
- Do NOT judge based on expectations or growth alone.
- Different definitions (e.g. customers vs customers-with-orders) are NOT errors.
- Only mark SQL incorrect if evidences shows a true flaw in the results or the logic.

CRITICAL RULES:
- You must return VALID JSON.
- DO NOT add explnations outside JSON.
- DO NO use markdown.
- DO NOT include extra text.

Return JSON ONLY in this format:
{ "decision": "correct_sql | not_correct_sql", "confidence": "high | medium | low", "reason": "provide an explanation here" } """
    prompt = f"""
Original question:
{question}

Original SQL:
{original_sql}

User feedback:
{user_feedback}

Interpreted intent:
{intent_info}

Schema context:
{rag_context}

Below is the validation evidence.
""" resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md) parsed = safe_json_load(resp.content) if not parsed: decision = { "decision": "correct_sql", "confidence": "low", "reason": "LLM response was not valid JSON" } else: decision = parsed log_pipeline( "DECISION_LLM_DONE", f"decision={decision.get('decision')} | confidence={decision.get('confidence')} | reason ={decision.get('reason')}" ) return decision # %% def explain_validation_with_llm( question: str, feedback: str, original_df: pd.DataFrame, validation_payload: dict, schema_context: str, ) -> str: """ Business explanation of validation results. """ log_pipeline("EXPLAIN_START") evidence_md = ( "## Pipeline: Feedback --> Validation -> Explanation\n\n" "### Baseline — Original Result\n" ) evidence_md += original_df.to_markdown(index=False) for i, chk in enumerate(validation_payload.get("context", []), 1): evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n" evidence_md += chk["df"].to_markdown(index=False) for i, chk in enumerate(validation_payload.get("equivalence", []), 1): evidence_md += f"\n\n### Verification Check {i}: {chk['label']}\n" evidence_md += chk["df"].to_markdown(index=False) system_prompt = """ You are a senior data analyst explaining results to a business user. Rules: - DO NOT mention SQL mechanics - Compare numbers explicitly - Explain trends year-over-year. - Explain what the actual numbers are and what it means for the business. - If the user is concerned about the SQL or the numbers in their feedback, based on verification and context checks, explain the results to them. 
""" prompt = f""" Original question: {question} User feedback: {feedback} Schema of the data: {schema_context} """ resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md) log_pipeline("EXPLAIN_DONE") return resp.content.strip() # %% # %% # %% def feedback_pipeline_interactive( feedback_rating: str, feedback_comment: str, last_sql: str, last_rag_context: str, last_question: str, last_answer_md: str, last_df_state: dict, feedback_sql: str, attempts_left: int, ): """ ALWAYS returns EXACTLY 10 values NO booleans followup_signal belongs to {"FOLLOWUP", "NONE"} """ last_df = state_to_df(last_df_state) rating = (feedback_rating or "").strip().lower() comment = (feedback_comment or "").strip() attempts_left = max(0, int(attempts_left or 0)) # ---------- INVALID RATING ---------- if rating not in {"correct", "wrong"}: return ( last_answer_md, last_df, last_sql, last_rag_context, last_question, last_answer_md, df_to_state(last_df), "NONE", "", attempts_left, ) # ---------- CORRECT ---------- if rating == "correct": update_cache_on_feedback( question=last_question, original_doc_md=find_cached_doc_by_sql( last_question, last_sql, get_sql_cache_store() ), user_marked_good=True, llm_corrected_sql=None, llm_corrected_answer_md=None, store=get_sql_cache_store(), ) md = "**Confirmed correct by user**\n\n" + last_answer_md return ( md, last_df, last_sql, last_rag_context, last_question, md, df_to_state(last_df), "NONE", "", attempts_left, ) # ---------- WRONG ---------- attempts_left = max(0, attempts_left - 1) if not is_feedback_sufficient(comment): return ( last_answer_md, last_df, last_sql, last_rag_context, last_question, last_answer_md, df_to_state(last_df), "FOLLOWUP", build_followup_prompt_for_user(comment), attempts_left, ) # ---------- VALIDATION ---------- intent_info = interpret_feedback_intent(last_question, comment) validation = run_validation_queries( intent_info=intent_info, question=last_question, original_sql=last_sql, user_feedback=comment, 
        rag_context=last_rag_context,
    )
    decision = decide_after_validation_with_llm(
        question=last_question,
        original_sql=last_sql,
        user_feedback=comment,
        intent_info=intent_info,
        original_df=last_df,
        validation_payload=validation,
        rag_context=last_rag_context,
    )
    # ---------- EXPLAIN ----------
    # SQL judged correct: defend the original answer with evidence instead
    # of rewriting the query.
    if decision.get("decision") == "correct_sql":
        explanation = explain_validation_with_llm(
            question=last_question,
            feedback=comment,
            original_df=last_df,
            validation_payload=validation,
            schema_context=last_rag_context,
        )
        return (
            explanation,
            last_df,
            last_sql,
            last_rag_context,
            last_question,
            explanation,
            df_to_state(last_df),
            "NONE",
            "",
            attempts_left,
        )
    # ---------- SQL CORRECTION ----------
    corrected_sql, explanation = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=comment,
        validation_payload=validation,
        decision_reason=decision.get("reason", ""),
        rag_context=last_rag_context,
    )
    df_new, err = execute_sql(corrected_sql, conn)
    final_md = "**SQL corrected**\n\n" + summarize_results(
        question=last_question,
        sql=corrected_sql,
        df=df_new if not err else None,
        rag_context=last_rag_context,
        error=err,
    )
    return (
        final_md,
        df_new if not err else pd.DataFrame(),
        corrected_sql,
        last_rag_context,
        last_question,
        final_md,
        df_to_state(df_new if not err else pd.DataFrame()),
        "NONE",
        "",
        attempts_left,
    )

# %%
# %%
import gradio as gr
import pandas as pd

# -------------------------------
# Helpers
# -------------------------------
def attempts_text(attempts: int) -> str:
    # Markdown banner shown above the chat; switches message at zero.
    if attempts <= 0:
        return "**Feedback exhausted. Please ask a new question.**"
    return f"**Feedback attempts remaining: {attempts}**"

# -------------------------------
# CHAT TURN
# -------------------------------
def chat_turn(user_input, history, state):
    """Handle one question: show a placeholder, run the backend, swap in the answer."""
    history = history or []
    # NOTE(review): messages-format dicts are appended here, but the Chatbot
    # below is constructed without type="messages" — confirm against the
    # installed Gradio version.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": "🔎 Thinking…"})
    (
        answer_md,
        _df,
        sql,
        rag,
        q,
        ans_state,
        df_state,
        attempts,
        attempts_md,
        debug_md,
    ) = backend_pipeline(user_input)
    # Replace the "Thinking…" placeholder with the real answer.
    history.pop()
    history.append({"role": "assistant", "content": answer_md})
    new_state = {
        "last_sql": sql,
        "last_rag": rag,
        "last_question": q,
        "last_answer": ans_state,
        "last_df": df_state,
        "attempts": attempts,
    }
    df = state_to_df(df_state)
    return history, new_state, attempts_md, debug_md, df

# -------------------------------
# FEEDBACK TURN
# -------------------------------
def feedback_turn(feedback_rating, feedback_text, history, state):
    """Handle a Good/Bad click: gatekeeping, then feedback_pipeline_interactive."""
    history = history or []
    # ---------- Attempts exhausted ----------
    if state["attempts"] <= 0:
        history.append({
            "role": "assistant",
            "content": "Feedback attempts are exhausted. Please ask a new question."
        })
        return history, state, attempts_text(0), pd.DataFrame()
    # ---------- BAD without text (consume attempt) ----------
    if feedback_rating == "wrong" and not (feedback_text or "").strip():
        new_attempts = max(0, state["attempts"] - 1)
        history.append({
            "role": "assistant",
            "content": "Please provide details before submitting negative feedback."
        })
        new_state = dict(state)
        new_state["attempts"] = new_attempts
        return history, new_state, attempts_text(new_attempts), state_to_df(state["last_df"])
    # ---------- Normal feedback ----------
    history.append({
        "role": "user",
        "content": f"{feedback_rating.upper()} feedback"
    })
    history.append({
        "role": "assistant",
        "content": "Validating feedback…"
    })
    (
        answer_md,
        _,
        sql_new,
        rag_new,
        q_new,
        ans_state,
        df_state,
        followup_flag,
        followup_prompt,
        attempts_new,
    ) = feedback_pipeline_interactive(
        feedback_rating=feedback_rating,
        feedback_comment=feedback_text,
        last_sql=state["last_sql"],
        last_rag_context=state["last_rag"],
        last_question=state["last_question"],
        last_answer_md=state["last_answer"],
        last_df_state=state["last_df"],
        feedback_sql="",
        attempts_left=state["attempts"],
    )
    # Replace the "Validating feedback…" placeholder with either the
    # follow-up request or the final answer.
    history.pop()
    history.append({
        "role": "assistant",
        "content": followup_prompt if followup_flag == "FOLLOWUP" else answer_md
    })
    new_state = {
        "last_sql": sql_new,
        "last_rag": rag_new,
        "last_question": q_new,
        "last_answer": ans_state,
        "last_df": df_state,
        "attempts": attempts_new,
    }
    df = state_to_df(df_state)
    return history, new_state, attempts_text(attempts_new), df

# -------------------------------
# UI
# -------------------------------
with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("# Analytics Assistant Chatbot")
    # NOTE(review): chat_turn/feedback_turn append {"role": ..., "content": ...}
    # dicts — older Gradio defaults this component to tuple pairs; confirm
    # type="messages" is the default in the pinned version.
    chat = gr.Chatbot(height=520)
    attempts_md = gr.Markdown(value=attempts_text(4))
    show_debug = gr.Checkbox(label="Show SQL / Cache info", value=False)
    debug_panel = gr.Markdown(visible=False)
    result_df = gr.Dataframe(
        label="Query Results",
        wrap=True,
        interactive=False,
    )
    # Per-session conversation state mirrored from the backend's 10-tuple.
    state = gr.State({
        "last_sql": "",
        "last_rag": "",
        "last_question": "",
        "last_answer": "",
        "last_df": {},
        "attempts": 4,
    })
    user_input = gr.Textbox(
        placeholder="Ask a question…",
        lines=1,
        show_label=False,
    )
    feedback_input = gr.Textbox(
        placeholder="Details (mandatory if the result seems wrong to you)…",
        lines=2,
        show_label=False,
    )
    # Buttons
    with gr.Row():
        # Spacer column pushes the buttons to the right.
        gr.Column(scale=6)
        with gr.Column(scale=2, min_width=110):
            good_btn = gr.Button("Good", size="sm")
        with gr.Column(scale=2, min_width=110):
            bad_btn = gr.Button("Bad", size="sm")
    # Toggle debug panel
    # `debug_panel` is passed as an input so its current value survives the toggle.
    show_debug.change(
        lambda show, md: gr.update(value=md, visible=show),
        inputs=[show_debug, debug_panel],
        outputs=debug_panel,
    )
    # Question submit
    user_input.submit(
        chat_turn,
        inputs=[user_input, chat, state],
        outputs=[chat, state, attempts_md, debug_panel, result_df],
        queue=False,
    )
    # GOOD feedback
    # gr.State("correct"/"wrong") injects the rating as a constant input.
    good_btn.click(
        feedback_turn,
        inputs=[gr.State("correct"), feedback_input, chat, state],
        outputs=[chat, state, attempts_md, result_df],
        queue=False,
    )
    # BAD feedback
    bad_btn.click(
        feedback_turn,
        inputs=[gr.State("wrong"), feedback_input, chat, state],
        outputs=[chat, state, attempts_md, result_df],
        queue=False,
    )

# %%
if __name__ == "__main__":
    demo.launch()

# %%