# Uploaded to Hugging Face Spaces by Shashwat-18 ("Upload app.py", commit f12d484, verified).
# %%
import warnings
warnings.filterwarnings("ignore")
import os
import sqlite3
from typing import Tuple, Optional, List, Dict, Any
import matplotlib.pyplot as plt
import gradio as gr
import pandas as pd
import uuid
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
# %% [markdown]
# ### 0. Global configurations.
# %%
# Load environment variables from a local .env file (e.g. GROQ_API_KEY).
load_dotenv()

# Path of the local SQLite database file built from the Olist CSVs.
DB_PATH = "olist.db"
# Directory containing the raw Olist CSV files.
DATA_DIR = "data"

# Groq Model
# Hosted Groq chat model used for SQL generation, repair and summaries.
GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast at import time: every LLM call below depends on this key.
    raise ValueError("GROQ_API_KEY not found.")

# Embedding model
# Sentence-transformers model used to embed schema documents for RAG retrieval.
EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
# %%
import datetime

# Append-only file that accumulates every prompt sent to the LLM.
PROMPT_LOG_PATH = "prompt_logs.txt"


def log_prompt(tag: str, prompt: str, log_path: Optional[str] = None) -> None:
    """
    Append the full prompt to a log file, with a tag and timestamp.

    Args:
        tag: Short label identifying which pipeline step produced the prompt.
        prompt: The full prompt text sent to the LLM.
        log_path: Optional override of the destination file; defaults to the
            module-level PROMPT_LOG_PATH. (Generalized so callers/tests can
            redirect logging without touching the global.)
    """
    path = PROMPT_LOG_PATH if log_path is None else log_path
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    # Open in append mode so logs survive across app restarts.
    with open(path, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(prompt)
        f.write("\n================ END PROMPT ================\n")
# %%
# Append-only file that accumulates pipeline events: raw model output,
# final SQL, and execution/validation errors.
RUN_LOG_PATH = "run_logs.txt"


def log_run_event(tag: str, content: str, log_path: Optional[str] = None) -> None:
    """
    Append model response, final SQL, and error info into a run log.

    Args:
        tag: Short label identifying the event type (e.g. "FINAL_SQL").
        content: Free-form event payload to record.
        log_path: Optional override of the destination file; defaults to the
            module-level RUN_LOG_PATH. (Generalized so callers/tests can
            redirect logging without touching the global.)
    """
    path = RUN_LOG_PATH if log_path is None else log_path
    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
    header = f"\n\n================ {tag} @ {timestamp} ================\n"
    # Append-only so earlier runs are preserved for debugging.
    with open(path, "a", encoding="utf-8") as f:
        f.write(header)
        f.write(content)
        f.write("\n================ END EVENT ================\n")
# %% [markdown]
# ### 1. Database setup (from CSVs)
# %%
def init_db(db_path: Optional[str] = None, data_dir: Optional[str] = None) -> sqlite3.Connection:
    """
    Load all CSVs from the data/ folder into a local SQLite DB.
    Table names are derived from file names (without .csv).

    Args:
        db_path: SQLite file to create/open; defaults to the module-level
            DB_PATH. (Generalized so tests can use ":memory:".)
        data_dir: Directory containing the Olist CSVs; defaults to DATA_DIR.

    Returns:
        An open sqlite3 connection with one table per CSV that was found.
    """
    db_path = DB_PATH if db_path is None else db_path
    data_dir = DATA_DIR if data_dir is None else data_dir
    # check_same_thread=False: the connection is shared across Gradio worker threads.
    conn = sqlite3.connect(db_path, check_same_thread=False)
    csv_files = [
        "olist_customers_dataset.csv",
        "olist_orders_dataset.csv",
        "olist_order_items_dataset.csv",
        "olist_products_dataset.csv",
        "olist_order_reviews_dataset.csv",
        "olist_order_payments_dataset.csv",
        "product_category_name_translation.csv",
        "olist_sellers_dataset.csv",
        "olist_geolocation_dataset.csv",
    ]
    for fname in csv_files:
        path = os.path.join(data_dir, fname)
        # (Removed a leftover bare `print(path)` debug line that duplicated
        # the "Loading ..." message below.)
        if not os.path.exists(path):
            # Missing files are skipped so a partial dataset still works.
            print(f"CSV not found: {path} - skipping")
            continue
        table_name = os.path.splitext(fname)[0]
        print(f"Loading {path} into table {table_name}...")
        df = pd.read_csv(path)
        # if_exists="replace": re-running the app refreshes tables instead of
        # appending duplicate rows.
        df.to_sql(table_name, conn, if_exists="replace", index=False)
    return conn
# Module-level shared connection used by the query pipeline and Gradio handlers.
conn = init_db()
# %% [markdown]
# ### 2. Schema extractor + metadata YAML
# %%
# Hand-written business documentation for each Olist table, keyed by table
# name. extract_schema_metadata() merges these descriptions into the
# auto-introspected SQLite schema before it is embedded for RAG retrieval.
OLIST_DOCS: Dict[str, Dict[str, Any]] = {
    "olist_customers_dataset": {
        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
        "columns": {
            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "customer_city": "Customer's city as captured at the time of the order or registration.",
            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
        }
    },
    "olist_orders_dataset": {
        "description": "Customer orders placed on the Olist marketplace, one row per order.",
        "columns": {
            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
        }
    },
    "olist_order_items_dataset": {
        "description": "Order line items, one row per product per order.",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
        }
    },
    "olist_products_dataset": {
        "description": "Product catalog with physical and category attributes, one row per product.",
        "columns": {
            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
            "product_photos_qty": "Number of product images associated with the listing.",
            "product_weight_g": "Product weight in grams.",
            "product_length_cm": "Product length in centimeters (package dimension).",
            "product_height_cm": "Product height in centimeters (package dimension).",
            "product_width_cm": "Product width in centimeters (package dimension)."
        }
    },
    "olist_order_reviews_dataset": {
        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
        "columns": {
            "review_id": "Primary key. Unique identifier for each review record.",
            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
            "review_comment_title": "Optional short text title or summary of the review.",
            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
            "review_creation_date": "Date when the customer created the review.",
            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
        }
    },
    "olist_order_payments_dataset": {
        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
        "columns": {
            "order_id": "Foreign key to olist_orders_dataset.order_id.",
            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
            "payment_installments": "Number of installments chosen by the customer for this payment.",
            "payment_value": "Monetary amount paid in this payment record (in BRL)."
        }
    },
    "product_category_name_translation": {
        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
        "columns": {
            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
            "product_category_name_english": "Translated product category name in English."
        }
    },
    "olist_sellers_dataset": {
        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
        "columns": {
            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
            "seller_city": "City where the seller is located.",
            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
        }
    },
    "olist_geolocation_dataset": {
        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
        "columns": {
            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
            "geolocation_lat": "Latitude coordinate of the location.",
            "geolocation_lng": "Longitude coordinate of the location.",
            "geolocation_city": "City name for the location.",
            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
        }
    }
}
# %%
def extract_schema_metadata(
    connection: sqlite3.Connection,
    docs: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """
    Introspect SQLite tables, columns and foreign key relationships.

    Args:
        connection: Open SQLite connection to introspect.
        docs: Optional manual documentation keyed by table name (defaults to
            the module-level OLIST_DOCS). Each entry may carry a
            "description" string and a "columns" name->description mapping.

    Returns:
        {"tables": {<name>: {"description": str,
                             "columns": {col: desc},
                             "relationships": [str, ...]}}}
    """
    docs = OLIST_DOCS if docs is None else docs
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    metadata: Dict[str, Any] = {"tables": {}}
    for table in tables:
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()
        # Manual docs for the table (empty dict when the table is undocumented).
        table_docs = docs.get(table, {})
        table_desc: str = table_docs.get(
            "description",
            f"Table '{table}' from Olist dataset."
        )
        column_docs: Dict[str, str] = table_docs.get("columns", {})
        columns_meta: Dict[str, str] = {}
        for c in cols:
            # PRAGMA table_info row layout: (cid, name, type, notnull, dflt, pk)
            col_name = c[1]
            col_type = c[2] or "TEXT"
            if col_name in column_docs:
                columns_meta[col_name] = column_docs[col_name]
            else:
                columns_meta[col_name] = f"Column '{col_name}' of type {col_type}"
        cursor.execute(f"PRAGMA foreign_key_list('{table}')")
        fk_rows = cursor.fetchall()
        relationships: List[str] = []
        for fk in fk_rows:
            # PRAGMA foreign_key_list row layout: (id, seq, table, from, to, ...)
            ref_table = fk[2]
            from_col = fk[3]
            to_col = fk[4]
            # BUG FIX: the original f-string omitted the " -> " separator,
            # producing garbled text like "child.pidparent.id (foreign key)".
            relationships.append(
                f"{table}.{from_col} -> {ref_table}.{to_col} (foreign key)"
            )
        metadata["tables"][table] = {
            "description": table_desc,
            "columns": columns_meta,
            "relationships": relationships,
        }
    return metadata
# Build the enriched schema metadata once at import time from the live DB.
schema_metadata = extract_schema_metadata(conn)
# Notebook-style echo of the metadata (no effect when run as a plain script).
schema_metadata
# %%
def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    """
    Render metadata dict into a YAML-style string.

    Double quotes inside descriptions/relationships are converted to single
    quotes so the emitted double-quoted scalars stay well-formed.
    """
    out: List[str] = ["tables:"]
    for table_name, table_info in metadata["tables"].items():
        out.append(f"  {table_name}:")
        description = table_info.get("description", "").replace('"', "'")
        out.append(f'    description: "{description}"')
        out.append("    columns:")
        for column, column_desc in table_info.get("columns", {}).items():
            safe_desc = column_desc.replace('"', "'")
            out.append(f'      {column}: "{safe_desc}"')
        relationships = table_info.get("relationships", [])
        if relationships:
            out.append("    relationships:")
            for relationship in relationships:
                safe_rel = relationship.replace('"', "'")
                out.append(f'      - "{safe_rel}"')
    return "\n".join(out)
# Render the global YAML snapshot of the schema (used as one extra RAG document).
schema_yaml = build_schema_yaml(schema_metadata)
# %% [markdown]
# ### 4. Build schema documents for RAG (taking samples from the table)
# %%
def build_schema_documents(
    connection: sqlite3.Connection,
    schema_metadata: Dict[str, Any],
    sample_rows: int = 5,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Build one rich RAG document per table, using schema_metadata.

    Each document includes:
    - Table name
    - Table description
    - Columns with type + description
    - Relationships (FKs)
    - A few sample rows

    Args:
        connection: Open SQLite connection holding the Olist tables.
        schema_metadata: Output of extract_schema_metadata(). NOTE(review):
            this parameter shadows the module-level global of the same name.
        sample_rows: Number of example rows embedded into each document.

    Returns:
        docs: list of plain-text documents (one per table)
        metadatas: list of metadata dicts aligned with docs
            (each has doc_type="table_schema", table_name="<name>")
    """
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    docs: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    for table in tables:
        # Assumes every physical table has an entry in schema_metadata; true
        # when both come from the same sqlite_master scan, else KeyError.
        tmeta = schema_metadata["tables"][table]
        table_desc = tmeta.get("description", "")
        columns_meta = tmeta.get("columns", {})
        relationships = tmeta.get("relationships", [])
        # Use PRAGMA to get types, then enrich with descriptions -- STILL NEED TO WORK ON.
        cursor.execute(f"PRAGMA table_info('{table}')")
        cols = cursor.fetchall()
        col_lines = []
        for c in cols:
            # PRAGMA table_info row layout: (cid, name, type, notnull, dflt, pk)
            col_name = c[1]
            col_type = c[2] or "TEXT"
            col_desc = columns_meta.get(col_name, f"Column '{col_name}' of type {col_type}")
            col_lines.append(f"- {col_name} ({col_type}): {col_desc}")
        # Sample rows
        try:
            sample_df = pd.read_sql_query(
                f"SELECT * FROM '{table}' LIMIT {sample_rows}",
                connection,
            )
            # to_markdown requires the optional 'tabulate' dependency.
            sample_text = sample_df.to_markdown(index=False)
        except Exception:
            # Best-effort: a document without samples is still useful.
            sample_text = "(could not fetch sample rows)"
        # Relationships block
        rel_block = ""
        if relationships:
            rel_block = "Relationships:\n" + "\n".join(
                f"- {rel}" for rel in relationships
            ) + "\n"
        doc_text = (
            f"Table: {table}\n"
            f"Description: {table_desc}\n\n"
            f"Columns:\n" + "\n".join(col_lines) + "\n\n"
            f"{rel_block}\n"
            f"Example rows:\n{sample_text}\n"
        )
        docs.append(doc_text)
        metadatas.append({
            "doc_type": "table_schema",
            "table_name": table,
        })
    return docs, metadatas
# %%
# Build RAG texts + metadata
schema_docs, schema_doc_metas = build_schema_documents(conn, schema_metadata)

# Parallel lists: RAG_TEXTS[i] is described by RAG_METADATAS[i].
RAG_TEXTS: List[str] = []
RAG_METADATAS: List[Dict[str, Any]] = []

# 1) Per-table docs
RAG_TEXTS.extend(schema_docs)
RAG_METADATAS.extend(schema_doc_metas)

# 2) Global YAML as a separate doc (tagged so the retriever can filter it out)
RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
RAG_METADATAS.append({"doc_type": "global_schema"})
# %% [markdown]
# ### 5. Build Chroma store + global RAG retriever
#
# %%
def build_store_final() -> Tuple[Chroma, Any]:
    """
    Build the production RAG store with fixed embedding model.

    Embeds all module-level RAG documents into an in-memory Chroma
    collection and returns both the store and a retriever restricted to
    the per-table schema documents.
    """
    embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL_NAME)
    # Random suffix avoids clashing with a stale collection on app reloads.
    unique_collection = f"olist_schema_all_mpnet_{uuid.uuid4().hex[:8]}"
    vector_store = Chroma.from_texts(
        RAG_TEXTS,
        embedder,
        metadatas=RAG_METADATAS,
        collection_name=unique_collection,
        persist_directory=None,  # in-memory
    )
    # k=3 closest table docs; the global YAML doc is filtered out by doc_type.
    schema_retriever = vector_store.as_retriever(
        search_kwargs={
            "k": 3,
            "filter": {"doc_type": "table_schema"},
        }
    )
    return vector_store, schema_retriever


rag_store, rag_retriever = build_store_final()
# %% [markdown]
# ### 6. Groq LLM via LangChain
# %%
# Shared LangChain chat-model handle used by all LLM calls below.
llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
def get_rag_context(question: str) -> str:
    """
    Retrieve the most relevant schema documents for the question.

    Returns the retrieved documents' text joined by a visual separator.
    """
    retrieved = rag_retriever.invoke(question)
    parts = [doc.page_content for doc in retrieved]
    return "\n\n---\n\n".join(parts)
def clean_sql(sql: str) -> str:
    """Strip markdown code fences and surrounding whitespace from LLM SQL output."""
    cleaned = sql.strip()
    if "```" not in cleaned:
        return cleaned
    # Drop the fence markers (with or without the "sql" language tag).
    return cleaned.replace("```sql", "").replace("```", "").strip()
# 6b. General / descriptive questions
GENERAL_DESC_KEYWORDS = [
    "what is this dataset about",
    "what is this data about",
    "describe this dataset",
    "describe the dataset",
    "dataset overview",
    "data overview",
    "summary of the dataset",
    "explain this dataset",
]


def is_general_question(question: str) -> bool:
    """
    Detect high-level descriptive questions where we should answer
    directly from schema context instead of generating SQL.
    """
    normalized = question.lower().strip()
    for keyword in GENERAL_DESC_KEYWORDS:
        if keyword in normalized:
            return True
    return False
def answer_general_question(question: str) -> str:
    """
    Use the RAG schema docs to generate a rich, high-level description
    of the Olist dataset for conceptual questions.

    Args:
        question: A high-level descriptive question (no SQL is generated).

    Returns:
        Markdown overview text produced by the LLM.
    """
    rag_context = get_rag_context(question)
    # NOTE: this text is part of the runtime prompt sent to the LLM.
    system_instructions = """
You are a data documentation expert.
You will be given:
- Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
- A high-level user question like "what is this dataset about?".
Your job:
- Write a clear, structured overview of the dataset.
- Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
- Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
- Target a data analyst seeing the dataset for the first time.
Do NOT write SQL. Answer in Markdown.
"""
    prompt = (
        system_instructions
        + "\n\n=== SCHEMA CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nDetailed dataset overview:"
    )
    # Persist the exact prompt and raw response for offline debugging/audit.
    log_prompt("GENERAL_DATASET_QUESTION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)
    return response.content.strip()
# %% [markdown]
# ### 7. SQL execution / validation
# %%
def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """
    Execute SQL on SQLite and return a DataFrame, else return an error.

    Returns:
        (df, None) on success, (None, error_message) on failure.
    """
    try:
        result = pd.read_sql_query(sql, connection)
    except Exception as exc:
        return None, str(exc)
    return result, None
def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
    """
    Basic SQL validator:
    - Rejects statements that do not start with a read-only keyword. The LLM
      is *asked* not to emit destructive SQL, but EXPLAIN QUERY PLAN alone
      would happily accept DROP/DELETE, so prompt compliance must not be the
      only guard.
    - Uses EXPLAIN QUERY PLAN to detect syntax or schema issues.

    Returns:
        (True, None) if the query looks valid, (False, error_message) otherwise.

    NOTE(review): the keyword check only inspects the first token, so an
    exotic `WITH ... DELETE` statement would still pass; a sqlite3
    authorizer callback would be the fully robust fix.
    """
    stripped = sql.strip()
    if not stripped:
        return False, "Empty SQL statement."
    first_token = stripped.split(None, 1)[0].upper()
    if first_token not in ("SELECT", "WITH", "EXPLAIN"):
        return False, f"Only read-only queries are allowed (statement starts with '{first_token}')."
    try:
        cursor = connection.cursor()
        # EXPLAIN QUERY PLAN compiles the statement without executing it,
        # surfacing syntax errors and unknown tables/columns cheaply.
        cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True, None
    except Exception as e:
        return False, str(e)
# %% [markdown]
# ### 8. SQL Generation + Repair with LLM
# %%
def generate_sql(question: str, rag_context: str) -> str:
    """
    Ask the LLM to generate one SQLite query for the question.

    Args:
        question: The user's natural-language question.
        rag_context: Retrieved schema documentation to ground the model.

    Returns:
        The generated SQL with markdown fences stripped (not yet validated).
    """
    # NOTE: this text is part of the runtime prompt sent to the LLM.
    system_instructions = """
You are a senior data analyst writing SQL for a SQLite database.
You will be given:
- A description of available tables, columns, and relationships (schema + YAML metadata).
- A natural language question from the user.
Your job:
- Write ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use explicit JOINs with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
- Always use floating-point division for percentage calculations using 1.0 * numerator / denominator, and round to 2 decimals when appropriate.
Return ONLY the SQL query, no explanation or markdown.
"""
    prompt = (
        system_instructions
        + "\n\n"
        + "=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\nSQL query:"
    )
    # Persist the exact prompt and raw response for offline debugging/audit.
    log_prompt("SQL_GENERATION", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)
    sql = clean_sql(response.content)
    return sql
def repair_sql(
    question: str, rag_context: str, bad_sql: str, error_message: str
) -> str:
    """
    Ask the LLM to correct a failing SQL query.

    Args:
        question: The user's original natural-language question.
        rag_context: Retrieved schema documentation to ground the model.
        bad_sql: The previously generated SQL that failed validation.
        error_message: The SQLite error text produced by the failure.

    Returns:
        The corrected SQL with markdown fences stripped (not yet validated).
    """
    # NOTE: this text is part of the runtime prompt sent to the LLM.
    system_instructions = """
You are a senior data analyst fixing an existing SQL query for a SQLite database.
You will be given:
- Schema context (tables, columns, relationships).
- The user's question.
- A previously generated SQL query that failed.
- The SQLite error message.
Your job:
- Diagnose why the query failed.
- Rewrite ONE valid SQLite SQL query that answers the question.
- ONLY use tables and columns that exist in the schema_context.
- Use explicit JOINs with ON conditions.
- Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
Return ONLY the corrected SQL query, no explanation or markdown.
"""
    prompt = (
        system_instructions
        + "\n\n=== RAG CONTEXT ===\n"
        + rag_context
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== PREVIOUS (FAILING) SQL ===\n"
        + bad_sql
        + "\n\n=== SQLITE ERROR ===\n"
        + error_message
        + "\n\nCorrected SQL query:"
    )
    # Persist the exact prompt and raw response for offline debugging/audit.
    log_prompt("SQL_REPAIR", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)
    sql = clean_sql(response.content)
    return sql
# %% [markdown]
# ### 9. Result summarization
# %%
def summarize_results(
    question: str,
    sql: str,
    df: Optional[pd.DataFrame],
    rag_context: str,
    error: Optional[str] = None,
) -> str:
    """
    Ask the LLM to produce a concise, human-readable answer.

    Args:
        question: The user's original natural-language question.
        sql: The SQL statement that was (or would have been) executed.
        df: Query result, or None when execution failed/was skipped.
        rag_context: Retrieved schema text, included for grounding.
        error: Optional validation/execution error message.

    Returns:
        Markdown-formatted answer text.
    """
    # NOTE: this text is part of the runtime prompt sent to the LLM.
    system_instructions = """
You are a senior data analyst.
You will be given:
- The user's question.
- The final SQL that was executed.
- A small preview of the query result (as a Markdown table, if available).
- Optional error information if the query failed.
Your job:
- Provide a clear, concise answer in Markdown.
- If the result is numeric / aggregated, explain what it means in business terms.
- If there was an error, explain it simply and suggest how the user could rephrase.
- Do NOT show raw SQL unless it is helpful to the user.
"""
    # Cap the preview at 50 rows to keep the prompt small.
    if df is not None and not df.empty:
        preview_rows = min(len(df), 50)
        df_preview_md = df.head(preview_rows).to_markdown(index=False)
    else:
        df_preview_md = "(no rows returned)"
    prompt = (
        system_instructions
        + "\n\n=== USER QUESTION ===\n"
        + question
        + "\n\n=== EXECUTED SQL ===\n"
        + sql
        + "\n\n=== QUERY RESULT PREVIEW ===\n"
        + df_preview_md
        + "\n\n=== RAG CONTEXT (schema) ===\n"
        + rag_context
    )
    if error:
        prompt += "\n\n=== ERROR ===\n" + error
    # Persist the exact prompt and raw response for offline debugging/audit.
    log_prompt("RESULT_SUMMARY", prompt)
    response = llm.invoke(prompt)
    log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)
    return response.content.strip()
# %% [markdown]
# ### 10. Backend Pipeline
# %%
def backend_pipeline(question: str) -> Tuple[str, pd.DataFrame]:
    """
    End-to-end flow:
    1. Retrieve RAG context (schema).
    2. LLM generates SQL.
    3. SQL validator checks & attempts auto-repair if needed - Currently Done Once.
    4. Execute query on SQLite.
    5. LLM summarizes results.

    Args:
        question: Raw user input from the Gradio textbox.

    Returns:
        (answer_markdown, result_dataframe) for the two Gradio outputs; the
        DataFrame is empty when there is no tabular result to show.
    """
    # Guard: empty/whitespace-only input short-circuits the pipeline.
    if not question or not question.strip():
        return "Please type a question.", pd.DataFrame()
    # High-level "what is this dataset" questions skip SQL generation entirely.
    if is_general_question(question):
        overview_md = answer_general_question(question)
        return overview_md, pd.DataFrame()
    # Step 1: RAG retrieve
    rag_context = get_rag_context(question)
    # Step 2: LLM generates SQL
    sql = generate_sql(question, rag_context)
    original_sql = sql  # kept for the repair logs below
    # Step 3: Validate SQL; auto-repair once if needed
    is_valid, validation_error = validate_sql(sql, conn)
    repaired = False
    if not is_valid and validation_error:
        # Attempt repair (a single attempt only, by design)
        repaired_sql = repair_sql(question, rag_context, sql, validation_error)
        repaired_valid, repaired_error = validate_sql(repaired_sql, conn)
        if repaired_valid:
            # Successfully repaired
            sql = repaired_sql
            repaired = True
            log_run_event(
                "SQL_REPAIR_SUCCESS",
                f"Original SQL:\n{original_sql}\n\nRepaired SQL:\n{sql}\n",
            )
            validation_error = None
        else:
            # Repair attempt failed
            log_run_event(
                "SQL_REPAIR_FAILED",
                f"Original SQL:\n{original_sql}\n\nRepaired SQL:\n{repaired_sql}\n\nValidation error:\n{repaired_error}",
            )
            validation_error = repaired_error or validation_error
    # Log final SQL (whether repaired or not)
    log_run_event("FINAL_SQL", sql)
    # Step 4: Execute query
    df, exec_error = (None, None)
    if validation_error:
        # Never execute SQL that failed validation; surface the error instead.
        exec_error = validation_error
        log_run_event(
            "EXECUTION_SKIPPED_DUE_TO_VALIDATION_ERROR",
            f"SQL:\n{sql}\n\nValidation error:\n{validation_error}",
        )
    else:
        df, exec_error = execute_sql(sql, conn)
        if exec_error:
            log_run_event(
                "EXECUTION_ERROR",
                f"SQL:\n{sql}\n\nExecution error:\n{exec_error}",
            )
        else:
            rows = 0 if df is None else len(df)
            log_run_event(
                "EXECUTION_SUCCESS",
                f"SQL:\n{sql}\n\nRows returned: {rows}",
            )
    # Step 5: LLM summary / answer (runs even on error, to explain the failure)
    summary_text = summarize_results(
        question=question,
        sql=sql,
        df=df,
        rag_context=rag_context,
        error=exec_error,
    )
    # Build SQL debug / status section
    sql_status_lines = []
    if exec_error:
        sql_status_lines.append("There was an error running the SQL.")
        sql_status_lines.append(f"**Error:** `{exec_error}`")
    else:
        sql_status_lines.append("Query ran successfully.")
    if repaired:
        sql_status_lines.append("_Note: The original SQL was corrected by the assistant._")
    sql_status_lines.append("\n**Final SQL used:**\n")
    sql_status_lines.append(f"```sql\n{sql}\n```")
    sql_debug_md = "\n".join(sql_status_lines)
    # Result table (only shown when execution succeeded)
    if df is not None and exec_error is None:
        df_preview = df
    else:
        df_preview = pd.DataFrame()
    # Combine summary and SQL status
    answer_md = summary_text + "\n\n---\n\n" + sql_debug_md
    return answer_md, df_preview
# %% [markdown]
# ### 11. Gradio Interface
# %%
# Gradio UI: one question box in, one markdown answer + one result table out.
with gr.Blocks() as demo:
    gr.Markdown(
        """
Olist Analytics Assistant (RAG + Groq LLM)
Ask questions in natural language about the Olist dataset.
The app will:
1. Retrieve schema with RAG.
2. Generate and (if needed) repair SQL
3. Run it on a local SQLite DB
4. Show the SQL query, the output and summarize the results
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            question_in = gr.Textbox(
                label="Your question",
                placeholder="e.g. Total revenue per seller in 2017 (include shipping)",
                lines=4,
            )
            submit_btn = gr.Button("Run")
        with gr.Column(scale=2):
            answer_out = gr.Markdown(label="Answer & Details")
            table_out = gr.Dataframe(label="Result table")
    # Wire the button to the full backend flow; outputs map 1:1 to components.
    submit_btn.click(
        fn=backend_pipeline,
        inputs=question_in,
        outputs=[answer_out, table_out],
    )

# Alias expected by some hosting runtimes (e.g. Hugging Face Spaces).
app = demo
# %%
if __name__ == "__main__":
    # (Removed a redundant `import os` here; os is already imported at module top.)
    # Bind to all interfaces so the app is reachable inside a container;
    # the PORT env var lets the hosting platform pick the port (default 7860).
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", 7860)),
    )