# Source: Hugging Face Space file "app.py" (commit 5f9b2f9, user anjalirathore).
# --- Imports (deduplicated: os and load_dotenv were imported twice) ---------
import os
import warnings

from dotenv import load_dotenv
import pandas as pd
import google.generativeai as genai
import gradio as gr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Load variables from a local .env file (no-op when the file is absent).
load_dotenv()

# Read the API key from the environment and warn if missing, so a
# misconfigured deployment fails with a clear message instead of an API error.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")
genai.configure(api_key=GEMINI_API_KEY)

# Dataset location for Hugging Face Spaces: a 'data/' folder next to this
# script, falling back to the current directory when it does not exist.
dataset_folder = "./data"  # Assuming files are in a 'data/' folder
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory

# Print available files for debugging.
print("Available files:", os.listdir(dataset_folder))

# All columns are read as str below, so pandas' mixed-type DtypeWarning is
# pure noise here — silence it.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
def _read_csv_as_text(path):
    """Read one CSV with every column forced to str, trying utf-8 then latin1.

    Forcing str avoids pandas' mixed-type inference on messy exports; the
    latin1 fallback handles files that are not valid UTF-8 (latin1 decodes
    any byte sequence, so it never raises).
    """
    for encoding in ("utf-8", "latin1"):
        try:
            # Sample a few rows first to discover the column names.
            sample_df = pd.read_csv(path, nrows=5, encoding=encoding)
            column_types = {col: str for col in sample_df.columns}  # Force all columns to string
            return pd.read_csv(path, dtype=column_types, low_memory=False, encoding=encoding)
        except UnicodeDecodeError:
            continue
    # Defensive: unreachable in practice, latin1 never fails to decode.
    raise ValueError(f"Could not decode {path}")


# Load all CSV files in the dataset folder; unreadable files are reported
# and skipped rather than aborting startup.
dataframes = []
for file in os.listdir(dataset_folder):
    if not file.endswith(".csv"):  # Only process CSV files
        continue
    try:
        df = _read_csv_as_text(os.path.join(dataset_folder, file))
        dataframes.append(df.fillna(''))  # Fill NaN values with empty strings
    except Exception as e:
        print(f"Error reading {file}: {e}")

# Merge all CSV files into one DataFrame (only if there are valid files).
if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()  # Empty DataFrame as a fallback
def load_dataset_metadata(dataset_folder):
    """Loads metadata from all CSV files in the dataset folder.

    Args:
        dataset_folder: Directory to scan for ``*.csv`` files.

    Returns:
        A pair ``(dataframes, metadata_list)`` where ``dataframes`` is a list
        of ``(filename, DataFrame)`` tuples and ``metadata_list`` holds one
        textual schema summary per table (table name + column names), used as
        the corpus for semantic table matching.

    Unreadable files are reported and skipped (consistent with the startup
    loader) instead of crashing every query; a utf-8 -> latin1 encoding
    fallback is applied for the same reason.
    """
    dataframes = []
    metadata_list = []
    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        path = os.path.join(dataset_folder, file)
        try:
            try:
                df = pd.read_csv(path, encoding="utf-8")
            except UnicodeDecodeError:
                df = pd.read_csv(path, encoding="latin1")
        except Exception as e:
            print(f"Error reading {file}: {e}")
            continue
        dataframes.append((file, df))
        # Generate table metadata: table name (filename sans extension) plus
        # a comma-separated column list.
        columns = df.columns.tolist()
        table_metadata = f"""
Table: {file.replace('.csv', '')}
Columns:
{', '.join(columns)}
"""
        metadata_list.append(table_metadata)
    return dataframes, metadata_list
