File size: 6,463 Bytes
3851cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f50c10
3851cd3
 
 
 
3f50c10
 
 
 
 
 
 
 
 
 
 
 
 
3851cd3
 
 
 
 
 
 
 
cfb9d4e
 
 
3851cd3
 
cfb9d4e
 
3851cd3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from dotenv import load_dotenv
import os
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import sqlite3
import pandas as pd

# Load variables from a .env file (if any) before reading the environment,
# then look up the Groq API key; `api` is None when the variable is unset.
load_dotenv()
api = os.environ.get("groq_api_key")

# 🔹 STEP 1: Create a sample SQLite database file (college.db) with mock data
def setup_database():
    """Rebuild the sample ``college.db`` SQLite database from scratch.

    Opens (creating it if needed) the ``college.db`` file relative to the
    current working directory, drops the ``student``, ``employee`` and
    ``course_info`` tables if they exist, recreates them and inserts one
    sample row into each. Because every table is dropped first, repeated
    calls always leave the database with exactly one row per table.

    Raises:
        sqlite3.Error: If any statement fails. The connection is closed
            before the error propagates.
    """
    conn = sqlite3.connect("college.db")
    # try/finally guarantees the connection is released even when one of the
    # statements below raises.
    try:
        cursor = conn.cursor()

        # Drop existing tables so the schema and rows below are authoritative.
        for table in ("student", "employee", "course_info"):
            cursor.execute(f"DROP TABLE IF EXISTS {table};")

        # Student table
        cursor.execute("""
    CREATE TABLE student (
        student_id INTEGER,
        first_name TEXT,
        last_name TEXT,
        date_of_birth TEXT,
        email TEXT,
        phone_number TEXT,
        major TEXT,
        year_of_enrollment INTEGER
    );
    """)

        cursor.execute("INSERT INTO student VALUES (1, 'Alice', 'Smith', '2000-05-01', 'alice@example.com', '1234567890', 'Computer Science', 2019);")

        # Employee table
        cursor.execute("""
    CREATE TABLE employee (
        employee_id INTEGER,
        first_name TEXT,
        last_name TEXT,
        email TEXT,
        department TEXT,
        position TEXT,
        salary REAL,
        date_of_joining TEXT
    );
    """)

        cursor.execute("INSERT INTO employee VALUES (101, 'John', 'Doe', 'john@college.edu', 'CSE', 'Professor', 80000, '2015-08-20');")

        # Course table; instructor_id 101 matches the sample employee row.
        cursor.execute("""
    CREATE TABLE course_info (
        course_id INTEGER,
        course_name TEXT,
        course_code TEXT,
        instructor_id INTEGER,
        department TEXT,
        credits INTEGER,
        semester TEXT
    );
    """)

        cursor.execute("INSERT INTO course_info VALUES (501, 'AI Basics', 'CS501', 101, 'CSE', 4, 'Fall');")

        conn.commit()
    finally:
        conn.close()

# Build the sample database once, at import time, so that college.db exists
# with its three tables before any generated query is run against it.
setup_database()

# 🔹 STEP 2: Embedding & LLM logic
def create_metadata_embeddings():
    """Encode the per-table metadata descriptions with a sentence encoder.

    Returns:
        A 5-tuple ``(embeddings, model, student, employee, course)``: the
        encoded metadata in student/employee/course order, the
        ``SentenceTransformer`` instance that produced it, and the three
        metadata strings themselves.
    """
    # NOTE(review): these descriptions look like truncated placeholders; the
    # full column-level metadata presumably belongs here — confirm.
    student, employee, course = (
        """Table: student...""",
        """Table: employee...""",
        """Table: course_info...""",
    )
    # A new model instance is constructed on every call.
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model.encode([student, employee, course]), model, student, employee, course

def find_best_fit(embeddings, model, user_query, student, employee, course):
    """Pick the table metadata most similar to the user's question.

    Args:
        embeddings: Metadata embeddings; assumed to be ordered student,
            employee, course so positions line up with the candidates below
            (TODO confirm against the caller).
        model: Encoder whose ``encode`` method embeds ``user_query``.
        user_query: The natural-language question.
        student: Metadata text for the student table.
        employee: Metadata text for the employee table.
        course: Metadata text for the course table.

    Returns:
        Whichever of ``student``, ``employee`` or ``course`` sits at the
        position of the highest cosine similarity to the query.
    """
    candidates = (student, employee, course)
    scores = cosine_similarity(model.encode([user_query]), embeddings)
    return candidates[scores.argmax()]

