PD03 committed on
Commit
af13fe6
·
verified ·
1 Parent(s): b6070af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -41
app.py CHANGED
@@ -1,26 +1,20 @@
1
- import os
2
  import gradio as gr
3
  import pandas as pd
4
  import duckdb
5
  import torch
6
  from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
- # 1) Load your synthetic data into DuckDB
9
  df = pd.read_csv("synthetic_profit.csv")
10
  conn = duckdb.connect(":memory:")
11
  conn.register("sap", df)
12
-
13
- # 2) Build a one-line schema description
14
  schema = ", ".join(df.columns)
15
- # e.g. "Region,Product,FiscalYear,FiscalQuarter,Revenue,Profit,ProfitMargin"
16
 
17
- # 3) Prepare the T5-WikiSQL model & tokenizer (slow, SentencePiece)
18
  MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"
19
  device = 0 if torch.cuda.is_available() else -1
20
-
21
  tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
22
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
23
-
24
  sql_generator = pipeline(
25
  "text2text-generation",
26
  model=model,
@@ -30,59 +24,75 @@ sql_generator = pipeline(
30
  max_length=128,
31
  )
32
 
33
- # 4) NL SQL with schema + example few-shot
34
  def generate_sql(question: str) -> str:
35
  prompt = f"""
36
  -- DuckDB table `sap` columns: {schema}
37
 
38
- -- EXAMPLE
39
  -- Q: What is the total profit by region?
40
- -- SQL: SELECT Region, SUM(Profit) AS total_profit
41
- -- FROM sap
42
- -- GROUP BY Region;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  -- NOW
45
  Q: {question}
46
  SQL:
47
- """
 
48
  out = sql_generator(prompt)[0]["generated_text"].strip()
49
- # strip ``` if the model wraps it
50
- if out.startswith("```") and out.endswith("```"):
51
- out = "\n".join(out.splitlines()[1:-1])
52
- return out
 
 
53
 
54
- # 5) Core QA function: generate SQL, run it, format result
55
  def answer_profitability(question: str) -> str:
56
- sql = generate_sql(question)
 
 
 
57
 
58
- # run the SQL
59
  try:
60
- result_df = conn.execute(sql).df()
 
 
 
61
  except Exception as e:
62
- return (
63
- f"❌ SQL execution error:\n{e}\n\n"
64
- f"Generated SQL:\n```sql\n{sql}\n```"
65
- )
66
 
67
- # format output
68
- if result_df.empty:
69
- return f"No rows returned.\n\nSQL was:\n```sql\n{sql}\n```"
70
- if result_df.shape == (1,1):
71
- return str(result_df.iat[0,0])
72
- return result_df.to_markdown(index=False)
73
 
74
- # 6) Gradio UI
75
  iface = gr.Interface(
76
  fn=answer_profitability,
77
- inputs=gr.Textbox(lines=2, placeholder="Ask a question…", label="Question"),
78
- outputs=gr.Textbox(lines=8, placeholder="Answer appears here", label="Answer"),
79
- title="SAP Profitability Q&A (HF-Only SQL Generation)",
80
- description=(
81
- "Uses a T5-WikiSQL model with schema+example prompting to\n"
82
- "translate your question into valid SQL, then runs it in DuckDB."
83
- ),
84
  allow_flagging="never",
85
  )
86
 
87
  if __name__ == "__main__":
88
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import duckdb
4
  import torch
5
  from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
+ # Load data
8
  df = pd.read_csv("synthetic_profit.csv")
9
  conn = duckdb.connect(":memory:")
10
  conn.register("sap", df)
 
 
11
  schema = ", ".join(df.columns)
 
12
 
13
+ # Model & tokenizer
14
  MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"
15
  device = 0 if torch.cuda.is_available() else -1
 
16
  tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
17
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
18
  sql_generator = pipeline(
19
  "text2text-generation",
20
  model=model,
 
24
  max_length=128,
25
  )
26
 
27
+ # Prompt→SQL with few-shot
28
  def generate_sql(question: str) -> str:
29
  prompt = f"""
30
  -- DuckDB table `sap` columns: {schema}
31
 
32
+ -- EXAMPLE 1
33
  -- Q: What is the total profit by region?
34
+ -- SQL:
35
+ SELECT
36
+ Region,
37
+ SUM(Profit) AS total_profit
38
+ FROM sap
39
+ GROUP BY Region;
40
+
41
+ -- EXAMPLE 2
42
+ -- Q: What is the total revenue for Product A in EMEA in Q1 2024?
43
+ -- SQL:
44
+ SELECT
45
+ SUM(Revenue) AS total_revenue
46
+ FROM sap
47
+ WHERE
48
+ Product = 'Product A'
49
+ AND Region = 'EMEA'
50
+ AND FiscalYear = 2024
51
+ AND FiscalQuarter = 'Q1';
52
 
53
  -- NOW
54
  Q: {question}
55
  SQL:
56
+ """.strip()
57
+
58
  out = sql_generator(prompt)[0]["generated_text"].strip()
59
+ if "SELECT" in out.upper():
60
+ sql = out[out.upper().index("SELECT"):]
61
+ if ";" in sql:
62
+ sql = sql[: sql.rindex(";") + 1]
63
+ return sql
64
+ raise ValueError(f"Did not generate a SELECT; got:\n{out}")
65
 
66
+ # NL→SQL→DuckDB→Result
67
  def answer_profitability(question: str) -> str:
68
+ try:
69
+ sql = generate_sql(question)
70
+ except Exception as e:
71
+ return f"❌ Prompt/SQL error:\n{e}"
72
 
 
73
  try:
74
+ rel = conn.execute(sql)
75
+ if rel is None:
76
+ return f"❌ No relation returned for SQL:\n```sql\n{sql}\n```"
77
+ df_out = rel.df()
78
  except Exception as e:
79
+ return f"❌ SQL execution error:\n{e}\n\nGenerated SQL:\n```sql\n{sql}\n```"
 
 
 
80
 
81
+ if df_out.empty:
82
+ return f"No rows.\n\n```sql\n{sql}\n```"
83
+ if df_out.shape == (1,1):
84
+ return str(df_out.iat[0,0])
85
+ return df_out.to_markdown(index=False)
 
86
 
87
+ # Gradio UI
88
  iface = gr.Interface(
89
  fn=answer_profitability,
90
+ inputs=gr.Textbox(lines=2, label="Question"),
91
+ outputs=gr.Textbox(lines=8, label="Answer"),
92
+ title="SAP Profitability Q&A",
93
+ description="Translate English → SQL → DuckDB → Answer",
 
 
 
94
  allow_flagging="never",
95
  )
96
 
97
  if __name__ == "__main__":
98
+ iface.launch(server_name="0.0.0.0", server_port=7860)