File size: 7,513 Bytes
d3a8a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import os
from functools import lru_cache

import gradio as gr
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from database import initialize_database, run_sql_query

# Copy any variables defined in a local .env file into the process
# environment, then read the Groq API key (None when it is not set).
load_dotenv()
api = os.environ.get("groq_api_key")

@lru_cache(maxsize=1)
def create_metadata_embeddings():
    """Build the sentence-embedding index over the table descriptions (once).

    Each table's schema is described in plain English. Those descriptions are
    encoded with the ``all-MiniLM-L6-v2`` SentenceTransformer so that a user
    query can later be matched against them by cosine similarity.

    Loading the transformer and encoding three fixed strings is expensive and
    the result never changes. The function is therefore memoised: only the
    first call does the work, and every later call (one per Gradio request)
    returns the same objects.

    Returns:
        tuple: ``(embeddings, model, student, employee, course)``.
        ``embeddings`` holds one vector per description, in the order
        student, employee, course. ``model`` is the loaded
        SentenceTransformer. The last three items are the raw description
        strings. The returned objects are shared between calls, so callers
        must treat them as read-only.
    """
    student = """
    Table: student
    Columns:
    - student_id: an integer representing the unique ID of a student.
    - first_name: a string containing the first name of the student.
    - last_name: a string containing the last name of the student.
    - date_of_birth: a date representing the student's birthdate.
    - email: a string for the student's email address.
    - phone_number: a string for the student's contact number.
    - major: a string representing the student's major field of study.
    - year_of_enrollment: an integer for the year the student enrolled.
    """

    employee = """
    Table: employee
    Columns:
    - employee_id: an integer representing the unique ID of an employee.
    - first_name: a string containing the first name of the employee.
    - last_name: a string containing the last name of the employee.
    - email: a string for the employee's email address.
    - department: a string for the department the employee works in.
    - position: a string representing the employee's job title.
    - salary: a float representing the employee's salary.
    - date_of_joining: a date for when the employee joined the college.
    """

    # The instructor_id note tells the LLM how course joins to employee.
    course = """
    Table: course
    Columns:
    - course_id: an integer representing the unique ID of the course.
    - course_name: a string containing the course's name.
    - course_code: a string for the course's unique code.
    - instructor_id: an integer for the ID of the instructor that refers to employee_id in the employee table (who teaches the course).
    - department: a string for the department offering the course.
    - credits: an integer representing the course credits.
    - semester: a string for the semester when the course is offered.
    """

    # Order matters: callers pair these embeddings positionally with the
    # ["student", "employee", "course"] names.
    metadata_list = [student, employee, course]
    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)

    return embeddings, model, student, employee, course


# def find_best_fit(embeddings, model, user_query, student, employee, course):
#     query_embedding = model.encode([user_query])
#     similarities = cosine_similarity(query_embedding, embeddings)
#     best_match_table = similarities.argmax()
#     return [student, employee, course][best_match_table]
def find_best_fit(embeddings, model, user_query, metadata_list, table_names, top_k=2, threshold=0.4):
    """Pick the tables whose descriptions are semantically closest to a query.

    The query is embedded with ``model`` and compared by cosine similarity
    against each row of ``embeddings``. There is one row per table, aligned
    positionally with ``metadata_list`` and ``table_names``.

    Every table that scores at least ``threshold`` is returned. If none
    reaches the threshold, the ``top_k`` best-scoring tables are returned
    instead, so the SQL generator always receives some schema.

    Args:
        embeddings: Table-description embeddings, one row per table.
        model: Object exposing ``encode(list_of_str)``. It must be the same
            model that produced ``embeddings``.
        user_query: Natural-language request from the user.
        metadata_list: Schema description for each table.
        table_names: Name of each table.
        top_k: How many tables to fall back to when nothing meets
            ``threshold``. Values <= 0 yield no fallback tables.
        threshold: Minimum cosine similarity for a table to count as a match.

    Returns:
        list[tuple[str, str, float]]: ``(table_name, metadata, similarity)``
        tuples, ordered from most to least similar.
    """
    query_embedding = model.encode([user_query])
    # cosine_similarity returns a (1, n_tables) matrix; keep its single row.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Rank every table once, best first. sorted() is stable even with
    # reverse=True, so tied tables keep their original order.
    ranked = sorted(
        zip(table_names, metadata_list, similarities),
        key=lambda match: match[2],
        reverse=True,
    )

    matched_tables = [match for match in ranked if match[2] >= threshold]
    if not matched_tables:
        # Nothing is confidently related, so fall back to the closest top_k.
        # max() guards against top_k <= 0: the previous
        # ``argsort()[-top_k:]`` returned *every* table when top_k == 0.
        matched_tables = ranked[:max(top_k, 0)]

    return matched_tables



