alokik29 commited on
Commit
fb85d2e
ยท
verified ยท
1 Parent(s): f248584

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -65
app.py CHANGED
@@ -2,96 +2,92 @@ import torch
2
  import sqlite3
3
  import pandas as pd
4
  import gradio as gr
 
5
  from langchain_community.llms import HuggingFacePipeline
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
 
8
- # ============================================================
9
- # ๐Ÿš€ Load SQLCoder model
10
- # ============================================================
11
  model_id = "defog/sqlcoder-7b-2"
12
-
13
  tokenizer = AutoTokenizer.from_pretrained(model_id)
14
- model = AutoModelForCausalLM.from_pretrained(
15
- model_id,
16
- torch_dtype="auto",
17
- device_map="auto"
18
- )
19
-
20
- pipe = pipeline(
21
- "text-generation",
22
- model=model,
23
- tokenizer=tokenizer,
24
- max_new_tokens=256,
25
- do_sample=False
26
- )
27
-
28
  sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)
29
 
30
- # ============================================================
31
- # ๐Ÿง  Define query function
32
- # ============================================================
33
  def ask_question(user_db, question):
34
- """Takes an uploaded SQLite database + a question, returns SQL + result"""
35
  if not user_db:
36
- return "โŒ Please upload a database file.", None
37
-
38
  conn = sqlite3.connect(user_db.name)
39
  cursor = conn.cursor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Create a Text-to-SQL prompt
42
- prompt = f"""
43
- You are an expert SQL generator.
44
- The database follows the Chinook schema with tables:
45
- customers, invoices, invoice_items, tracks, albums, artists, employees, genres, media_types, playlists, playlist_track.
46
- Translate this question into a valid SQLite query for this schema.
47
- Return only SQL (no text).
48
- Question: {question}
49
- SQL:
50
- """
51
 
 
 
 
 
 
52
 
53
- # โœ… Use .invoke() instead of calling the object directly
 
 
 
54
  response = sqlcoder_llm.invoke(prompt)
55
-
56
- # Ensure we get plain string
57
- if isinstance(response, dict) and "text" in response:
58
- response = response["text"]
59
- elif isinstance(response, list):
60
- response = response[0]["generated_text"]
61
-
62
- # Clean and finalize SQL
63
- sql_query = response.strip().split("SQL:")[-1].strip()
64
- sql_query = sql_query.split("\n")[0].strip()
65
- if not sql_query.endswith(";"):
66
- sql_query += ";"
67
-
68
  try:
69
- cursor.execute(sql_query)
70
  rows = cursor.fetchall()
71
- columns = [desc[0] for desc in cursor.description]
72
- df = pd.DataFrame(rows, columns=columns)
 
 
73
  conn.close()
74
- return sql_query, df
75
- except Exception as e:
76
  conn.close()
77
- return f"โŒ Error executing query: {e}\n\nGenerated SQL:\n{sql_query}", None
78
 
79
- # ============================================================
80
- # ๐ŸŽจ Gradio UI
81
- # ============================================================
82
  demo = gr.Interface(
83
  fn=ask_question,
84
  inputs=[
85
- gr.File(label="Upload SQLite Database (.db)"),
86
- gr.Textbox(label="Ask your question")
87
  ],
88
  outputs=[
89
- gr.Textbox(label="Generated SQL Query"),
90
- gr.Dataframe(label="Query Result")
91
  ],
92
- title="๐Ÿง  Text-to-SQL on Your Own Database",
93
- description="Upload your SQLite database and ask natural language questions."
94
  )
95
 
96
- if __name__ == "__main__":
97
- demo.launch()
 
2
  import sqlite3
3
  import pandas as pd
4
  import gradio as gr
5
+ import re
6
  from langchain_community.llms import HuggingFacePipeline
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
 
9
+ # Load model
 
 
10
  model_id = "defog/sqlcoder-7b-2"
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
13
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256, do_sample=False)
 
 
 
 
 
 
 
 
 
 
 
 
14
  sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)
15
 
 
 
 
16
  def ask_question(user_db, question):
 
17
  if not user_db:
18
+ return "โŒ Upload database", None
19
+
20
  conn = sqlite3.connect(user_db.name)
21
  cursor = conn.cursor()
22
+
23
+ # Get full schema with columns
24
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
25
+ tables = [row[0] for row in cursor.fetchall()]
26
+
27
+ schema_info = []
28
+ for table in tables:
29
+ cursor.execute(f"PRAGMA table_info({table});")
30
+ columns = [col[1] for col in cursor.fetchall()]
31
+ schema_info.append(f"{table}({', '.join(columns)})")
32
+
33
+ schema_text = "\n".join(schema_info)
34
+
35
+ # Smart prompt - let model figure out the right table
36
+ prompt = f"""You are a SQL expert. Generate ONLY the SQL query, nothing else.
37
 
38
+ Database Schema:
39
+ {schema_text}
 
 
 
 
 
 
 
 
40
 
41
+ Instructions:
42
+ - Use the EXACT table and column names from the schema above
43
+ - If user asks about concepts (like "sales", "customers", "products"), find the most relevant table
44
+ - Return ONLY valid SQL with semicolon
45
+ - No explanations, no markdown, just SQL
46
 
47
+ Question: {question}
48
+ SQL:"""
49
+
50
+ # Generate SQL
51
  response = sqlcoder_llm.invoke(prompt)
52
+ sql = str(response).strip()
53
+
54
+ # Extract SQL
55
+ if "SQL:" in sql:
56
+ sql = sql.split("SQL:")[-1].strip()
57
+ sql = sql.split("\n")[0].strip()
58
+ if not sql.endswith(";"):
59
+ sql += ";"
60
+
61
+ # Remove common formatting
62
+ sql = sql.replace("```sql", "").replace("```", "").strip()
63
+
64
+ # Execute
65
  try:
66
+ cursor.execute(sql)
67
  rows = cursor.fetchall()
68
+ if cursor.description:
69
+ df = pd.DataFrame(rows, columns=[d[0] for d in cursor.description])
70
+ else:
71
+ df = pd.DataFrame()
72
  conn.close()
73
+ return f"โœ… SQL:\n{sql}\n\n๐Ÿ“Š {len(df)} rows", df
74
+ except sqlite3.Error as e:
75
  conn.close()
76
+ return f"โŒ Error: {e}\n\nSQL tried:\n{sql}\n\n๐Ÿ’ก Available tables:\n{schema_text}", None
77
 
78
+ # UI
 
 
79
  demo = gr.Interface(
80
  fn=ask_question,
81
  inputs=[
82
+ gr.File(label="๐Ÿ“ Upload Database (.db)"),
83
+ gr.Textbox(label="โ“ Ask Question", placeholder="e.g., show all data, highest value, total count")
84
  ],
85
  outputs=[
86
+ gr.Textbox(label="๐Ÿค– SQL & Status", lines=6),
87
+ gr.Dataframe(label="๐Ÿ“Š Results")
88
  ],
89
+ title="๐Ÿ”ฎ Universal Text-to-SQL",
90
+ description="Upload ANY SQLite database and ask questions. The AI will figure out the right tables!"
91
  )
92
 
93
+ demo.launch()