|
|
import warnings

warnings.filterwarnings("ignore")

import datetime
import os
import sqlite3
import uuid
from typing import Tuple, Optional, List, Dict, Any

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq
|
|
|
|
|
|
|
|
|
|
|
# Load environment variables (GROQ_API_KEY, optionally PORT) from a .env file.
load_dotenv()

# Local SQLite database file and the folder holding the Olist CSV exports.
DB_PATH = "olist.db"
DATA_DIR = "data"

# Groq-hosted chat model used for SQL generation, repair and summarization.
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"

# Fail fast at import time: every LLM call below requires this key.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found.")

# Sentence-transformers model used to embed the schema documents for RAG.
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
|
|
|
# File that accumulates every prompt sent to the LLM (append-only).
PROMPT_LOG_PATH = "prompt_logs.txt"


def log_prompt(tag: str, prompt: str) -> None:
    """
    Append the full prompt to a log file, with a tag and timestamp.
    """
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    banner = f"\n\n================ {tag} @ {stamp} ================\n"
    footer = "\n================ END PROMPT ================\n"
    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(banner + prompt + footer)
|
|
|
|
|
|
|
# File that accumulates raw model replies, final SQL and errors (append-only).
RUN_LOG_PATH = "run_logs.txt"


def log_run_event(tag: str, content: str) -> None:
    """
    Append model response, final SQL, and error info into a run log.
    """
    stamp = datetime.datetime.now().isoformat(timespec="seconds")
    banner = f"\n\n================ {tag} @ {stamp} ================\n"
    footer = "\n================ END EVENT ================\n"
    with open(RUN_LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(banner + content + footer)
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_db() -> sqlite3.Connection:
    """
    Load all CSVs from the data/ folder into a local SQLite DB.

    Table names are derived from file names (without .csv). Missing files
    are reported and skipped so the app can still start on partial data.

    Returns:
        An open SQLite connection. check_same_thread=False lets the
        connection be shared across the Gradio worker threads.
    """
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)

    # Known files of the Olist e-commerce dataset export.
    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]

    for fname in csv_files:
        path = os.path.join(DATA_DIR, fname)
        # Fix: removed a stray debug `print(path)` that duplicated the
        # messages below on every iteration.
        if not os.path.exists(path):
            print(f"CSV not found: {path} - skipping")
            continue

        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        df = pd.read_csv(path)
        # Replace (not append) so re-running initialization stays idempotent.
        df.to_sql(table_name, conn, if_exists="replace", index=False)

    return conn
|
|
|
# Module-level connection shared by the whole app (created once at import).
conn = init_db()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hand-written documentation for the Olist dataset. Keys are SQLite table
# names; each entry carries a table description plus per-column notes that
# enrich the auto-introspected schema (consumed by extract_schema_metadata).
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            "review_creation_date": "Date when the customer created the review.",
            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
|
|
|
|
|
|
|
|
|
def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns and foreign key relationships.

    Column and table descriptions come from OLIST_DOCS when available,
    otherwise a generic fallback string is generated.
    """
    cur = connection.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cur.fetchall()]

    result: Dict[str, Any] = {"tables": {}}

    for tname in table_names:
        # PRAGMA table_info row layout: (cid, name, type, notnull, dflt, pk)
        cur.execute(f"PRAGMA table_info('{tname}')")
        col_rows = cur.fetchall()

        docs_entry = OLIST_DOCS.get(tname, {})
        description: str = docs_entry.get(
            "description", f"Table '{tname}' from Olist dataset."
        )
        col_docs: Dict[str, str] = docs_entry.get("columns", {})

        columns: Dict[str, str] = {}
        for row in col_rows:
            name = row[1]
            ctype = row[2] or "TEXT"
            columns[name] = col_docs.get(name, f"Column '{name}' of type {ctype}")

        # PRAGMA foreign_key_list row layout: (id, seq, table, from, to, ...)
        cur.execute(f"PRAGMA foreign_key_list('{tname}')")
        rels: List[str] = [
            f"{tname}.{fk[3]} → {fk[2]}.{fk[4]} (foreign key)"
            for fk in cur.fetchall()
        ]

        result["tables"][tname] = {
            "description": description,
            "columns": columns,
            "relationships": rels,
        }

    return result
|
|
|
# Introspected schema for all loaded tables, enriched with OLIST_DOCS.
# Fix: removed the bare `schema_metadata` expression that followed this
# assignment — a notebook-style cell leftover with no effect in a script.
schema_metadata = extract_schema_metadata(conn)
|
|
|
|
|
|
|
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render metadata dict into a YAML-style string.

    Double quotes inside descriptions are swapped for single quotes so the
    emitted double-quoted YAML scalars stay well-formed.
    """
    out: List[str] = ["tables:"]
    for table_name, info in metadata["tables"].items():
        out.append(f"  {table_name}:")
        safe_desc = info.get("description", "").replace('"', "'")
        out.append(f'    description: "{safe_desc}"')
        out.append("    columns:")
        for cname, cdesc in info.get("columns", {}).items():
            out.append('      {}: "{}"'.format(cname, cdesc.replace('"', "'")))
        relationships = info.get("relationships", [])
        if relationships:
            out.append("    relationships:")
            out.extend(
                '      - "{}"'.format(rel.replace('"', "'"))
                for rel in relationships
            )
    return "\n".join(out)
