Spaces:
Sleeping
Sleeping
import os
import re
from functools import lru_cache

from dotenv import load_dotenv
import gradio as gr
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from database import initialize_database, run_sql_query
# Pull variables from a local .env file (if any) into the process environment,
# then read the Groq API key from it.
load_dotenv()
# NOTE(review): `api` is None when the variable is unset; whether the Groq client
# then falls back to its own environment lookup is library behaviour — confirm.
api = os.environ.get("groq_api_key")
@lru_cache(maxsize=1)
def create_metadata_embeddings():
    """Build the table-metadata embeddings used to route a query to tables.

    The SentenceTransformer model is loaded and the three metadata
    descriptions are encoded on the first call only; every later call returns
    the same cached tuple. The request handler calls this once per request,
    so without the cache each request would reload the model from disk and
    re-encode the metadata.

    Returns:
        Tuple ``(embeddings, model, student, employee, course)``:
        the encoded metadata (one embedding per table, in
        student / employee / course order), the SentenceTransformer instance,
        and the three metadata strings. The returned objects are shared
        between calls, so callers must not mutate them.
    """
    student = """
Table: student
Columns:
- student_id: an integer representing the unique ID of a student.
- first_name: a string containing the first name of the student.
- last_name: a string containing the last name of the student.
- date_of_birth: a date representing the student's birthdate.
- email: a string for the student's email address.
- phone_number: a string for the student's contact number.
- major: a string representing the student's major field of study.
- year_of_enrollment: an integer for the year the student enrolled.
"""
    employee = """
Table: employee
Columns:
- employee_id: an integer representing the unique ID of an employee.
- first_name: a string containing the first name of the employee.
- last_name: a string containing the last name of the employee.
- email: a string for the employee's email address.
- department: a string for the department the employee works in.
- position: a string representing the employee's job title.
- salary: a float representing the employee's salary.
- date_of_joining: a date for when the employee joined the college.
"""
    course = """
Table: course
Columns:
- course_id: an integer representing the unique ID of the course.
- course_name: a string containing the course's name.
- course_code: a string for the course's unique code.
- instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course).
- department: a string for the department offering the course.
- credits: an integer representing the course credits.
- semester: a string for the semester when the course is offered.
"""
    metadata_list = [student, employee, course]
    # Loading the model is the expensive step; the lru_cache above ensures it
    # happens once per process instead of once per request.
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)
    return embeddings, model, student, employee, course
| # def find_best_fit(embeddings, model, user_query, student, employee, course): | |
| # query_embedding = model.encode([user_query]) | |
| # similarities = cosine_similarity(query_embedding, embeddings) | |
| # best_match_table = similarities.argmax() | |
| # return [student, employee, course][best_match_table] | |
def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
    """Select the tables whose metadata is semantically closest to a query.

    Args:
        embeddings: Table-metadata embeddings, one per table, in the same
            order as ``metadata_list`` and ``table_names``.
        model: Encoder exposing ``encode(list_of_str)``; presumably the same
            model that produced ``embeddings``.
        user_query: The natural-language question.
        metadata_list: Metadata text for each table.
        table_names: Name of each table.
        top_k: How many best-scoring tables to return when no table reaches
            ``threshold``. Values <= 0 give an empty fallback.
        threshold: Minimum cosine similarity for a table to be selected.

    Returns:
        List of ``(table_name, metadata, similarity)`` tuples sorted by
        descending similarity. Every table scoring at least ``threshold`` is
        included (this list is not capped at ``top_k``). If none qualifies,
        the ``top_k`` highest-scoring tables are returned instead.
    """
    query_embedding = model.encode([user_query])
    # cosine_similarity yields one row per query; we passed a single query.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Table indices ordered best-first, so both branches return sorted results.
    ranked = similarities.argsort()[::-1]
    selected = [i for i in ranked if similarities[i] >= threshold]
    if not selected:
        # Slice from the front: the previous `argsort()[-top_k:]` form turned
        # into `[-0:]` (i.e. every table) when top_k == 0.
        selected = ranked[:max(top_k, 0)]

    return [(table_names[i], metadata_list[i], similarities[i]) for i in selected]
| # def create_prompt(user_query, table_metadata): | |
| # system_prompt = """ | |
| # You are an SQL query generator capable of handling multiple tables with relationships. | |
| # Use table metadata to construct accurate queries, including joins, based on the user's intent. | |
| # Ensure: | |
| # - The query is valid. | |
| # - All table and column names match metadata. | |
| # - Join conditions are based on matching keys (e.g., student_id = course.student_id). | |
| # Output a single-line SQL query only, without any explanation. | |
| # """ | |
| # user_prompt = f""" | |
| # User Query: {user_query} | |
| # Table Metadata: {table_metadata} | |
| # """ | |
| # return system_prompt, user_prompt | |
_SQL_GENERATOR_SYSTEM_PROMPT = """
You are a SQL query generator specialized in generating SQL queries for one or more tables.
Your task is to convert natural language queries into SQL statements using the provided metadata.
Rules:
- Use JOINs only if required by user intent.
- Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.
- Use standard SQL syntax in a single line.
- Do NOT explain or add comments — only return the SQL query string.
Input Format:
User Query: A natural language request.
Table Metadata: List of available tables and their structures.
Output Format:
SQL Query: A valid single-line SQL query only.
"""


