Balaprime commited on
Commit
aee6de8
·
verified ·
1 Parent(s): 0bb948b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py CHANGED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ import gradio as gr
4
+ from transformers import pipeline
5
+
6
+ load_dotenv()
7
+
8
+ api = os.getenv("groq_api_key") # Not needed with sqlcoder
9
+
10
+ def create_metadata_for_sqlcoder(): # Simplified metadata for sqlcoder
11
+ student_schema = """
12
+ Table: student
13
+ Columns:
14
+ - student_id (INTEGER)
15
+ - first_name (TEXT)
16
+ - last_name (TEXT)
17
+ - date_of_birth (DATE)
18
+ - email (TEXT)
19
+ - phone_number (TEXT)
20
+ - major (TEXT)
21
+ - year_of_enrollment (INTEGER)
22
+ """
23
+
24
+ employee_schema = """
25
+ Table: employee
26
+ Columns:
27
+ - employee_id (INTEGER)
28
+ - first_name (TEXT)
29
+ - last_name (TEXT)
30
+ - email (TEXT)
31
+ - department (TEXT)
32
+ - position (TEXT)
33
+ - salary (REAL)
34
+ - date_of_joining (DATE)
35
+ """
36
+
37
+ course_schema = """
38
+ Table: course_info
39
+ Columns:
40
+ - course_id (INTEGER)
41
+ - course_name (TEXT)
42
+ - course_code (TEXT)
43
+ - instructor_id (INTEGER)
44
+ - department (TEXT)
45
+ - credits (INTEGER)
46
+ - semester (TEXT)
47
+ """
48
+ schemas = {
49
+ "student": student_schema,
50
+ "employee": employee_schema,
51
+ "course": course_schema,
52
+ }
53
+ return schemas
54
+
55
+
56
+ def find_best_fit(user_query, schemas): # Simple keyword matching
57
+ """
58
+ Basic table selection based on keywords in the user query. This is a simplified
59
+ version and could be improved with more sophisticated methods.
60
+ """
61
+ query_lower = user_query.lower()
62
+ if "student" in query_lower:
63
+ return schemas["student"]
64
+ elif "employee" in query_lower:
65
+ return schemas["employee"]
66
+ elif "course" in query_lower:
67
+ return schemas["course"]
68
+ else:
69
+ # Default to student if no table is clearly mentioned
70
+ return schemas["student"]
71
+
72
+
73
+
74
+ def create_prompt(user_query, table_metadata):
75
+ """
76
+ Prompt for sqlcoder, including schema and query.
77
+ """
78
+ prompt = f"""
79
+ <s>[INST]You are a text-to-SQL model. Generate a SQL query to answer the question:
80
+ {user_query}
81
+ Here is the schema of the table:
82
+ {table_metadata}
83
+ [/INST]
84
+ SELECT
85
+ """
86
+ return prompt # sqlcoder expects the prompt to end with SELECT
87
+
88
+
89
+ def generate_output(prompt):
90
+ """
91
+ Use the b-mc2/sqlcoder model to generate the SQL query.
92
+ """
93
+ # Use a pipeline for easier interaction with the model
94
+ sql_generator = pipeline("text2sql", model="b-mc2/sqlcoder")
95
+ try:
96
+ result = sql_generator(prompt) # No extra parameters needed.
97
+ # The model is supposed to return only the SQL.
98
+ return result
99
+ except Exception as e:
100
+ return f"Error generating SQL: {e}"
101
+
102
+
103
+
104
+ def response(user_query):
105
+ """
106
+ Main function to process the user query and return the SQL response.
107
+ """
108
+ schemas = create_metadata_for_sqlcoder()
109
+ table_metadata = find_best_fit(user_query, schemas)
110
+ prompt = create_prompt(user_query, table_metadata)
111
+ output = generate_output(prompt)
112
+ return output
113
+
114
+
115
+
116
+ desc = """
117
+ There are three tables in the database:
118
+
119
+ Student Table:
120
+ The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.
121
+
122
+ Employee Table:
123
+ The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.
124
+
125
+ Course Info Table:
126
+ The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.
127
+ """
128
+
129
+ demo = gr.Interface(
130
+ fn=response,
131
+ inputs=gr.Textbox(label="Please provide the natural language query"),
132
+ outputs=gr.Textbox(label="SQL Query"),
133
+ title="SQL Query generator",
134
+ description=desc,
135
+ )
136
+
137
+ demo.launch(share="True")