Spaces:
Sleeping
Sleeping
| # %% | |
| import os | |
| import uuid | |
| import sqlite3 | |
| import datetime | |
| import json | |
| import re | |
| import time | |
| from typing import Dict, Any, List, Tuple, Optional | |
| import hashlib | |
| import pandas as pd | |
| from groq import Groq | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| # %% | |
# 0. Global config
# NOTE(review): DB_PATH is rebound further down in this file to the chat-history DB.
DB_PATH = "olist.db"
# Directory holding the Olist CSV exports read by init_db().
DATA_DIR = "data"

# Groq chat-completion model and client
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if GROQ_API_KEY is None or GROQ_API_KEY == "":
    # Fail fast at import time when the key is missing or empty.
    raise ValueError("GROQ_API_KEY not found.")
groq_client = Groq(api_key=GROQ_API_KEY)

# Sentence-transformers model used by the RAG and SQL-cache vector stores
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Append-only text logs written by log_prompt() / log_run_event()
PROMPT_LOG_PATH = "prompt_logs.txt"
RUN_LOG_PATH = "run_logs.txt"
| # %% | |
| import os | |
| import chromadb | |
| from chromadb.config import Settings | |
# Resolved once at import so the store location does not follow later cwd changes.
CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)

# Lazily created, module-wide Chroma client (see get_chroma_client).
_CHROMA_CLIENT = None


def get_chroma_client():
    """Return the shared persistent Chroma client, creating it on first call.

    The client persists under CHROMA_PERSIST_DIR with anonymized telemetry
    disabled; every later call returns the same instance.
    """
    global _CHROMA_CLIENT
    if _CHROMA_CLIENT is not None:
        return _CHROMA_CLIENT
    _CHROMA_CLIENT = chromadb.PersistentClient(
        path=CHROMA_PERSIST_DIR,
        settings=Settings(anonymized_telemetry=False),
    )
    return _CHROMA_CLIENT
| # %% | |
def _append_log_entry(path: str, tag: str, body: str, footer: str) -> None:
    """Append one tagged, second-resolution timestamped entry to the text log *path*."""
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(f"\n\n================ {tag} @ {stamp} ================\n")
        handle.write(body)
        handle.write(footer)


def log_prompt(tag: str, prompt: str) -> None:
    """Append *prompt* to PROMPT_LOG_PATH under a tag/timestamp header."""
    _append_log_entry(
        PROMPT_LOG_PATH, tag, prompt,
        "\n================ END PROMPT ================\n",
    )


def log_run_event(tag: str, content: str) -> None:
    """Append a run event (model response, final SQL, errors) to RUN_LOG_PATH."""
    _append_log_entry(
        RUN_LOG_PATH, tag, content,
        "\n================ END EVENT ================\n",
    )
| # %% | |
| from pathlib import Path | |
# Data locations are anchored to the working directory at import time.
BASE_DIR = Path.cwd()
# SQLite file with the Olist analytics tables (init_db also adds user_feedback here).
ANALYTICS_DB_PATH = BASE_DIR.joinpath("data", "olist_analytics.db")
# SQLite file with conversations, messages and state snapshots.
CHAT_DB_PATH = BASE_DIR.joinpath("chat_data", "chat_history.db")
ANALYTICS_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
CHAT_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
| # %% | |
| # %% | |
def get_analytics_conn() -> sqlite3.Connection:
    """Open a new connection to the analytics database at ANALYTICS_DB_PATH.

    ``check_same_thread=False`` allows the connection to be used from threads
    other than the one that created it.

    NOTE: ``with get_analytics_conn() as conn:`` commits/rolls back on exit
    but does NOT close the connection (sqlite3 semantics).
    """
    connection = sqlite3.connect(ANALYTICS_DB_PATH, check_same_thread=False)
    return connection
def get_chat_conn() -> sqlite3.Connection:
    """Open a new connection to the chat-history database at CHAT_DB_PATH.

    Rows are returned as ``sqlite3.Row`` (access by column name or index),
    and foreign-key enforcement is switched on, since SQLite leaves it off
    per connection by default.
    """
    chat_db = sqlite3.connect(CHAT_DB_PATH)
    chat_db.row_factory = sqlite3.Row
    chat_db.execute("PRAGMA foreign_keys = ON;")
    return chat_db
| # %% | |
| # ============================== | |
| # Chat message helpers | |
| # ============================== | |
def assistant_msg(text: str):
    """Return a ``(user_text, assistant_text)`` pair with only the assistant side set."""
    return None, text


def user_msg(text: str):
    """Return a ``(user_text, assistant_text)`` pair with only the user side set."""
    return text, None


def add_assistant_state(history, text):
    """Append an assistant-only pair to *history* in place and return the same list."""
    history.append((None, text))
    return history
| # %% | |
def df_to_state(df: pd.DataFrame) -> dict:
    """Serialize a DataFrame into a JSON-friendly ``{"columns", "rows"}`` dict.

    Every cell is stringified with ``astype(str)``. ``None``, a non-DataFrame
    value, or an empty frame yields ``{"columns": [], "rows": []}``.
    """
    if not isinstance(df, pd.DataFrame) or df.empty:
        return {"columns": [], "rows": []}
    rows = df.astype(str).values.tolist()
    # Defensive guard: a non-empty frame always produces at least one row.
    if rows:
        return {"columns": list(df.columns), "rows": rows}
    return {"columns": [], "rows": []}
def state_to_df(state: dict) -> pd.DataFrame:
    """Rebuild a DataFrame from a ``{"columns", "rows"}`` dict (see df_to_state).

    Returns an empty DataFrame when *state* is not a dict, has no columns,
    or its "rows" entry is missing/None.
    """
    if isinstance(state, dict):
        columns = state.get("columns")
        rows = state.get("rows")
        if columns and rows is not None:
            return pd.DataFrame(rows, columns=columns)
    return pd.DataFrame()
| # %% | |
def init_feedback_table(conn: sqlite3.Connection) -> None:
    """Ensure the ``user_feedback`` table exists on *conn* and commit.

    Uses ``CREATE TABLE IF NOT EXISTS``: an existing table is left as-is
    (no columns are added or altered). ``rating`` is constrained to
    'good' / 'bad'.
    """
    ddl = """
    CREATE TABLE IF NOT EXISTS user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        generated_sql TEXT,
        model_answer TEXT,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
        comment TEXT,
        corrected_sql TEXT
    )
    """
    conn.execute(ddl)
    conn.commit()
def record_feedback(
    conn: sqlite3.Connection,
    question: str,
    generated_sql: str,
    model_answer: str,
    rating: str,  # "good" or "bad"
    comment: Optional[str] = None,
    corrected_sql: Optional[str] = None,
) -> None:
    """Insert one row into ``user_feedback`` and commit.

    Args:
        rating: Case-insensitive 'good' or 'bad'; stored lower-cased.
        corrected_sql: Optional externally supplied correction of the SQL.

    Raises:
        ValueError: if *rating* is not 'good' or 'bad' (case-insensitive).
    """
    normalized_rating = rating.lower()
    if normalized_rating not in {"good", "bad"}:
        raise ValueError("rating must be 'good' or 'bad'")
    created_at = datetime.datetime.now().isoformat(timespec="seconds")
    values = (
        created_at, question, generated_sql, model_answer,
        normalized_rating, comment, corrected_sql,
    )
    conn.execute(
        """
        INSERT INTO user_feedback (
            created_at, question, generated_sql, model_answer,
            rating, comment, corrected_sql
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """,
        values,
    )
    conn.commit()
def get_last_feedback_for_question(
    conn: sqlite3.Connection,
    question: str,
) -> Optional[Dict[str, Any]]:
    """Return the most recent feedback row for *question*, or None if none exists.

    ``created_at`` is stored with one-second resolution, so several rows can
    share the same timestamp. ``id`` is an AUTOINCREMENT key that only grows,
    so it breaks such ties and the latest insert wins deterministically.

    Returns:
        Dict with keys created_at, generated_sql, model_answer, rating,
        comment and corrected_sql.
    """
    row = conn.execute(
        """
        SELECT created_at, generated_sql, model_answer,
               rating, comment, corrected_sql
        FROM user_feedback
        WHERE question = ?
        ORDER BY created_at DESC, id DESC
        LIMIT 1
        """,
        (question,),
    ).fetchone()
    if row is None:
        return None
    keys = (
        "created_at", "generated_sql", "model_answer",
        "rating", "comment", "corrected_sql",
    )
    return dict(zip(keys, row))
| # %% | |
def init_db() -> None:
    """(Re)load the Olist CSV exports into the analytics SQLite database.

    Each CSV found under DATA_DIR is written to a table named after the file
    (extension stripped), replacing any existing table of that name; missing
    files are skipped. The user_feedback table is then created if absent.

    The connection is closed afterwards. sqlite3's own context manager only
    commits/rolls back, so ``closing`` is used to release the handle.
    """
    from contextlib import closing

    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]
    with closing(get_analytics_conn()) as conn:
        with conn:  # commit on success, roll back on error
            for fname in csv_files:
                path = os.path.join(DATA_DIR, fname)
                if not os.path.exists(path):
                    continue
                table_name = os.path.splitext(fname)[0]
                df = pd.read_csv(path)
                df.to_sql(table_name, conn, if_exists="replace", index=False)
            init_feedback_table(conn)
| # %% | |
# 4. Manual docs for Olist tables
# Hand-written table/column descriptions. extract_schema_metadata() prefers
# these over generic "Column '<name>' of type <T>" text, and they end up in the
# RAG schema documents shown to the SQL-generating model, so accuracy matters.
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short title of the review written by the customer (usually in Portuguese).",
            "review_comment_message": "Optional free-text comment written by the customer describing the experience (usually in Portuguese).",
            "review_creation_date": "Date when the satisfaction survey was sent to the customer (not the date the customer answered it).",
            "review_answer_timestamp": "Timestamp when the customer answered the satisfaction survey, i.e. submitted the review. It is not a reply from Olist or the seller."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
# Import-time side effect: reload every available Olist CSV into the analytics
# DB (replacing existing tables) and ensure the user_feedback table exists.
init_db()
| # %% | |
| from pathlib import Path | |
| # ---------------------------------- | |
| # Configuration | |
| # ---------------------------------- | |
| from pathlib import Path | |
BASE_DIR = Path.cwd()
# Chat-history SQLite file. NOTE: this rebinds the module-level DB_PATH (set to
# "olist.db" in the global config) to the same location as CHAT_DB_PATH.
DB_PATH = BASE_DIR / "chat_data" / "chat_history.db"
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
# ----------------------------------
# SQL schema
# ----------------------------------
# Idempotent DDL (IF NOT EXISTS) for the chat-history tables. Foreign keys are
# only enforced on connections that run "PRAGMA foreign_keys = ON".
SCHEMA_SQL = """
-- ===============================
-- Conversations (chat threads)
-- ===============================
CREATE TABLE IF NOT EXISTS conversations (
conversation_id TEXT PRIMARY KEY,
title TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
-- ===============================
-- Messages inside conversations
-- ===============================
CREATE TABLE IF NOT EXISTS conversation_messages (
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id TEXT NOT NULL,
role TEXT CHECK(role IN ('user', 'assistant')) NOT NULL,
content TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id)
);
-- ===============================
-- Snapshot of last known state
-- ===============================
CREATE TABLE IF NOT EXISTS conversation_state (
conversation_id TEXT PRIMARY KEY,
state_json TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id)
);
"""
| # ---------------------------------- | |
| # Execute | |
| # ---------------------------------- | |
from contextlib import closing