def create_prompt(user_query, table_metadata):
    """Build the (system, user) prompt pair sent to the SQL-generating model.

    Args:
        user_query: The natural-language request.
        table_metadata: Description of the table the query should target.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple: the fixed instruction text
        and a message combining the query with the table metadata.
    """
    # NOTE(review): the opening sentence restricts the model to a single
    # table while the first rule allows multi-table JOINs, and the "Output
    # Format" section asks for readable formatting while the following line
    # demands a single line. The text is kept verbatim here; confirm the
    # intended behaviour before rewording it.
    system_text = """You are a SQL query generator specialized in generating SQL queries for a single table at a time. Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.
  
Rules:
- Multi-Table Queries Allowed: You can generate queries involving multiple tables using appropriate SQL JOIN operations, based on the provided metadata.
- Join Logic: Use INNER JOIN, LEFT JOIN, or other appropriate joins based on logical relationships (e.g., foreign keys like `student_id`, `instructor_id`, etc.) inferred from the metadata.
- Metadata-Based Validation: Always ensure the generated query matches the table names, columns, and data types as described in the metadata.
- User Intent: Accurately capture the user's requirements such as filters, sorting, aggregations, and selections across one or more tables.
- SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.
- Output Format: Provide only the SQL query in a single line. Do not include explanations or any extra text.
  
  Input Format:
  User Query: The user's natural language request.
  Table Metadata: The structure of the relevant table, including the table name, column names, and data types.
  
  Output Format:
  SQL Query: A valid SQL query formatted for readability.
  Do not output anything else except the SQL query.Not even a single word extra.Ouput the whole query in a single line only.
  You are ready to generate SQL queries based on the user input and table metadata."""
    user_text = "User Query: {}\nTable Metadata: {}".format(user_query, table_metadata)
    return system_text, user_text

def generate_sql(system_prompt, user_prompt):
    """Ask the Groq chat model for SQL and keep the reply only if it is a SELECT.

    Args:
        system_prompt: Instruction text sent with the "system" role.
        user_prompt: Request text sent with the "user" role.

    Returns:
        The whitespace-stripped reply when it starts with "select"
        (case-insensitive); otherwise ``None``.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # The client is authenticated with the module-level `api` key.
    completion = Groq(api_key=api).chat.completions.create(
        messages=messages,
        model="llama3-70b-8192",
    )
    sql_text = completion.choices[0].message.content.strip()
    # Replies that do not begin with SELECT are discarded.
    return sql_text if sql_text.lower().startswith("select") else None

# 🔹 STEP 3: Execute SQL and return results
def execute_sql(sql_query):
    """Run a SQL query against ``college.db`` and return the rows as a DataFrame.

    Args:
        sql_query: The SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` holding the result set on success, or the
        error message as a ``str`` when connecting or executing fails.
    """
    try:
        conn = sqlite3.connect("college.db")
        # Close the connection whether or not the query succeeds; a failing
        # query must not leave the database handle open.
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            conn.close()
    except Exception as e:
        # Errors are reported to the caller as text rather than raised.
        return str(e)

# 🔹 STEP 4: Final combined response

def response(user_query):
    """Turn a natural-language question into SQL, run it, and format the outcome.

    The table metadata closest to the question is selected, a prompt is built
    from it, the model is asked for a SELECT statement, and that statement is
    executed against ``college.db``.

    Args:
        user_query: The question typed by the user.

    Returns:
        A string containing the generated SQL followed by either the fetched
        rows or an error message.
    """
    embeddings, model, student, employee, course = create_metadata_embeddings()
    table_metadata = find_best_fit(embeddings, model, user_query, student, employee, course)
    system_prompt, user_prompt = create_prompt(user_query, table_metadata)
    sql_query = generate_sql(system_prompt, user_prompt)

    # generate_sql returns None when the model reply is not a SELECT; report
    # that directly instead of handing None to the database driver.
    if sql_query is None:
        return f"SQL Query:\n{sql_query}\n\nQuery Result:\nError: the model did not return a SELECT statement"

    # Run the query against the SQLite database.
    try:
        conn = sqlite3.connect("college.db")
        # Always release the connection, including when execution fails.
        try:
            result = conn.cursor().execute(sql_query).fetchall()
        finally:
            conn.close()
    except Exception as e:
        return f"SQL Query:\n{sql_query}\n\nQuery Result:\nError: {str(e)}"

    return f"SQL Query:\n{sql_query}\n\nQuery Result:\n{result}"


# 🔹 Gradio UI
# Description text displayed by the interface; defined once and passed below.
desc = """Ask a natural language question about students, employees, or courses. I'll generate and run a SQL query for you."""

# Single text box in, single text box out, backed by response().
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Your Question"),
    outputs=gr.Textbox(label="SQL + Result"),
    title="Natural Language to SQL + Result",
    description=desc,
)



demo.launch(share=True)