|
|
from dotenv import load_dotenv |
|
|
import os |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import gradio as gr |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
import google.generativeai as genai |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
import pandas as pd |
|
|
# Load environment variables from a local .env file into os.environ.
load_dotenv()

# Gemini API key; None when the variable is absent from the environment.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not GEMINI_API_KEY:
    # Warn but keep going: genai.configure below is still called with None,
    # so API calls will fail later unless the key arrives some other way.
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

# Configure the google-generativeai client (key may be None — see warning above).
genai.configure(api_key=GEMINI_API_KEY)
|
|
|
|
|
|
|
|
# Directory containing the CSV files that make up the dataset.
dataset_folder = "./data"

if not os.path.exists(dataset_folder):
    # Fall back to the current working directory so the script still runs.
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."

print("Available files:", os.listdir(dataset_folder))

import warnings

# Silence pandas mixed-dtype warnings; every column is forced to str during
# the full read below, so the warning is noise here.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
|
|
|
|
|
|
|
|
def _read_csv_utf8_or_latin1(csv_path, **read_kwargs):
    """Read a CSV trying UTF-8 first, falling back to Latin-1 on decode errors."""
    try:
        return pd.read_csv(csv_path, encoding="utf-8", **read_kwargs)
    except UnicodeDecodeError:
        return pd.read_csv(csv_path, encoding="latin1", **read_kwargs)


# Load every CSV in the dataset folder into one combined DataFrame.
dataframes = []
for fname in os.listdir(dataset_folder):
    if not fname.endswith(".csv"):
        continue
    try:
        path = os.path.join(dataset_folder, fname)
        # Sniff the first rows only to learn the column names, then force
        # every column to str on the full read to avoid mixed-dtype issues.
        sample_df = _read_csv_utf8_or_latin1(path, nrows=5)
        column_types = {col: str for col in sample_df.columns}
        df = _read_csv_utf8_or_latin1(path, dtype=column_types, low_memory=False)
        dataframes.append(df.fillna(''))
    except Exception as e:
        # Best-effort load: report the bad file and keep going.
        print(f"Error reading {fname}: {e}")

if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()
|
|
|
|
|
|
|
|
def load_dataset_metadata(dataset_folder):
    """Loads metadata from all CSV files in the dataset folder.

    Args:
        dataset_folder: Directory scanned (non-recursively) for ``*.csv`` files.

    Returns:
        A pair ``(dataframes, metadata_list)`` where ``dataframes`` is a list
        of ``(filename, DataFrame)`` tuples and ``metadata_list`` holds one
        human-readable description (table name + column names) per file.
    """
    dataframes = []
    metadata_list = []

    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        path = os.path.join(dataset_folder, file)
        try:
            # Match the encoding fallback used by the bulk loader:
            # UTF-8 first, then Latin-1.
            try:
                df = pd.read_csv(path, encoding="utf-8")
            except UnicodeDecodeError:
                df = pd.read_csv(path, encoding="latin1")
        except Exception as e:
            # Skip unreadable files instead of aborting the whole load.
            print(f"Error reading {file}: {e}")
            continue
        dataframes.append((file, df))

        # Build a compact schema description used for embedding / prompting.
        columns = df.columns.tolist()
        table_metadata = f"""
Table: {file.replace('.csv', '')}
Columns:
{', '.join(columns)}
"""
        metadata_list.append(table_metadata)

    return dataframes, metadata_list
|
|
|
|
|
def create_metadata_embeddings(metadata_list):
    """Embed each table-metadata string; return (embeddings, embedding model)."""
    # NOTE(review): the model is (re)loaded on every call — consider caching
    # upstream if this becomes hot.
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    vectors = encoder.encode(metadata_list)
    return vectors, encoder
|
|
|
|
|
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata entry whose embedding is most similar to the query."""
    encoded_query = model.encode([user_query])
    # Cosine similarity of the query against every table's metadata embedding;
    # argmax picks the single closest table.
    scores = cosine_similarity(encoded_query, embeddings)
    return metadata_list[scores.argmax()]
|
|
|
|
|
def create_prompt(user_query, table_metadata):
    """Generates a direct and structured SQL prompt with stricter formatting."""
    # The prompt pins the model to a single table's schema and demands raw SQL
    # only; downstream validation rejects anything that is not a query.
    system_prompt = f"""
You are an AI assistant that generates precise SQL queries based on user questions.

**Table Name & Columns:**
{table_metadata}

**User Query:**
{user_query}

**Output Format (STRICT):**
- Provide ONLY the SQL query.
- Do NOT include explanations, comments, or unnecessary text.
- Ensure the table and column names match exactly.
- If the query is impossible, return: "ERROR: Unable to generate query."

**Example Queries:**
- User: "Show all startups founded in 2020."
- AI Response: SELECT * FROM startups WHERE founded_year = 2020;

- User: "List the top 5 startups by total funding."
- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
"""
    return system_prompt
|
|
|
|
|
|
|
|
def generate_sql_query(system_prompt):
    """Uses Gemini API to generate an SQL query.

    Args:
        system_prompt: Full prompt text produced by create_prompt().

    Returns:
        The generated SQL string, or a warning message when the model output
        is not a usable query or the API call fails.
    """
    try:
        model = genai.GenerativeModel("gemini-2.5-pro")
        response = model.generate_content(system_prompt)
        print("🔍 Full API Response:", response)

        result = response.text.strip()
        # Gemini frequently wraps SQL in markdown code fences (```sql ... ```);
        # strip them so the validity check below sees the bare statement.
        if result.startswith("```"):
            result = result.strip("`")
            if result.lower().startswith("sql"):
                result = result[3:]
            result = result.strip()
        print(f"✅ AI Response: {result}")

        # Accept plain SELECTs as well as CTE-style queries (WITH ... SELECT).
        if result.lower().startswith(("select", "with")):
            return result
        else:
            print("⚠️ Gemini did not generate a valid SQL query.")
            return "⚠️ Invalid SQL query generated."

    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."
|
|
|
|
|
def response(user_query, dataset_folder):
    """Processes the user query and returns an SQL query."""
    # Pipeline: load schemas -> embed -> pick best table -> build prompt -> LLM.
    tables, metadata_list = load_dataset_metadata(dataset_folder)
    embeddings, encoder = create_metadata_embeddings(metadata_list)
    best_metadata = find_best_fit(embeddings, encoder, user_query, metadata_list)
    return generate_sql_query(create_prompt(user_query, best_metadata))
|
|
|
|
|
# Folder the Gradio handler reads from.
# NOTE(review): this re-binds the module-level name unconditionally, overriding
# the "." fallback chosen earlier when ./data is missing — confirm intended.
dataset_folder = "./data"

# NOTE(review): this module-level user_query is never used — the callback
# below shadows it with its own parameter.
user_query = "Show me the top 10 startups with the highest funding."

def sql_query_interface(user_query):
    """Gradio callback: turn a natural-language question into an SQL query."""
    return response(user_query, dataset_folder)
|
|
|
|
|
|
|
|
# Gradio UI: a single textbox in, the generated SQL query out.
iface = gr.Interface(
    fn=sql_query_interface,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="AI-Powered SQL Query Generator"
)

if __name__ == "__main__":
    # Launch the local Gradio server only when executed as a script.
    iface.launch()