Balaprime committed on
Commit
d3a8a4a
·
verified ·
1 Parent(s): 1941ae5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -207
app.py CHANGED
@@ -1,207 +1,207 @@
1
- from dotenv import load_dotenv
2
- import os
3
- from sentence_transformers import SentenceTransformer
4
- import gradio as gr
5
- from sklearn.metrics.pairwise import cosine_similarity
6
- from groq import Groq
7
- from database import initialize_database, run_sql_query
8
-
9
- load_dotenv()
10
- api = os.getenv("groq_api_key")
11
-
12
- def create_metadata_embeddings():
13
- student = """
14
- Table: student
15
- Columns:
16
- - student_id: an integer representing the unique ID of a student.
17
- - first_name: a string containing the first name of the student.
18
- - last_name: a string containing the last name of the student.
19
- - date_of_birth: a date representing the student's birthdate.
20
- - email: a string for the student's email address.
21
- - phone_number: a string for the student's contact number.
22
- - major: a string representing the student's major field of study.
23
- - year_of_enrollment: an integer for the year the student enrolled.
24
- """
25
-
26
- employee = """
27
- Table: employee
28
- Columns:
29
- - employee_id: an integer representing the unique ID of an employee.
30
- - first_name: a string containing the first name of the employee.
31
- - last_name: a string containing the last name of the employee.
32
- - email: a string for the employee's email address.
33
- - department: a string for the department the employee works in.
34
- - position: a string representing the employee's job title.
35
- - salary: a float representing the employee's salary.
36
- - date_of_joining: a date for when the employee joined the college.
37
- """
38
-
39
- course = """
40
- Table: course
41
- Columns:
42
- - course_id: an integer representing the unique ID of the course.
43
- - course_name: a string containing the course's name.
44
- - course_code: a string for the course's unique code.
45
- - instructor_id: an integer for the ID of the instructor teaching the course.
46
- - department: a string for the department offering the course.
47
- - credits: an integer representing the course credits.
48
- - semester: a string for the semester when the course is offered.
49
- """
50
-
51
- metadata_list = [student, employee, course]
52
- model = SentenceTransformer('all-MiniLM-L6-v2')
53
- embeddings = model.encode(metadata_list)
54
-
55
- return embeddings, model, student, employee, course
56
-
57
-
58
- # def find_best_fit(embeddings, model, user_query, student, employee, course):
59
- # query_embedding = model.encode([user_query])
60
- # similarities = cosine_similarity(query_embedding, embeddings)
61
- # best_match_table = similarities.argmax()
62
- # return [student, employee, course][best_match_table]
63
- def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
64
- """
65
- Identifies relevant tables for a query based on semantic similarity.
66
- Returns a list of matching table metadata strings.
67
- """
68
- query_embedding = model.encode([user_query])
69
- similarities = cosine_similarity(query_embedding, embeddings)[0] # Flatten array
70
-
71
- matched_tables = []
72
- for idx, sim in enumerate(similarities):
73
- if sim >= threshold:
74
- matched_tables.append((table_names[idx], metadata_list[idx], sim))
75
-
76
- # If none meet the threshold, use top-k highest scoring
77
- if not matched_tables:
78
- top_indices = similarities.argsort()[-top_k:][::-1]
79
- matched_tables = [(table_names[i], metadata_list[i], similarities[i]) for i in top_indices]
80
-
81
- return matched_tables # List of (table_name, metadata, similarity)
82
-
83
-
84
-
85
- # def create_prompt(user_query, table_metadata):
86
-
87
- # system_prompt = """
88
- # You are an SQL query generator capable of handling multiple tables with relationships.
89
-
90
- # Use table metadata to construct accurate queries, including joins, based on the user's intent.
91
-
92
- # Ensure:
93
- # - The query is valid.
94
- # - All table and column names match metadata.
95
- # - Join conditions are based on matching keys (e.g., student_id = course.student_id).
96
-
97
- # Output a single-line SQL query only, without any explanation.
98
-
99
- # """
100
-
101
- # user_prompt = f"""
102
- # User Query: {user_query}
103
- # Table Metadata: {table_metadata}
104
- # """
105
- # return system_prompt, user_prompt
106
- def create_prompt(user_query, table_metadata):
107
- system_prompt = """
108
- You are a SQL query generator specialized in generating SQL queries for one or more tables.
109
- Your task is to convert natural language queries into SQL statements using the provided metadata.
110
-
111
- Rules:
112
- - Use JOINs only if required by user intent.
113
- - Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.
114
- - Use standard SQL syntax in a single line.
115
- - Do NOT explain or add comments — only return the SQL query string.
116
-
117
- Input Format:
118
- User Query: A natural language request.
119
- Table Metadata: List of available tables and their structures.
120
-
121
- Output Format:
122
- SQL Query: A valid single-line SQL query only.
123
- """
124
-
125
- user_prompt = f"""
126
- User Query: {user_query}
127
- Table Metadata: {table_metadata}
128
- """
129
- return system_prompt, user_prompt
130
-
131
-
132
-
133
- def generate_output(system_prompt, user_prompt):
134
- client = Groq(api_key=api)
135
- chat_completion = client.chat.completions.create(
136
- messages=[
137
- {"role": "system", "content": system_prompt},
138
- {"role": "user", "content": user_prompt}
139
- ],
140
- model="llama3-70b-8192"
141
- )
142
- res = chat_completion.choices[0].message.content
143
- return res if res.lower().startswith("select") else "Can't perform the task at the moment."
144
-
145
-
146
- # def response(user_query):
147
- # embeddings, model, student, employee, course = create_metadata_embeddings()
148
- # table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
149
- # system_prompt, user_prompt = create_prompt(user_query, table_metadata)
150
- # sql_query = generate_output(system_prompt, user_prompt)
151
-
152
- # if sql_query.lower().startswith("select"):
153
- # result = run_sql_query(sql_query)
154
- # return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
155
- # else:
156
- # return sql_query
157
- def response(user_query):
158
- embeddings, model, student, employee, course = create_metadata_embeddings()
159
- metadata_list = [student, employee, course]
160
- table_names = ["student", "employee", "course"]
161
-
162
- matched_tables = find_best_fit(embeddings, model, user_query, metadata_list, table_names)
163
-
164
- combined_metadata = "\n\n".join([table[1] for table in matched_tables])
165
-
166
- system_prompt, user_prompt = create_prompt(user_query, combined_metadata)
167
-
168
- sql_query = generate_output(system_prompt, user_prompt)
169
-
170
- if sql_query.lower().startswith("select"):
171
- result = run_sql_query(sql_query)
172
- return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
173
- else:
174
- return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."
175
-
176
-
177
-
178
-
179
-
180
- # Initialize DB on app launch
181
- initialize_database()
182
-
183
- desc = """
184
- There are three tables in the database:
185
-
186
- Student Table:
187
- Contains student ID, name, DOB, email, phone, major, year of enrollment.
188
-
189
- Employee Table:
190
- Contains employee ID, name, email, department, position, salary, and joining date.
191
-
192
- Course Info Table:
193
- Contains course ID, name, code, instructor ID, department, credits, and semester.
194
- """
195
-
196
- import gradio as gr
197
-
198
- demo = gr.Interface(
199
- fn=response,
200
- inputs=gr.Textbox(label="Please provide the natural language query"),
201
- outputs=gr.Textbox(label="SQL Query and Result"),
202
- title="SQL Query Generator with Results",
203
- description=desc
204
- )
205
-
206
- demo.launch(share=True)
207
-
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ from sentence_transformers import SentenceTransformer
4
+ import gradio as gr
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from groq import Groq
7
+ from database import initialize_database, run_sql_query
8
+
9
# Pull settings from a local .env file (if present) into the environment.
load_dotenv()
# Groq API key; None when "groq_api_key" is not set in the environment.
api = os.environ.get("groq_api_key")
11
+
12
def create_metadata_embeddings():
    """Return the table descriptions together with their sentence embeddings.

    The embedding model and the encoded descriptions are built on the first
    call and reused on later calls, so the model is not reloaded for every
    incoming query.

    Returns:
        A tuple ``(embeddings, model, student, employee, course)``:
        ``embeddings`` is the result of ``model.encode`` on the three
        descriptions in the order student, employee, course; ``model`` is
        the ``SentenceTransformer`` instance used to produce them; the last
        three items are the description strings themselves.
    """
    cached = getattr(create_metadata_embeddings, "_cached", None)
    if cached is not None:
        return cached

    student = "\n".join((
        "",
        "Table: student",
        "Columns:",
        "- student_id: an integer representing the unique ID of a student.",
        "- first_name: a string containing the first name of the student.",
        "- last_name: a string containing the last name of the student.",
        "- date_of_birth: a date representing the student's birthdate.",
        "- email: a string for the student's email address.",
        "- phone_number: a string for the student's contact number.",
        "- major: a string representing the student's major field of study.",
        "- year_of_enrollment: an integer for the year the student enrolled.",
        "",
    ))

    employee = "\n".join((
        "",
        "Table: employee",
        "Columns:",
        "- employee_id: an integer representing the unique ID of an employee.",
        "- first_name: a string containing the first name of the employee.",
        "- last_name: a string containing the last name of the employee.",
        "- email: a string for the employee's email address.",
        "- department: a string for the department the employee works in.",
        "- position: a string representing the employee's job title.",
        "- salary: a float representing the employee's salary.",
        "- date_of_joining: a date for when the employee joined the college.",
        "",
    ))

    course = "\n".join((
        "",
        "Table: course",
        "Columns:",
        "- course_id: an integer representing the unique ID of the course.",
        "- course_name: a string containing the course's name.",
        "- course_code: a string for the course's unique code.",
        "- instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course).",
        "- department: a string for the department offering the course.",
        "- credits: an integer representing the course credits.",
        "- semester: a string for the semester when the course is offered.",
        "",
    ))

    metadata_list = [student, employee, course]
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)

    # Remember the result on the function object so later calls skip the
    # model load and the re-encoding of descriptions that never change.
    result = (embeddings, model, student, employee, course)
    create_metadata_embeddings._cached = result
    return result
