Balaprime's picture
Update app.py
d3a8a4a verified
import functools
import os

import gradio as gr
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from database import initialize_database, run_sql_query
# Populate os.environ from a .env file (when one is present), then read the
# Groq API key from the lower-case "groq_api_key" variable; `api` is None if
# that variable is unset.
load_dotenv()
api = os.environ.get("groq_api_key")
@functools.lru_cache(maxsize=1)
def create_metadata_embeddings():
    """Describe the database tables and embed the descriptions for retrieval.

    Loading the 'all-MiniLM-L6-v2' SentenceTransformer and encoding the
    schema text is expensive and the result never changes, so it is computed
    once per process and cached (this function is called for every query).

    Returns:
        tuple: ``(embeddings, model, student, employee, course)`` where
        ``student``, ``employee`` and ``course`` are the schema description
        strings, ``model`` is the loaded SentenceTransformer, and
        ``embeddings`` is ``model.encode([student, employee, course])`` --
        one embedding per description, in that order. The same objects are
        returned on every call, so callers must not mutate them.
    """
    student = """
Table: student
Columns:
- student_id: an integer representing the unique ID of a student.
- first_name: a string containing the first name of the student.
- last_name: a string containing the last name of the student.
- date_of_birth: a date representing the student's birthdate.
- email: a string for the student's email address.
- phone_number: a string for the student's contact number.
- major: a string representing the student's major field of study.
- year_of_enrollment: an integer for the year the student enrolled.
"""
    employee = """
Table: employee
Columns:
- employee_id: an integer representing the unique ID of an employee.
- first_name: a string containing the first name of the employee.
- last_name: a string containing the last name of the employee.
- email: a string for the employee's email address.
- department: a string for the department the employee works in.
- position: a string representing the employee's job title.
- salary: a float representing the employee's salary.
- date_of_joining: a date for when the employee joined the college.
"""
    course = """
Table: course
Columns:
- course_id: an integer representing the unique ID of the course.
- course_name: a string containing the course's name.
- course_code: a string for the course's unique code.
- instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course).
- department: a string for the department offering the course.
- credits: an integer representing the course credits.
- semester: a string for the semester when the course is offered.
"""
    metadata_list = [student, employee, course]
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)
    return embeddings, model, student, employee, course
# def find_best_fit(embeddings, model, user_query, student, employee, course):
# query_embedding = model.encode([user_query])
# similarities = cosine_similarity(query_embedding, embeddings)
# best_match_table = similarities.argmax()
# return [student, employee, course][best_match_table]
def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
    """Select the tables whose metadata is semantically relevant to a query.

    Every table whose cosine similarity to the query is at least
    ``threshold`` is returned. If none qualifies, the ``top_k`` highest
    scoring tables are returned instead, so the caller normally still has
    some schema context for the prompt.

    Args:
        embeddings: Embeddings of ``metadata_list``, aligned with it
            (as produced by ``model.encode``).
        model: Sentence-embedding model exposing ``encode``.
        user_query: Natural-language question from the user.
        metadata_list: Table metadata strings, aligned with ``embeddings``.
        table_names: Table names, aligned with ``metadata_list``.
        top_k: Number of tables to fall back to when nothing meets
            ``threshold``. Values <= 0 give an empty fallback.
        threshold: Minimum cosine similarity for a table to count as a match.

    Returns:
        list[tuple]: ``(table_name, metadata, similarity)`` triples, ordered
        from most to least similar.
    """
    query_embedding = model.encode([user_query])
    # cosine_similarity yields a 1 x n_tables matrix; keep its only row.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Rank all tables best-first so both branches return the most relevant
    # schema first (it then leads the combined metadata in the prompt).
    ranked = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)

    selected = [i for i in ranked if similarities[i] >= threshold]
    if not selected:
        # Nothing clears the bar: fall back to the best top_k tables.
        # max(top_k, 0) matters: with the old argsort()[-top_k:] idiom,
        # top_k == 0 sliced [-0:] and returned *every* table.
        selected = ranked[:max(top_k, 0)]

    return [(table_names[i], metadata_list[i], similarities[i]) for i in selected]
# def create_prompt(user_query, table_metadata):
# system_prompt = """
# You are an SQL query generator capable of handling multiple tables with relationships.
# Use table metadata to construct accurate queries, including joins, based on the user's intent.
# Ensure:
# - The query is valid.
# - All table and column names match metadata.
# - Join conditions are based on matching keys (e.g., student_id = course.student_id).
# Output a single-line SQL query only, without any explanation.
# """
# user_prompt = f"""
# User Query: {user_query}
# Table Metadata: {table_metadata}
# """
# return system_prompt, user_prompt
def create_prompt(user_query, table_metadata):
    """Build the system and user messages for the SQL-generation model.

    The output-format section asks for a bare SELECT statement with no label
    and no code fences. The reply is accepted downstream only if it starts
    with "select", so a "SQL Query:" prefix or a non-SELECT statement would
    be rejected.

    Args:
        user_query: Natural-language question from the user.
        table_metadata: Text describing the candidate tables and columns.

    Returns:
        tuple[str, str]: ``(system_prompt, user_prompt)``.
    """
    system_prompt = """
You are a SQL query generator specialized in generating SQL queries for one or more tables.
Your task is to convert natural language queries into SQL statements using the provided metadata.
Rules:
- Generate only read-only SELECT statements.
- Use JOINs only if required by user intent.
- Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.
- Use standard SQL syntax in a single line.
- Do NOT explain or add comments — only return the SQL query string.
Input Format:
User Query: A natural language request.
Table Metadata: List of available tables and their structures.
Output Format:
Only the SQL statement itself, on a single line, starting with SELECT. Do not add any label or prefix, markdown code fences, or other text.
"""
    user_prompt = f"""
User Query: {user_query}
Table Metadata: {table_metadata}
"""
    return system_prompt, user_prompt