|
|
|
# YAML-style rendering of the full schema; embedded as a global RAG document.
schema_yaml = build_schema_yaml(schema_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table, using schema_metadata.

    Each document includes:
    - Table name
    - Table description
    - Columns with type + description
    - Relationships (FKs)
    - A few sample rows

    Returns:
        docs: list of plain-text documents (one per table)
        metadatas: list of metadata dicts aligned with docs
                   (each has doc_type="table_schema", table_name="<name>")
    """
    cur = connection.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
    table_names = [r[0] for r in cur.fetchall()]

    documents: List[str] = []
    doc_metas: List[Dict[str, Any]] = []

    for tname in table_names:
        info = schema_metadata["tables"][tname]
        description = info.get("description", "")
        col_docs = info.get("columns", {})
        fk_rels = info.get("relationships", [])

        cur.execute(f"PRAGMA table_info('{tname}')")
        col_lines: List[str] = []
        for row in cur.fetchall():
            cname = row[1]
            ctype = row[2] or "TEXT"
            cdesc = col_docs.get(cname, f"Column '{cname}' of type {ctype}")
            col_lines.append(f"- {cname} ({ctype}): {cdesc}")

        # Sample rows are best-effort: to_markdown needs the optional
        # 'tabulate' package, so degrade gracefully on any failure.
        try:
            preview = pd.read_sql_query(
                f"SELECT * FROM '{tname}' LIMIT {sample_rows}",
                connection,
            )
            sample_text = preview.to_markdown(index=False)
        except Exception:
            sample_text = "(could not fetch sample rows)"

        rel_block = ""
        if fk_rels:
            rel_block = (
                "Relationships:\n"
                + "\n".join(f"- {rel}" for rel in fk_rels)
                + "\n"
            )

        documents.append(
            f"Table: {tname}\n"
            f"Description: {description}\n\n"
            f"Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        doc_metas.append({
            "doc_type": "table_schema",
            "table_name": tname,
        })

    return documents, doc_metas
|
|
|
|
|
|
|
|
|
# One rich document per table, plus aligned metadata dicts.
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

# Full RAG corpus; RAG_TEXTS and RAG_METADATAS stay index-aligned.
RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# Per-table schema documents (doc_type="table_schema").
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# One extra document carrying the whole schema as YAML (doc_type="global_schema").
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_store_final() -> Tuple[Chroma, Any]:
    """
    Build the production RAG store with fixed embedding model.

    Returns the Chroma store and a retriever restricted to per-table
    schema documents (top 3 matches).
    """
    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)

    # Random suffix keeps the collection name unique across app restarts.
    collection = f"olist_schema_all_mpnet_{uuid.uuid4().hex[:8]}"

    vector_store = Chroma.from_texts(
        RAG_TEXTS,
        embedder,
        metadatas=RAG_METADATAS,
        collection_name=collection,
        persist_directory=None,
    )

    schema_retriever = vector_store.as_retriever(
        search_kwargs={
            "k": 3,
            "filter": {"doc_type": "table_schema"},
        }
    )

    return vector_store, schema_retriever
|
|
|
# Build the in-memory vector store and retriever once at import time.
rag_store, rag_retriever = build_store_final()

# Shared Groq chat model used for SQL generation, repair and summarization.
llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
|
|
|
def get_rag_context(question: str) -> str:
    """
    Retrieve the most relevant schema documents for the question.

    Documents are concatenated with a visible separator so the LLM can
    tell individual tables apart.
    """
    retrieved = rag_retriever.invoke(question)
    parts = [doc.page_content for doc in retrieved]
    return "\n\n---\n\n".join(parts)
|
|
|
def clean_sql(sql: str) -> str:
    """Strip surrounding whitespace and Markdown code fences from LLM output."""
    cleaned = sql.strip()
    if "```" not in cleaned:
        return cleaned
    return cleaned.replace("```sql", "").replace("```", "").strip()
|
|
|
|
|
|
|
# Substrings that mark a question as a high-level "describe the data" request.
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    Detect high-level descriptive questions where we should answer
    directly from schema context instead of generating SQL.
    """
    normalized = question.strip().lower()
    for keyword in GENERAL_DESC_KEYWORDS:
        if keyword in normalized:
            return True
    return False
|
|
|
|
|
def answer_general_question(question: str) -> str:
    """
    Use the RAG schema docs to generate a rich, high-level description
    of the Olist dataset for conceptual questions.
    """
    schema_context = get_rag_context(question)

    system_instructions = """
You are a data documentation expert.

You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".

Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a data analyst seeing the dataset for the first time.

Do NOT write SQL. Answer in Markdown.
"""

    prompt = (
        f"{system_instructions}\n\n=== SCHEMA CONTEXT ===\n{schema_context}"
        f"\n\n=== USER QUESTION ===\n{question}\n\nDetailed dataset overview:"
    )

    log_prompt("GENERAL_DATASET_QUESTION", prompt)

    reply = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", reply.content)

    return reply.content.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Execute SQL on SQLite and return a DataFrame, else return an error.

    Returns:
        (df, None) on success, (None, error_message) on failure.
    """
    try:
        result = pd.read_sql_query(sql, connection)
    except Exception as exc:
        return None, str(exc)
    return result, None
|
|
|
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Basic SQL validator:
    - Uses EXPLAIN QUERY PLAN to detect syntax or schema issues.

    Returns:
        (True, None) when the query plans cleanly, (False, error) otherwise.
    """
    try:
        connection.cursor().execute(f"EXPLAIN QUERY PLAN {sql}")
    except Exception as exc:
        return False, str(exc)
    return True, None
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_sql(question: str, rag_context: str) -> str:
    """
    Ask the LLM for ONE SQLite query answering the question, grounded in
    the retrieved schema context. Returns the query with fences stripped.
    """
    system_instructions = """
You are a senior data analyst writing SQL for a SQLite database.

You will be given:
- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.

Your job:
- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use explicit JOINs with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator, and round to 2 decimals when appropriate.

Return ONLY the SQL query, no explanation or markdown.
"""

    prompt = (
        f"{system_instructions}\n\n=== RAG CONTEXT ===\n{rag_context}"
        f"\n\n=== USER QUESTION ===\n{question}\n\nSQL query:"
    )

    log_prompt("SQL_GENERATION", prompt)

    reply = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", reply.content)

    return clean_sql(reply.content)
|
|
|
def repair_sql(
    question: str, rag_context: str, bad_sql: str, error_message: str
) -> str:
    """
    Ask the LLM to correct a failing SQL query.
    """
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.

You will be given:
- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.

Your job:
- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use explicit JOINs with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.

Return ONLY the corrected SQL query, no explanation or markdown.
"""

    prompt = (
        f"{system_instructions}\n\n=== RAG CONTEXT ===\n{rag_context}"
        f"\n\n=== USER QUESTION ===\n{question}"
        f"\n\n=== PREVIOUS (FAILING) SQL ===\n{bad_sql}"
        f"\n\n=== SQLITE ERROR ===\n{error_message}"
        f"\n\nCorrected SQL query:"
    )

    log_prompt("SQL_REPAIR", prompt)

    reply = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", reply.content)

    return clean_sql(reply.content)
|
|
|
|
|
|
|
|
|
|
|
|
|
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.
    """
    system_instructions = """
You are a senior data analyst.

You will be given:
- The user's question.
- The final SQL that was executed.
- A small preview of the query result (as a Markdown table, if available).
- Optional error information if the query failed.

Your job:
- Provide a clear, concise answer in Markdown.
- If the result is numeric / aggregated, explain what it means in business terms.
- If there was an error, explain it simply and suggest how the user could rephrase.
- Do NOT show raw SQL unless it is helpful to the user.
"""

    # Cap the preview at 50 rows so huge results do not blow up the prompt.
    if df is None or df.empty:
        preview_md = "(no rows returned)"
    else:
        preview_md = df.head(50).to_markdown(index=False)

    prompt = (
        f"{system_instructions}\n\n=== USER QUESTION ===\n{question}"
        f"\n\n=== EXECUTED SQL ===\n{sql}"
        f"\n\n=== QUERY RESULT PREVIEW ===\n{preview_md}"
        f"\n\n=== RAG CONTEXT (schema) ===\n{rag_context}"
    )

    if error:
        prompt += f"\n\n=== ERROR ===\n{error}"

    log_prompt("RESULT_SUMMARY", prompt)

    reply = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", reply.content)

    return reply.content.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
def backend_pipeline(question: str):
    """
    End-to-end flow:

    1. Retrieve RAG context (schema).
    2. LLM generates SQL.
    3. SQL validator checks & attempts auto-repair if needed - Currently Done Once.
    4. Execute query on SQLite.
    5. LLM summarizes results.

    Returns:
        Tuple of (answer_markdown, result_dataframe) consumed by the Gradio UI.
    """
    # Guard: empty or whitespace-only input.
    if not question or not question.strip():
        return "Please type a question.", pd.DataFrame()

    # High-level "describe the dataset" questions skip SQL entirely.
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return overview_md, pd.DataFrame()

    # Step 1: retrieve relevant schema documents.
    rag_context = get_rag_context(question)

    # Step 2: first-pass SQL generation.
    sql = generate_sql(question, rag_context)
    original_sql = sql  # kept only for the repair log messages

    # Step 3: validate; on failure, ask the LLM to repair exactly once.
    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False

    if not is_valid and validation_error:
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)

        if repaired_valid:
            # Repair succeeded: adopt the new SQL and clear the error.
            sql = repaired_sql
            repaired = True
            log_run_event(
                "SQL_REPAIR_SUCCESS",
                f"Original SQL:\n{original_sql}\n\nRepaired SQL:\n{sql}\n",
            )
            validation_error = None
        else:
            log_run_event(
                "SQL_REPAIR_FAILED",
                f"Original SQL:\n{original_sql}\n\nRepaired SQL:\n{repaired_sql}\n\nValidation error:\n{repaired_error}",
            )
            # Prefer the repair error; fall back to the original one.
            validation_error = repaired_error or validation_error

    log_run_event("FINAL_SQL", sql)

    # Step 4: execute only when validation (or repair) succeeded.
    df, exec_error = (None, None)
    if validation_error:
        exec_error = validation_error
        log_run_event(
            "EXECUTION_SKIPPED_DUE_TO_VALIDATION_ERROR",
            f"SQL:\n{sql}\n\nValidation error:\n{validation_error}",
        )
    else:
        df, exec_error = execute_sql(sql, conn)
        if exec_error:
            log_run_event(
                "EXECUTION_ERROR",
                f"SQL:\n{sql}\n\nExecution error:\n{exec_error}",
            )
        else:
            rows = 0 if df is None else len(df)
            log_run_event(
                "EXECUTION_SUCCESS",
                f"SQL:\n{sql}\n\nRows returned: {rows}",
            )

    # Step 5: natural-language summary (also explains failures to the user).
    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )

    # Build the status/debug footer shown under the summary.
    sql_status_lines = []
    if exec_error:
        sql_status_lines.append("There was an error running the SQL.")
        sql_status_lines.append(f"**Error:** `{exec_error}`")
    else:
        sql_status_lines.append("Query ran successfully.")

    if repaired:
        sql_status_lines.append("_Note: The original SQL was corrected by the assistant._")

    sql_status_lines.append("\n**Final SQL used:**\n")
    sql_status_lines.append(f"```sql\n{sql}\n```")

    sql_debug_md = "\n".join(sql_status_lines)

    # Only show the result table when execution actually succeeded.
    if df is not None and exec_error is None:
        df_preview = df
    else:
        df_preview = pd.DataFrame()

    answer_md = summary_text + "\n\n---\n\n" + sql_debug_md

    return answer_md, df_preview
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: question box on the left, answer + result table on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
Olist Analytics Assistant (RAG + Groq LLM)

Ask questions in natural language about the Olist dataset.
The app will:

1. Retrieve schema with RAG.
2. Generate and (if needed) repair SQL
3. Run it on a local SQLite DB
4. Show the SQL query, the output and summarize the results
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Left column: free-text question plus the run button.
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total revenue per seller in 2017 (include shipping)",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            # Right column: markdown answer and the raw result dataframe.
            answer_out = gr.Markdown(label="Answer & Details")
            table_out = gr.Dataframe(label="Result table")

    # Wire the button to the full question -> SQL -> answer pipeline.
    submit_btn.click(
        fn=backend_pipeline,
        inputs=question_in,
        outputs=[answer_out, table_out],
    )

# NOTE(review): exposed as `app`, presumably for an external launcher or
# hosting platform that imports this module — confirm against deployment.
app = demo
|
|
|
|
|
if __name__ == "__main__":
    # Fix: removed the redundant local `import os` — os is already imported
    # at module top, and the re-import only shadowed it inside this block.
    # Bind to all interfaces; honor a platform-provided PORT (default 7860).
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
    )
|
|
|