Balaprime committed on
Commit
2401c92
·
verified ·
1 Parent(s): e7f0576

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -307
app.py CHANGED
@@ -1,309 +1,158 @@
 
 
 
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
- import re
5
- import sqlparse
6
-
7
# Checkpoint shared by the model and tokenizer loaders below.
_CHECKPOINT = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Device used when moving tokenized inputs: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Causal language model; dtype and device placement are left to the library.
model = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    torch_dtype="auto",
    device_map="auto",
)

# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
15
-
16
- # # Few-shot examples to include in each prompt
17
- # examples = [
18
- # {
19
- # "question": "Get the names and emails of customers who placed an order in the last 30 days.",
20
- # "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
21
- # },
22
- # {
23
- # "question": "Find all employees with a salary greater than 50000.",
24
- # "sql": "SELECT * FROM employees WHERE salary > 50000;"
25
- # },
26
- # {
27
- # "question": "List all product names and their categories where the price is below 50.",
28
- # "sql": "SELECT name, category FROM products WHERE price < 50;"
29
- # },
30
- # {
31
- # "question": "How many users registered in the year 2022?",
32
- # "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
33
- # }
34
- # ]
35
-
36
def generate_sql(question, context=None):
    """Generate a raw SQL completion for a natural-language question.

    The prompt consists of a fixed instruction line, an optional
    "Table Context" section, and the question followed by ``SQL:``.
    No few-shot examples are included.

    Args:
        question: Natural-language question to translate.
        context: Optional schema / table definitions. Ignored when empty or
            whitespace-only.

    Returns:
        The text generated after the prompt, with special tokens removed and
        surrounding whitespace stripped. Sampling is enabled, so repeated
        calls may return different text.
    """
    prompt_parts = ["Translate natural language questions to SQL queries.\n\n"]

    # Add table context if available
    if context and context.strip():
        prompt_parts.append(f"Table Context:\n{context}\n\n")

    # Add the current question
    prompt_parts.append(f"Q: {question}\nSQL:")
    prompt = "".join(prompt_parts)

    # Move the inputs to wherever the model actually lives: the model is
    # loaded with device_map="auto", so its placement is decided by the
    # library rather than by the module-level `device` guess.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only: skip autograd bookkeeping. Passing the whole encoding
    # forwards the attention mask together with the input ids.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    sql_query = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return sql_query.strip()
65
-
66
def clean_sql_output(sql_text):
    """Reduce raw model output to a single, formatted SQL statement.

    Steps:
        1. Remove SQL comments and markdown code fences.
        2. Split the text into candidate queries.
        3. Drop duplicates and pick the longest query containing SELECT
           (or the longest query overall when none contains SELECT).
        4. Terminate it with ``;``, collapse whitespace and pretty-print it.

    Args:
        sql_text: Raw text produced by the model.

    Returns:
        The formatted query, or ``""`` when no query text remains after
        cleaning. If formatting with sqlparse fails, the single-line query
        is returned unformatted.
    """
    # Remove SQL comments (both single line and multi-line)
    sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)

    # Remove markdown code block syntax if present
    sql_text = re.sub(r'```sql|```', '', sql_text)

    queries = _split_sql_queries(sql_text)
    if not queries:
        return ""

    if len(queries) > 1:
        candidates = _dedupe_sql_queries(queries)
        select_queries = [
            q for q in candidates if re.search(r'SELECT\s+', q, re.IGNORECASE)
        ]
        # Prefer SELECT queries; among the preferred set the longest wins
        # and max() keeps the earliest one on ties.
        best_query = max(select_queries or candidates, key=len)
    else:
        best_query = queries[0]

    # Clean up the chosen query
    best_query = best_query.strip()
    if not best_query.endswith(';'):
        best_query += ';'

    # Final formatting to ensure consistent spacing
    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        # Use sqlparse to nicely format the SQL for display
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is best-effort; fall back to the single-line query.
        return best_query


def _split_sql_queries(sql_text):
    """Split cleaned text into non-empty, stripped candidate queries."""
    if ';' in sql_text:
        pieces = sql_text.split(';')
    else:
        # Without semicolons, treat every SELECT keyword as the start of a
        # new query (this also splits at a SELECT inside a subquery).
        flattened = re.sub(r'\s+', ' ', sql_text)
        starts = [
            m.start() for m in re.finditer(r'SELECT\s+', flattened, re.IGNORECASE)
        ]
        if len(starts) > 1:
            bounds = starts + [len(flattened)]
            pieces = [flattened[a:b] for a, b in zip(bounds, bounds[1:])]
        else:
            pieces = [sql_text]
    return [piece.strip() for piece in pieces if piece.strip()]


