Update app.py
Browse files
app.py
CHANGED
|
@@ -3,13 +3,22 @@ import os
|
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
import gradio as gr
|
| 5 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
load_dotenv()
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
|
|
|
|
| 13 |
if not os.path.exists(dataset_folder):
|
| 14 |
print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
|
| 15 |
dataset_folder = "." # Fallback: Look in the current directory
|
|
@@ -25,28 +34,26 @@ warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
|
|
| 25 |
# Load all CSV files in the dataset folder
|
| 26 |
dataframes = []
|
| 27 |
for file in os.listdir(dataset_folder):
|
| 28 |
-
if file.endswith(".csv"):
|
| 29 |
try:
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
encoding="utf-8"
|
| 35 |
-
|
| 36 |
-
|
| 37 |
|
| 38 |
column_types = {col: str for col in sample_df.columns} # Force all columns to string
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
dtype=column_types,
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
dataframes.append(df) # Append DataFrame to the list
|
| 50 |
except Exception as e:
|
| 51 |
print(f"Error reading {file}: {e}")
|
| 52 |
|
|
@@ -120,33 +127,32 @@ def create_prompt(user_query, table_metadata):
|
|
| 120 |
|
| 121 |
|
| 122 |
def generate_sql_query(system_prompt):
|
| 123 |
-
"""Uses
|
| 124 |
try:
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
messages=[{"role": "system", "content": system_prompt}],
|
| 128 |
-
model="llama3-70b-8192"
|
| 129 |
-
)
|
| 130 |
|
| 131 |
-
#
|
| 132 |
-
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
print(f"✅ AI Response: {result}") # Debugging
|
| 137 |
|
| 138 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
if result.lower().startswith("select"):
|
| 140 |
return result
|
| 141 |
else:
|
| 142 |
-
print("⚠️
|
| 143 |
-
return "⚠️
|
| 144 |
|
| 145 |
except Exception as e:
|
| 146 |
-
print(f"❌ API Error: {e}")
|
| 147 |
return "⚠️ API failed. Check logs."
|
| 148 |
|
| 149 |
-
|
| 150 |
def response(user_query, dataset_folder):
|
| 151 |
"""Processes the user query and returns an SQL query."""
|
| 152 |
dataframes, metadata_list = load_dataset_metadata(dataset_folder)
|
|
|
|
| 3 |
# --- Imports and environment setup ---
# Standard library
import os

# Third-party
import google.generativeai as genai
import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Load variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Read the API key from env and warn if missing so a misconfigured
# deployment is obvious before the first API call fails.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

# Configure the Gemini client. NOTE(review): with api_key=None the library
# falls back to GOOGLE_API_KEY / application-default credentials, so this is
# deliberately unconditional.
genai.configure(api_key=GEMINI_API_KEY)

# Use the current directory for Hugging Face Spaces.
dataset_folder = "./data"  # Assuming files are in a 'data/' folder

# Verify the folder exists; fall back to the working directory otherwise.
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory
|
|
|
| 34 |
# Load all CSV files in the dataset folder.
dataframes = []


def _read_csv_resilient(csv_path, **read_kwargs):
    """Read a CSV with utf-8, falling back to latin1 on decode errors.

    The same two-encoding fallback was previously duplicated for the
    sample read and the full read; this helper keeps them in sync.
    """
    try:
        return pd.read_csv(csv_path, encoding="utf-8", **read_kwargs)
    except UnicodeDecodeError:
        # Some exports (e.g. Excel / legacy tools) are latin1-encoded.
        return pd.read_csv(csv_path, encoding="latin1", **read_kwargs)


for file in os.listdir(dataset_folder):
    if file.endswith(".csv"):  # Check if the file is a CSV
        try:
            path = os.path.join(dataset_folder, file)

            # Sample a few rows first so every column can be forced to str,
            # avoiding mixed-dtype surprises when reading the full file.
            sample_df = _read_csv_resilient(path, nrows=5)
            column_types = {col: str for col in sample_df.columns}  # Force all columns to string

            df = _read_csv_resilient(path, dtype=column_types, low_memory=False)

            df = df.fillna('')  # Fill NaN values with empty strings
            dataframes.append(df)
        except Exception as e:
            # Best-effort loading: report and skip unreadable files rather
            # than aborting the whole app at startup.
            print(f"Error reading {file}: {e}")
|
| 59 |
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def generate_sql_query(system_prompt):
    """Use the Gemini API to generate an SQL query from *system_prompt*.

    Returns the generated SQL string when it looks like a valid query
    (starts with SELECT, or WITH for CTEs), otherwise a short warning
    string. Never raises: API failures are logged and reported as a
    warning string so the caller/UI stays responsive.
    """
    try:
        # Initialize the Gemini model (use a reliable text model).
        model = genai.GenerativeModel("gemini-2.5-pro")

        # Generate content from the system prompt.
        response = model.generate_content(system_prompt)

        # Debug: print the full Gemini response.
        print("🔍 Full API Response:", response)

        # Extract AI text response.
        result = response.text.strip()

        # LLMs frequently wrap SQL in markdown fences (```sql ... ```);
        # strip them so valid queries aren't rejected by the check below.
        if result.startswith("```"):
            result = result.strip("`").strip()
            if result.lower().startswith("sql"):
                # Drop the optional language tag after the opening fence.
                result = result[3:].strip()

        print(f"✅ AI Response: {result}")

        # Validate SQL query: accept plain SELECTs and WITH ... SELECT CTEs
        # (the old SELECT-only check rejected valid CTE queries).
        if result.lower().startswith(("select", "with")):
            return result
        else:
            print("⚠️ Gemini did not generate a valid SQL query.")
            return "⚠️ Invalid SQL query generated."

    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."
|
| 155 |
|
|
|
|
| 156 |
def response(user_query, dataset_folder):
|
| 157 |
"""Processes the user query and returns an SQL query."""
|
| 158 |
dataframes, metadata_list = load_dataset_metadata(dataset_folder)
|