# def create_prompt(user_query, table_metadata):

#     system_prompt = """
#     You are an SQL query generator capable of handling multiple tables with relationships.

# Use table metadata to construct accurate queries, including joins, based on the user's intent.

# Ensure:
# - The query is valid.
# - All table and column names match metadata.
# - Join conditions are based on matching keys (e.g., student_id = course.student_id).

# Output a single-line SQL query only, without any explanation.

#     """

#     user_prompt = f"""
#     User Query: {user_query}
#     Table Metadata: {table_metadata}
#     """
#     return system_prompt, user_prompt
def create_prompt(user_query, table_metadata):
    """Assemble the two chat messages used to ask the LLM for SQL.

    Args:
        user_query: The natural-language request typed by the user.
        table_metadata: Schema text for the table(s) judged relevant to the
            request. It is inserted verbatim into the user message.

    Returns:
        tuple[str, str]: ``(system_prompt, user_prompt)``. The first holds
        the fixed generation rules for the system role; the second holds the
        request plus metadata for the user role.
    """
    # Fixed instructions sent as the system message on every request.
    instructions = """
    You are a SQL query generator specialized in generating SQL queries for one or more tables. 
    Your task is to convert natural language queries into SQL statements using the provided metadata.

    Rules:
    - Use JOINs only if required by user intent.
    - Ensure the generated SQL query only uses the tables and columns mentioned in the metadata.
    - Use standard SQL syntax in a single line.
    - Do NOT explain or add comments — only return the SQL query string.

    Input Format:
    User Query: A natural language request.
    Table Metadata: List of available tables and their structures.

    Output Format:
    SQL Query: A valid single-line SQL query only.
    """

    # Per-request message, written piecewise. It keeps the same leading
    # newline and 4-space indentation as an indented triple-quoted literal.
    request = (
        "\n"
        f"    User Query: {user_query}\n"
        f"    Table Metadata: {table_metadata}\n"
        "    "
    )
    return instructions, request