def _normalize_sql(query):
    """Return a case/whitespace-insensitive key used to compare queries."""
    try:
        return sqlparse.format(
            query + ('' if query.strip().endswith(';') else ';'),
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # If sqlparse fails, just do basic normalization
        return re.sub(r'\s+', ' ', query.lower().strip())


def _dedupe_sql_queries(queries):
    """Drop queries whose normalized form was already seen, keeping order."""
    seen = set()
    unique = []
    for query in queries:
        key = _normalize_sql(query)
        if key not in seen:
            seen.add(key)
            unique.append(query)
    return unique
171
-
172
def process_input(question, table_context):
    """Turn a user question (plus optional schema text) into display-ready SQL.

    Returns a prompt to enter a question when `question` is blank, an apology
    message when cleaning leaves no query, and the cleaned SQL otherwise.
    """
    if question.strip():
        # Generate, then clean, in one pass.
        cleaned_sql = clean_sql_output(generate_sql(question, table_context))
        return cleaned_sql or "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
    return "Please enter a question."
187
-
188
# Sample table context examples for the example selector.
# Each entry is a block of CREATE TABLE statements; the UI examples refer to
# these entries by index.
example_contexts = [
    # Example 1: a single `customers` table
    """
CREATE TABLE customers (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(100),
    order_date DATE
);
""",

    # Example 2: a single `products` table
    """
CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    category VARCHAR(50),
    price DECIMAL(10,2),
    stock_quantity INT
);
""",

    # Example 3: `employees` and `departments` tables
    """
CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department VARCHAR(50),
    salary DECIMAL(10,2),
    hire_date DATE
);
CREATE TABLE departments (
    id INT PRIMARY KEY,
    name VARCHAR(50),
    manager_id INT,
    budget DECIMAL(15,2)
);
"""
]

# Sample question examples
# NOTE(review): this list does not appear to be referenced anywhere else in
# the visible file; the examples widget uses its own inline questions.
example_questions = [
    "Get the names and emails of customers who placed an order in the last 30 days.",
    "Find all products with less than 10 items in stock.",
    "List all employees in the Sales department with a salary greater than 50000.",
    "What is the total budget for departments with more than 5 employees?",
    "Count how many products are in each category where the price is greater than 100."
]
237
-
238
# Create the Gradio interface
# Layout: a title and intro, then one row with two columns (inputs on the
# left, generated SQL on the right), followed by clickable examples and an
# "About" section.
# NOTE(review): nesting below is reconstructed from the statement order;
# confirm the examples/about sections sit directly under the Blocks context.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        # Input column: question, optional schema, and the submit button.
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2
            )

            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10
            )

            submit_btn = gr.Button("Generate SQL Query")

        # Output column: the generated query, shown with SQL highlighting.
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12
            )

    # Examples section
    gr.Markdown("### Try some examples")

    # Each example pairs a question with one of the sample schemas and fills
    # both input boxes.
    example_selector = gr.Examples(
        examples=[
            ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
            ["Find the total number of employees in each department", example_contexts[2]],
            ["Get customers who placed orders in the last 7 days", example_contexts[0]],
            ["Count the number of products in each category", example_contexts[1]],
            ["Find the average salary by department", example_contexts[2]]
        ],
        inputs=[question_input, table_context]
    )

    # Set up the submit button to trigger the process_input function
    submit_btn.click(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Also trigger on pressing Enter in the question input
    question_input.submit(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Add information about the model
    gr.Markdown("""
    ### About
    This app uses a fine-tuned language model to convert natural language questions into SQL queries.

    - **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
    - **How to use**:
    1. Enter your question in natural language
    2. If you have specific table schemas, add them in the Table Context field
    3. Click "Generate SQL Query" or press Enter

    Note: The model works best when table context is provided, but can generate generic SQL queries without it.
    """)

# Launch the app
demo.launch()
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ from sentence_transformers import SentenceTransformer
4
  import gradio as gr
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from groq import Groq
7
+
8
+
9
# Load environment variables through python-dotenv before reading the key.
load_dotenv()

# Groq API key taken from the "groq_api_key" environment variable
# (None when the variable is not set).
api = os.environ.get("groq_api_key")
12
+
13
# Memoized result of create_metadata_embeddings(); None until the first call.
_METADATA_CACHE = None


def create_metadata_embeddings():
    """Build (once) the table descriptions and their sentence embeddings.

    The descriptions are constants, so the SentenceTransformer model is
    loaded and the descriptions are encoded only on the first call; later
    calls return the same cached tuple instead of repeating that work on
    every request.

    Returns:
        A tuple ``(embeddings, model, student, employee, course)`` where
        ``embeddings`` is the encoding of the three description strings in
        the order student, employee, course, ``model`` is the loaded
        SentenceTransformer, and the remaining items are the description
        strings themselves. The same objects are returned on every call, so
        callers should treat them as read-only.
    """
    global _METADATA_CACHE
    if _METADATA_CACHE is not None:
        return _METADATA_CACHE

    student = """
    Table: student
    Columns:
    - student_id: an integer representing the unique ID of a student.
    - first_name: a string containing the first name of the student.
    - last_name: a string containing the last name of the student.
    - date_of_birth: a date representing the student's birthdate.
    - email: a string for the student's email address.
    - phone_number: a string for the student's contact number.
    - major: a string representing the student's major field of study.
    - year_of_enrollment: an integer for the year the student enrolled.
    """

    employee = """
    Table: employee
    Columns:
    - employee_id: an integer representing the unique ID of an employee.
    - first_name: a string containing the first name of the employee.
    - last_name: a string containing the last name of the employee.
    - email: a string for the employee's email address.
    - department: a string for the department the employee works in.
    - position: a string representing the employee's job title.
    - salary: a float representing the employee's salary.
    - date_of_joining: a date for when the employee joined the college.
    """

    course = """
    Table: course_info
    Columns:
    - course_id: an integer representing the unique ID of the course.
    - course_name: a string containing the course's name.
    - course_code: a string for the course's unique code.
    - instructor_id: an integer for the ID of the instructor teaching the course.
    - department: a string for the department offering the course.
    - credits: an integer representing the course credits.
    - semester: a string for the semester when the course is offered.
    """

    # Order matters: find_best_fit maps the row index of the best match back
    # to student / employee / course.
    metadata_list = [student, employee, course]

    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)

    _METADATA_CACHE = (embeddings, model, student, employee, course)
    return _METADATA_CACHE
59
+
60
def find_best_fit(embeddings, model, user_query, student, employee, course):
    """Pick the table description that best matches the user's query.

    The query is encoded with `model` and compared against `embeddings` by
    cosine similarity. Index 0 of the best match selects `student`, index 1
    selects `employee`, and any other index selects `course`.
    """
    query_vector = model.encode([user_query])
    scores = cosine_similarity(query_vector, embeddings)
    winner = scores.argmax()

    if winner == 0:
        return student
    if winner == 1:
        return employee
    return course
72
+
73
+
74
+
75
def create_prompt(user_query, table_metadata):
    """Assemble the system and user prompts sent to the chat model.

    Args:
        user_query: The natural-language request typed by the user.
        table_metadata: Description of the table the query should target.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple: fixed generation rules, and
        a message embedding the query and the table metadata.
    """
    # Per-request message; the two placeholders are filled in below.
    user_template = """
    User Query: {query}
    Table Metadata: {metadata}
    """

    # Fixed instructions describing the task, rules and expected output.
    system_prompt = """
    You are a SQL query generator specialized in generating SQL queries for a single table at a time. Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.

    Rules:
    Single Table Only: Assume all queries are related to a single table provided in the metadata. Ignore any references to other tables.
    Metadata-Based Validation: Always ensure the generated query matches the table name, columns, and data types provided in the metadata.
    User Intent: Accurately capture the user's requirements, such as filters, sorting, or aggregations, as expressed in natural language.
    SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.

    Input Format:
    User Query: The user's natural language request.
    Table Metadata: The structure of the relevant table, including the table name, column names, and data types.

    Output Format:
    SQL Query: A valid SQL query formatted for readability.
    Do not output anything else except the SQL query.Not even a single word extra.Ouput the whole query in a single line only.
    You are ready to generate SQL queries based on the user input and table metadata.
    """

    filled_user_prompt = user_template.format(
        query=user_query, metadata=table_metadata
    )
    return system_prompt, filled_user_prompt
102
+
103
+
104
+
105
def generate_output(system_prompt, user_prompt):
    """Ask the Groq chat model for a SQL query and validate the reply.

    Args:
        system_prompt: Instructions describing the SQL-generation task.
        user_prompt: Message containing the user's query and table metadata.

    Returns:
        The model's reply with surrounding whitespace removed when it starts
        with ``SELECT`` (case-insensitive); otherwise the fixed message
        "Can't perform the task at the moment.". Only SELECT statements are
        ever passed through.
    """
    client = Groq(api_key=api)
    # NOTE(review): confirm this model id is still served by Groq.
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model="llama3-70b-8192",
    )

    # The reply may be missing or padded with whitespace/newlines; normalise
    # it first so a valid query is not rejected just because of leading
    # blanks, and a missing reply does not raise.
    res = (chat_completion.choices[0].message.content or "").strip()

    if res.lower().startswith("select"):
        return res
    return "Can't perform the task at the moment."
119
+
120
+
121
def response(user_query):
    """Run the full pipeline for one natural-language query.

    Builds the table embeddings, selects the best-matching table, creates the
    prompts and returns whatever `generate_output` produces for them.
    """
    embeddings, model, *tables = create_metadata_embeddings()
    table_metadata = find_best_fit(embeddings, model, user_query, *tables)
    return generate_output(*create_prompt(user_query, table_metadata))
131
+
132
# Description shown under the title; summarises the three tables that
# queries can be generated for.
desc = """

There are three tables in the database:


Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.


Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.


Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.

"""

# Single text box in, single text box out, wired to `response`.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc
)

# `share` is a boolean flag. The previous string "True" only worked because
# any non-empty string is truthy (the string "False" would have enabled
# sharing as well).
demo.launch(share=True)