Spaces:
Sleeping
Sleeping
File size: 7,513 Bytes
d3a8a4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
from dotenv import load_dotenv
import os
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
from database import initialize_database, run_sql_query
# Load environment variables via python-dotenv (typically from a .env file,
# when one is present) before any of them are read below.
load_dotenv()

# Groq API key used by generate_output(); None when "groq_api_key" is unset.
api = os.environ.get("groq_api_key")
# Process-wide cache: the embedding model is loaded and the (static) schema
# descriptions are encoded once instead of on every call.
_METADATA_CACHE = None


def create_metadata_embeddings():
    """Build (and cache) embeddings for the table schema descriptions.

    The first call loads the ``all-MiniLM-L6-v2`` SentenceTransformer model
    and encodes the three schema descriptions. Later calls return the very
    same tuple, because neither the model nor the descriptions change at
    runtime and reloading the model per request is wasted work.

    Returns:
        tuple: ``(embeddings, model, student, employee, course)`` where
        ``embeddings`` is ``model.encode([student, employee, course])``
        (rows in that order), ``model`` is the SentenceTransformer instance,
        and the last three items are the schema description strings.
    """
    global _METADATA_CACHE
    if _METADATA_CACHE is not None:
        return _METADATA_CACHE

    # The description text is kept flush-left so the strings handed to the
    # encoder (and later to the LLM prompt) carry no indentation.
    student = """
Table: student
Columns:
- student_id: an integer representing the unique ID of a student.
- first_name: a string containing the first name of the student.
- last_name: a string containing the last name of the student.
- date_of_birth: a date representing the student's birthdate.
- email: a string for the student's email address.
- phone_number: a string for the student's contact number.
- major: a string representing the student's major field of study.
- year_of_enrollment: an integer for the year the student enrolled.
"""
    employee = """
Table: employee
Columns:
- employee_id: an integer representing the unique ID of an employee.
- first_name: a string containing the first name of the employee.
- last_name: a string containing the last name of the employee.
- email: a string for the employee's email address.
- department: a string for the department the employee works in.
- position: a string representing the employee's job title.
- salary: a float representing the employee's salary.
- date_of_joining: a date for when the employee joined the college.
"""
    course = """
Table: course
Columns:
- course_id: an integer representing the unique ID of the course.
- course_name: a string containing the course's name.
- course_code: a string for the course's unique code.
- instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course).
- department: a string for the department offering the course.
- credits: an integer representing the course credits.
- semester: a string for the semester when the course is offered.
"""
    metadata_list = [student, employee, course]
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)

    _METADATA_CACHE = (embeddings, model, student, employee, course)
    return _METADATA_CACHE
# def find_best_fit(embeddings, model, user_query, student, employee, course):
# query_embedding = model.encode([user_query])
# similarities = cosine_similarity(query_embedding, embeddings)
# best_match_table = similarities.argmax()
# return [student, employee, course][best_match_table]
def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
    """Select the tables whose schema description best matches a query.

    The query is embedded with ``model`` and compared by cosine similarity
    against each row of ``embeddings``. Every table scoring at least
    ``threshold`` is returned, in the order of ``table_names``. When no table
    reaches the threshold, the ``top_k`` best-scoring tables are returned
    instead, highest score first.

    Args:
        embeddings: Embeddings of the table descriptions, one row per table,
            aligned index-for-index with ``metadata_list`` and ``table_names``.
        model: Object exposing ``encode(list_of_str)`` for embedding the query.
        user_query: Natural-language request to match against the tables.
        metadata_list: Schema description string for each table.
        table_names: Name of each table.
        top_k: Number of tables to return in the fallback case; values of
            zero or less yield an empty list.
        threshold: Minimum similarity for a table to count as a match.

    Returns:
        list[tuple]: ``(table_name, metadata, similarity)`` for each selected
        table.
    """
    query_embedding = model.encode([user_query])
    # cosine_similarity returns a 1 x n_tables matrix; keep the single row.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    matched_tables = [
        (table_names[idx], metadata_list[idx], sim)
        for idx, sim in enumerate(similarities)
        if sim >= threshold
    ]
    if matched_tables:
        return matched_tables

    # Nothing met the threshold: fall back to the top_k highest scores.
    # Reverse first and then take a prefix; a ``[-top_k:]`` slice would return
    # every table for top_k == 0 because ``[-0:]`` is the whole array.
    top_indices = similarities.argsort()[::-1][:max(top_k, 0)]
    return [(table_names[i], metadata_list[i], similarities[i]) for i in top_indices]