def generate_output(system_prompt, user_prompt, model="llama3-70b-8192"):
    """Ask the Groq chat model for SQL and return it only if it is a single SELECT.

    Args:
        system_prompt: Instructions for the model (see ``create_prompt``).
        user_prompt: The user's request plus the relevant table metadata.
        model: Groq model id. The default is the id this app was written
            against. NOTE(review): Groq retires model ids over time; confirm
            this one is still served, or pass a replacement.

    Returns:
        str: The cleaned SQL text when the reply is one SELECT statement.
        Otherwise the fixed message ``"Can't perform the task at the moment."``.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        model=model
    )
    sql = _clean_sql_reply(chat_completion.choices[0].message.content)
    if _is_single_select(sql):
        return sql
    return "Can't perform the task at the moment."


def _clean_sql_reply(reply):
    """Strip whitespace, a leading "SQL Query:" label and a Markdown code fence from an LLM reply."""
    # Message content can be None (e.g. an empty completion).
    if not reply:
        return ""
    text = reply.strip()
    # The system prompt's "Output Format" section shows a "SQL Query:" label,
    # which models sometimes echo in front of the query.
    label = "sql query:"
    if text.lower().startswith(label):
        text = text[len(label):].strip()
    # Models often wrap code in a Markdown fence such as ```sql ... ```.
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        inner = text[3:-3]
        opener, newline, body = inner.partition("\n")
        # Drop the language tag (or empty remainder) on the opening fence line.
        if newline and opener.strip().lower() in ("", "sql", "sqlite", "mysql", "postgresql"):
            inner = body
        text = inner.strip()
    return text


def _is_single_select(sql):
    """Return True if ``sql`` is one SELECT statement; a trailing ';' is allowed."""
    if not sql.lower().startswith("select"):
        return False
    # The SQL is produced from untrusted user text and then executed, so
    # reject anything after an inner ';' (e.g. "SELECT 1; DROP TABLE student").
    # Deliberately conservative: a ';' inside a string literal is rejected too.
    return ";" not in sql.rstrip().rstrip(";")


# def response(user_query):
#     embeddings, model, student, employee, course = create_metadata_embeddings()
#     table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
#     system_prompt, user_prompt = create_prompt(user_query, table_metadata)
#     sql_query = generate_output(system_prompt, user_prompt)

#     if sql_query.lower().startswith("select"):
#         result = run_sql_query(sql_query)
#         return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"
#     else:
#         return sql_query
def response(user_query):
    """Gradio handler: turn a natural-language question into SQL and run it.

    Pipeline: load the table-description embeddings, pick the tables relevant
    to ``user_query``, prompt the LLM with their metadata, and execute the
    returned query if it is a SELECT.

    Args:
        user_query: Text typed by the user in the Gradio textbox.

    Returns:
        str: ``"SQL Query:\\n<sql>\\n\\nResult:\\n<result>"``. A non-SELECT
        reply and an exception raised while executing the query are reported
        in that same format rather than propagating to the UI. Blank input
        gets a short prompt to enter a question.
    """
    # Avoid running the embedding model and an LLM call on an empty textbox.
    if not user_query or not user_query.strip():
        return "Please enter a question about the student, employee, or course tables."

    embeddings, model, student, employee, course = create_metadata_embeddings()
    metadata_list = [student, employee, course]
    table_names = ["student", "employee", "course"]

    matched_tables = find_best_fit(embeddings, model, user_query, metadata_list, table_names)

    # Each match is (table_name, metadata, similarity); only the schema text
    # goes into the prompt.
    combined_metadata = "\n\n".join(metadata for _, metadata, _ in matched_tables)

    system_prompt, user_prompt = create_prompt(user_query, combined_metadata)

    sql_query = generate_output(system_prompt, user_prompt).strip()

    if not sql_query.lower().startswith("select"):
        return f"SQL Query:\n{sql_query}\n\nResult:\nUnable to fetch data or unsupported query."

    try:
        result = run_sql_query(sql_query)
    except Exception as exc:
        # UI boundary: LLM-written SQL can reference bad columns or be
        # malformed, so show the error in the output box instead of crashing
        # the request.
        result = f"Error while executing query: {exc}"
    return f"SQL Query:\n{sql_query}\n\nResult:\n{result}"





# Initialize DB on app launch
initialize_database()

# Description shown above the Gradio interface.
desc = """
There are three tables in the database:

Student Table:  
Contains student ID, name, DOB, email, phone, major, year of enrollment.

Employee Table:  
Contains employee ID, name, email, department, position, salary, and joining date.

Course Info Table:  
Contains course ID, name, code, instructor ID, department, credits, and semester.
"""

demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query and Result"),
    title="SQL Query Generator with Results",
    description=desc
)

# Start the server only when run as a script (``python app.py``). Importing
# this module, e.g. from tests or Gradio's reload mode, then does not open a
# public share link.
if __name__ == "__main__":
    demo.launch(share=True)