def create_metadata_embeddings(metadata_list):
    """Creates embeddings for all table metadata.

    The SentenceTransformer is loaded once and cached on the function object:
    loading the model is slow, and this function runs on every user query via
    ``response()``.

    Returns:
        ``(embeddings, model)`` — the encoded metadata and the (cached) model.
    """
    model = getattr(create_metadata_embeddings, "_model", None)
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
        create_metadata_embeddings._model = model
    embeddings = model.encode(metadata_list)
    return embeddings, model
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata entry whose embedding best matches the user query.

    Encodes the query with the same model used for the table metadata and
    picks the table with the highest cosine similarity.
    """
    encoded_query = model.encode([user_query])
    scores = cosine_similarity(encoded_query, embeddings)
    best_index = scores.argmax()
    return metadata_list[best_index]
def create_prompt(user_query, table_metadata):
    """Generates a direct and structured SQL prompt with stricter formatting.

    Args:
        user_query: The user's natural-language question.
        table_metadata: Textual schema summary of the best-matching table
            (as produced by ``load_dataset_metadata``).

    Returns:
        The full system prompt string to send to the model.
    """
    # NOTE(review): the prompt body is kept left-aligned so no Python
    # indentation leaks into the text sent to the model. The few-shot
    # examples steer the model toward bare SQL with no commentary.
    system_prompt = f"""
You are an AI assistant that generates precise SQL queries based on user questions.
**Table Name & Columns:**
{table_metadata}
**User Query:**
{user_query}
**Output Format (STRICT):**
- Provide ONLY the SQL query.
- Do NOT include explanations, comments, or unnecessary text.
- Ensure the table and column names match exactly.
- If the query is impossible, return: "ERROR: Unable to generate query."
**Example Queries:**
- User: "Show all startups founded in 2020."
- AI Response: SELECT * FROM startups WHERE founded_year = 2020;
- User: "List the top 5 startups by total funding."
- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
"""
    return system_prompt
def _strip_code_fences(text):
    """Remove a surrounding markdown code fence (``` or ```sql) if present.

    Gemini frequently wraps SQL answers in fenced code blocks even when asked
    not to; the original validation rejected every such response.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        if cleaned.lower().startswith("sql"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def generate_sql_query(system_prompt):
    """Uses Gemini API to generate an SQL query.

    Args:
        system_prompt: Full prompt string (from ``create_prompt``).

    Returns:
        The generated SQL string, or a warning message when the response is
        not a usable query or the API call fails (callers display it as-is).
    """
    try:
        # Initialize the Gemini model (use a reliable text model)
        model = genai.GenerativeModel("gemini-2.5-pro")
        # Generate content from the system prompt
        response = model.generate_content(system_prompt)
        # Debug: print the full Gemini response
        print("🔍 Full API Response:", response)
        # Extract AI text response, stripping any markdown fence wrapper.
        result = _strip_code_fences(response.text)
        print(f"✅ AI Response: {result}")
        # Validate: accept plain SELECTs and WITH ... SELECT (CTE) queries;
        # anything else (explanations, the prompt's ERROR sentinel) is rejected.
        if result.lower().startswith(("select", "with")):
            return result
        print("⚠️ Gemini did not generate a valid SQL query.")
        return "⚠️ Invalid SQL query generated."
    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."
def response(user_query, dataset_folder):
    """Processes the user query and returns an SQL query.

    Pipeline: load table metadata -> embed it -> pick the best-matching
    table for the query -> build the prompt -> ask Gemini for the SQL.
    """
    tables, metadata = load_dataset_metadata(dataset_folder)
    vectors, embedder = create_metadata_embeddings(metadata)
    best_table = find_best_fit(vectors, embedder, user_query, metadata)
    prompt = create_prompt(user_query, best_table)
    return generate_sql_query(prompt)
# NOTE(review): the original reassigned dataset_folder = "./data" here,
# silently undoing the missing-folder fallback computed at startup; we reuse
# the already-resolved value instead. The unused example query is kept as a
# documented constant rather than a stray variable.
EXAMPLE_QUERY = "Show me the top 10 startups with the highest funding."


def sql_query_interface(user_query):
    """Gradio callback: generate SQL for the user's natural-language query."""
    return response(user_query, dataset_folder)
# Define Gradio UI: a single text input mapped to the generated-SQL output.
_ui_config = dict(
    fn=sql_query_interface,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Generated SQL Query"),
    title="AI-Powered SQL Query Generator",
)
iface = gr.Interface(**_ui_config)

# Run Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()