# Apply the chat-history schema. closing() releases the connection even when
# the DDL raises (previously an exception left it open).
with closing(sqlite3.connect(DB_PATH)) as conn:
    conn.execute("PRAGMA foreign_keys = ON;")
    # executescript() commits any pending transaction before running the script.
    conn.executescript(SCHEMA_SQL)
    conn.commit()
print(" Chat history tables initialized at:", DB_PATH.resolve())
| # %% | |
| import json | |
| import uuid | |
| from pathlib import Path | |
# NOTE(review): this relative DB_PATH is not read by get_conn(), which goes
# through get_chat_conn() and therefore CHAT_DB_PATH.
DB_PATH = Path("chat_data") / "chat_history.db"
# ------------------------------
# Connection helper
# ------------------------------
def get_conn():
    """Return a new chat-history connection (alias of get_chat_conn())."""
    connection = get_chat_conn()
    return connection
| # ------------------------------ | |
| # Conversation lifecycle | |
| # ------------------------------ | |
def create_conversation(title: str | None = None) -> str:
    """Insert a new conversation row and return its id.

    Args:
        title: Display title; a falsy value (None or "") becomes
            "New conversation".

    Returns:
        The new ``conversation_id`` (``uuid4().hex``, 32 hex characters).

    ``created_at``/``updated_at`` are naive ISO-8601 UTC strings. They are
    built from an aware ``now(timezone.utc)`` with tzinfo dropped, which gives
    the same format as the deprecated ``utcnow()`` without the deprecation.
    """
    conversation_id = uuid.uuid4().hex
    now = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat()
    )
    title = title or "New conversation"
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO conversations (conversation_id, title, created_at, updated_at)
            VALUES (?, ?, ?, ?)
            """,
            (conversation_id, title, now, now),
        )
    return conversation_id
def list_conversations(limit: int = 50):
    """Return up to *limit* conversations, most recently updated first.

    Each item is a plain dict with keys ``conversation_id``, ``title`` and
    ``updated_at``, built from the ``sqlite3.Row`` objects that get_chat_conn()
    configures.
    """
    query = """
        SELECT conversation_id, title, updated_at
        FROM conversations
        ORDER BY updated_at DESC
        LIMIT ?
    """
    with get_conn() as conn:
        records = conn.execute(query, (limit,)).fetchall()
    return list(map(dict, records))
def rename_conversation(conversation_id: str, new_title: str):
    """Set a conversation's title and bump its ``updated_at``.

    Unknown ids are a silent no-op (the UPDATE matches no rows). The
    timestamp is a naive ISO-8601 UTC string, in the same format
    ``utcnow().isoformat()`` produced, without the deprecated API.
    """
    now = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat()
    )
    with get_conn() as conn:
        conn.execute(
            """
            UPDATE conversations
            SET title = ?, updated_at = ?
            WHERE conversation_id = ?
            """,
            (new_title, now, conversation_id),
        )
| # ------------------------------ | |
| # Messages | |
| # ------------------------------ | |
def append_message(conversation_id: str, role: str, content: str):
    """Append a message to a conversation and bump the conversation's ``updated_at``.

    Both statements run in the same transaction, committed when the ``with``
    block exits. The schema's CHECK constraint limits *role* to 'user' or
    'assistant'. With foreign keys enabled, an unknown *conversation_id* is
    rejected by SQLite (sqlite3.IntegrityError).

    Timestamps are naive ISO-8601 UTC strings, in the same format
    ``utcnow().isoformat()`` produced, without the deprecated API.
    """
    now = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat()
    )
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO conversation_messages
            (conversation_id, role, content, created_at)
            VALUES (?, ?, ?, ?)
            """,
            (conversation_id, role, content, now),
        )
        conn.execute(
            """
            UPDATE conversations
            SET updated_at = ?
            WHERE conversation_id = ?
            """,
            (now, conversation_id),
        )
def load_messages(conversation_id: str):
    """Return a conversation's messages in insertion order.

    Each message is ``{"role": ..., "content": ...}``, ordered by the
    AUTOINCREMENT ``message_id``. An unknown id yields an empty list.
    """
    query = """
        SELECT role, content
        FROM conversation_messages
        WHERE conversation_id = ?
        ORDER BY message_id ASC
    """
    with get_conn() as conn:
        records = conn.execute(query, (conversation_id,)).fetchall()
    messages = []
    for record in records:
        messages.append({"role": record["role"], "content": record["content"]})
    return messages
| # ------------------------------ | |
| # State snapshot | |
| # ------------------------------ | |
def save_state(conversation_id: str, state: dict):
    """Upsert the JSON snapshot of a conversation's state.

    *state* must be JSON-serializable (``json.dumps`` raises TypeError
    otherwise). The row is inserted, or replaced on conflict. This does not
    touch ``conversations.updated_at``.

    The timestamp is a naive ISO-8601 UTC string, in the same format
    ``utcnow().isoformat()`` produced, without the deprecated API.
    """
    now = (
        datetime.datetime.now(datetime.timezone.utc)
        .replace(tzinfo=None)
        .isoformat()
    )
    payload = json.dumps(state)
    with get_conn() as conn:
        conn.execute(
            """
            INSERT INTO conversation_state (conversation_id, state_json, updated_at)
            VALUES (?, ?, ?)
            ON CONFLICT(conversation_id)
            DO UPDATE SET
                state_json = excluded.state_json,
                updated_at = excluded.updated_at
            """,
            (conversation_id, payload, now),
        )
def load_state(conversation_id: str) -> dict:
    """Return the decoded state snapshot for a conversation, or {} if none is stored."""
    query = """
        SELECT state_json
        FROM conversation_state
        WHERE conversation_id = ?
    """
    with get_conn() as conn:
        found = conn.execute(query, (conversation_id,)).fetchone()
    return json.loads(found["state_json"]) if found else {}
| # %% | |
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """Introspect user tables, their columns and foreign keys.

    Returns:
        ``{"tables": {name: {"description", "columns", "relationships"}}}``.
        Descriptions come from OLIST_DOCS when present. Otherwise a generic
        "Column '<name>' of type <TYPE>" text is used, with TYPE defaulting
        to TEXT when SQLite reports no declared type. Relationships are
        "<table>.<col> → <ref_table>.<ref_col> (foreign key)" strings.

    SQLite-internal tables (names starting with ``sqlite_``, e.g. the
    ``sqlite_sequence`` created by AUTOINCREMENT columns) are skipped so they
    are not documented as business tables.
    """
    cursor = connection.cursor()
    cursor.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND substr(name, 1, 7) != 'sqlite_'"
    )
    tables = [row[0] for row in cursor.fetchall()]

    metadata: Dict[str, Any] = {"tables": {}}
    for table in tables:
        # Quote the identifier (doubling embedded quotes) so odd names cannot break the PRAGMA.
        ident = '"' + table.replace('"', '""') + '"'

        table_docs = OLIST_DOCS.get(table, {})
        table_desc: str = table_docs.get(
            "description",
            f"Table '{table}' from Olist dataset."
        )
        column_docs: Dict[str, str] = table_docs.get("columns", {})

        cursor.execute(f"PRAGMA table_info({ident})")
        columns_meta: Dict[str, str] = {}
        for col in cursor.fetchall():
            col_name = col[1]
            col_type = col[2] or "TEXT"
            columns_meta[col_name] = column_docs.get(
                col_name, f"Column '{col_name}' of type {col_type}"
            )

        # foreign_key_list rows: (id, seq, ref_table, from_col, to_col, ...)
        cursor.execute(f"PRAGMA foreign_key_list({ident})")
        relationships: List[str] = [
            f"{table}.{fk[3]} → {fk[2]}.{fk[4]} (foreign key)"
            for fk in cursor.fetchall()
        ]

        metadata["tables"][table] = {
            "description": table_desc,
            "columns": columns_meta,
            "relationships": relationships,
        }
    return metadata
from contextlib import closing

# Introspect the analytics schema once at import time. closing() releases the
# connection afterwards; sqlite3's own context manager would only end the
# transaction and leave the module-level `conn` open.
with closing(get_analytics_conn()) as conn:
    schema_metadata = extract_schema_metadata(conn)
| # %% | |
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """Render schema metadata as a YAML document string.

    Layout, two spaces per nesting level::

        tables:
          <table>:
            description: "<text>"
            columns:
              <column>: "<text>"
            relationships:        # only when the table has any
              - "<text>"

    Values are emitted as double-quoted YAML scalars. To keep them valid,
    backslashes are escaped, double quotes are replaced with single quotes
    and line breaks are flattened to spaces.
    """
    def _scalar(text: Any) -> str:
        # Escape for the inside of a double-quoted YAML scalar.
        cleaned = str(text).replace("\\", "\\\\").replace('"', "'")
        return cleaned.replace("\r", " ").replace("\n", " ")

    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        lines.append(f'    description: "{_scalar(tinfo.get("description", ""))}"')
        lines.append("    columns:")
        for col_name, col_desc in tinfo.get("columns", {}).items():
            lines.append(f'      {col_name}: "{_scalar(col_desc)}"')
        rels = tinfo.get("relationships", [])
        if rels:
            lines.append("    relationships:")
            lines.extend(f'      - "{_scalar(rel)}"' for rel in rels)
    return "\n".join(lines)