def create_prompt(user_query, table_metadata):
    """Assemble the chat messages for the text-to-SQL model.

    Args:
        user_query: The user's natural-language request.
        table_metadata: Text describing the candidate tables and columns.

    Returns:
        ``(system_prompt, user_prompt)``. The system prompt is a fixed set of
        generation rules; the user prompt carries the query and metadata,
        framed by leading and trailing newlines.
    """
    message_lines = [
        "",
        f"User Query: {user_query}",
        f"Table Metadata: {table_metadata}",
        "",
    ]
    return _SQL_GENERATOR_SYSTEM_PROMPT, "\n".join(message_lines)
def _extract_sql(text):
    """Return the bare SQL statement contained in a raw LLM reply.

    Unwraps a Markdown code fence (```sql ... ```) if present, drops a leading
    "SQL Query:" label (the prompt's "Output Format" section shows that
    label, so models tend to echo it), and strips surrounding whitespace.
    Returns "" for an empty or None reply.
    """
    if not text:
        return ""
    fenced = re.search(r"```(?:sql)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
    sql = fenced.group(1) if fenced else text
    sql = re.sub(r"^\s*sql\s+query\s*:\s*", "", sql, flags=re.IGNORECASE)
    return sql.strip()


def generate_output(system_prompt, user_prompt, model="llama3-70b-8192", client=None):
    """Ask the Groq chat model for a SQL query and vet the reply.

    Args:
        system_prompt: System message holding the generation rules.
        user_prompt: User message holding the question and table metadata.
        model: Groq model identifier (defaults to the model used originally).
        client: Optional pre-built Groq-compatible client. When omitted, a
            new ``Groq(api_key=api)`` client is created for this call.

    Returns:
        The cleaned SQL text when the reply is a single statement starting
        with SELECT. Otherwise returns the fixed message
        "Can't perform the task at the moment.".
    """
    if client is None:
        client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model=model,
    )
    sql = _extract_sql(chat_completion.choices[0].message.content)

    # The SQL is derived from untrusted user input and is executed by the
    # caller, so only a single read-only statement is let through. A trailing
    # ';' is allowed; any other ';' (e.g. "SELECT 1; DROP TABLE x") is
    # rejected. This is conservative: a ';' inside a string literal is also
    # refused.
    is_single_statement = ";" not in sql.rstrip(";")
    if sql.lower().startswith("select") and is_single_statement:
        return sql
    return "Can't perform the task at the moment."
| # def response(user_query): | |
| # embeddings, model, student, employee, course = create_metadata_embeddings() | |
| # table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course) | |
| # system_prompt, user_prompt = create_prompt(user_query, table_metadata) | |
| # sql_query = generate_output(system_prompt, user_prompt) | |
| # if sql_query.lower().startswith("select"): | |
| # result = run_sql_query(sql_query) | |
| # return f"SQL Query:\n{sql_query}\n\nResult:\n{result}" | |
| # else: | |
| # return sql_query | |
def response(user_query):
    """Answer a natural-language question with generated SQL plus its result.

    Pipeline: pick the relevant tables by embedding similarity, prompt the
    LLM with their metadata, and execute the returned SQL only when it is a
    SELECT.

    Args:
        user_query: The question typed into the Gradio textbox.

    Returns:
        A display string containing the SQL query and either the query result
        or a notice that the data could not be fetched.
    """
    embeddings, model, *table_docs = create_metadata_embeddings()
    names = ["student", "employee", "course"]

    hits = find_best_fit(embeddings, model, user_query, table_docs, names)
    # Each hit is (table_name, metadata, similarity); only the metadata goes into the prompt.
    context = "\n\n".join(metadata for _name, metadata, _score in hits)

    system_prompt, user_prompt = create_prompt(user_query, context)
    sql_query = generate_output(system_prompt, user_prompt)

    if not sql_query.lower().startswith("select"):
        return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."
    result = run_sql_query(sql_query)
    return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
# Initialize DB on app launch (before any request can reach `response`).
initialize_database()

# Shown under the app title so users know which tables they can ask about.
desc = """
There are three tables in the database:
Student Table:
Contains student ID, name, DOB, email, phone, major, year of enrollment.
Employee Table:
Contains employee ID, name, email, department, position, salary, and joining date.
Course Table:
Contains course ID, name, code, instructor ID, department, credits, and semester.
"""

# `gr` is already imported at the top of the file; the duplicate import was removed.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query and Result"),
    title="SQL Query Generator with Results",
    description=desc,
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not start a server.
    # NOTE(review): share=True requests a public share link when run locally;
    # on hosted Spaces it is presumably ignored — confirm.
    demo.launch(share=True)