File size: 6,462 Bytes
1052f2c
 
 
 
 
5f9b2f9
 
 
1052f2c
 
5f9b2f9
 
 
 
 
 
1052f2c
5f9b2f9
 
1052f2c
5f9b2f9
1052f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f9b2f9
1052f2c
5f9b2f9
 
 
 
 
 
 
1052f2c
 
5f9b2f9
 
 
 
 
 
 
 
 
1052f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f9b2f9
1052f2c
5f9b2f9
 
1052f2c
5f9b2f9
 
1052f2c
5f9b2f9
 
1052f2c
5f9b2f9
 
 
 
 
1052f2c
 
 
5f9b2f9
 
1052f2c
 
5f9b2f9
1052f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# --- Imports: stdlib first, then third-party (duplicates removed) ---
import os
import warnings

import google.generativeai as genai
import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Load variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Read the Gemini API key from the environment; warn (don't crash) if it is
# missing so the app can still start and surface the error later.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

genai.configure(api_key=GEMINI_API_KEY)

# Use the current directory for Hugging Face Spaces
dataset_folder = "./data" # Assuming files are in a 'data/' folder

# Verify the folder exists; fall back to the current working directory.
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory

# Print available files for debugging
print("Available files:", os.listdir(dataset_folder))

# Silence pandas DtypeWarning: every column is deliberately loaded as str below.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)

def _read_csv_all_str(path):
    """Read a CSV with every column forced to str.

    Tries utf-8 first and falls back to latin1 (which accepts any byte).
    The detected encoding is reused for the full read, so the sample and the
    full load can no longer disagree. NaN values are replaced with ''.
    """
    # Sample a few rows to discover the columns and a working encoding.
    try:
        encoding = "utf-8"
        sample_df = pd.read_csv(path, nrows=5, encoding=encoding)
    except UnicodeDecodeError:
        encoding = "latin1"
        sample_df = pd.read_csv(path, nrows=5, encoding=encoding)

    # Force every column to string to avoid mixed-dtype surprises.
    column_types = {col: str for col in sample_df.columns}
    df = pd.read_csv(path, dtype=column_types, low_memory=False, encoding=encoding)
    return df.fillna('')  # Fill NaN values with empty strings


# Load all CSV files in the dataset folder.
dataframes = []
for file in os.listdir(dataset_folder):
    if not file.endswith(".csv"):  # Only CSV files are loaded
        continue
    try:
        dataframes.append(_read_csv_all_str(os.path.join(dataset_folder, file)))
    except Exception as e:
        # Best-effort: skip unreadable files but keep loading the rest.
        print(f"Error reading {file}: {e}")

# Merge all CSV files into one DataFrame (only if there are valid files)
if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()  # Create an empty DataFrame as a fallback
    

def load_dataset_metadata(dataset_folder):
    """Loads metadata from all CSV files in the dataset folder.

    Args:
        dataset_folder: Directory scanned for ``*.csv`` files.

    Returns:
        A tuple ``(dataframes, metadata_list)`` where ``dataframes`` is a
        list of ``(filename, DataFrame)`` pairs and ``metadata_list`` holds
        one human-readable schema string per table (used for embedding-based
        table selection).
    """
    dataframes = []
    metadata_list = []

    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        path = os.path.join(dataset_folder, file)
        try:
            # utf-8 first, latin1 fallback — consistent with the main loader,
            # which previously tolerated files this function crashed on.
            try:
                df = pd.read_csv(path, encoding="utf-8")
            except UnicodeDecodeError:
                df = pd.read_csv(path, encoding="latin1")
        except Exception as e:
            # Skip unreadable files instead of aborting the whole scan.
            print(f"Error reading {file}: {e}")
            continue

        dataframes.append((file, df))

        # Generate table metadata (string layout kept byte-identical).
        columns = df.columns.tolist()
        table_metadata = f"""
            Table: {file.replace('.csv', '')}
            Columns:
            {', '.join(columns)}
            """
        metadata_list.append(table_metadata)

    return dataframes, metadata_list

def create_metadata_embeddings(metadata_list):
    """Embed each table-metadata string with a MiniLM sentence encoder.

    Returns the embedding matrix together with the encoder itself so callers
    can project queries into the same vector space.
    """
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder.encode(metadata_list), encoder

def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata string whose embedding is most cosine-similar to the query."""
    query_vec = model.encode([user_query])
    scores = cosine_similarity(query_vec, embeddings)
    best = scores.argmax()
    return metadata_list[best]

def create_prompt(user_query, table_metadata):
    """Assemble the strict-format system prompt for the SQL generator.

    Embeds the selected table's metadata plus the user's question and
    instructs the model to emit bare SQL only.
    """
    # The prompt text is runtime behavior — reproduced verbatim.
    return f"""
    You are an AI assistant that generates precise SQL queries based on user questions.

    **Table Name & Columns:**
    {table_metadata}

    **User Query:**
    {user_query}

    **Output Format (STRICT):**
    - Provide ONLY the SQL query.
    - Do NOT include explanations, comments, or unnecessary text.
    - Ensure the table and column names match exactly.
    - If the query is impossible, return: "ERROR: Unable to generate query."

    **Example Queries:**
    - User: "Show all startups founded in 2020."
    - AI Response: SELECT * FROM startups WHERE founded_year = 2020;
    
    - User: "List the top 5 startups by total funding."
    - AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
    """


def _strip_code_fences(text):
    """Remove markdown ``` fences that Gemini often wraps around SQL output."""
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence (``` or ```sql) and the closing fence.
        lines = [ln for ln in text.splitlines() if not ln.strip().startswith("```")]
        text = "\n".join(lines).strip()
    return text


def generate_sql_query(system_prompt):
    """Uses Gemini API to generate an SQL query.

    Returns the cleaned SQL text on success, or a warning string if the
    model output is not a query (or the API call fails).
    """
    try:
        # Initialize the Gemini model (use a reliable text model)
        model = genai.GenerativeModel("gemini-2.5-pro")

        # Generate content from the system prompt
        response = model.generate_content(system_prompt)

        # Debug: print the full Gemini response
        print("🔍 Full API Response:", response)

        # Extract AI text response; un-fence it first, otherwise valid
        # queries wrapped in ```sql blocks were rejected below.
        result = _strip_code_fences(response.text)
        print(f"✅ AI Response: {result}")

        # Validate: accept plain SELECT and CTE (WITH ...) queries.
        if result.lower().startswith(("select", "with")):
            return result
        else:
            print("⚠️ Gemini did not generate a valid SQL query.")
            return "⚠️ Invalid SQL query generated."

    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."

def response(user_query, dataset_folder):
    """Processes the user query and returns an SQL query."""
    # NOTE(review): metadata and embeddings are rebuilt on every call —
    # consider caching if the dataset is static.
    tables, metadata = load_dataset_metadata(dataset_folder)
    embeddings, encoder = create_metadata_embeddings(metadata)
    best_table = find_best_fit(embeddings, encoder, user_query, metadata)
    prompt = create_prompt(user_query, best_table)
    return generate_sql_query(prompt)

# NOTE(review): the original script reset dataset_folder to "./data" here,
# which silently undid the "." fallback computed at the top of the file when
# ./data does not exist. The configured value is kept instead.
user_query = "Show me the top 10 startups with the highest funding."  # example query

def sql_query_interface(user_query):
    """Gradio callback: turn a natural-language question into an SQL query."""
    return response(user_query, dataset_folder)

# Define Gradio UI: one textbox in, the generated SQL out.
query_box = gr.Textbox(label="Enter your query")
sql_box = gr.Textbox(label="Generated SQL Query")

iface = gr.Interface(
    fn=sql_query_interface,
    inputs=query_box,
    outputs=sql_box,
    title="AI-Powered SQL Query Generator",
)

# Launch only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()