Spaces:
Build error
Build error
File size: 3,772 Bytes
aee6de8 3fbbd1d aee6de8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import functools
import os

import gradio as gr
from dotenv import load_dotenv
from transformers import pipeline
# Populate os.environ from a local .env file, if one exists (python-dotenv).
# NOTE(review): nothing visible in this module reads environment variables any
# more (the Groq API key lookup became unnecessary with the local model);
# confirm no other code relies on .env values before removing this call.
load_dotenv()
def create_metadata_for_sqlcoder():  # Simplified metadata for sqlcoder
    """Describe the three demo tables as plain-text schemas.

    Returns:
        dict keyed by lookup keyword ("student", "employee", "course"), in that
        order. Each value is a text block that starts with an empty line and
        contains ``Table: <name>``, ``Columns:`` and then one
        ``- <column> (<TYPE>)`` line per column. Every line is
        newline-terminated.
    """
    table_specs = [
        ("student", "student", [
            ("student_id", "INTEGER"),
            ("first_name", "TEXT"),
            ("last_name", "TEXT"),
            ("date_of_birth", "DATE"),
            ("email", "TEXT"),
            ("phone_number", "TEXT"),
            ("major", "TEXT"),
            ("year_of_enrollment", "INTEGER"),
        ]),
        ("employee", "employee", [
            ("employee_id", "INTEGER"),
            ("first_name", "TEXT"),
            ("last_name", "TEXT"),
            ("email", "TEXT"),
            ("department", "TEXT"),
            ("position", "TEXT"),
            ("salary", "REAL"),
            ("date_of_joining", "DATE"),
        ]),
        # The lookup keyword is "course", but the table itself is course_info.
        ("course", "course_info", [
            ("course_id", "INTEGER"),
            ("course_name", "TEXT"),
            ("course_code", "TEXT"),
            ("instructor_id", "INTEGER"),
            ("department", "TEXT"),
            ("credits", "INTEGER"),
            ("semester", "TEXT"),
        ]),
    ]

    def render(table_name, columns):
        lines = [f"Table: {table_name}", "Columns:"]
        lines += [f"- {column} ({sql_type})" for column, sql_type in columns]
        # Leading and trailing newline reproduce the triple-quoted layout.
        return "\n" + "\n".join(lines) + "\n"

    return {keyword: render(name, cols) for keyword, name, cols in table_specs}
def find_best_fit(user_query, schemas):  # Simple keyword matching
    """Pick the schema whose table keyword appears in *user_query*.

    Keywords are tested as case-insensitive substrings in the priority order
    "student", "employee", "course", and the first hit wins. If none of them
    occurs, the "student" schema is returned. This is a deliberately simple
    heuristic and could be replaced by something more sophisticated.

    Args:
        user_query: The user's natural-language question.
        schemas: Mapping that must contain the keyword keys above.

    Returns:
        The selected value from *schemas*.
    """
    lowered = user_query.lower()
    chosen = next(
        (keyword for keyword in ("student", "employee", "course") if keyword in lowered),
        "student",  # no table mentioned explicitly -> default table
    )
    return schemas[chosen]
def create_prompt(user_query, table_metadata):
    """Build the instruction prompt sent to the SQL model.

    The question and the chosen table schema are wrapped between ``[INST]``
    and ``[/INST]`` markers. The prompt ends with the keyword ``SELECT`` so
    the model continues the query from there.

    Args:
        user_query: Natural-language question from the user.
        table_metadata: Text description of the table to query.

    Returns:
        The prompt string. It ends exactly with ``SELECT``: no trailing newline,
        and no leading newline before ``<s>``.
    """
    # NOTE(review): "<s>" is written literally; many tokenizers also prepend a
    # BOS token themselves, which would duplicate it — confirm for the model used.
    return (
        "<s>[INST]You are a text-to-SQL model. Generate a SQL query to answer the question:\n"
        f"{user_query}\n"
        "Here is the schema of the table:\n"
        f"{table_metadata}\n"
        "[/INST]\n"
        "SELECT"
    )
@functools.lru_cache(maxsize=4)
def _load_sql_generator(model):
    """Return a text-generation pipeline for *model*, built once per model id."""
    # "text2sql" is not a transformers pipeline task; SQLCoder-style models are
    # causal LMs served through "text-generation". Loading weights is the
    # expensive step, so the cache keeps it from repeating on every request.
    return pipeline("text-generation", model=model)


def generate_output(prompt, model="b-mc2/sqlcoder", max_new_tokens=256):
    """Run the SQL model on *prompt* and return the generated query text.

    Args:
        prompt: Complete model prompt (see ``create_prompt``). When it ends in
            ``SELECT`` (ignoring trailing whitespace), that keyword is put back
            in front of the completion so the caller gets a whole statement.
        model: Hugging Face model id passed to ``transformers.pipeline``.
            NOTE(review): confirm this id resolves on the Hub; published
            SQLCoder checkpoints are under ``defog/`` (e.g. ``defog/sqlcoder``).
        max_new_tokens: Cap on the number of generated tokens.

    Returns:
        The SQL text as a ``str``, or ``"Error generating SQL: <reason>"`` if
        the model could not be loaded or generation failed.
    """
    try:
        generator = _load_sql_generator(model)
        outputs = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,         # deterministic decoding for reproducible SQL
            return_full_text=False,  # only the continuation, not the echoed prompt
        )
    except Exception as e:  # UI boundary: show the failure instead of crashing
        return f"Error generating SQL: {e}"
    completion = outputs[0]["generated_text"].strip()
    if prompt.rstrip().endswith("SELECT"):
        # The prompt primed "SELECT", so the completion starts after it.
        completion = f"SELECT {completion}"
    return completion
def response(user_query):
    """Turn a natural-language question into SQL text for the Gradio UI.

    Steps:
        1. Pick a table schema by keyword.
        2. Wrap the question and that schema in a model prompt.
        3. Return whatever ``generate_output`` produces for the prompt.
    """
    chosen_schema = find_best_fit(user_query, create_metadata_for_sqlcoder())
    return generate_output(create_prompt(user_query, chosen_schema))
# Passed to the Interface as its description so users know which tables exist.
desc = """
There are three tables in the database:
Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.
Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.
Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
"""

demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    # `share` expects a bool; the former string "True" only worked because any
    # non-empty string (including "False") is truthy.
    demo.launch(share=True)
|