56
+
57
+
58
+ # def find_best_fit(embeddings, model, user_query, student, employee, course):
59
+ # query_embedding = model.encode([user_query])
60
+ # similarities = cosine_similarity(query_embedding, embeddings)
61
+ # best_match_table = similarities.argmax()
62
+ # return [student, employee, course][best_match_table]
63
def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
    """Pick the tables whose descriptions are most similar to the query.

    The query is encoded with ``model`` and compared against ``embeddings``
    by cosine similarity. Every table scoring at least ``threshold`` is
    returned, in table order. If none reaches the threshold, the ``top_k``
    highest-scoring tables are returned instead, best first.

    Args:
        embeddings: Encoded table descriptions, one row per table.
        model: Object with an ``encode`` method used to embed the query.
        user_query: Natural-language question from the user.
        metadata_list: Table description strings, aligned with ``embeddings``.
        table_names: Table names, aligned with ``embeddings``.
        top_k: Number of tables to fall back to when nothing meets
            ``threshold``. ``0`` yields an empty fallback.
        threshold: Minimum cosine similarity for a table to count as a match.

    Returns:
        A list of ``(table_name, metadata, similarity)`` tuples.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")

    query_embedding = model.encode([user_query])
    # cosine_similarity returns a 1 x n_tables matrix; keep its single row.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    def entry(index):
        return (table_names[index], metadata_list[index], similarities[index])

    matched_tables = [
        entry(idx) for idx, sim in enumerate(similarities) if sim >= threshold
    ]
    if matched_tables:
        return matched_tables

    # Nothing met the threshold: fall back to the best-scoring tables.
    # Reversing before slicing keeps top_k == 0 empty; slicing the tail with
    # [-top_k:] would select every table in that case.
    top_indices = similarities.argsort()[::-1][:top_k]
    return [entry(i) for i in top_indices]
82
+
83
+
84
+
85
+ # def create_prompt(user_query, table_metadata):
86
+
87
+ # system_prompt = """
88
+ # You are an SQL query generator capable of handling multiple tables with relationships.
89
+
90
+ # Use table metadata to construct accurate queries, including joins, based on the user's intent.
91
+
92
+ # Ensure:
93
+ # - The query is valid.
94
+ # - All table and column names match metadata.
95
+ # - Join conditions are based on matching keys (e.g., student_id = course.student_id).
96
+
97
+ # Output a single-line SQL query only, without any explanation.
98
+
99
+ # """
100
+
101
+ # user_prompt = f"""
102
+ # User Query: {user_query}
103
+ # Table Metadata: {table_metadata}
104
+ # """
105
+ # return system_prompt, user_prompt
106
def create_prompt(user_query, table_metadata):
    """Build the system and user messages for the SQL-generation request.

    Args:
        user_query: Natural-language question from the user.
        table_metadata: Description text of the tables the query may use.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple of strings. The system
        prompt holds the fixed instructions; the user prompt carries the
        query and the table metadata.
    """
    instruction_lines = (
        "",
        "You are a SQL query generator specialized in generating SQL queries for one or more tables.",
        "Your task is to convert natural language queries into SQL statements using the provided metadata.",
        "",
        "Rules:",
        "- Use JOINs only if required by user intent.",
        "- Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.",
        "- Use standard SQL syntax in a single line.",
        "- Do NOT explain or add comments — only return the SQL query string.",
        "",
        "Input Format:",
        "User Query: A natural language request.",
        "Table Metadata: List of available tables and their structures.",
        "",
        "Output Format:",
        "SQL Query: A valid single-line SQL query only.",
        "",
    )
    system_prompt = "\n".join(instruction_lines)

    user_prompt = f"\nUser Query: {user_query}\nTable Metadata: {table_metadata}\n"

    return system_prompt, user_prompt