# Render the introspected schema once at import time; this text is added to the
# RAG corpus below as the "global_schema" document.
schema_yaml = build_schema_yaml(schema_metadata)
| # %% | |
# 6. Build schema documents for RAG (taking samples from the table)
def build_schema_documents(
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Build one RAG document, plus its metadata, per documented table.

    A table gets a document only if it appears both in ``sqlite_master`` and
    in ``schema_metadata["tables"]``. The document contains:

    * the table name and description;
    * one "- name (TYPE): description" line per column;
    * an optional relationships block;
    * up to *sample_rows* example rows rendered with ``DataFrame.to_markdown``.

    If fetching or rendering the sample fails, a placeholder line is used.

    Returns:
        ``(documents, metadatas)``: parallel lists. Each metadata dict is
        ``{"doc_type": "table_schema", "table_name": <table>}``.
    """
    documented = schema_metadata["tables"]
    texts: List[str] = []
    metas: List[Dict[str, Any]] = []

    with get_analytics_conn() as connection:
        cur = connection.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
        table_names = [row[0] for row in cur.fetchall()]

        for table in table_names:
            if table not in documented:
                continue
            info = documented[table]
            description = info.get("description", "")
            column_docs = info.get("columns", {})
            relationships = info.get("relationships", [])

            cur.execute(f"PRAGMA table_info('{table}')")
            column_lines = []
            for column in cur.fetchall():
                name = column[1]
                declared = column[2] or "TEXT"
                text = column_docs.get(name, f"Column '{name}' of type {declared}")
                column_lines.append(f"- {name} ({declared}): {text}")

            try:
                preview = pd.read_sql_query(
                    f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                    connection,
                ).to_markdown(index=False)
            except Exception:
                # Best-effort: e.g. to_markdown needs the optional `tabulate` package.
                preview = "(could not fetch sample rows)"

            if relationships:
                rel_section = (
                    "Relationships:\n"
                    + "\n".join(f"- {rel}" for rel in relationships)
                    + "\n"
                )
            else:
                rel_section = ""

            texts.append(
                f"Table: {table}\n"
                f"Description: {description}\n\n"
                "Columns:\n"
                + "\n".join(column_lines)
                + "\n\n"
                + rel_section
                + "\n"
                + f"Example rows:\n{preview}\n"
            )
            metas.append({"doc_type": "table_schema", "table_name": table})

    return texts, metas
# Build RAG texts + metadata: one document per table, then the global YAML
# schema as a single extra document.
# NOTE(review): rag_retriever filters on doc_type == "table_schema", so the
# "global_schema" document is only reachable by querying rag_store directly.
schema_docs, schema_doc_metas = build_schema_documents(schema_metadata)
RAG_TEXTS: List[str] = [*schema_docs, "SCHEMA_METADATA_YAML:\n" + schema_yaml]
RAG_METADATAS: List[Dict[str, Any]] = [*schema_doc_metas, {"doc_type": "global_schema"}]
| # %% | |
def build_store_final():
    """Open the persisted schema vector store, seeding it on first use.

    The "rag_schema_store" Chroma collection is filled with RAG_TEXTS /
    RAG_METADATAS only when it is empty, so documents from an earlier run are
    reused as-is. NOTE(review): schema changes are not picked up until the
    collection is cleared.

    Returns:
        ``(rag_store, rag_retriever)``. The retriever returns the top 3
        documents whose ``doc_type`` is "table_schema".
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBED_MODEL_NAME,
        encode_kwargs={"normalize_embeddings": True},
    )
    store = Chroma(
        client=get_chroma_client(),
        collection_name="rag_schema_store",
        embedding_function=embeddings,
    )
    needs_seeding = store._collection.count() == 0
    if needs_seeding:
        store.add_texts(texts=RAG_TEXTS, metadatas=RAG_METADATAS)
    retriever = store.as_retriever(
        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
    )
    return store, retriever


# Initialize once
rag_store, rag_retriever = build_store_final()
| # %% | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
# Lazily created singletons backing the SQL answer cache.
_sql_cache_store = None
_sql_embedding_fn = None
SQL_CACHE_COLLECTION = "sql_cache_mpnet"


def get_sql_cache_store():
    """Return the shared Chroma store for cached SQL answers, creating it once.

    Uses EMBED_MODEL_NAME with normalized embeddings and the
    SQL_CACHE_COLLECTION collection on the shared persistent Chroma client.
    """
    global _sql_cache_store, _sql_embedding_fn
    if _sql_cache_store is None:
        if _sql_embedding_fn is None:
            _sql_embedding_fn = HuggingFaceEmbeddings(
                model_name=EMBED_MODEL_NAME,
                encode_kwargs={"normalize_embeddings": True},
            )
        _sql_cache_store = Chroma(
            client=get_chroma_client(),
            collection_name=SQL_CACHE_COLLECTION,
            embedding_function=_sql_embedding_fn,
        )
    return _sql_cache_store
| # %% | |
def sanitize_metadata_for_chroma(metadata: dict) -> dict:
    """Return a copy of *metadata* with only values Chroma can store.

    ``None`` values are dropped. int/float/bool/str values are kept as-is.
    Anything else is converted with ``str()``. A falsy *metadata* yields ``{}``.
    """
    primitive = (int, float, bool, str)
    return {
        key: value if isinstance(value, primitive) else str(value)
        for key, value in (metadata or {}).items()
        if value is not None
    }
| # %% | |
def normalize_question_strict(q: str) -> str:
    """Deterministic normalization used as the exact-match cache key.

    Steps: lower-case, delete punctuation (anything that is neither a word
    character nor whitespace), collapse whitespace runs to one space, then
    trim both ends. The final trim matters: punctuation preceded by a space
    (``"what ?"``) would otherwise leave a trailing space and produce a
    different key than ``"what?"``. Falsy input yields "".
    """
    if not q:
        return ""
    q = q.lower()
    q = re.sub(r"[^\w\s]", "", q)
    q = re.sub(r"\s+", " ", q)
    return q.strip()
| # %% | |
# Helper: normalize question
def normalize_question_text(q: str) -> str:
    """Lower-case *q*, turn punctuation into spaces, collapse whitespace and trim.

    Falsy input yields "".
    """
    if not q:
        return ""
    lowered = q.strip().lower()
    spaced = re.sub(r"[^\w\s]", " ", lowered)
    collapsed = re.sub(r"\s+", " ", spaced)
    return collapsed.strip()
# Helper: compute success rate
def compute_success_rate(md: dict) -> float:
    """Return ``success_count / total_feedbacks`` from *md*.

    Missing or falsy counts (None, 0) are treated as 0. The result is 0.0
    when there is no feedback at all.
    """
    successes = md.get("success_count", 0) or 0
    total = md.get("total_feedbacks", 0) or 0
    return float(successes) / float(total) if total > 0 else 0.0
# Insert initial cache entry (no feedback yet)
def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
    """Cache a freshly executed question/SQL/answer before any feedback exists.

    Initial metrics: views=1, success_count=0, total_feedbacks=0 and a neutral
    success_rate=0.5, the same default rank_cached_candidates() assumes for
    entries without feedback. *extra_metadata* keys override these.

    ``normalized_question`` is built with normalize_question_strict(), the
    key retrieve_exact_cached_sql() filters on, so entries written here are
    reachable by exact lookup. The entry is written through langchain_upsert(),
    which sanitizes the metadata for Chroma and stores the record under the
    generated id, so ``metadata["id"]`` matches the stored document id.

    Returns:
        The hex id of the new cache entry.
    """
    if store is None:
        store = get_sql_cache_store()
    ident = uuid.uuid4().hex
    md = {
        "id": ident,
        "normalized_question": normalize_question_strict(question),
        "sql": sql,
        "answer_md": answer_md,
        "saved_at": time.time(),
        "views": 1,
        "success_count": 0,
        "total_feedbacks": 0,
        "success_rate": 0.5,
    }
    if extra_metadata:
        md.update(extra_metadata)
    langchain_upsert(store=store, text=question, metadata=md, cache_id=ident)
    return ident
| # %% | |
| import time | |
| import logging | |
| from typing import Optional, List, Dict, Any | |
| from difflib import SequenceMatcher | |
_logger = logging.getLogger(__name__)
# -------------------------------------------------------------------------
# Utility helpers
# -------------------------------------------------------------------------
def _now_ts() -> float:
    """Return the current wall-clock time as a Unix timestamp in seconds."""
    return time.time()


def similarity_score(a: str, b: str) -> float:
    """Return difflib's ``SequenceMatcher`` ratio (0.0–1.0) between *a* and *b*.

    Falsy inputs (e.g. None) are treated as ""; two empty inputs score 1.0.
    """
    left = a if a else ""
    right = b if b else ""
    return SequenceMatcher(None, left, right).ratio()
| # ------------------------------------------------------------------------- | |
| # Persist helper | |
| # ------------------------------------------------------------------------- | |
def langchain_upsert(
    store,
    text: str,
    metadata: dict,
    cache_id: str,
):
    """Write one text + metadata record to a LangChain vector store.

    Metadata is first passed through sanitize_metadata_for_chroma(). The
    record is written with ``ids=[cache_id]``; whether an existing id is
    overwritten depends on the store's ``add_texts``. If that call raises
    TypeError (e.g. a store that does not accept ``ids``), it is retried
    without an explicit id.
    """
    payload = {
        "texts": [text],
        "metadatas": [sanitize_metadata_for_chroma(metadata)],
    }
    try:
        store.add_texts(ids=[cache_id], **payload)
    except TypeError:
        store.add_texts(**payload)
| # ------------------------------------------------------------------------- | |
| # Cache insert / update | |
| # ------------------------------------------------------------------------- | |
def cache_sql_answer_dedup(
    question: str,
    sql: str,
    answer_md: str,
    metadata: dict,
    store,
):
    """Upsert a cached answer keyed deterministically by (question, SQL).

    The record id is generate_cache_id(question, sql), so re-caching the same
    pair targets the same record. The embedded text is normalize_text(question),
    used for semantic search, while ``normalized_question`` stores the
    normalize_question_strict() key used for exact lookups. Counters,
    ``saved_at`` and ``last_viewed_at`` are carried over from *metadata*
    (with defaults); ``last_updated_at`` is always set to now.

    Returns:
        ``{"question", "sql", "answer_md", "metadata"}``, with the metadata
        that was stored.
    """
    semantic_text = normalize_text(question)
    exact_key = normalize_question_strict(question)
    cache_id = generate_cache_id(question, sql)
    now = _now_ts()

    metric_defaults = {
        "good_count": 0,
        "bad_count": 0,
        "total_feedbacks": 0,
        "success_rate": 0.5,
        "views": 0,
    }
    md = {
        "cache_id": cache_id,
        "sql": sql,
        "answer_md": answer_md,
        "normalized_question": exact_key,
        "saved_at": metadata.get("saved_at", now),
        "last_updated_at": now,
        "last_viewed_at": metadata.get("last_viewed_at", 0),
        **{name: metadata.get(name, fallback) for name, fallback in metric_defaults.items()},
    }

    langchain_upsert(store=store, text=semantic_text, metadata=md, cache_id=cache_id)
    return {"question": question, "sql": sql, "answer_md": answer_md, "metadata": md}
| # ------------------------------------------------------------------------- | |
| # Find exact cached entry (question + SQL) | |
| # ------------------------------------------------------------------------- | |
def find_cached_doc_by_sql(question: str, sql: str, store):
    """Fetch the cache entry whose id is ``generate_cache_id(question, sql)``.

    The id is looked up directly on the store's underlying Chroma collection
    (``store._collection``).

    Returns:
        ``{"id", "question", "sql", "answer_md", "metadata"}``, or None when
        the store exposes no such collection, the id is absent, the entry has
        no metadata, or the lookup fails. Failures are logged rather than
        silently discarded.
    """
    cache_id = generate_cache_id(question, sql)
    coll = getattr(store, "_collection", None)
    if coll is None or not hasattr(coll, "get"):
        return None
    try:
        res = coll.get(ids=[cache_id])
    except Exception:
        # Best-effort cache: a failed lookup is treated as a miss.
        _logger.warning("SQL cache lookup failed for id %s", cache_id, exc_info=True)
        return None
    metadatas = (res or {}).get("metadatas") or []
    md = metadatas[0] if metadatas else None
    if md is None:
        return None
    return {
        "id": cache_id,
        "question": question,
        "sql": md.get("sql"),
        "answer_md": md.get("answer_md"),
        "metadata": md,
    }
| # ------------------------------------------------------------------------- | |
| # Retrieve cached answers ranked primarily by success rate | |
| # ------------------------------------------------------------------------- | |
| import re | |
| import unicodedata | |
def normalize_text(text: str) -> str:
    """Fold *text* to a lower-case, punctuation-free ASCII form.

    Accents are removed via NFKD decomposition, and characters with no ASCII
    form are dropped. Punctuation becomes spaces, and whitespace runs collapse
    to single spaces with both ends trimmed. Falsy input yields "".
    """
    if not text:
        return ""
    ascii_only = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    spaced = re.sub(r"[^\w\s]", " ", ascii_only.lower())
    return re.sub(r"\s+", " ", spaced).strip()


