anjalirathore commited on
Commit
1052f2c
·
verified ·
1 Parent(s): de5e1d8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +174 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ from sentence_transformers import SentenceTransformer
4
+ import gradio as gr
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from groq import Groq
7
+ import pandas as pd
8
+ load_dotenv()
9
+ groq_api_key = os.getenv("groq_api_key")
10
+
11
+ dataset_folder = "./data"
12
+
13
+ if not os.path.exists(dataset_folder):
14
+ print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
15
+ dataset_folder = "." # Fallback: Look in the current directory
16
+
17
+ # Print available files for debugging
18
+ print("Available files:", os.listdir(dataset_folder))
19
+
20
+ import warnings
21
+
22
+ # Ignore DtypeWarning
23
+ warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
24
+
25
+ # Load all CSV files in the dataset folder
26
+ dataframes = []
27
+ for file in os.listdir(dataset_folder):
28
+ if file.endswith(".csv"):
29
+ try:
30
+ # Read first few rows to identify column names
31
+ sample_df = pd.read_csv(
32
+ os.path.join(dataset_folder, file),
33
+ nrows=5, # Read only first 5 rows for column type inference
34
+ encoding="utf-8",
35
+ errors="replace" # Replace encoding errors with a placeholder
36
+ )
37
+
38
+ column_types = {col: str for col in sample_df.columns} # Force all columns to string
39
+
40
+ # Read the entire file with enforced column types
41
+ df = pd.read_csv(
42
+ os.path.join(dataset_folder, file),
43
+ dtype=column_types, # Apply enforced string types
44
+ low_memory=False, # Avoid chunk-based reading issues
45
+ encoding="utf-8",
46
+ errors="replace"
47
+ ).fillna('') # Fill NaN values with empty strings
48
+
49
+ dataframes.append(df) # Append DataFrame to the list
50
+ except Exception as e:
51
+ print(f"Error reading {file}: {e}")
52
+
53
+ # Merge all CSV files into one DataFrame (only if there are valid files)
54
+ if dataframes:
55
+ full_data = pd.concat(dataframes, ignore_index=True)
56
+ else:
57
+ print("Warning: No valid CSV files found in the dataset folder.")
58
+ full_data = pd.DataFrame() # Create an empty DataFrame as a fallback
59
+
60
+
61
+ def load_dataset_metadata(dataset_folder):
62
+ """Loads metadata from all CSV files in the dataset folder."""
63
+ dataframes = []
64
+ metadata_list = []
65
+
66
+ for file in os.listdir(dataset_folder):
67
+ if file.endswith(".csv"):
68
+ df = pd.read_csv(os.path.join(dataset_folder, file))
69
+ dataframes.append((file, df))
70
+
71
+ # Generate table metadata
72
+ columns = df.columns.tolist()
73
+ table_metadata = f"""
74
+ Table: {file.replace('.csv', '')}
75
+ Columns:
76
+ {', '.join(columns)}
77
+ """
78
+ metadata_list.append(table_metadata)
79
+
80
+ return dataframes, metadata_list
81
+
82
+ def create_metadata_embeddings(metadata_list):
83
+ """Creates embeddings for all table metadata."""
84
+ model = SentenceTransformer('all-MiniLM-L6-v2')
85
+ embeddings = model.encode(metadata_list)
86
+ return embeddings, model
87
+
88
+ def find_best_fit(embeddings, model, user_query, metadata_list):
89
+ """Finds the best matching table based on user query."""
90
+ query_embedding = model.encode([user_query])
91
+ similarities = cosine_similarity(query_embedding, embeddings)
92
+ best_match_index = similarities.argmax()
93
+ return metadata_list[best_match_index]
94
+
95
+ def create_prompt(user_query, table_metadata):
96
+ """Generates a direct and structured SQL prompt with stricter formatting."""
97
+ system_prompt = f"""
98
+ You are an AI assistant that generates precise SQL queries based on user questions.
99
+
100
+ **Table Name & Columns:**
101
+ {table_metadata}
102
+
103
+ **User Query:**
104
+ {user_query}
105
+
106
+ **Output Format (STRICT):**
107
+ - Provide ONLY the SQL query.
108
+ - Do NOT include explanations, comments, or unnecessary text.
109
+ - Ensure the table and column names match exactly.
110
+ - If the query is impossible, return: "ERROR: Unable to generate query."
111
+
112
+ **Example Queries:**
113
+ - User: "Show all startups founded in 2020."
114
+ - AI Response: SELECT * FROM startups WHERE founded_year = 2020;
115
+
116
+ - User: "List the top 5 startups by total funding."
117
+ - AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
118
+ """
119
+ return system_prompt
120
+
121
+
122
+ def generate_sql_query(system_prompt):
123
+ """Uses Groq API to generate an SQL query with better debugging."""
124
+ try:
125
+ client = Groq(api_key=groq_api_key)
126
+ chat_completion = client.chat.completions.create(
127
+ messages=[{"role": "system", "content": system_prompt}],
128
+ model="llama3-70b-8192"
129
+ )
130
+
131
+ # Debug: Print entire response
132
+ print("🔍 Full API Response:", chat_completion)
133
+
134
+ # Extract AI response
135
+ result = chat_completion.choices[0].message.content.strip()
136
+ print(f"✅ AI Response: {result}") # Debugging
137
+
138
+ # Check if the response starts with "SELECT"
139
+ if result.lower().startswith("select"):
140
+ return result
141
+ else:
142
+ print("⚠️ AI did not generate a valid SQL query!")
143
+ return "⚠️ AI response is not a valid SQL query."
144
+
145
+ except Exception as e:
146
+ print(f"❌ API Error: {e}")
147
+ return "⚠️ API failed. Check logs."
148
+
149
+
150
+ def response(user_query, dataset_folder):
151
+ """Processes the user query and returns an SQL query."""
152
+ dataframes, metadata_list = load_dataset_metadata(dataset_folder)
153
+ embeddings, model = create_metadata_embeddings(metadata_list)
154
+ table_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
155
+ system_prompt = create_prompt(user_query, table_metadata)
156
+ return generate_sql_query(system_prompt)
157
+
158
+ dataset_folder = "./data" # Change this based on where your files are uploaded
159
+ user_query = "Show me the top 10 startups with the highest funding."
160
+
161
+ def sql_query_interface(user_query):
162
+ return response(user_query, dataset_folder)
163
+
164
+ # Define Gradio UI
165
+ iface = gr.Interface(
166
+ fn=sql_query_interface,
167
+ inputs=gr.Textbox(label="Enter your query"),
168
+ outputs=gr.Textbox(label="Generated SQL Query"),
169
+ title="AI-Powered SQL Query Generator"
170
+ )
171
+
172
+ # Run Gradio app
173
+ if __name__ == "__main__":
174
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ sentence-transformers
2
+ gradio
3
+ scikit-learn
4
+ groq
5
+ pandas
6
+ python-dotenv