File size: 3,772 Bytes
aee6de8
 
 
 
 
 
 
3fbbd1d
aee6de8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from dotenv import load_dotenv
import os
import gradio as gr
from transformers import pipeline

# Read key/value pairs from a local .env file (if one exists) into the
# process environment before anything else runs.
load_dotenv()

#api = os.getenv("groq_api_key") # Not needed with sqlcoder

def create_metadata_for_sqlcoder():  # Simplified metadata for sqlcoder
    """
    Build the plain-text schema descriptions used when prompting the model.

    Returns:
        dict: maps the keys "student", "employee" and "course" (in that
        insertion order) to a multi-line string listing the table name and
        its columns with their SQL types.
    """
    table_descriptions = {}

    table_descriptions["student"] = """
    Table: student
    Columns:
    - student_id (INTEGER)
    - first_name (TEXT)
    - last_name (TEXT)
    - date_of_birth (DATE)
    - email (TEXT)
    - phone_number (TEXT)
    - major (TEXT)
    - year_of_enrollment (INTEGER)
    """

    table_descriptions["employee"] = """
    Table: employee
    Columns:
    - employee_id (INTEGER)
    - first_name (TEXT)
    - last_name (TEXT)
    - email (TEXT)
    - department (TEXT)
    - position (TEXT)
    - salary (REAL)
    - date_of_joining (DATE)
    """

    # Note: the key is "course" while the table itself is named course_info.
    table_descriptions["course"] = """
    Table: course_info
    Columns:
    - course_id (INTEGER)
    - course_name (TEXT)
    - course_code (TEXT)
    - instructor_id (INTEGER)
    - department (TEXT)
    - credits (INTEGER)
    - semester (TEXT)
    """

    return table_descriptions


def find_best_fit(user_query, schemas):  # Simple keyword matching
    """
    Pick the schema whose table keyword occurs in the user's question.

    The question is lower-cased and searched for the substrings "student",
    "employee" and "course", in that order; the first keyword found decides
    which entry of ``schemas`` is returned.  If none of them occurs, the
    "student" entry is returned as the fallback.
    """
    lowered = user_query.lower()
    for keyword in ("student", "employee", "course"):
        if keyword in lowered:
            return schemas[keyword]
    # No table keyword was mentioned: fall back to the student schema.
    return schemas["student"]



def create_prompt(user_query, table_metadata):
    """
    Build the completion prompt for sqlcoder from the question and schema.

    Args:
        user_query: the natural-language question typed by the user.
        table_metadata: the plain-text schema description of the table the
            question is about.

    Returns:
        str: the instruction block followed by the word ``SELECT``, with no
        leading or trailing whitespace.
    """
    prompt = f"""
    <s>[INST]You are a text-to-SQL model. Generate a SQL query to answer the question:
    {user_query}
    Here is the schema of the table:
    {table_metadata}
    [/INST]
    SELECT
    """
    # The triple-quoted literal starts with a newline and ends with a newline
    # plus indentation.  Strip both so the text begins at the <s> marker and
    # really does finish on SELECT for the model to continue from.
    return prompt.strip()  # sqlcoder expects the prompt to end with SELECT


# Lazily-created pipeline, shared by every request so the model is only
# loaded once per process instead of once per query.
_SQL_GENERATOR = None


def _get_sql_generator():
    """Return the shared generation pipeline, creating it on first use."""
    global _SQL_GENERATOR
    if _SQL_GENERATOR is None:
        # "text2sql" is not one of the task names documented for
        # transformers.pipeline; completing a prompt that ends in SELECT is
        # a "text-generation" job.
        # NOTE(review): confirm that the "b-mc2/sqlcoder" model id exists on
        # the Hub -- it is kept unchanged here.
        _SQL_GENERATOR = pipeline("text-generation", model="b-mc2/sqlcoder")
    return _SQL_GENERATOR


def _extract_generated_text(result):
    """Pull the generated string out of a pipeline result, if present."""
    first = result
    if isinstance(result, (list, tuple)) and result:
        first = result[0]
    if isinstance(first, dict) and "generated_text" in first:
        return first["generated_text"]
    # Unknown result shape: show it as text rather than failing.
    return str(result)


def generate_output(prompt):
    """
    Use the b-mc2/sqlcoder model to generate the SQL query.

    Args:
        prompt: the full prompt text to hand to the model.

    Returns:
        str: the text generated by the model, or a message starting with
        "Error generating SQL:" if loading the model or running it failed.
    """
    try:
        # Creating the pipeline is inside the try block as well, so a failed
        # model load is reported to the user like any other generation error
        # instead of escaping as an unhandled exception.
        sql_generator = _get_sql_generator()
        result = sql_generator(prompt)
    except Exception as e:
        return f"Error generating SQL: {e}"
    # The output widget is a text box, so hand back the generated string
    # rather than the raw list-of-dicts structure.
    return _extract_generated_text(result)



def response(user_query):
    """
    Turn a natural-language question into the model's SQL answer.

    Looks up the schema that best matches the question, wraps question and
    schema into a prompt, and returns whatever the generation step yields.
    """
    table_metadata = find_best_fit(user_query, create_metadata_for_sqlcoder())
    return generate_output(create_prompt(user_query, table_metadata))



# Markdown description shown under the title of the web UI.  The two trailing
# spaces after each table heading are intentional (Markdown line breaks).
desc = """
There are three tables in the database:

Student Table:  
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.

Employee Table:  
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.

Course Info Table:  
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
"""

# Single text box in, single text box out; `response` does all the work.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

# `share` is a boolean flag.  The string "True" only worked because any
# non-empty string is truthy (so "False" would have enabled sharing too).
demo.launch(share=True)