Omkar1872 committed on
Commit
dc27e7e
·
verified ·
1 Parent(s): b12feaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -56
app.py CHANGED
@@ -1,56 +1,53 @@
1
- import gradio as gr
2
- import pandas as pd
3
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
- import sqlite3
5
-
6
- # Load model and tokenizer
7
- model_name = "mrm8488/t5-base-finetuned-wikiSQL"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
-
11
- def generate_sql_query(natural_language, data):
12
- # Load uploaded CSV
13
- df = pd.read_csv(data.name)
14
-
15
- # Create in-memory SQLite DB
16
- conn = sqlite3.connect(":memory:")
17
- df.to_sql("data_table", conn, index=False, if_exists="replace")
18
-
19
- # Create schema description
20
- schema = ", ".join([f"{col}" for col in df.columns])
21
-
22
- # Combine user query and schema
23
- input_text = f"translate English to SQL: {natural_language} | table columns: {schema}"
24
-
25
- # Generate SQL query
26
- inputs = tokenizer(input_text, return_tensors="pt")
27
- outputs = model.generate(**inputs, max_length=256)
28
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
-
30
- try:
31
- # Execute the generated SQL query
32
- result_df = pd.read_sql_query(sql_query, conn)
33
- except Exception as e:
34
- result_df = pd.DataFrame({"Error": [str(e)]})
35
-
36
- conn.close()
37
- return sql_query, result_df.head()
38
-
39
- # Gradio UI
40
- iface = gr.Interface(
41
- fn=generate_sql_query,
42
- inputs=[
43
- gr.Textbox(label="Enter your question (Natural Language)", placeholder="e.g., Show customers with age > 30"),
44
- gr.File(label="Upload CSV dataset")
45
- ],
46
- outputs=[
47
- gr.Textbox(label="Generated SQL Query"),
48
- gr.Dataframe(label="Query Result")
49
- ],
50
- title="🧠 Natural Language to SQL Generator",
51
- description="Upload a CSV file and ask questions in plain English. The app converts them into SQL and shows the result.",
52
- allow_flagging="never"
53
- )
54
-
55
- if __name__ == "__main__":
56
- iface.launch()
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import sqlite3
5
+
6
+ # Load model
7
+ model_name = "mrm8488/t5-base-finetuned-wikiSQL"
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
+
11
+ def nl_to_sql(question, file):
12
+ try:
13
+ df = pd.read_csv(file.name)
14
+ except Exception as e:
15
+ return f"Error reading CSV: {e}", pd.DataFrame()
16
+
17
+ # Create SQLite DB
18
+ conn = sqlite3.connect(":memory:")
19
+ df.to_sql("data_table", conn, index=False, if_exists="replace")
20
+
21
+ # Schema description
22
+ schema = ", ".join(df.columns)
23
+ text = f"translate English to SQL: {question} | table columns: {schema}"
24
+
25
+ inputs = tokenizer(text, return_tensors="pt")
26
+ outputs = model.generate(**inputs, max_length=256)
27
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
+
29
+ # Try executing SQL query
30
+ try:
31
+ result = pd.read_sql_query(sql_query, conn)
32
+ except Exception as e:
33
+ result = pd.DataFrame({"Error": [str(e)]})
34
+
35
+ conn.close()
36
+ return sql_query, result.head()
37
+
38
+ iface = gr.Interface(
39
+ fn=nl_to_sql,
40
+ inputs=[
41
+ gr.Textbox(label="Ask your question (Natural Language)", placeholder="e.g., Show customers older than 30"),
42
+ gr.File(label="Upload your CSV file")
43
+ ],
44
+ outputs=[
45
+ gr.Textbox(label="Generated SQL Query"),
46
+ gr.Dataframe(label="Result Preview")
47
+ ],
48
+ title="🧠 Natural Language to SQL Generator",
49
+ description="Upload a CSV and ask questions in plain English. Generates SQL and shows results instantly."
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ iface.launch()