# NOTE(review): the following three lines were viewer/extraction residue
# (file size, git object hashes, a column ruler) — not valid Python; kept
# only as a comment so the module parses.
import os
import warnings

import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
load_dotenv()

# Read the Gemini API key from the environment. Only call configure() when a
# key is actually present: configuring with api_key=None would shadow the
# library's own credential discovery and produce a confusing error later.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

# Hugging Face Spaces layout: CSVs are expected in ./data.
dataset_folder = "./data"  # Assuming files are in a 'data/' folder
if not os.path.isdir(dataset_folder):  # isdir: we need a folder, not just any path
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory

# Print available files for debugging
print("Available files:", os.listdir(dataset_folder))
import warnings

# Mixed-type CSV columns trigger DtypeWarning; we force everything to str
# below, so the warning is pure noise.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)


def _read_csv_resilient(path, **kwargs):
    """Read *path* as CSV with utf-8, falling back to latin1 on decode errors.

    Keyword arguments are forwarded to ``pd.read_csv`` unchanged.
    """
    try:
        return pd.read_csv(path, encoding="utf-8", **kwargs)
    except UnicodeDecodeError:
        return pd.read_csv(path, encoding="latin1", **kwargs)


# Load all CSV files in the dataset folder, forcing every column to string so
# heterogeneous files can be concatenated safely.
dataframes = []
for file in os.listdir(dataset_folder):
    if not file.endswith(".csv"):
        continue
    try:
        path = os.path.join(dataset_folder, file)
        # Sample a few rows first just to discover the column names.
        sample_df = _read_csv_resilient(path, nrows=5)
        column_types = {col: str for col in sample_df.columns}
        df = _read_csv_resilient(path, dtype=column_types, low_memory=False)
        dataframes.append(df.fillna(''))  # NaN -> empty string
    except Exception as e:
        # Best-effort ingestion: one bad file must not take down the app.
        print(f"Error reading {file}: {e}")

# Merge all CSV files into one DataFrame (only if there are valid files)
if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()  # Empty fallback so later code can still run
def load_dataset_metadata(dataset_folder):
    """Load every CSV in *dataset_folder* and build a metadata string per table.

    Args:
        dataset_folder: Directory containing ``.csv`` files.

    Returns:
        tuple: ``(dataframes, metadata_list)`` where *dataframes* is a list of
        ``(filename, DataFrame)`` pairs and *metadata_list* holds one
        human-readable "Table/Columns" string per file.
    """
    dataframes = []
    metadata_list = []
    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        path = os.path.join(dataset_folder, file)
        # Fall back to latin1 so the same badly-encoded files tolerated by the
        # startup ingestion loop don't crash metadata loading.
        try:
            df = pd.read_csv(path)
        except UnicodeDecodeError:
            df = pd.read_csv(path, encoding="latin1")
        dataframes.append((file, df))
        # Generate table metadata used for embedding-based table selection.
        columns = df.columns.tolist()
        table_metadata = f"""
Table: {file.replace('.csv', '')}
Columns:
{', '.join(columns)}
"""
        metadata_list.append(table_metadata)
    return dataframes, metadata_list
def create_metadata_embeddings(metadata_list):
    """Encode each table-metadata string with a MiniLM sentence embedder.

    Returns the embedding matrix together with the encoder itself so callers
    can embed user queries with the same model.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder.encode(metadata_list), encoder
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata string whose embedding best matches *user_query*.

    Cosine similarity is computed directly with numpy (one query vector vs. a
    small matrix), avoiding the sklearn round-trip; zero-norm vectors are
    guarded so they score 0 instead of raising a division warning.

    Args:
        embeddings: 2-D array-like, one row per entry in *metadata_list*.
        model: Sentence encoder exposing ``encode(list[str]) -> array``.
        user_query: Natural-language question from the user.
        metadata_list: Metadata strings aligned with *embeddings* rows.

    Returns:
        The metadata string with the highest cosine similarity (ties -> first).
    """
    query_vec = np.asarray(model.encode([user_query]), dtype=float).ravel()
    matrix = np.asarray(embeddings, dtype=float)
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)
    denom[denom == 0] = 1.0  # avoid 0/0 for degenerate (all-zero) vectors
    similarities = matrix @ query_vec / denom
    return metadata_list[int(similarities.argmax())]
def create_prompt(user_query, table_metadata):
    """Build the strict, SQL-only system prompt for the Gemini model.

    Args:
        user_query: The user's natural-language question.
        table_metadata: "Table/Columns" description of the chosen table.

    Returns:
        The full prompt string, with format rules and few-shot examples.
    """
    return f"""
You are an AI assistant that generates precise SQL queries based on user questions.
**Table Name & Columns:**
{table_metadata}
**User Query:**
{user_query}
**Output Format (STRICT):**
- Provide ONLY the SQL query.
- Do NOT include explanations, comments, or unnecessary text.
- Ensure the table and column names match exactly.
- If the query is impossible, return: "ERROR: Unable to generate query."
**Example Queries:**
- User: "Show all startups founded in 2020."
- AI Response: SELECT * FROM startups WHERE founded_year = 2020;
- User: "List the top 5 startups by total funding."
- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
"""
def generate_sql_query(system_prompt):
"""Uses Gemini API to generate an SQL query."""
try:
# Initialize the Gemini model (use a reliable text model)
model = genai.GenerativeModel("gemini-2.5-pro")
# Generate content from the system prompt
response = model.generate_content(system_prompt)
# Debug: print the full Gemini response
print("🔍 Full API Response:", response)
# Extract AI text response
result = response.text.strip()
print(f"✅ AI Response: {result}")
# Validate SQL query
if result.lower().startswith("select"):
return result
else:
print("⚠️ Gemini did not generate a valid SQL query.")
return "⚠️ Invalid SQL query generated."
except Exception as e:
print(f"❌ Gemini API Error: {e}")
return "⚠️ API failed. Check logs."
def response(user_query, dataset_folder):
    """End-to-end pipeline: pick the best-matching table, then ask for SQL.

    Loads table metadata, embeds it, selects the table most similar to the
    query, and hands a structured prompt to the Gemini generator.
    """
    tables, metadata_list = load_dataset_metadata(dataset_folder)
    embeddings, encoder = create_metadata_embeddings(metadata_list)
    best_metadata = find_best_fit(embeddings, encoder, user_query, metadata_list)
    return generate_sql_query(create_prompt(user_query, best_metadata))
# NOTE(review): the startup code above already resolves dataset_folder and
# falls back to "." when ./data is missing; re-assigning "./data" here would
# silently undo that fallback, so the redundant assignment was removed.

# Example query kept for manual smoke-testing.
user_query = "Show me the top 10 startups with the highest funding."


def sql_query_interface(user_query):
    """Gradio callback: turn a natural-language question into an SQL query."""
    return response(user_query, dataset_folder)
# Gradio UI: one textbox in, generated SQL out.
iface = gr.Interface(
    fn=sql_query_interface,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="AI-Powered SQL Query Generator",
)

# Launch only when run as a script (Spaces / local dev).
if __name__ == "__main__":
    iface.launch()