Spaces:
Build error
Build error
| from dotenv import load_dotenv | |
| import os | |
| import gradio as gr | |
| from transformers import pipeline | |
# Load variables from a local .env file (if present) into the process environment.
# NOTE(review): no API key is read in this file any more; presumably this is kept so
# Hugging Face settings such as HF_TOKEN can be supplied via .env — confirm before removing.
load_dotenv()
def create_metadata_for_sqlcoder():  # Simplified metadata for sqlcoder
    """Return a plain-text schema description for every known table.

    The result maps a short table key ("student", "employee", "course") to a
    text block shaped like::

        Table: <table name>
        Columns:
        - <column> (<SQL type>)

    Each block starts and ends with a newline so it can be dropped into a prompt.
    """
    # key -> (actual table name, ordered list of (column, SQL type))
    table_specs = {
        "student": (
            "student",
            [
                ("student_id", "INTEGER"),
                ("first_name", "TEXT"),
                ("last_name", "TEXT"),
                ("date_of_birth", "DATE"),
                ("email", "TEXT"),
                ("phone_number", "TEXT"),
                ("major", "TEXT"),
                ("year_of_enrollment", "INTEGER"),
            ],
        ),
        "employee": (
            "employee",
            [
                ("employee_id", "INTEGER"),
                ("first_name", "TEXT"),
                ("last_name", "TEXT"),
                ("email", "TEXT"),
                ("department", "TEXT"),
                ("position", "TEXT"),
                ("salary", "REAL"),
                ("date_of_joining", "DATE"),
            ],
        ),
        "course": (
            "course_info",
            [
                ("course_id", "INTEGER"),
                ("course_name", "TEXT"),
                ("course_code", "TEXT"),
                ("instructor_id", "INTEGER"),
                ("department", "TEXT"),
                ("credits", "INTEGER"),
                ("semester", "TEXT"),
            ],
        ),
    }

    def render(table_name, columns):
        # One "- name (TYPE)" line per column, each newline-terminated.
        column_lines = "".join(f"- {col} ({sql_type})\n" for col, sql_type in columns)
        return f"\nTable: {table_name}\nColumns:\n{column_lines}"

    return {key: render(name, cols) for key, (name, cols) in table_specs.items()}
def find_best_fit(user_query, schemas):  # Simple keyword matching
    """
    Pick the schema whose table keyword appears in the user query.

    Keywords are checked in priority order ("student", then "employee", then
    "course") as case-insensitive substrings; the first hit wins. When none of
    them occurs, the student schema is returned as the fallback. This is a
    deliberately simple heuristic and could be replaced by something smarter.
    """
    lowered = user_query.lower()
    for table_key in ("student", "employee", "course"):
        if table_key in lowered:
            return schemas[table_key]
    # No table named explicitly: fall back to the student table.
    return schemas["student"]
def create_prompt(user_query, table_metadata):
    """
    Build the text-to-SQL prompt from the question and one table's schema.

    Args:
        user_query: the natural-language question from the user.
        table_metadata: plain-text schema block for the chosen table; any
            surrounding blank lines are trimmed.

    Returns:
        The prompt string. It ends exactly with "SELECT" (no trailing
        newline) so the model continues the query from that keyword.
    """
    # NOTE(review): the published defog SQLCoder prompt uses a "### Task /
    # ### Database Schema / ### Answer" layout rather than [INST] tags, and most
    # tokenizers prepend <s> themselves (possible duplicate BOS) — confirm
    # against the model card of the checkpoint actually deployed.
    prompt = (
        "<s>[INST]You are a text-to-SQL model. "
        "Generate a SQL query to answer the question:\n"
        f"{user_query}\n"
        "Here is the schema of the table:\n"
        f"{table_metadata.strip()}\n"
        "[/INST]\n"
        # The original triple-quoted literal left a trailing newline after
        # SELECT; end the prompt on the keyword itself.
        "SELECT"
    )
    return prompt
# NOTE(review): confirm this repo id exists on the Hugging Face Hub; the
# upstream SQLCoder checkpoints are published under "defog/".
_SQLCODER_MODEL = "b-mc2/sqlcoder"
# Lazily-created pipeline, shared by every request (loading a model per call
# is prohibitively slow).
_sql_generator = None


def _get_sql_generator():
    """Create the SQLCoder text-generation pipeline on first use, then reuse it."""
    global _sql_generator
    if _sql_generator is None:
        # "text2sql" is not a transformers pipeline task; SQLCoder is a causal
        # language model, so it is served through "text-generation".
        _sql_generator = pipeline("text-generation", model=_SQLCODER_MODEL)
    return _sql_generator


def generate_output(prompt, *, generator=None, max_new_tokens=256):
    """
    Run the SQL model on *prompt* and return the generated SQL as a string.

    Args:
        prompt: full model prompt. If it ends with "SELECT", that keyword is
            re-attached to the front of the completion.
        generator: optional callable with the transformers text-generation
            pipeline calling convention; defaults to the cached SQLCoder
            pipeline. Mainly useful for testing or swapping models.
        max_new_tokens: generation length limit passed to the model.

    Returns:
        The SQL query text, or "Error generating SQL: <reason>" if loading or
        running the model fails (the message is shown directly in the UI).
    """
    try:
        sql_generator = generator if generator is not None else _get_sql_generator()
        # return_full_text=False makes the pipeline return only the completion,
        # not the prompt echoed back.
        result = sql_generator(
            prompt, max_new_tokens=max_new_tokens, return_full_text=False
        )
        completion = result[0]["generated_text"].strip()
    except Exception as e:  # UI boundary: surface any failure as text
        return f"Error generating SQL: {e}"
    if prompt.rstrip().endswith("SELECT"):
        # The prompt primed the model with SELECT, so the completion lacks it.
        completion = f"SELECT {completion}"
    return completion
def response(user_query):
    """
    Gradio handler: turn a natural-language question into a SQL query.

    Pipeline: load the table schemas, pick the table the question refers to,
    build the model prompt, and return whatever the generator produces.
    """
    chosen_schema = find_best_fit(user_query, create_metadata_for_sqlcoder())
    return generate_output(create_prompt(user_query, chosen_schema))
# Help text shown under the title in the Gradio UI.
desc = """
There are three tables in the database:
Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.
Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.
Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
"""

# Single text-in / text-out interface wired to the response() handler.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

# share expects a bool; the previous string "True" only worked because any
# non-empty string is truthy.
demo.launch(share=True)