from dotenv import load_dotenv
import os
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import pandas as pd
# --- Environment & dataset bootstrap --------------------------------------
import warnings

load_dotenv()  # pull variables from a local .env file into the process env

# API key for the Groq chat-completions client used later in this script.
groq_api_key = os.getenv("groq_api_key")

# Default dataset location; fall back to the working directory when it is
# missing so the script can still start.
dataset_folder = "./data"
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."

# List candidate files up front to help debugging.
print("Available files:", os.listdir(dataset_folder))

# Mixed-type CSV columns would otherwise emit DtypeWarning on every read.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
# --- Load every CSV in the dataset folder into one DataFrame --------------
# All columns are forced to ``str`` so heterogeneous files can be
# concatenated without dtype conflicts; unreadable files are skipped.
dataframes = []
for file in os.listdir(dataset_folder):
    if not file.endswith(".csv"):
        continue
    try:
        # dtype=str coerces every column to string up front, which removes
        # the original two-pass read (sample for column names, then re-read
        # with a per-column dtype dict).
        # NOTE: pandas' read_csv takes ``encoding_errors`` (pandas >= 1.3),
        # not ``errors`` -- the original keyword raised TypeError and sent
        # every file into the except branch.
        df = pd.read_csv(
            os.path.join(dataset_folder, file),
            dtype=str,                   # force all columns to string
            low_memory=False,            # avoid chunk-based reading issues
            encoding="utf-8",
            encoding_errors="replace",   # substitute undecodable bytes
        ).fillna('')                     # empty string instead of NaN
        dataframes.append(df)
    except Exception as e:
        print(f"Error reading {file}: {e}")

# Merge all per-file frames; fall back to an empty DataFrame when the
# folder held no readable CSVs.
if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()
def load_dataset_metadata(dataset_folder):
    """Load every CSV in *dataset_folder* and build a textual description of each.

    Parameters
    ----------
    dataset_folder : str
        Directory scanned (non-recursively) for ``*.csv`` files.

    Returns
    -------
    tuple
        ``(dataframes, metadata_list)`` where ``dataframes`` is a list of
        ``(filename, DataFrame)`` pairs and ``metadata_list`` holds one
        "Table/Columns" description string per file, suitable for embedding.
    """
    dataframes = []
    metadata_list = []
    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        try:
            df = pd.read_csv(os.path.join(dataset_folder, file))
        except Exception as e:
            # Skip unreadable files instead of aborting the whole scan,
            # matching the error handling of the module-level CSV loader.
            print(f"Error reading {file}: {e}")
            continue
        dataframes.append((file, df))
        # Generate table metadata: table name (file stem) plus column list.
        columns = df.columns.tolist()
        table_metadata = f"""
Table: {file.replace('.csv', '')}
Columns:
{', '.join(columns)}
"""
        metadata_list.append(table_metadata)
    return dataframes, metadata_list
def create_metadata_embeddings(metadata_list):
    """Embed each table-metadata string with a MiniLM sentence encoder.

    Returns the ``(embeddings, model)`` pair so callers can reuse the same
    encoder when embedding the user query.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder.encode(metadata_list), encoder
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata string whose embedding is closest to the query."""
    # Embed the query with the same encoder used for the table metadata,
    # then rank tables by cosine similarity and pick the top hit.
    scores = cosine_similarity(model.encode([user_query]), embeddings)
    return metadata_list[scores.argmax()]
def create_prompt(user_query, table_metadata):
    """Build the system prompt instructing the LLM to emit only a SQL query."""
    # The prompt is deliberately strict: the downstream check accepts the
    # model response only if it looks like a bare SQL statement.
    return f"""
You are an AI assistant that generates precise SQL queries based on user questions.
**Table Name & Columns:**
{table_metadata}
**User Query:**
{user_query}
**Output Format (STRICT):**
- Provide ONLY the SQL query.
- Do NOT include explanations, comments, or unnecessary text.
- Ensure the table and column names match exactly.
- If the query is impossible, return: "ERROR: Unable to generate query."
**Example Queries:**
- User: "Show all startups founded in 2020."
- AI Response: SELECT * FROM startups WHERE founded_year = 2020;
- User: "List the top 5 startups by total funding."
- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
"""
def generate_sql_query(system_prompt):
    """Ask the Groq LLM for a SQL query and validate the response.

    Parameters
    ----------
    system_prompt : str
        Fully rendered prompt from ``create_prompt``.

    Returns
    -------
    str
        The generated SQL statement, or a warning string when the model
        response does not look like SQL / the API call fails.
    """
    try:
        client = Groq(api_key=groq_api_key)
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama3-70b-8192"
        )
        # Debug: Print entire response
        print("🔍 Full API Response:", chat_completion)
        # Extract AI response
        result = chat_completion.choices[0].message.content.strip()
        print(f"✅ AI Response: {result}")  # Debugging
        # Models often wrap SQL in markdown code fences despite the strict
        # prompt; unwrap before validating.
        if result.startswith("```"):
            result = result.strip("`")
            if result.lower().startswith("sql"):
                result = result[3:]
            result = result.strip()
        # Accept CTEs too -- the original check only allowed "SELECT" and
        # rejected valid queries beginning with "WITH".
        if result.lower().startswith(("select", "with")):
            return result
        print("⚠️ AI did not generate a valid SQL query!")
        return "⚠️ AI response is not a valid SQL query."
    except Exception as e:
        print(f"❌ API Error: {e}")
        return "⚠️ API failed. Check logs."
def response(user_query, dataset_folder):
    """End-to-end pipeline: pick the best-matching table and generate SQL."""
    # 1. Describe every CSV table in the folder (frames are not needed here).
    _, metadata_list = load_dataset_metadata(dataset_folder)
    # 2. Embed the descriptions and find the one closest to the query.
    embeddings, model = create_metadata_embeddings(metadata_list)
    best_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
    # 3. Ask the LLM for a SQL query grounded in that table's schema.
    return generate_sql_query(create_prompt(user_query, best_metadata))
# --- Gradio front end -----------------------------------------------------
dataset_folder = "./data"  # Change this based on where your files are uploaded
user_query = "Show me the top 10 startups with the highest funding."


def sql_query_interface(user_query):
    """Gradio callback: route the textbox query through the SQL pipeline."""
    return response(user_query, dataset_folder)


# Define Gradio UI: one text input, one text output.
iface = gr.Interface(
    fn=sql_query_interface,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="AI-Powered SQL Query Generator",
)

# Run Gradio app only when executed as a script.
if __name__ == "__main__":
    iface.launch()