import hashlib


def generate_cache_id(question: str, sql: str) -> str:
    """Return a deterministic SHA-1 hex id for a (question, SQL) pair.

    The question goes through normalize_text() and the SQL is only trimmed,
    so questions differing just in case/punctuation/accents share an id for
    the same SQL.
    """
    key = normalize_text(question) + "||" + (sql or "").strip()
    return hashlib.sha1(key.encode("utf-8")).hexdigest()
| def rank_cached_candidates(candidates: list[dict]) -> dict: | |
| """ | |
| Deterministic quality-first ranking. | |
| """ | |
| candidates.sort( | |
| key=lambda c: ( | |
| -float(c["metadata"].get("success_rate", 0.5)), | |
| -int(c["metadata"].get("good_count", 0)), | |
| int(c["metadata"].get("bad_count", 0)), | |
| -float(c["metadata"].get("last_updated_at", 0)), | |
| ) | |
| ) | |
| return candidates[0] | |
def retrieve_exact_cached_sql(question: str, store):
    """
    Return the best cached entry whose stored question matches exactly.

    Entries are selected with a metadata filter
    ``normalized_question == normalize_question_strict(question)``. Several
    entries (different SQL) may share a question, so they are ranked with
    ``rank_cached_candidates`` (quality first). Entries without usable SQL
    are skipped instead of aborting the whole lookup.

    Returns:
        A candidate dict (``matched_question``, ``sql``, ``answer_md``,
        ``distance`` = 0.0, ``metadata``), or ``None`` when nothing usable is
        cached or the lookup fails.
    """
    coll = getattr(store, "_collection", None)
    if coll is None:
        return None
    norm_q = normalize_question_strict(question)
    try:
        res = coll.get(where={"normalized_question": norm_q})
        candidates = [
            {
                "matched_question": question,
                "sql": md["sql"],
                "answer_md": md.get("answer_md", ""),
                "distance": 0.0,
                "metadata": md,
            }
            for md in (res or {}).get("metadatas") or []
            # One malformed entry must not hide the valid ones.
            if md and md.get("sql")
        ]
        if not candidates:
            return None
        return rank_cached_candidates(candidates)
    except Exception:
        # Best-effort cache: a failure means "no hit", but keep it visible.
        _logger.exception("retrieve_exact_cached_sql failed")
        return None
def retrieve_best_cached_sql(
    question: str,
    store,
    max_distance: float = 0.25,
    k: int = 20,
):
    """
    Semantic cache lookup: the best cached SQL for a similar question.

    Runs ``store.similarity_search_with_score`` on ``normalize_text(question)``
    for the top ``k`` hits. Keeps hits whose score (treated as a distance:
    lower is closer) is at most ``max_distance`` and whose metadata has
    non-empty SQL. The survivors are ranked quality-first:
    success rate, good count, fewest bad marks, closeness, then recency.

    Args:
        question: Raw user question.
        store: LangChain vector store holding cached (question, SQL) entries.
        max_distance: Largest accepted score.
        k: Number of nearest neighbours to inspect (was fixed at 20).

    Returns:
        The winning candidate dict, or ``None`` when no hit qualifies.
    """
    def _num(value, default, cast):
        # Unset counters may be stored as None; fall back to the default.
        return cast(default if value is None else value)

    results = store.similarity_search_with_score(normalize_text(question), k=k)
    candidates = []
    for doc, score in results:
        distance = float(score)
        if distance > max_distance:
            continue
        md = doc.metadata or {}
        sql = md.get("sql")
        if not sql:
            # An entry without SQL cannot be served.
            continue
        candidates.append({
            "matched_question": doc.page_content,
            "sql": sql,
            "answer_md": md.get("answer_md", ""),
            "distance": distance,
            "metadata": md,
            # ranking signals (forced numeric)
            "success_rate": _num(md.get("success_rate"), 0.5, float),
            "good": _num(md.get("good_count"), 0, int),
            "bad": _num(md.get("bad_count"), 0, int),
            "views": _num(md.get("views"), 0, int),
            "last_updated": _num(md.get("last_updated_at"), 0, float),
        })
    if not candidates:
        return None
    # Quality first; closeness only breaks ties between equally rated entries.
    # min() picks the same element a stable sort()[0] would.
    return min(
        candidates,
        key=lambda c: (
            -c["success_rate"],
            -c["good"],
            c["bad"],
            c["distance"],
            -c["last_updated"],
        ),
    )
| # ------------------------------------------------------------------------- | |
| # Increment views | |
| # ------------------------------------------------------------------------- | |
def increment_cache_views(metadata: dict, store):
    """
    Bump the view counter of a cached entry and write it back to the store.

    Works on a copy of ``metadata``, so the caller's dict is untouched.
    ``views`` is incremented, ``last_viewed_at`` and ``last_updated_at`` are
    set to the same timestamp, and ``saved_at`` is filled only if absent.

    Returns:
        True if the upsert succeeded; False when metadata or ``cache_id`` is
        missing or the update fails (the failure is logged).
    """
    if not metadata:
        return False
    cache_id = metadata.get("cache_id")
    if not cache_id:
        return False
    now = _now_ts()  # one timestamp so the three fields agree
    md = dict(metadata)
    try:
        # None counters count as 0 instead of raising outside error handling.
        md["views"] = int(md.get("views") or 0) + 1
        md["last_viewed_at"] = now
        md["last_updated_at"] = now
        md.setdefault("saved_at", now)
        # NOTE(review): the entry is re-embedded from "normalized_question";
        # confirm this matches the text used when the entry was first cached,
        # otherwise a view silently changes the entry's vector.
        langchain_upsert(
            store=store,
            text=normalize_text(md.get("normalized_question", "")),
            metadata=md,
            cache_id=cache_id,
        )
        return True
    except Exception:
        _logger.exception("increment_cache_views failed")
        return False
| # Update metrics on feedback | |
def update_cache_on_feedback(
    question: str,
    original_doc_md: dict,
    user_marked_good: bool,
    llm_corrected_sql: str | None,
    llm_corrected_answer_md: str | None,
    store,
):
    """
    Apply one piece of user feedback to the SQL cache.

    1. If the originally served entry is known (``original_doc_md`` in the
       shape returned by ``find_cached_doc_by_sql``), increment its good or
       bad counter, recompute ``success_rate = good_count / total_feedbacks``
       and upsert it.
    2. If both a corrected SQL and a corrected answer are given, cache them
       as a separate entry via ``cache_sql_answer_dedup``. This happens even
       when the original entry could not be found.

    Returns None.
    """
    if original_doc_md:
        md = dict(original_doc_md.get("metadata") or {})
        # find_cached_doc_by_sql also exposes the id at the top level; use it
        # when the stored metadata lacks "cache_id".
        cache_id = md.get("cache_id") or original_doc_md.get("id")
        if cache_id:
            # ---- feedback counts (None/missing counters start at 0) ----
            good = int(md.get("good_count") or 0)
            bad = int(md.get("bad_count") or 0)
            if user_marked_good:
                good += 1
            else:
                bad += 1
            total = int(md.get("total_feedbacks") or 0) + 1
            # ---- timestamps: preserve saved_at, refresh last_updated_at ----
            now = _now_ts()
            md.update(
                cache_id=cache_id,
                good_count=good,
                bad_count=bad,
                total_feedbacks=total,
                success_rate=good / total,
                last_updated_at=now,
            )
            md.setdefault("saved_at", now)
            langchain_upsert(
                store=store,
                text=normalize_text(question),
                metadata=md,
                cache_id=cache_id,
            )
    # -------------------------
    # Corrected SQL --> NEW ENTRY
    # -------------------------
    if llm_corrected_sql and llm_corrected_answer_md:
        # NOTE(review): good_count=1 / total_feedbacks=1 disagrees with
        # success_rate=0.5 (1/1 would be 1.0); values kept as-is, confirm intent.
        cache_sql_answer_dedup(
            question=question,
            sql=llm_corrected_sql,
            answer_md=llm_corrected_answer_md,
            metadata={
                "good_count": 1,
                "bad_count": 0,
                "total_feedbacks": 1,
                "success_rate": 0.5,
                "views": 0,
                "saved_at": _now_ts(),
            },
            store=store,
        )