130
+
131
+
132
+
133
def generate_output(system_prompt, user_prompt):
    """Ask the Groq chat model for a SQL query and return it if usable.

    The reply is cleaned of surrounding whitespace, a Markdown code fence
    and a leading ``SQL Query:`` label before it is checked, because the
    prompt's output format invites the model to include that label.

    Args:
        system_prompt: Instructions sent as the ``system`` message.
        user_prompt: Query and table metadata sent as the ``user`` message.

    Returns:
        The cleaned SQL text when it starts with ``select`` (case
        insensitive); otherwise the fixed message
        ``"Can't perform the task at the moment."``.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        # NOTE(review): model id is hard-coded; confirm Groq still serves it.
        model="llama3-70b-8192",
    )
    res = _extract_sql(chat_completion.choices[0].message.content)
    # The prefix test only filters out non-SELECT replies; it does not make
    # model-generated SQL safe to run against a database.
    return res if res.lower().startswith("select") else "Can't perform the task at the moment."


def _extract_sql(raw_reply):
    """Strip whitespace, code fences and a 'SQL Query:' label from a reply."""
    # A missing reply (None) is treated as empty text.
    text = (raw_reply or "").strip()

    if text.startswith("`"):
        text = text.strip("`")
        first_line, _, rest = text.partition("\n")
        # The opening fence line may be empty or carry a "sql" language tag.
        if first_line.strip().lower() in ("", "sql"):
            text = rest
        text = text.strip()

    label = "sql query:"
    if text.lower().startswith(label):
        text = text[len(label):].strip()

    return text
144
+
145
+
146
+ # def response(user_query):
147
+ # embeddings, model, student, employee, course = create_metadata_embeddings()
148
+ # table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
149
+ # system_prompt, user_prompt = create_prompt(user_query, table_metadata)
150
+ # sql_query = generate_output(system_prompt, user_prompt)
151
+
152
+ # if sql_query.lower().startswith("select"):
153
+ # result = run_sql_query(sql_query)
154
+ # return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
155
+ # else:
156
+ # return sql_query
157
def response(user_query):
    """Turn a natural-language question into SQL, run it, and format the reply.

    Args:
        user_query: Natural-language question entered in the UI.

    Returns:
        A string containing the generated SQL followed by either the query
        result or a notice that the query could not be run.
    """
    embeddings, model, student, employee, course = create_metadata_embeddings()

    # Names and descriptions are passed in the same order as the embeddings.
    matches = find_best_fit(
        embeddings,
        model,
        user_query,
        [student, employee, course],
        ["student", "employee", "course"],
    )

    # Each match is (table_name, metadata, similarity); only the metadata
    # text goes into the prompt.
    schema_context = "\n\n".join(entry[1] for entry in matches)

    system_prompt, user_prompt = create_prompt(user_query, schema_context)
    sql_query = generate_output(system_prompt, user_prompt)

    # Anything that is not a SELECT statement is reported, never executed.
    if not sql_query.lower().startswith("select"):
        return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."

    result = run_sql_query(sql_query)
    return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
175
+
176
+
177
+
178
+
179
+
180
# Initialize the database once, when the app module is loaded.
initialize_database()

# Description text passed to the Gradio interface below.
desc = "\n".join((
    "",
    "There are three tables in the database:",
    "",
    "Student Table:",
    "Contains student ID, name, DOB, email, phone, major, year of enrollment.",
    "",
    "Employee Table:",
    "Contains employee ID, name, email, department, position, salary, and joining date.",
    "",
    "Course Info Table:",
    "Contains course ID, name, code, instructor ID, department, credits, and semester.",
    "",
))

import gradio as gr

# Single text box in, single text box out; `response` handles each query.
demo = gr.Interface(
    title="SQL Query Generator with Results",
    description=desc,
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query and Result"),
)

demo.launch(share=True)
207
+