# New_nl2sql / app.py
# Last commit: 3fbbd1d "Update app.py" by Balaprime (verified)
import functools
import os

import gradio as gr
from dotenv import load_dotenv
from transformers import pipeline
# Load variables from a local .env file into the process environment, if one exists.
# NOTE(review): nothing in this file reads os.environ any more (the Groq API key
# lookup was dropped when switching to sqlcoder); presumably kept so credentials
# such as a Hugging Face token can come from .env — confirm before removing.
load_dotenv()
def _format_table_schema(table_name, columns):
    """Render one table as the plain-text schema snippet embedded in prompts.

    The result starts and ends with a newline: a ``Table:`` line, a
    ``Columns:`` line, then one ``- <column> (<TYPE>)`` line per column.
    """
    column_lines = "".join(
        f"- {column} ({sql_type})\n" for column, sql_type in columns
    )
    return f"\nTable: {table_name}\nColumns:\n{column_lines}"


def create_metadata_for_sqlcoder():
    """Return simplified plain-text schema descriptions for sqlcoder.

    Keys are the short names used for table selection ("student",
    "employee", "course"). Values are the schema text inserted into the
    model prompt. The "course" entry describes the table ``course_info``.
    A fresh dict is built on every call.
    """
    student_columns = (
        ("student_id", "INTEGER"),
        ("first_name", "TEXT"),
        ("last_name", "TEXT"),
        ("date_of_birth", "DATE"),
        ("email", "TEXT"),
        ("phone_number", "TEXT"),
        ("major", "TEXT"),
        ("year_of_enrollment", "INTEGER"),
    )
    employee_columns = (
        ("employee_id", "INTEGER"),
        ("first_name", "TEXT"),
        ("last_name", "TEXT"),
        ("email", "TEXT"),
        ("department", "TEXT"),
        ("position", "TEXT"),
        ("salary", "REAL"),
        ("date_of_joining", "DATE"),
    )
    course_columns = (
        ("course_id", "INTEGER"),
        ("course_name", "TEXT"),
        ("course_code", "TEXT"),
        ("instructor_id", "INTEGER"),
        ("department", "TEXT"),
        ("credits", "INTEGER"),
        ("semester", "TEXT"),
    )
    return {
        "student": _format_table_schema("student", student_columns),
        "employee": _format_table_schema("employee", employee_columns),
        "course": _format_table_schema("course_info", course_columns),
    }
def find_best_fit(user_query, schemas):
    """Choose a table schema for *user_query* by simple keyword matching.

    The keywords "student", "employee" and "course" are tested in that
    priority order with a case-insensitive substring check. The schema for
    the first one present is returned. If none is mentioned, the "student"
    schema is the fallback. This is a deliberately basic heuristic that
    could be replaced by something more sophisticated.
    """
    lowered = user_query.lower()
    for keyword in ("student", "employee", "course"):
        if keyword in lowered:
            return schemas[keyword]
    # No table clearly mentioned: fall back to the student table.
    return schemas["student"]
def create_prompt(user_query, table_metadata):
    """Build the sqlcoder prompt containing the question and the table schema.

    The text is wrapped in ``<s>[INST] ... [/INST]`` markers, starts with a
    newline, and ends with a line containing just ``SELECT``. The model is
    expected to continue the query from there.
    """
    sections = (
        "",
        "<s>[INST]You are a text-to-SQL model. Generate a SQL query to answer the question:",
        f"{user_query}",
        "Here is the schema of the table:",
        f"{table_metadata}",
        "[/INST]",
        # sqlcoder expects the prompt to end with SELECT
        "SELECT",
        "",
    )
    return "\n".join(sections)
@functools.lru_cache(maxsize=1)
def _get_sql_generator(model_name):
    """Build the text-generation pipeline for *model_name* once and reuse it."""
    # Loading a model is expensive; caching avoids reloading it on every
    # request. maxsize=1 keeps at most one model resident at a time.
    return pipeline("text-generation", model=model_name)


def _extract_sql(result, prompt):
    """Convert raw text-generation pipeline output into one SQL statement."""
    # The pipeline returns a list of dicts carrying a "generated_text" key.
    if isinstance(result, list):
        result = result[0] if result else ""
    if isinstance(result, dict):
        result = result.get("generated_text", "")
    text = str(result).strip()
    if not text:
        return ""
    # With return_full_text=False only the continuation is produced, so a
    # trailing "SELECT" in the prompt has to be put back in front.
    if prompt.rstrip().endswith("SELECT"):
        text = f"SELECT {text}"
    # Keep only the first statement: the model may continue generating after
    # it. NOTE(review): a ';' inside a string literal would truncate early.
    end = text.find(";")
    if end != -1:
        text = text[: end + 1]
    return text


def generate_output(prompt, *, model_name="b-mc2/sqlcoder", max_new_tokens=256):
    """Generate a SQL query for *prompt* using a sqlcoder-style causal LM.

    Args:
        prompt: Full model prompt (this app builds it with ``create_prompt``).
        model_name: Hugging Face model id to load. NOTE(review): confirm that
            "b-mc2/sqlcoder" exists on the Hub; Defog publishes sqlcoder as,
            e.g., "defog/sqlcoder-7b-2".
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated SQL as a string, or a message starting with
        "Error generating SQL:" if loading the model or generation fails.
    """
    try:
        # Constructing the pipeline inside the try means a bad model id or a
        # download failure is reported like any other generation error.
        sql_generator = _get_sql_generator(model_name)
        result = sql_generator(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: same question -> same SQL
            return_full_text=False,  # only the continuation, not the prompt
        )
    except Exception as e:  # UI boundary: show the error instead of crashing
        return f"Error generating SQL: {e}"
    sql = _extract_sql(result, prompt)
    return sql or "Error generating SQL: the model returned no text"
def response(user_query):
    """Turn a natural-language question into a SQL query for the Gradio UI.

    Picks the table schema that best matches the question, builds the model
    prompt, and returns the result of ``generate_output`` for that prompt.
    Blank input (None, empty or whitespace-only) is answered with a short
    message and never reaches the model.
    """
    # Running the model on an empty question only wastes time and produces
    # meaningless SQL.
    if not user_query or not user_query.strip():
        return "Please enter a natural language question."
    schemas = create_metadata_for_sqlcoder()
    table_metadata = find_best_fit(user_query, schemas)
    prompt = create_prompt(user_query, table_metadata)
    return generate_output(prompt)
desc = """
There are three tables in the database:
Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.
Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.
Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
"""

# Gradio UI: one textbox in (the natural-language question), one textbox out
# (the generated SQL), with the table overview above as the description.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

if __name__ == "__main__":
    # Launch only when run as a script so importing this module has no side
    # effects. `share` takes a bool; the string "True" worked previously only
    # because any non-empty string is truthy.
    demo.launch(share=True)