| # %% | |
def log_pipeline(event: str, detail: str = ""):
    """
    Print one timestamped pipeline event to stdout.

    Output format: ``[YYYY-mm-dd HH:MM:SS] PIPELINE::<event>``, followed by
    `` | <detail>`` when ``detail`` is non-empty. All pipeline logging goes
    through here so the destination can later be changed in one place.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    suffix = f" | {detail}" if detail else ""
    print(f"[{stamp}] PIPELINE::{event}{suffix}")
| # ### 8. Groq LLM via LangChain | |
| from langchain_groq import ChatGroq | |
| import re | |
| import gradio as gr | |
# Shared LangChain chat model; every LLM call below goes through this object.
llm = ChatGroq(groq_api_key=GROQ_API_KEY, model=GROQ_MODEL_NAME)
| # %% | |
def get_rag_context(question: str) -> str:
    """
    Fetch schema documentation relevant to ``question``.

    Invokes the module-level ``rag_retriever`` and joins the page contents of
    the returned documents with a ``---`` separator line. Returns "" when
    nothing is retrieved.
    """
    separator = "\n\n---\n\n"
    pages = [doc.page_content for doc in rag_retriever.invoke(question)]
    return separator.join(pages)
def clean_sql(sql: str) -> str:
    """
    Turn raw LLM output into a bare SQL string.

    If the text contains a closed Markdown code fence (optionally tagged
    ``sql``/``sqlite``, any case), only the first fenced block is kept, so
    prose the model writes around it is dropped. Otherwise any stray fence
    markers are removed. The result is whitespace-trimmed; None gives "".
    """
    sql = (sql or "").strip()
    match = re.search(
        r"```[ \t]*(?:sqlite|sql)?[ \t]*\n?(.*?)```",
        sql,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if match:
        return match.group(1).strip()
    if "```" in sql:
        # Unclosed fence: strip the markers (and their tag), keep the rest.
        sql = re.sub(r"```(?:sqlite|sql)?", "", sql, flags=re.IGNORECASE).strip()
    return sql
def extract_sql_from_markdown(text: str) -> str:
    """
    Extract the SQL from an LLM reply.

    Lookup order:
      1. the first ```sql / ```sqlite block (tag matched case-insensitively);
      2. the first untagged ``` block (fences are paired left to right, so a
         block tagged with another language is never mistaken for SQL);
      3. otherwise the whole text.
    The result is whitespace-trimmed; None gives "".
    """
    text = text or ""
    match = re.search(
        r"```(?:sqlite|sql)(.*?)```", text, flags=re.DOTALL | re.IGNORECASE
    )
    if match:
        return match.group(1).strip()
    for fence in re.finditer(r"```([^\n`]*)\n(.*?)```", text, flags=re.DOTALL):
        if not fence.group(1).strip():
            return fence.group(2).strip()
    return text.strip()
def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
    """
    Return the text after the first case-insensitive occurrence of ``marker``.

    If the marker is absent, the whole text is returned. Output is trimmed.
    The search runs on the original text (via a regex) rather than on
    ``text.upper()``, because upper-casing can change the string length
    (e.g. "ß" -> "SS") and misalign the slice.
    """
    text = text or ""
    match = re.search(re.escape(marker), text, flags=re.IGNORECASE)
    if match is None:
        return text.strip()
    return text[match.end():].strip()
# 6b. General / descriptive questions
# Lower-case phrases (matched as substrings) that mark a question as a request
# for a dataset overview rather than a data query.
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    Return True if ``question`` asks for a high-level dataset description.

    Such questions are answered from the schema context alone, without
    generating SQL. Matching is case-insensitive substring search against
    ``GENERAL_DESC_KEYWORDS``.
    """
    normalized = question.lower().strip()
    for phrase in GENERAL_DESC_KEYWORDS:
        if phrase in normalized:
            return True
    return False
def answer_general_question(question: str) -> str:
    """
    Answer a conceptual "what is this dataset?" question without SQL.

    Retrieves schema context for the question, asks the LLM for a Markdown
    overview aimed at non-technical readers, logs the prompt and the raw
    reply, and returns the trimmed answer.
    """
    context = get_rag_context(question)
    system_instructions = """
You are a data documentation expert.
You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".
Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a non-technical person.
Do NOT write SQL. Answer in Markdown.
"""
    prompt = "".join([
        system_instructions,
        "\n\n=== SCHEMA CONTEXT ===\n",
        context,
        "\n\n=== USER QUESTION ===\n",
        question,
        "\n\nDetailed dataset overview:",
    ])
    log_prompt("GENERAL_DATASET_QUESTION", prompt)
    reply = llm.invoke(prompt).content
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", reply)
    return reply.strip()
| # %% | |
| # ### 9. SQL execution / validation | |
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Run ``sql`` on ``connection`` and load the result into a DataFrame.

    Returns ``(df, None)`` on success. Any ordinary exception is converted to
    ``(None, message)`` rather than raised.

    NOTE(review): the statement is executed exactly as given; nothing here
    restricts it to read-only SQL.
    """
    outcome: Tuple[Optional[pd.DataFrame], Optional[str]]
    try:
        frame = pd.read_sql_query(sql, connection)
    except Exception as exc:
        outcome = (None, str(exc))
    else:
        outcome = (frame, None)
    return outcome
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Check that ``sql`` compiles against the connected SQLite database.

    Prefixes the statement with ``EXPLAIN QUERY PLAN``, so SQLite parses it
    and resolves tables/columns, reporting syntax or schema problems without
    running the query.

    Returns:
        ``(True, None)`` when SQLite accepts it, else ``(False, error_text)``.
    """
    cursor = None
    try:
        cursor = connection.cursor()
        cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True, None
    except Exception as e:
        return False, str(e)
    finally:
        # Release the statement handle now instead of leaving it to GC.
        if cursor is not None:
            cursor.close()
| # %% | |
| # ### 10. SQL Generation + Repair with LLM | |
def review_and_correct_sql_with_llm(
    question: str,
    generated_sql: str,
    user_feedback_comment: str,
    validation_payload: dict,
    decision_reason: str,
    rag_context: str,
) -> tuple[str, str]:
    """
    Ask the LLM to correct SQL that was judged incorrect.

    Call this only after the decision step concluded the SQL needs fixing.
    The prompt contains the question, the original SQL (in a closed ```sql
    fence), the user's feedback, the decision reason, the labelled
    equivalence-check results and the schema. It asks for a ```sql block
    followed by an ``EXPLANATION:`` section, which is the shape
    ``extract_sql_from_markdown`` and ``extract_explanation_after_marker``
    parse.

    Returns:
        ``(corrected_sql, explanation)``. If the reply contains no SQL, the
        original SQL is returned. If the marker is missing, the whole reply
        is the explanation.
    """
    log_pipeline("SQL_CORRECTION_START")
    evidence_md = "".join(
        f"{chk.get('label', 'check')}:\n{chk['df'].to_markdown(index=False)}\n\n"
        for chk in (validation_payload or {}).get("equivalence", [])
    )
    prompt = f"""
Original question:
{question}
Original SQL:
```sql
{generated_sql}
```
User feedback:
{user_feedback_comment}
Reason SQL is considered incorrect:
{decision_reason}
Validation evidence:
{evidence_md}
Schema:
{rag_context}
TASK:
Fix the SQL ONLY if necessary
If SQL is already correct, return it unchanged
Explain your reasoning
Return corrected SQL and explanation.
Respond in exactly this format:
```sql
<final SQL query>
```
EXPLANATION: <your reasoning>
"""
    log_prompt("SQL_CORRECTION", prompt)
    resp = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_CORRECTION", resp.content)
    corrected_sql = clean_sql(extract_sql_from_markdown(resp.content) or generated_sql)
    explanation = extract_explanation_after_marker(resp.content, "EXPLANATION:")
    log_pipeline(
        "SQL_CORRECTION_DONE",
        f"sql_changed={corrected_sql.strip() != generated_sql.strip()}"
    )
    return corrected_sql, explanation
| # Generate SQL | |
def generate_sql(question: str, rag_context: str) -> str:
    """
    Generate SQL for ``question`` with the LLM.

    If a past feedback record for the same question holds a corrected SQL,
    it is added to the prompt as external guidance. If the feedback table
    does not exist yet (``sqlite3.OperationalError``), it is created and no
    guidance is used.

    The prompt asks for the bare query only, because the reply is consumed
    directly as SQL. The prompt and raw reply are logged, and the reply is
    passed through ``clean_sql``.
    """
    # External correction from past feedback, if any
    with get_analytics_conn() as conn:
        try:
            last_fb = get_last_feedback_for_question(conn, question)
        except sqlite3.OperationalError:
            init_feedback_table(conn)
            last_fb = None
    if last_fb and last_fb.get("corrected_sql"):
        previous_feedback_block = f"""
EXTERNAL USER FEEDBACK FROM PAST RUNS:
Previous generated SQL:
{last_fb.get('generated_sql') or '(unknown)'}
Corrected SQL (preferred reference):
{last_fb['corrected_sql']}
User comment & prior explanation:
{last_fb.get('comment') or '(none)'}
You must avoid repeating the same mistake and should follow the logic of
the corrected SQL when appropriate, while still reasoning from the schema.
"""
    else:
        previous_feedback_block = ""
    system_instructions = f"""
You are a senior data analyst writing SQL for a SQLite database.
You will be given:
- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.
Your job:
- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
and round to 2 decimals when appropriate.
- Return ONLY the SQL query, with no explanation or markdown.
{previous_feedback_block}
"""
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nSQL query:"
    )
    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)
    return clean_sql(response.content)
| # %% | |
| # Repair SQL | |
def repair_sql(
    question: str,
    rag_context: str,
    bad_sql: str,
    error_message: str,
) -> str:
    """
    Ask the LLM to rewrite a SQL query that SQLite rejected.

    The prompt carries the schema context, the user's question, the failing
    query and SQLite's error text. The prompt and raw reply are logged, and
    the reply is passed through ``clean_sql`` before being returned.
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.
You will be given:
- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.
Your job:
- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use correct JOINS (Left, Right, Inner, Outer, Full Outer etc.) with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Return ONLY the corrected SQL query, no explanation or markdown.
"""
    labelled_sections = (
        ("=== RAG CONTEXT ===", rag_context),
        ("=== USER QUESTION ===", question),
        ("=== PREVIOUS (FAILING) SQL ===", bad_sql),
        ("=== SQLITE ERROR ===", error_message),
    )
    prompt = system_instructions
    for heading, content in labelled_sections:
        prompt += "\n\n" + heading + "\n" + content
    prompt += "\n\nCorrected SQL query:"
    log_prompt("SQL_REPAIR", prompt)
    raw_reply = llm.invoke(prompt).content
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", raw_reply)
    return clean_sql(raw_reply)
| # %% | |
| ### 11. Result summarization | |
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM for a concise, human-readable Markdown answer.

    The prompt includes the question, the executed SQL, a Markdown preview of
    at most 50 result rows (or "(no rows returned)" when ``df`` is None or
    empty), the schema context and, if given, the error text. The prompt and
    raw reply are logged; the trimmed reply is returned.
    """
    system_instructions = """
You are a senior data analyst.
You will be given:
The user's question.
The final SQL that was executed.
A small preview of the query result (as a Markdown table, if available).
Optional error information if the query failed.
Your job:
Provide a clear, concise answer in Markdown.
If the result is numeric / aggregated, explain what it means in business terms.
If there was an error, explain it simply and suggest how the user could rephrase.
Do NOT show raw SQL unless it is helpful to the user.
"""
    preview_limit = 50
    if df is None or df.empty:
        table_md = "(no rows returned)"
    else:
        table_md = df.head(preview_limit).to_markdown(index=False)
    pieces = [
        system_instructions,
        "\n\n=== USER QUESTION ===\n",
        question,
        "\n\n=== EXECUTED SQL ===\n",
        sql,
        "\n\n=== QUERY RESULT PREVIEW ===\n",
        table_md,
        "\n\n=== RAG CONTEXT (schema) ===\n",
        rag_context,
    ]
    if error:
        pieces.extend(("\n\n=== ERROR ===\n", error))
    prompt = "".join(pieces)
    log_prompt("RESULT_SUMMARY", prompt)
    reply = llm.invoke(prompt).content
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", reply)
    return reply.strip()
| # %% | |
def backend_pipeline(question: str):
    """
    Answer a user question, cache-first.

    Flow:
      1. Blank question: ask the user to type one.
      2. Cache lookup: exact normalized-question match first, then the best
         semantically similar entry (distance <= 0.25). On a hit, the view
         counter is bumped, the cached SQL is re-executed for a fresh table,
         and the cached answer is returned.
      3. Otherwise: generate SQL with the LLM, validate it (one repair
         attempt if validation fails), execute it, and summarize. The result
         is cached only if execution succeeded, so a failing query is never
         served from the cache later.

    ALWAYS returns EXACTLY 10 values:
        (answer_md, df, sql, rag_context, question, answer_md,
         df_state, attempts_left, attempts_text, debug_md)
    """
    attempts_left = 4
    attempts_text = f"**Feedback attempts remaining: {attempts_left}**"
    if not question or not question.strip():
        return (
            "Please type a question.",
            pd.DataFrame(),
            "",
            "",
            "",
            "",
            {},
            attempts_left,
            attempts_text,
            "",
        )
    store = get_sql_cache_store()
    # ---------------- CACHE LOOKUP ----------------
    cached = retrieve_exact_cached_sql(question, store)
    if cached is None:
        cached = retrieve_best_cached_sql(question, store, max_distance=0.25)
    # ---------------- CACHE HIT ----------------
    if cached:
        increment_cache_views(cached["metadata"], store)
        rag_context = get_rag_context(question)
        sql = cached["sql"]
        with get_analytics_conn() as conn:
            df, err = execute_sql(sql, conn)
        if err:
            # Still served from cache (strict cache-first), but leave a trace.
            log_pipeline("CACHE_SQL_EXEC_FAILED", err)
            df = pd.DataFrame()
        md = cached["metadata"]
        answer_md = cached.get("answer_md", "")
        debug_md = f"""
### Query Provenance (Cache)
**Matched question:** `{cached.get("matched_question", "")}`
**Success rate:** {md.get("success_rate", 0.5):.2f}
**Good / Bad:** {md.get("good_count", 0)} / {md.get("bad_count", 0)}
**Views:** {md.get("views", 0)}
**SQL used:**
{sql}
"""
        return (
            answer_md,
            df,
            sql,
            rag_context,
            question,
            answer_md,
            df_to_state(df),
            attempts_left,
            attempts_text,
            debug_md,
        )
    # ---------------- LLM FLOW ----------------
    rag_context = get_rag_context(question)
    sql = generate_sql(question, rag_context)
    with get_analytics_conn() as conn:
        is_valid, err = validate_sql(sql, conn)
    if not is_valid and err:
        sql = repair_sql(question, rag_context, sql, err)
    with get_analytics_conn() as conn:
        df, exec_error = execute_sql(sql, conn)
    if exec_error:
        df = pd.DataFrame()
    answer_md = summarize_results(
        question=question,
        sql=sql,
        df=df if not exec_error else None,
        rag_context=rag_context,
        error=exec_error,
    )
    if exec_error:
        # A failing query must not enter the cache: lookups are cache-first,
        # so it would be served for this question from then on.
        log_pipeline("LLM_SQL_NOT_CACHED", exec_error)
    else:
        cache_sql_answer_dedup(
            question=question,
            sql=sql,
            answer_md=answer_md,
            metadata={
                "good_count": 0,
                "bad_count": 0,
                "total_feedbacks": 0,
                "success_rate": 0.5,
                "views": 1,
                "saved_at": time.time(),
            },
            store=store,
        )
    debug_md = f"""
Query Provenance (LLM Generated)
Source: LLM
Execution status: {"Error" if exec_error else "Success"}
SQL used:
{sql}
"""
    return (
        answer_md,
        df,
        sql,
        rag_context,
        question,
        answer_md,
        df_to_state(df),
        attempts_left,
        attempts_text,
        debug_md,
    )
| # %% | |
| import re | |
def looks_like_sql(text: str) -> bool:
    """
    Heuristically decide whether ``text`` contains a SQL query.

    Requires query structure rather than a single keyword, so ordinary
    feedback such as "the numbers from 2017 look off" is not taken for SQL.
    Accepted shapes (case-insensitive):
      - ``SELECT`` followed later by ``FROM``;
      - a CTE opener ``WITH <name> AS (``.

    Returns False for empty or None input.
    """
    if not text:
        return False
    pattern = r"\bselect\b[\s\S]*?\bfrom\b|\bwith\s+\w+\s+as\s*\("
    return re.search(pattern, text, flags=re.IGNORECASE) is not None
def is_feedback_sufficient(feedback_text: str) -> bool:
    """
    Decide, without an LLM call, whether feedback is concrete enough to act on.

    The trimmed, lower-cased feedback counts as sufficient when any of these
    holds:
      - it is at least 60 characters long;
      - ``looks_like_sql`` says it contains SQL;
      - it contains a standalone number and a comparison phrase
        ("last year", "higher", ...);
      - it is at least 20 characters long and contains an analytical signal
        word (substring match).
    Empty or None feedback is never sufficient.
    """
    if not feedback_text:
        return False
    lowered = feedback_text.strip().lower()
    if len(lowered) >= 60 or looks_like_sql(lowered):
        return True
    comparison_terms = (
        "last year", "previous", "this year",
        "increase", "decrease", "higher", "lower",
    )
    has_number = re.search(r"\b\d+\b", lowered) is not None
    if has_number and any(term in lowered for term in comparison_terms):
        return True
    signal_terms = (
        "filter", "where", "year", "month", "should", "instead",
        "expected", "wrong", "missing", "count", "distinct",
        "join", "exclude", "include", "only",
    )
    return len(lowered) >= 20 and any(term in lowered for term in signal_terms)
| # %% | |
def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
    """
    Build the fixed follow-up message shown when feedback is too vague.

    The message asks what looks wrong, what the expected number is, or for a
    corrected SQL snippet, and lists copy-pasteable examples. When
    ``sample_feedback`` is non-empty, it is quoted back to the user first.
    """
    message_lines = (
        "Thanks — I need a bit more detail to act on this feedback.\n\n",
        "Please tell me one (or more) of the following so I can check and correct the result:\n\n",
        "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n",
        " the **time range** (year/month), or the **filters** applied?\n",
        "2. If you expected a different number, what was the expected number (and how was it computed)?\n",
        "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n",
        "Examples you can copy-paste:\n",
        "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n",
        "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n",
        "- \"Please exclude canceled orders (order_status = 'canceled').\"\n",
        "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n",
        "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare.",
    )
    message = "".join(message_lines)
    if not sample_feedback:
        return message
    return f'I saw your feedback: "{sample_feedback}"\n\n{message}'
| # %% | |
def interpret_feedback_intent(question: str, feedback_text: str) -> dict:
    """
    Classify what the user's feedback claims about the answer.

    Turns business feedback into a testable hypothesis; no SQL is generated.
    The LLM is asked for ``{"intent": ..., "confidence": ...}``. The reply is
    parsed with ``safe_json_load``, so fenced or chatty JSON still parses.

    Returns:
        A dict with at least ``intent`` and ``confidence``. When the reply
        has no JSON object or no intent, returns
        ``{"intent": "unclear", "confidence": "low"}``.
    """
    system_prompt = """
You are a senior analytics reviewer.
Your task:
- Identify what the user is CLAIMING.
- Do NOT assume they are correct.
- Do NOT write SQL.
Possible intents:
- comparison_claim (mentions last year / previous / higher / lower)
- overcount
- undercount
- missing_entities
- wrong_filter
- wrong_time_range
- unclear
Return JSON ONLY.
"""
    prompt = f"""
USER QUESTION:
{question}
USER FEEDBACK:
{feedback_text}
Return JSON exactly like this:
{{
"intent": "comparison_claim | overcount | undercount | missing_entities | wrong_filter | wrong_time_range | unclear",
"confidence": "low | medium | high"
}}
"""
    resp = llm.invoke(system_prompt + "\n\n" + prompt)
    # Replaces a debug print(); goes to the run log like other LLM replies.
    log_run_event("RAW_MODEL_RESPONSE_FEEDBACK_INTENT", resp.content)
    parsed = safe_json_load(resp.content)
    if not isinstance(parsed, dict) or not parsed.get("intent"):
        return {
            "intent": "unclear",
            "confidence": "low",
        }
    parsed.setdefault("confidence", "low")
    return parsed
| # %% | |
def run_validation_queries(
    intent_info: dict,
    question: str,
    original_sql: str,
    user_feedback: str,
    rag_context: str,
) -> dict:
    """
    Ask the LLM for validation queries about the user's claim, then run them.

    The LLM proposes 2-4 queries, each tagged ``equivalence`` (an alternative
    formulation of the original question) or ``context`` (data that confirms
    or challenges the user's feedback). Each query is executed. A query is
    skipped when it:
      - is not a read-only SELECT/WITH statement (the SQL can originate from
        user-pasted feedback);
      - fails to execute (the error is logged);
      - returns no rows.

    Returns:
        ``{"equivalence": [...], "context": [...]}``; each item is
        ``{"label": str, "sql": str, "df": DataFrame}``. Both lists are empty
        when the LLM reply contains no JSON object.
    """
    def _is_read_only_query(query: str) -> bool:
        # Must start with SELECT or WITH, and a CTE must not lead into DML
        # such as "WITH x AS (...) DELETE ...".
        if not re.match(r"\s*\(*\s*(select|with)\b", query, flags=re.IGNORECASE):
            return False
        return re.search(
            r"\)\s*(insert|update|delete|replace)\b", query, flags=re.IGNORECASE
        ) is None

    log_pipeline(
        "VALIDATION_START",
        f"intent={intent_info.get('intent')}"
    )
    system_prompt = f"""
You are validating a business claim.
User intent:
{intent_info.get("intent")}
Original question:
{question}
User feedback (IMPORTANT):
{user_feedback}
Original SQL:
{original_sql}
Schema context:
{rag_context}
TASK:
- Generate 2-4 SQL queries.
- Mark EACH query explicitly as one of:
- "equivalence":
* Alternative formulations of the SAME question
* Used to verify SQL correctness (joins, filters, logic)
- "context":
* Used to VERIFY or CHALLENGE the user's feedback
* If the user mentions years, comparisons, growth, or missing entities,
generate year-over-year or comparative queries.
- DO NOT fix or modify the original SQL.
- These queries are ONLY for validation and analysis.
IMPORTANT:
- Context queries MUST be driven by the user's feedback.
- If the user claims a number, for example, values of a previous year, then generate queries that can verify that claim from the data.
- If the user provides some query to test in the feedback for the original question, use it as well. Do not assume it is correct.
Return JSON ONLY in this format:
{{
"checks": [
{{
"type": "equivalence | context",
"sql": "SELECT ...",
"label": "short description of what this validates"
}}
]
}}
"""
    resp = llm.invoke(system_prompt)
    raw = resp.content.strip()
    log_run_event("RAW_MODEL_RESPONSE_VALIDATION", raw)
    # safe_json_load copes with ```json fences and surrounding prose. The old
    # manual stripping removed the first "json" substring anywhere in the reply.
    parsed = safe_json_load(raw)
    if not isinstance(parsed, dict):
        log_pipeline("VALIDATION_PARSE_FAILED", "reply did not contain a JSON object")
        return {
            "equivalence": [],
            "context": [],
        }
    equivalence = []
    context = []
    for chk in parsed.get("checks") or []:
        if not isinstance(chk, dict):
            continue
        sql = chk.get("sql")
        chk_type = chk.get("type")
        label = chk.get("label", "unnamed check")
        if not isinstance(sql, str) or not sql.strip() or chk_type not in ("equivalence", "context"):
            continue
        if not _is_read_only_query(sql):
            log_pipeline(
                "VALIDATION_SQL_SKIPPED",
                f"type={chk_type} | label={label} | reason=not a read-only query"
            )
            continue
        with get_analytics_conn() as conn:
            df, err = execute_sql(sql, conn)
        log_pipeline(
            "VALIDATION_SQL_EXECUTED",
            f"type={chk_type} | label={label}"
        )
        if err:
            log_pipeline("VALIDATION_SQL_FAILED", f"label={label} | error={err}")
            continue
        if df is None or df.empty:
            continue
        payload = {
            "label": label,
            "sql": sql,
            "df": df,
        }
        if chk_type == "equivalence":
            equivalence.append(payload)
        else:
            context.append(payload)
    log_pipeline(
        "VALIDATION_DONE",
        f"equivalence={len(equivalence)}, context={len(context)}"
    )
    return {
        "equivalence": equivalence,
        "context": context,
    }
| # %% | |
| def safe_json_load(text: str) -> dict | None: | |
| try: | |
| return json.loads(text) | |
| except Exception: | |
| pass | |
| # Try extracting first JSON block | |
| match = re.search(r"\{[\s\S]*\}", text) | |
| if match: | |
| try: | |
| return json.loads(match.group(0)) | |
| except Exception: | |
| return None | |
| return None | |
| # %% | |
def decide_after_validation_with_llm(
    question: str,
    original_sql: str,
    user_feedback: str,
    intent_info: dict,
    original_df: pd.DataFrame,
    validation_payload: dict,
    rag_context: str,
) -> dict:
    """
    Ask the LLM whether the original SQL must change, given the evidence.

    The evidence is the original result plus every equivalence and context
    check, each rendered as a Markdown table.

    Returns:
        A dict with ``decision`` (always exactly ``"correct_sql"`` or
        ``"not_correct_sql"``), ``confidence`` and ``reason``. Both an
        unparseable reply and an unrecognized decision value fall back to
        ``"correct_sql"`` with low confidence, so the original SQL is only
        rewritten on an explicit ``not_correct_sql`` verdict.
    """
    log_pipeline("DECISION_LLM_START")
    # Build evidence
    evidence_md = "### Original Result\n"
    evidence_md += "(no rows)" if original_df is None else original_df.to_markdown(index=False)
    for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
        evidence_md += f"\n\n### Equivalence Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)
    for i, chk in enumerate(validation_payload.get("context", []), 1):
        evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
        evidence_md += chk["df"].to_markdown(index=False)
    system_prompt = """
You are a senior analytics and data analyst.
Your task:
- Decide whether the ORIGINAL SQL is logically correct or not, given the original question, the initial result, the feedback and the evidences of the equivalence check and context shown to you.
- Do NOT judge based on expectations or growth alone.
- Different definitions (e.g. customers vs customers-with-orders) are NOT errors.
- Only mark SQL incorrect if evidences shows a true flaw in the results or the logic.
CRITICAL RULES:
- You must return VALID JSON.
- DO NOT add explnations outside JSON.
- DO NO use markdown.
- DO NOT include extra text.
Return JSON ONLY in this format:
{
"decision": "correct_sql | not_correct_sql",
"confidence": "high | medium | low",
"reason": "provide an explanation here"
}
"""
    prompt = f"""
Original question:
{question}
Original SQL:
{original_sql}
User feedback:
{user_feedback}
Interpreted intent:
{intent_info}
Schema context:
{rag_context}
Below is the validation evidence.
"""
    resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md)
    log_run_event("RAW_MODEL_RESPONSE_DECISION", resp.content)
    parsed = safe_json_load(resp.content)
    if not isinstance(parsed, dict):
        decision = {
            "decision": "correct_sql",
            "confidence": "low",
            "reason": "LLM response was not valid JSON"
        }
    else:
        decision = dict(parsed)
        verdict = str(decision.get("decision", "")).strip().lower()
        if verdict in ("correct_sql", "not_correct_sql"):
            decision["decision"] = verdict
        else:
            # Missing or echoed-template values ("correct_sql | not_correct_sql")
            # used to fall through to an SQL rewrite; treat them like invalid JSON.
            decision = {
                "decision": "correct_sql",
                "confidence": "low",
                "reason": (
                    f"LLM returned an unrecognized decision {parsed.get('decision')!r}; "
                    f"keeping the original SQL. Reason given: {parsed.get('reason', '')}"
                ),
            }
    log_pipeline(
        "DECISION_LLM_DONE",
        f"decision={decision.get('decision')} | confidence={decision.get('confidence')} | reason ={decision.get('reason')}"
    )
    return decision
| # %% | |
def explain_validation_with_llm(
    question: str,
    feedback: str,
    original_df: pd.DataFrame,
    validation_payload: dict,
    schema_context: str,
) -> str:
    """
    Turn validation evidence into a business-facing explanation.

    Builds a Markdown evidence pack as tables: the original result, then
    every context check, then every equivalence ("verification") check.
    The LLM is asked to explain the numbers to the user without SQL jargon.
    Returns the trimmed reply.
    """
    log_pipeline("EXPLAIN_START")
    pieces = [
        "## Pipeline: Feedback --> Validation -> Explanation\n\n"
        "### Baseline — Original Result\n",
        original_df.to_markdown(index=False),
    ]
    for section_title, payload_key in (
        ("Context Check", "context"),
        ("Verification Check", "equivalence"),
    ):
        for number, check in enumerate(validation_payload.get(payload_key, []), 1):
            pieces.append(f"\n\n### {section_title} {number}: {check['label']}\n")
            pieces.append(check["df"].to_markdown(index=False))
    evidence_md = "".join(pieces)
    system_prompt = """
You are a senior data analyst explaining results to a business user.
Rules:
- DO NOT mention SQL mechanics
- Compare numbers explicitly
- Explain trends year-over-year.
- Explain what the actual numbers are and what it means for the business.
- If the user is concerned about the SQL or the numbers in their feedback, based on verification and context checks, explain the results to them.
"""
    prompt = f"""
Original question:
{question}
User feedback:
{feedback}
Schema of the data:
{schema_context}
"""
    reply = llm.invoke("\n\n".join((system_prompt, prompt, evidence_md)))
    log_pipeline("EXPLAIN_DONE")
    return reply.content.strip()
| # %% | |
| # %% | |
def _feedback_outputs(
    answer_md,
    df,
    sql,
    rag_context,
    question,
    df_state,
    followup_signal,
    followup_prompt,
    attempts_left,
):
    """Arrange values into the fixed 10-slot tuple of feedback_pipeline_interactive.

    The answer markdown fills both the display slot (index 0) and the
    answer-state slot (index 5). ``df_state`` is passed in already serialized,
    so callers control exactly when ``df_to_state`` runs.
    """
    return (
        answer_md,
        df,
        sql,
        rag_context,
        question,
        answer_md,
        df_state,
        followup_signal,
        followup_prompt,
        attempts_left,
    )


def feedback_pipeline_interactive(
    feedback_rating: str,
    feedback_comment: str,
    last_sql: str,
    last_rag_context: str,
    last_question: str,
    last_answer_md: str,
    last_df_state: dict,
    feedback_sql: str,
    attempts_left: int,
):
    """Process one round of user feedback on a previously answered question.

    Always returns exactly 10 values, in this order:
    (answer_md, df, sql, rag_context, question, answer_state, df_state,
    followup_signal, followup_prompt, attempts_left).
    ``followup_signal`` is the string "FOLLOWUP" or "NONE", never a boolean.

    Branches:
    * unrecognised rating  -> previous answer echoed back unchanged;
    * "correct"            -> cache marked good, feedback recorded as "good";
    * "wrong", too vague   -> one attempt consumed, follow-up prompt returned;
    * "wrong", SQL upheld  -> explanation of the validation results;
    * "wrong", otherwise   -> LLM-corrected SQL is executed, feedback recorded
                              as "bad", and the new result is summarised.

    ``feedback_sql`` is accepted for interface compatibility but not used here.
    """
    previous_df = state_to_df(last_df_state)
    verdict = (feedback_rating or "").strip().lower()
    note = (feedback_comment or "").strip()
    remaining = max(0, int(attempts_left or 0))

    # Unknown rating: leave everything exactly as it was.
    if verdict not in {"correct", "wrong"}:
        return _feedback_outputs(
            last_answer_md,
            previous_df,
            last_sql,
            last_rag_context,
            last_question,
            df_to_state(previous_df),
            "NONE",
            "",
            remaining,
        )

    # Positive feedback: reinforce the cached answer and persist the rating.
    if verdict == "correct":
        cached_doc = find_cached_doc_by_sql(
            last_question, last_sql, get_sql_cache_store()
        )
        update_cache_on_feedback(
            question=last_question,
            original_doc_md=cached_doc,
            user_marked_good=True,
            llm_corrected_sql=None,
            llm_corrected_answer_md=None,
            store=get_sql_cache_store(),
        )
        with get_analytics_conn() as conn:
            record_feedback(
                conn=conn,
                question=last_question,
                generated_sql=last_sql,
                model_answer=last_answer_md,
                rating="good",
                comment=note or None,
                corrected_sql=None,
            )
        confirmed_md = "**Confirmed correct by user**\n\n" + last_answer_md
        return _feedback_outputs(
            confirmed_md,
            previous_df,
            last_sql,
            last_rag_context,
            last_question,
            df_to_state(previous_df),
            "NONE",
            "",
            remaining,
        )

    # Negative feedback always consumes one attempt.
    remaining = max(0, remaining - 1)

    # Not enough detail to act on: ask the user to elaborate.
    if not is_feedback_sufficient(note):
        return _feedback_outputs(
            last_answer_md,
            previous_df,
            last_sql,
            last_rag_context,
            last_question,
            df_to_state(previous_df),
            "FOLLOWUP",
            build_followup_prompt_for_user(note),
            remaining,
        )

    # Check the complaint against the data before deciding what to do.
    intent_info = interpret_feedback_intent(last_question, note)
    validation = run_validation_queries(
        intent_info=intent_info,
        question=last_question,
        original_sql=last_sql,
        user_feedback=note,
        rag_context=last_rag_context,
    )
    decision = decide_after_validation_with_llm(
        question=last_question,
        original_sql=last_sql,
        user_feedback=note,
        intent_info=intent_info,
        original_df=previous_df,
        validation_payload=validation,
        rag_context=last_rag_context,
    )

    # Validation upheld the original SQL: explain the numbers instead.
    if decision.get("decision") == "correct_sql":
        explanation = explain_validation_with_llm(
            question=last_question,
            feedback=note,
            original_df=previous_df,
            validation_payload=validation,
            schema_context=last_rag_context,
        )
        return _feedback_outputs(
            explanation,
            previous_df,
            last_sql,
            last_rag_context,
            last_question,
            df_to_state(previous_df),
            "NONE",
            "",
            remaining,
        )

    # Otherwise let the LLM rewrite the SQL, run it and record the correction.
    corrected_sql, _review_notes = review_and_correct_sql_with_llm(
        question=last_question,
        generated_sql=last_sql,
        user_feedback_comment=note,
        validation_payload=validation,
        decision_reason=decision.get("reason", ""),
        rag_context=last_rag_context,
    )
    with get_analytics_conn() as conn:
        df_new, err = execute_sql(corrected_sql, conn)
        record_feedback(
            conn=conn,
            question=last_question,
            generated_sql=last_sql,
            model_answer=last_answer_md,
            rating="bad",
            comment=note,
            corrected_sql=corrected_sql,
        )
    summary = summarize_results(
        question=last_question,
        sql=corrected_sql,
        df=None if err else df_new,
        rag_context=last_rag_context,
        error=err,
    )
    final_md = "**SQL corrected**\n\n" + summary
    result_df = pd.DataFrame() if err else df_new
    return _feedback_outputs(
        final_md,
        result_df,
        corrected_sql,
        last_rag_context,
        last_question,
        df_to_state(result_df),
        "NONE",
        "",
        remaining,
    )
| # %% | |
| # %% | |
| import gradio as gr | |
| # ========================================================= | |
| # Helpers | |
| # ========================================================= | |
def attempts_text(attempts: int) -> str:
    """Render the feedback-attempts banner shown under the chat.

    Any count of zero or below yields the "exhausted" notice; a positive
    count is shown verbatim.
    """
    return (
        "**Feedback exhausted. Please ask a new question.**"
        if attempts <= 0
        else f"**Feedback attempts remaining: {attempts}**"
    )
def conversation_choices():
    """Return dropdown choices for all stored conversations.

    Each entry is a ``(label, value)`` pair: the conversation's ``title`` is
    displayed and its ``conversation_id`` is the selected value.
    """
    pairs = []
    for conv in list_conversations():
        pairs.append((conv["title"], conv["conversation_id"]))
    return pairs
def load_conversation_dropdown():
    """Return a dropdown update carrying the current list of conversations."""
    latest_choices = conversation_choices()
    return gr.update(choices=latest_choices)
| # ========================================================= | |
| # CHAT TURN FIXED | |
| # ========================================================= | |
def chat_turn(user_input, history, state):
    """Answer one user question and persist the exchange.

    On the first turn of a session a conversation record is created, titled
    with the current timestamp and the first 50 characters of the question.
    The question is run through ``backend_pipeline``; both the question and
    the answer are appended to the stored transcript, and the session state
    is saved.

    Args:
        user_input: Text submitted in the question box.
        history: Chat history as a list of ``{"role", "content"}`` dicts
            (``None`` is treated as empty).
        state: Session-state dict held in ``gr.State``; mutated in place to
            record a newly created ``conversation_id``.

    Returns:
        A 6-tuple matching the ``user_input.submit`` outputs: chat history,
        new session state, attempts markdown, debug markdown, result
        DataFrame, and a dropdown update with refreshed conversation choices.
    """
    history = history or []

    # A blank submission would otherwise create an empty conversation record
    # and run the full LLM/SQL pipeline on an empty question: change nothing.
    if not (user_input or "").strip():
        return history, state, gr.update(), gr.update(), gr.update(), gr.update()

    # --- Create conversation if new ---
    if state.get("conversation_id") is None:
        ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
        cid = create_conversation(title=f"{ts} · {user_input[:50]}")
        state["conversation_id"] = cid
    else:
        cid = state["conversation_id"]

    append_message(cid, "user", user_input)
    history.append({"role": "user", "content": user_input})

    # This handler returns once rather than yielding, so a transient
    # "Thinking…" placeholder would never be rendered; the real answer is
    # appended directly after the pipeline finishes.
    (
        answer_md,
        df,
        sql,
        rag,
        q,
        ans_state,
        df_state,
        attempts,
        attempts_md,
        debug_md,
    ) = backend_pipeline(user_input)

    history.append({"role": "assistant", "content": answer_md})
    append_message(cid, "assistant", answer_md)

    new_state = {
        "conversation_id": cid,
        "last_sql": sql,
        "last_rag": rag,
        "last_question": q,
        "last_answer": ans_state,
        "last_df": df_state,
        "attempts": attempts,
    }
    save_state(cid, new_state)

    return (
        history,
        new_state,
        attempts_md,
        debug_md,
        df,  # feeds the result_df table directly
        gr.update(choices=conversation_choices()),
    )
| # ========================================================= | |
| # FEEDBACK TURN | |
| # ========================================================= | |
def feedback_turn(feedback_rating, feedback_text, history, state):
    """Apply the user's Good/Bad feedback to the last answer in this session.

    Args:
        feedback_rating: "correct" (Good button) or "wrong" (Bad button).
        feedback_text: Free-text details; required for "wrong" feedback.
        history: Chat history as ``{"role", "content"}`` dicts
            (``None`` is treated as empty).
        state: Session-state dict held in ``gr.State``.

    Returns:
        A 4-tuple matching the Good/Bad button outputs: chat history, session
        state, attempts markdown, and the DataFrame for the results table.
    """
    history = history or []
    cid = state.get("conversation_id")
    # Default of 4 matches the initial state used by the UI, new_chat and
    # restore_conversation.
    attempts = state.get("attempts", 4)

    # ---------- Nothing answered yet ----------
    # Without a previous question there is no SQL/answer to judge. Running the
    # pipeline here would record empty feedback and write messages/state under
    # a possibly-None conversation id.
    if not state.get("last_question"):
        history.append({
            "role": "assistant",
            "content": "Please ask a question before giving feedback.",
        })
        return history, state, attempts_text(attempts), pd.DataFrame()

    # ---------- Attempts exhausted ----------
    if attempts <= 0:
        history.append({
            "role": "assistant",
            "content": "Feedback attempts are exhausted. Please ask a new question."
        })
        return history, state, attempts_text(0), pd.DataFrame()

    # ---------- BAD without text ----------
    # The rejected submission still consumes one attempt.
    if feedback_rating == "wrong" and not (feedback_text or "").strip():
        new_attempts = max(0, attempts - 1)
        history.append({
            "role": "assistant",
            "content": "Please provide details before submitting negative feedback."
        })
        state["attempts"] = new_attempts
        save_state(cid, state)
        return (
            history,
            state,
            attempts_text(new_attempts),
            state_to_df(state.get("last_df", {})),
        )

    # Record the user's action both in the live chat and in the stored
    # transcript, so a restored conversation shows the same exchange.
    user_marker = f"{feedback_rating.upper()} feedback"
    history.append({"role": "user", "content": user_marker})
    append_message(cid, "user", user_marker)

    (
        answer_md,
        df,
        sql_new,
        rag_new,
        q_new,
        ans_state,
        df_state,
        followup_flag,
        followup_prompt,
        attempts_new,
    ) = feedback_pipeline_interactive(
        feedback_rating=feedback_rating,
        feedback_comment=feedback_text,
        last_sql=state.get("last_sql", ""),
        last_rag_context=state.get("last_rag", ""),
        last_question=state["last_question"],
        last_answer_md=state.get("last_answer", ""),
        last_df_state=state.get("last_df", {}),
        feedback_sql="",
        attempts_left=attempts,
    )

    # Show and persist the same message: the follow-up question when more
    # detail is needed, otherwise the (possibly corrected) answer.
    reply = followup_prompt if followup_flag == "FOLLOWUP" else answer_md
    history.append({"role": "assistant", "content": reply})
    append_message(cid, "assistant", reply)

    new_state = {
        "conversation_id": cid,
        "last_sql": sql_new,
        "last_rag": rag_new,
        "last_question": q_new,
        "last_answer": ans_state,
        "last_df": df_state,
        "attempts": attempts_new,
    }
    save_state(cid, new_state)
    return history, new_state, attempts_text(attempts_new), df
| # ========================================================= | |
| # RESTORE / NEW CHAT | |
| # ========================================================= | |
def restore_conversation(conversation_id):
    """Reload a saved conversation's transcript and session state.

    Any state keys missing from storage are filled with the same defaults a
    brand-new session starts with. Other handlers index keys such as
    ``state["attempts"]`` and ``state["last_sql"]`` directly. ``new_chat``
    creates a conversation without calling ``save_state``, so its stored
    state may be empty or absent.

    Args:
        conversation_id: Value selected in the conversation dropdown.

    Returns:
        A 4-tuple: stored chat messages, session-state dict, attempts
        markdown, and the last result DataFrame.
    """
    defaults = {
        "last_sql": "",
        "last_rag": "",
        "last_question": "",
        "last_answer": "",
        "last_df": {},
        "attempts": 4,
    }
    msgs = load_messages(conversation_id)
    # `or {}` tolerates load_state returning None for unsaved conversations.
    saved = load_state(conversation_id) or {}
    st = {**defaults, **saved, "conversation_id": conversation_id}
    df = state_to_df(st["last_df"])
    return msgs, st, attempts_text(st["attempts"]), df
def new_chat():
    """Start a new stored conversation and reset the UI.

    Returns a 4-tuple: an empty chat history, a fresh session state bound to
    the new conversation id (4 feedback attempts), the attempts banner, and an
    empty results table.
    """
    fresh_id = create_conversation("New conversation")
    initial_state = {
        "conversation_id": fresh_id,
        "last_sql": "",
        "last_rag": "",
        "last_question": "",
        "last_answer": "",
        "last_df": {},
        "attempts": 4,
    }
    banner = attempts_text(4)
    return [], initial_state, banner, pd.DataFrame()
| # ========================================================= | |
| # UI | |
| # ========================================================= | |
# Build the Gradio UI. Component creation order determines the on-page
# layout, so the statements below are order-sensitive.
with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("# Analytics Assistant Chatbot")
    # Top bar: pick a saved conversation or start a new one.
    with gr.Row():
        conversation_selector = gr.Dropdown(
            label="Conversation history",
            choices=[],  # populated by demo.load at the end of this block
            interactive=True,
        )
        new_chat_btn = gr.Button("New chat")
    # NOTE(review): handlers emit history as {"role", "content"} dicts; depending
    # on the installed Gradio version, gr.Chatbot may need type="messages" to
    # accept that format — confirm against the pinned Gradio release.
    chat = gr.Chatbot(height=520)
    attempts_md = gr.Markdown(value=attempts_text(4))
    show_debug = gr.Checkbox(label="Show SQL / Cache info", value=False)
    # Hidden until the checkbox is ticked; chat_turn writes its debug markdown here.
    debug_panel = gr.Markdown(visible=False)
    result_df = gr.Dataframe(
        label="Query Results",
        interactive=False,
        wrap=True,
    )
    # Per-session state with the same keys chat_turn/feedback_turn read and
    # new_chat resets; conversation_id stays None until the first question.
    state = gr.State({
        "conversation_id": None,
        "last_sql": "",
        "last_rag": "",
        "last_question": "",
        "last_answer": "",
        "last_df": {},
        "attempts": 4,
    })
    user_input = gr.Textbox(placeholder="Ask a question…", lines=1, show_label=False)
    feedback_input = gr.Textbox(
        placeholder="Details (required if result is wrong)…",
        lines=2,
        show_label=False,
    )
    # Feedback buttons; the empty scale-6 column presumably acts as a spacer
    # pushing them to the right.
    with gr.Row():
        gr.Column(scale=6)
        good_btn = gr.Button("Good", size="sm")
        bad_btn = gr.Button("Bad", size="sm")
    # Toggle the debug panel, re-rendering its current markdown value.
    show_debug.change(
        lambda show, md: gr.update(value=md, visible=show),
        inputs=[show_debug, debug_panel],
        outputs=debug_panel,
    )
    # Ask a question; the output order must match chat_turn's 6-tuple.
    # NOTE(review): queue=False on a long-running LLM/SQL call while the
    # feedback handlers use the default queue — confirm this is intended.
    user_input.submit(
        chat_turn,
        inputs=[user_input, chat, state],
        outputs=[chat, state, attempts_md, debug_panel, result_df, conversation_selector],
        queue=False,
    )
    # Both feedback buttons route to feedback_turn with a fixed rating string.
    good_btn.click(
        lambda fb, h, s: feedback_turn("correct", fb, h, s),
        inputs=[feedback_input, chat, state],
        outputs=[chat, state, attempts_md, result_df],
    )
    bad_btn.click(
        lambda fb, h, s: feedback_turn("wrong", fb, h, s),
        inputs=[feedback_input, chat, state],
        outputs=[chat, state, attempts_md, result_df],
    )
    # Selecting a conversation reloads its transcript, state, attempts and table.
    conversation_selector.change(
        restore_conversation,
        inputs=conversation_selector,
        outputs=[chat, state, attempts_md, result_df],
    )
    new_chat_btn.click(
        new_chat,
        outputs=[chat, state, attempts_md, result_df],
    )
    # Fill the dropdown with stored conversations when the page loads.
    demo.load(load_conversation_dropdown, outputs=conversation_selector)
| # %% | |
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()