def _clean_sql_reply(text):
    """Strip the wrappers chat models often put around a bare SQL statement.

    Removes surrounding whitespace, a leading "SQL Query:" label (which a
    model may echo from a prompt's format description), and a Markdown code
    fence such as ```sql ... ```. Returns "" for a missing or empty reply.
    """
    if not text:
        return ""
    sql = text.strip()
    label = "sql query:"
    if sql.lower().startswith(label):
        sql = sql[len(label):].strip()
    if sql.startswith("```"):
        sql = sql[3:]
        if sql.endswith("```"):
            sql = sql[:-3]
        # Drop an optional language tag (e.g. "sql") on the opening fence line.
        tag, newline, rest = sql.partition("\n")
        if newline and tag.strip().lower() in ("", "sql"):
            sql = rest
        sql = sql.strip()
    return sql


def generate_output(system_prompt, user_prompt, model="llama3-70b-8192"):
    """Ask the Groq chat model for SQL and return it only if it is one SELECT.

    Args:
        system_prompt: Instructions for the model.
        user_prompt: The user's question plus the table metadata.
        model: Groq model id. Defaults to the id this app was written
            against. NOTE(review): Groq retires model ids over time; confirm
            this one is still served.

    Returns:
        str: The cleaned SELECT statement, or the fixed message
        "Can't perform the task at the moment." when the reply is empty,
        is not a SELECT, or contains more than one statement.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model=model,
    )
    # message.content can be None; _clean_sql_reply maps that to "".
    sql = _clean_sql_reply(chat_completion.choices[0].message.content)
    if not sql.lower().startswith("select"):
        return "Can't perform the task at the moment."
    # The SELECT-prefix gate alone would let stacked statements such as
    # "SELECT 1; DROP TABLE student" through. Only a trailing ';' is allowed.
    # A ';' inside a string literal is rejected too, which is an acceptable
    # trade-off for a single-line query generator.
    if ";" in sql.rstrip(";").rstrip():
        return "Can't perform the task at the moment."
    return sql
# def response(user_query):
# embeddings, model, student, employee, course = create_metadata_embeddings()
# table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
# system_prompt, user_prompt = create_prompt(user_query, table_metadata)
# sql_query = generate_output(system_prompt, user_prompt)
# if sql_query.lower().startswith("select"):
# result = run_sql_query(sql_query)
# return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
# else:
# return sql_query
def response(user_query):
    """Answer a natural-language question by generating and running SQL.

    Pipeline:
    1. Pick the relevant table schemas by embedding similarity.
    2. Ask the LLM for a SELECT statement.
    3. Execute the statement and format the SQL together with its result.

    Args:
        user_query: The question typed into the Gradio textbox.

    Returns:
        str: Text for the output textbox. It contains the generated SQL and
        either the query result or an explanation of why there is none.
    """
    # Avoid spending an embedding pass and an LLM call on an empty box.
    if not user_query or not user_query.strip():
        return "Please enter a question about the student, employee, or course tables."

    embeddings, model, student, employee, course = create_metadata_embeddings()
    metadata_list = [student, employee, course]
    table_names = ["student", "employee", "course"]
    matched_tables = find_best_fit(embeddings, model, user_query, metadata_list, table_names)
    combined_metadata = "\n\n".join(metadata for _, metadata, _ in matched_tables)

    system_prompt, user_prompt = create_prompt(user_query, combined_metadata)
    sql_query = generate_output(system_prompt, user_prompt)
    if not sql_query.lower().startswith("select"):
        return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."

    try:
        result = run_sql_query(sql_query)
    except Exception as exc:
        # UI boundary: LLM-written SQL can reference wrong columns or be
        # malformed. Show the query and the error instead of a generic
        # Gradio error.
        return f"SQL Query:\n{sql_query}\n\nResult:\nError while running the query: {exc}"
    return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
# Run the project's database setup once at import time, before the UI is built.
initialize_database()

# Passed as the Interface description. The table names here match the real
# tables (student, employee, course) that the SQL is generated against.
desc = """
There are three tables in the database:
Student Table:
Contains student ID, name, DOB, email, phone, major, year of enrollment.
Employee Table:
Contains employee ID, name, email, department, position, salary, and joining date.
Course Table:
Contains course ID, name, code, instructor ID (an employee ID), department, credits, and semester.
"""

import gradio as gr  # NOTE(review): repeats the top-of-file import; harmless.

# One textbox in, one textbox out; `response` does the NL -> SQL -> result work.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query and Result"),
    title="SQL Query Generator with Results",
    description=desc,
)
demo.launch(share=True)