talk_to_data / app.py
PD03's picture
Update app.py
03083eb verified
raw
history blame
2.68 kB
# app.py
import gradio as gr
import pandas as pd
import duckdb
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Open an in-memory DuckDB database and expose the synthetic profit data
#    to it under the table name `sap`
conn = duckdb.connect(database=":memory:")
df = pd.read_csv("synthetic_profit.csv")
conn.register("sap", df)

# 2) Comma-separated list of the table's column names, embedded in the
#    prompt sent to the SQL-generation model
schema = ", ".join(df.columns.tolist())
# 3) Set up the text-to-SQL generator from an open-source checkpoint that was
#    fine-tuned on WikiSQL
MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"

# Device index handed to the pipeline: GPU 0 when CUDA is available, else -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)

sql_generator = pipeline(
    task="text2text-generation",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
    max_length=128,  # cap on the length of each generated sequence
)
# 4) Function to turn a natural-language question into SQL
def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown ``` fence, if it has one."""
    fence = "```"
    text = text.strip()
    if not (text.startswith(fence) and text.endswith(fence)):
        return text

    # Well-formed fence: the opening marker (plus optional language tag) and
    # the closing marker each sit on their own line, so drop those two lines.
    between = "\n".join(text.splitlines()[1:-1])
    if between.strip():
        return between

    # Otherwise a marker shares its line with the SQL (e.g. "```SELECT 1```").
    # Dropping whole lines would discard the query, so slice the markers off.
    inner = text[len(fence):-len(fence)]
    head, newline, tail = inner.partition("\n")
    if newline and len(head.split()) <= 1:
        # The opening line held nothing, or only a bare language tag ("sql").
        inner = tail
    return inner.strip()


def generate_sql(question: str) -> str:
    """Translate a natural-language question into a SQL string.

    Builds a prompt that names the `sap` table and lists its columns (from
    the module-level ``schema``), runs it through ``sql_generator`` and
    returns the first generated text with surrounding whitespace and any
    enclosing triple-backtick fence removed.

    Args:
        question: The user's question in plain English.

    Returns:
        The generated SQL text. It is not validated here.
    """
    # NOTE(review): this checkpoint may expect its own task prefix rather than
    # a free-form instruction -- confirm the prompt format against the model
    # card before relying on the output quality.
    prompt = (
        "Translate the following English question into SQL for a DuckDB table named `sap` "
        f"with columns ({schema}):\n\n"
        f"Question: {question}\nSQL:"
    )
    generated = sql_generator(prompt)[0]["generated_text"]
    return _strip_code_fence(generated)
# 5) Core Q&A function: NL → SQL → execute → format
def answer_profitability(question: str) -> str:
    """Answer a natural-language question by generating SQL and running it.

    The question is turned into SQL with ``generate_sql``, executed on the
    module-level DuckDB connection ``conn``, and the result is formatted as
    text for the UI.

    Args:
        question: The user's question. Blank or missing input is rejected
            without calling the model.

    Returns:
        One of:
        - a short prompt asking for a question, when the input is blank;
        - an error message including the generated SQL, when execution fails;
        - a "No results." message including the SQL, when the result is empty;
        - the bare value, when the result is a single cell;
        - otherwise the result rendered as a table.
    """
    # Blank input gives the model nothing to translate; skip the model call
    # and the inevitable SQL error.
    if not question or not question.strip():
        return "Please enter a question."

    # a) generate the SQL
    sql = generate_sql(question)

    # b) execute it
    # NOTE(security): the SQL is produced by a model from user-supplied text
    # and is executed unchanged. DuckDB SQL is not limited to reading `sap`
    # (it can, for example, reach the local filesystem), so consider
    # restricting the connection or validating the statement before exposing
    # this app publicly.
    try:
        result_df = conn.execute(sql).df()
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return (
            f"❌ SQL execution error:\n{e}\n\n"
            f"Generated SQL:\n```sql\n{sql}\n```"
        )

    # c) format the result
    if result_df.empty:
        return f"No results.\n\n```sql\n{sql}\n```"
    if result_df.shape == (1, 1):
        return str(result_df.iat[0, 0])
    try:
        return result_df.to_markdown(index=False)
    except ImportError:
        # to_markdown needs the optional `tabulate` package; fall back to
        # pandas' plain-text table so the answer is still shown.
        return result_df.to_string(index=False)
# 6) Gradio UI: one question textbox in, one answer textbox out
question_box = gr.Textbox(
    label="Question",
    placeholder="Ask a question about profitability…",
    lines=2,
)
answer_box = gr.Textbox(
    label="Answer",
    placeholder="Answer appears here",
    lines=8,
)

iface = gr.Interface(
    fn=answer_profitability,
    inputs=question_box,
    outputs=answer_box,
    title="SAP Profitability Q&A (HF SQL Generation + DuckDB)",
    description=(
        "Uses an open-source Hugging Face model fine-tuned on WikiSQL to translate your question into SQL, "
        "executes it against the `sap` table in DuckDB, and returns the result."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    # Listen on all network interfaces at port 7860 when run as a script.
    iface.launch(server_name="0.0.0.0", server_port=7860)