# def create_prompt(user_query, table_metadata):
# system_prompt = """
# You are an SQL query generator capable of handling multiple tables with relationships.
# Use table metadata to construct accurate queries, including joins, based on the user's intent.
# Ensure:
# - The query is valid.
# - All table and column names match metadata.
# - Join conditions are based on matching keys (e.g., student_id = course.student_id).
# Output a single-line SQL query only, without any explanation.
# """
# user_prompt = f"""
# User Query: {user_query}
# Table Metadata: {table_metadata}
# """
# return system_prompt, user_prompt
def create_prompt(user_query, table_metadata):
    """Assemble the system and user prompts for the SQL-generating model.

    Args:
        user_query: Natural-language request from the user.
        table_metadata: Schema description text of the tables to expose.

    Returns:
        tuple[str, str]: ``(system_prompt, user_prompt)``. Both strings start
        and end with a newline.
    """
    # The empty first/last entries produce the leading and trailing newline.
    instruction_lines = (
        "",
        "You are a SQL query generator specialized in generating SQL queries for one or more tables.",
        "Your task is to convert natural language queries into SQL statements using the provided metadata.",
        "Rules:",
        "- Use JOINs only if required by user intent.",
        "- Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.",
        "- Use standard SQL syntax in a single line.",
        "- Do NOT explain or add comments — only return the SQL query string.",
        "Input Format:",
        "User Query: A natural language request.",
        "Table Metadata: List of available tables and their structures.",
        "Output Format:",
        "SQL Query: A valid single-line SQL query only.",
        "",
    )
    request_lines = (
        "",
        f"User Query: {user_query}",
        f"Table Metadata: {table_metadata}",
        "",
    )
    return "\n".join(instruction_lines), "\n".join(request_lines)
def generate_output(system_prompt, user_prompt):
    """Ask the Groq chat model for a SQL query and return it if it is a SELECT.

    The raw reply is cleaned first (surrounding whitespace, a leading
    ``SQL Query:`` label and markdown code fences are removed) so that a
    valid query is not rejected merely because of its wrapping.

    Args:
        system_prompt: Instructions sent as the ``system`` message.
        user_prompt: Request sent as the ``user`` message.

    Returns:
        str: The cleaned SQL text when it starts with ``select``
        (case-insensitive); otherwise the fixed message
        ``"Can't perform the task at the moment."``.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        # NOTE(review): confirm this model id is still served by Groq.
        model="llama3-70b-8192",
    )
    # The message content may be None (no text returned); treat that as an
    # empty reply instead of failing on ``None.lower()``.
    res = _clean_sql_reply(chat_completion.choices[0].message.content or "")
    return res if res.lower().startswith("select") else "Can't perform the task at the moment."


def _clean_sql_reply(text):
    """Strip whitespace, a leading ``SQL Query:`` label and markdown fences."""
    label = "sql query:"
    cleaned = text.strip()
    # The system prompt's "Output Format" shows this label, so the model may
    # echo it in front of the query.
    if cleaned[:len(label)].lower() == label:
        cleaned = cleaned[len(label):].strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        # An opening fence may carry a language tag on its own line (```sql).
        tag, newline, remainder = cleaned.partition("\n")
        if newline and tag.strip().lower() == "sql":
            cleaned = remainder.strip()
    return cleaned
# def response(user_query):
# embeddings, model, student, employee, course = create_metadata_embeddings()
# table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
# system_prompt, user_prompt = create_prompt(user_query, table_metadata)
# sql_query = generate_output(system_prompt, user_prompt)
# if sql_query.lower().startswith("select"):
# result = run_sql_query(sql_query)
# return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
# else:
# return sql_query
def response(user_query):
    """Turn a natural-language question into SQL, run it, and format the outcome.

    Pipeline: fetch the schema embeddings, pick the relevant tables for the
    query, build the prompts, ask the model for SQL, and execute the query
    when the model's answer starts with ``select``.

    Args:
        user_query: Natural-language request entered by the user.

    Returns:
        str: ``"SQL Query:\\n<sql>\\n\\nResult:\\n<result>"``, where the result
        part is either the output of ``run_sql_query`` or a fixed
        "unable to fetch" message when the answer is not a SELECT.
    """
    embeddings, model, student, employee, course = create_metadata_embeddings()
    schema_texts = [student, employee, course]
    schema_names = ["student", "employee", "course"]

    hits = find_best_fit(embeddings, model, user_query, schema_texts, schema_names)
    # Each hit is (table_name, metadata, similarity); only the metadata text
    # is forwarded to the prompt.
    combined_metadata = "\n\n".join(hit[1] for hit in hits)

    system_prompt, user_prompt = create_prompt(user_query, combined_metadata)
    sql_query = generate_output(system_prompt, user_prompt)

    if not sql_query.lower().startswith("select"):
        return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."

    result = run_sql_query(sql_query)
    return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
# Initialize DB on app launch, before the UI starts accepting queries.
initialize_database()

# Text shown under the title in the Gradio UI.
desc = """
There are three tables in the database:
Student Table:
Contains student ID, name, DOB, email, phone, major, year of enrollment.
Employee Table:
Contains employee ID, name, email, department, position, salary, and joining date.
Course Info Table:
Contains course ID, name, code, instructor ID, department, credits, and semester.
"""

# Redundant re-import (gradio is already imported at the top of the file);
# kept because it is harmless.
import gradio as gr

# Single text input and single text output, both wired to response().
_query_box = gr.Textbox(label="Please provide the natural language query")
_result_box = gr.Textbox(label="SQL Query and Result")

demo = gr.Interface(
    fn=response,
    inputs=_query_box,
    outputs=_result_box,
    title="SQL Query Generator with Results",
    description=desc,
)

demo.launch(share=True)
|