PD03 committed on
Commit
b6070af
·
verified ·
1 Parent(s): 1dd709a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -1,24 +1,24 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import pandas as pd
5
  import duckdb
6
  import torch
7
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
 
9
  # 1) Load your synthetic data into DuckDB
10
  df = pd.read_csv("synthetic_profit.csv")
11
  conn = duckdb.connect(":memory:")
12
  conn.register("sap", df)
13
 
14
- # 2) Build a one-line schema string for prompts
15
- schema = ", ".join(df.columns) # e.g. "Region, Product, FiscalYear, ..."
 
16
 
17
- # 3) Load an open-source model fine-tuned on WikiSQL for SQL generation
18
  MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"
19
  device = 0 if torch.cuda.is_available() else -1
20
 
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
22
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
23
 
24
  sql_generator = pipeline(
@@ -30,25 +30,32 @@ sql_generator = pipeline(
30
  max_length=128,
31
  )
32
 
33
- # 4) Function to turn a natural-language question into SQL
34
  def generate_sql(question: str) -> str:
35
- prompt = (
36
- f"Translate the following English question into SQL for a DuckDB table named `sap` "
37
- f"with columns ({schema}):\n\n"
38
- f"Question: {question}\nSQL:"
39
- )
 
 
 
 
 
 
 
 
40
  out = sql_generator(prompt)[0]["generated_text"].strip()
41
- # strip triple-backticks if present
42
  if out.startswith("```") and out.endswith("```"):
43
  out = "\n".join(out.splitlines()[1:-1])
44
  return out
45
 
46
- # 5) Core Q&A function: NL → SQL → execute → format
47
  def answer_profitability(question: str) -> str:
48
- # a) generate the SQL
49
  sql = generate_sql(question)
50
 
51
- # b) execute it
52
  try:
53
  result_df = conn.execute(sql).df()
54
  except Exception as e:
@@ -57,22 +64,22 @@ def answer_profitability(question: str) -> str:
57
  f"Generated SQL:\n```sql\n{sql}\n```"
58
  )
59
 
60
- # c) format the result
61
  if result_df.empty:
62
- return f"No results.\n\n```sql\n{sql}\n```"
63
  if result_df.shape == (1,1):
64
  return str(result_df.iat[0,0])
65
  return result_df.to_markdown(index=False)
66
 
67
- # 6) Gradio UI with explicit inputs & outputs
68
  iface = gr.Interface(
69
  fn=answer_profitability,
70
- inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…", label="Question"),
71
  outputs=gr.Textbox(lines=8, placeholder="Answer appears here", label="Answer"),
72
- title="SAP Profitability Q&A (HF SQL Generation + DuckDB)",
73
  description=(
74
- "Uses an open-source Hugging Face model fine-tuned on WikiSQL to translate your question into SQL, "
75
- "executes it against the `sap` table in DuckDB, and returns the result."
76
  ),
77
  allow_flagging="never",
78
  )
 
1
+ import os
 
2
  import gradio as gr
3
  import pandas as pd
4
  import duckdb
5
  import torch
6
+ from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
  # 1) Load your synthetic data into DuckDB
9
  df = pd.read_csv("synthetic_profit.csv")
10
  conn = duckdb.connect(":memory:")
11
  conn.register("sap", df)
12
 
13
+ # 2) Build a one-line schema description
14
+ schema = ", ".join(df.columns)
15
+ # e.g. "Region, Product, FiscalYear, FiscalQuarter, Revenue, Profit, ProfitMargin"
16
 
17
+ # 3) Prepare the T5-WikiSQL model & tokenizer (slow, SentencePiece)
18
  MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"
19
  device = 0 if torch.cuda.is_available() else -1
20
 
21
+ tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
22
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
23
 
24
  sql_generator = pipeline(
 
30
  max_length=128,
31
  )
32
 
33
+ # 4) NL → SQL with schema + example few-shot
34
  def generate_sql(question: str) -> str:
35
+ prompt = f"""
36
+ -- DuckDB table `sap` columns: {schema}
37
+
38
+ -- EXAMPLE
39
+ -- Q: What is the total profit by region?
40
+ -- SQL: SELECT Region, SUM(Profit) AS total_profit
41
+ -- FROM sap
42
+ -- GROUP BY Region;
43
+
44
+ -- NOW
45
+ Q: {question}
46
+ SQL:
47
+ """
48
  out = sql_generator(prompt)[0]["generated_text"].strip()
49
+ # strip ``` if the model wraps it
50
  if out.startswith("```") and out.endswith("```"):
51
  out = "\n".join(out.splitlines()[1:-1])
52
  return out
53
 
54
+ # 5) Core QA function: generate SQL, run it, format result
55
  def answer_profitability(question: str) -> str:
 
56
  sql = generate_sql(question)
57
 
58
+ # run the SQL
59
  try:
60
  result_df = conn.execute(sql).df()
61
  except Exception as e:
 
64
  f"Generated SQL:\n```sql\n{sql}\n```"
65
  )
66
 
67
+ # format output
68
  if result_df.empty:
69
+ return f"No rows returned.\n\nSQL was:\n```sql\n{sql}\n```"
70
  if result_df.shape == (1,1):
71
  return str(result_df.iat[0,0])
72
  return result_df.to_markdown(index=False)
73
 
74
+ # 6) Gradio UI
75
  iface = gr.Interface(
76
  fn=answer_profitability,
77
+ inputs=gr.Textbox(lines=2, placeholder="Ask a question…", label="Question"),
78
  outputs=gr.Textbox(lines=8, placeholder="Answer appears here", label="Answer"),
79
+ title="SAP Profitability Q&A (HF-Only SQL Generation)",
80
  description=(
81
+ "Uses a T5-WikiSQL model with schema+example prompting to\n"
82
+ "translate your question into valid SQL, then runs it in DuckDB."
83
  ),
84
  allow_flagging="never",
85
  )