from dotenv import load_dotenv import os from sentence_transformers import SentenceTransformer import gradio as gr from sklearn.metrics.pairwise import cosine_similarity from groq import Groq from database import initialize_database, run_sql_query load_dotenv() api = os.getenv("groq_api_key") def create_metadata_embeddings(): student = """ Table: student Columns: - student_id: an integer representing the unique ID of a student. - first_name: a string containing the first name of the student. - last_name: a string containing the last name of the student. - date_of_birth: a date representing the student's birthdate. - email: a string for the student's email address. - phone_number: a string for the student's contact number. - major: a string representing the student's major field of study. - year_of_enrollment: an integer for the year the student enrolled. """ employee = """ Table: employee Columns: - employee_id: an integer representing the unique ID of an employee. - first_name: a string containing the first name of the employee. - last_name: a string containing the last name of the employee. - email: a string for the employee's email address. - department: a string for the department the employee works in. - position: a string representing the employee's job title. - salary: a float representing the employee's salary. - date_of_joining: a date for when the employee joined the college. """ course = """ Table: course Columns: - course_id: an integer representing the unique ID of the course. - course_name: a string containing the course's name. - course_code: a string for the course's unique code. - instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course). - department: a string for the department offering the course. - credits: an integer representing the course credits. - semester: a string for the semester when the course is offered. """ metadata_list = [student, employee, course] model = SentenceTransformer('all-MiniLM-L6-v2') embeddings = model.encode(metadata_list) return embeddings, model, student, employee, course # def find_best_fit(embeddings, model, user_query, student, employee, course): # query_embedding = model.encode([user_query]) # similarities = cosine_similarity(query_embedding, embeddings) # best_match_table = similarities.argmax() # return [student, employee, course][best_match_table] def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4): """ Identifies relevant tables for a query based on semantic similarity. Returns a list of matching table metadata strings. """ query_embedding = model.encode([user_query]) similarities = cosine_similarity(query_embedding, embeddings)[0] # Flatten array matched_tables = [] for idx, sim in enumerate(similarities): if sim >= threshold: matched_tables.append((table_names[idx], metadata_list[idx], sim)) # If none meet the threshold, use top-k highest scoring if not matched_tables: top_indices = similarities.argsort()[-top_k:][::-1] matched_tables = [(table_names[i], metadata_list[i], similarities[i]) for i in top_indices] return matched_tables # List of (table_name, metadata, similarity) # def create_prompt(user_query, table_metadata): # system_prompt = """ # You are an SQL query generator capable of handling multiple tables with relationships. # Use table metadata to construct accurate queries, including joins, based on the user's intent. # Ensure: # - The query is valid. # - All table and column names match metadata. # - Join conditions are based on matching keys (e.g., student_id = course.student_id). # Output a single-line SQL query only, without any explanation. # """ # user_prompt = f""" # User Query: {user_query} # Table Metadata: {table_metadata} # """ # return system_prompt, user_prompt def create_prompt(user_query, table_metadata): system_prompt = """ You are a SQL query generator specialized in generating SQL queries for one or more tables. Your task is to convert natural language queries into SQL statements using the provided metadata. Rules: - Use JOINs only if required by user intent. - Ensure the generated SQL query only uses the tables and columns mentioned in the metadata. - Use standard SQL syntax in a single line. - Do NOT explain or add comments — only return the SQL query string. Input Format: User Query: A natural language request. Table Metadata: List of available tables and their structures. Output Format: SQL Query: A valid single-line SQL query only. """ user_prompt = f""" User Query: {user_query} Table Metadata: {table_metadata} """ return system_prompt, user_prompt def generate_output(system_prompt, user_prompt): client = Groq(api_key=api) chat_completion = client.chat.completions.create( messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], model="llama3-70b-8192" ) res = chat_completion.choices[0].message.content return res if res.lower().startswith("select") else "Can't perform the task at the moment." # def response(user_query): # embeddings, model, student, employee, course = create_metadata_embeddings() # table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course) # system_prompt, user_prompt = create_prompt(user_query, table_metadata) # sql_query = generate_output(system_prompt, user_prompt) # if sql_query.lower().startswith("select"): # result = run_sql_query(sql_query) # return f"SQL Query:\n{sql_query}\n\nResult:\n{result}" # else: # return sql_query def response(user_query): embeddings, model, student, employee, course = create_metadata_embeddings() metadata_list = [student, employee, course] table_names = ["student", "employee", "course"] matched_tables = find_best_fit(embeddings, model, user_query, metadata_list, table_names) combined_metadata = "\n\n".join([table[1] for table in matched_tables]) system_prompt, user_prompt = create_prompt(user_query, combined_metadata) sql_query = generate_output(system_prompt, user_prompt) if sql_query.lower().startswith("select"): result = run_sql_query(sql_query) return f"SQL Query:\n{sql_query}\n\nResult:\n{result}" else: return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query." # Initialize DB on app launch initialize_database() desc = """ There are three tables in the database: Student Table: Contains student ID, name, DOB, email, phone, major, year of enrollment. Employee Table: Contains employee ID, name, email, department, position, salary, and joining date. Course Info Table: Contains course ID, name, code, instructor ID, department, credits, and semester. """ import gradio as gr demo = gr.Interface( fn=response, inputs=gr.Textbox(label="Please provide the natural language query"), outputs=gr.Textbox(label="SQL Query and Result"), title="SQL Query Generator with Results", description=desc ) demo.launch(share=True)