"""Gradio app that turns natural-language questions into SQL queries.

CSV files in the dataset folder are treated as tables: the table whose
name/column list best matches the question (by sentence-embedding similarity)
is handed to a Groq-hosted LLM, which is asked to write the SQL.
"""
from dotenv import load_dotenv
import os
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
from groq import Groq
import pandas as pd
# Run python-dotenv's loader first so the environment lookup below can see
# whatever it adds to the process environment.
load_dotenv()

# API key handed to the Groq client; None when "groq_api_key" is not set.
groq_api_key = os.environ.get("groq_api_key")

import warnings

dataset_folder = "./data"

if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory

# Print available files for debugging
print("Available files:", os.listdir(dataset_folder))

# Every column is read as text below, so mixed-dtype warnings are only noise.
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)


def load_csv_frames(folder):
    """Read every ``*.csv`` file in *folder* into an all-string DataFrame.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        A list with one DataFrame per readable CSV file, in sorted file-name
        order. All columns are read as ``str`` and missing cells are replaced
        by empty strings. Files that cannot be read are reported on stdout
        and skipped, so one bad file does not block the others.
    """
    frames = []
    # Sorted so the row order of the merged frame does not depend on the
    # arbitrary order returned by os.listdir.
    for file in sorted(os.listdir(folder)):
        if not file.endswith(".csv"):
            continue
        try:
            df = pd.read_csv(
                os.path.join(folder, file),
                dtype=str,  # Force all columns to string
                low_memory=False,  # Avoid chunk-based reading issues
                encoding="utf-8",
                # read_csv has no `errors` keyword (passing it raises
                # TypeError); `encoding_errors` is the supported spelling for
                # replacing undecodable bytes with a placeholder.
                encoding_errors="replace",
            ).fillna('')  # Fill NaN values with empty strings
        except Exception as e:  # best-effort load: report and move on
            print(f"Error reading {file}: {e}")
            continue
        frames.append(df)
    return frames


# Load all CSV files in the dataset folder
dataframes = load_csv_frames(dataset_folder)

# Merge all CSV files into one DataFrame (only if there are valid files)
if dataframes:
    full_data = pd.concat(dataframes, ignore_index=True)
else:
    print("Warning: No valid CSV files found in the dataset folder.")
    full_data = pd.DataFrame()  # Create an empty DataFrame as a fallback
    

def load_dataset_metadata(dataset_folder):
    """Read each CSV in *dataset_folder* and describe it as a text snippet.

    Returns a pair ``(dataframes, metadata_list)``. ``dataframes`` holds
    ``(file_name, DataFrame)`` tuples; ``metadata_list`` holds, in the same
    order, a text block naming the table (the file name with ``.csv``
    removed) followed by its comma-separated column names.
    """
    csv_names = [name for name in os.listdir(dataset_folder) if name.endswith(".csv")]

    loaded_tables = []
    descriptions = []
    for csv_name in csv_names:
        table = pd.read_csv(os.path.join(dataset_folder, csv_name))
        loaded_tables.append((csv_name, table))

        # The snippet's exact whitespace is kept as-is: it is the text that
        # later gets embedded and pasted into the LLM prompt.
        column_names = list(table.columns)
        descriptions.append(f"""
            Table: {csv_name.replace('.csv', '')}
            Columns:
            {', '.join(column_names)}
            """)

    return loaded_tables, descriptions

_EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
# Models already constructed in this process, keyed by model name.
_loaded_embedding_models = {}


def create_metadata_embeddings(metadata_list):
    """Embed every table-metadata snippet with the sentence-embedding model.

    The model object is built once per process and reused afterwards;
    constructing a new SentenceTransformer on every request repeats the same
    work for no benefit.

    Args:
        metadata_list: Text snippets, one per table.

    Returns:
        ``(embeddings, model)``: the result of ``model.encode(metadata_list)``
        and the model itself, so callers can encode queries with it.
    """
    model = _loaded_embedding_models.get(_EMBEDDING_MODEL_NAME)
    if model is None:
        model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
        _loaded_embedding_models[_EMBEDDING_MODEL_NAME] = model
    embeddings = model.encode(metadata_list)
    return embeddings, model

def find_best_fit(embeddings, model, user_query, metadata_list):
    """Return the metadata snippet most similar to the user's question.

    Args:
        embeddings: Embeddings of ``metadata_list`` (same order).
        model: Object whose ``encode`` method embeds a list of strings.
        user_query: The natural-language question.
        metadata_list: Candidate table-metadata snippets.

    Returns:
        The element of ``metadata_list`` whose embedding has the highest
        cosine similarity with the embedded query.

    Raises:
        ValueError: If ``metadata_list`` is empty (nothing to match against).
    """
    # Fail early with a readable message instead of a shape error from the
    # similarity computation.
    if len(metadata_list) == 0:
        raise ValueError("No table metadata available to match the query against.")

    query_embedding = model.encode([user_query])
    similarities = cosine_similarity(query_embedding, embeddings)
    # Only one query row is passed in, so the flat argmax is the index of
    # the best-matching metadata entry.
    best_match_index = int(similarities.argmax())
    return metadata_list[best_match_index]

def create_prompt(user_query, table_metadata):
    """Build the instruction text sent to the LLM for one question.

    The selected table's metadata and the user's question are dropped into a
    fixed template that demands a bare SQL statement as the answer and shows
    two example question/answer pairs.
    """
    prompt_template = """
    You are an AI assistant that generates precise SQL queries based on user questions.

    **Table Name & Columns:**
    {table_metadata}

    **User Query:**
    {user_query}

    **Output Format (STRICT):**
    - Provide ONLY the SQL query.
    - Do NOT include explanations, comments, or unnecessary text.
    - Ensure the table and column names match exactly.
    - If the query is impossible, return: "ERROR: Unable to generate query."

    **Example Queries:**
    - User: "Show all startups founded in 2020."
    - AI Response: SELECT * FROM startups WHERE founded_year = 2020;
    
    - User: "List the top 5 startups by total funding."
    - AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;
    """
    # Substituted values are inserted verbatim; braces inside them are not
    # re-interpreted as format fields.
    return prompt_template.format(table_metadata=table_metadata, user_query=user_query)


# Leading keywords accepted as the start of a read query ("with" covers CTEs).
_SQL_QUERY_STARTERS = ("select", "with")


def extract_sql(raw_text):
    """Return the SQL text of an LLM reply with any Markdown fence removed.

    Chat models often wrap code in a fenced block (optionally tagged, e.g.
    with ``sql``) even when told to answer with the bare query. When a fence
    is present, only the content of the first fenced block is kept; otherwise
    the reply is returned stripped of surrounding whitespace.

    Args:
        raw_text: The model's reply; may be ``None`` or empty.

    Returns:
        The cleaned text, or ``""`` when there is no reply text.
    """
    if not raw_text:
        return ""
    text = raw_text.strip()
    fence_start = text.find("```")
    if fence_start != -1:
        fenced = text[fence_start + 3:]
        fence_end = fenced.find("```")
        if fence_end != -1:
            fenced = fenced[:fence_end]
        # A purely alphabetic first line is a language tag, unless it is
        # itself the first keyword of the query.
        first_line, newline, remainder = fenced.partition("\n")
        tag = first_line.strip()
        if newline and tag.isalpha() and tag.lower() not in _SQL_QUERY_STARTERS:
            fenced = remainder
        text = fenced.strip()
    return text


def generate_sql_query(system_prompt):
    """Ask the Groq chat API for an SQL query and return it.

    Args:
        system_prompt: Full instruction text, sent as the system message.

    Returns:
        The generated SQL (code fences removed) when the reply starts with
        ``SELECT`` or ``WITH``; otherwise a warning string. Any exception
        raised while calling the API is logged to stdout and reported as a
        warning string instead of propagating.
    """
    try:
        client = Groq(api_key=groq_api_key)
        # NOTE(review): confirm this model id is still served by Groq.
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt}],
            model="llama3-70b-8192"
        )

        # Debug: Print entire response
        print("🔍 Full API Response:", chat_completion)

        # Extract AI response, tolerating Markdown code fences around it
        result = extract_sql(chat_completion.choices[0].message.content)
        print(f"✅ AI Response: {result}")  # Debugging

        # Accept only replies that begin like a read query
        if result.lower().startswith(_SQL_QUERY_STARTERS):
            return result

        print("⚠️ AI did not generate a valid SQL query!")
        return "⚠️ AI response is not a valid SQL query."

    except Exception as e:  # UI boundary: never let an API failure propagate
        print(f"❌ API Error: {e}")
        return "⚠️ API failed. Check logs."


def response(user_query, dataset_folder):
    """Turn a natural-language question into an SQL query string.

    Loads table metadata from the CSV files in *dataset_folder*, picks the
    table that best matches the question, builds the prompt and asks the LLM.

    Args:
        user_query: The user's question.
        dataset_folder: Directory containing the CSV tables.

    Returns:
        The generated SQL, or a warning string when the question is blank,
        no CSV tables are found, or query generation reports a problem.
    """
    # Nothing to translate: skip the CSV scan, embedding and API call.
    if not user_query or not user_query.strip():
        return "⚠️ Please enter a question."

    _, metadata_list = load_dataset_metadata(dataset_folder)
    # Without tables there is nothing to match the question against.
    if not metadata_list:
        return "⚠️ No CSV tables found in the dataset folder."

    embeddings, model = create_metadata_embeddings(metadata_list)
    table_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
    system_prompt = create_prompt(user_query, table_metadata)
    return generate_sql_query(system_prompt)

dataset_folder = "./data"  # Change this based on where your files are uploaded
# Re-apply the "fall back to the current directory" rule here: assigning the
# path unconditionally would send the UI callback to a folder that does not
# exist whenever "./data" is missing.
if not os.path.isdir(dataset_folder):
    dataset_folder = "."
user_query = "Show me the top 10 startups with the highest funding."


def sql_query_interface(user_query):
    """Gradio callback: answer *user_query* using the module-level dataset folder."""
    return response(user_query, dataset_folder)

# Define Gradio UI
# One text box in, one text box out.
_query_input = gr.Textbox(label="Enter your query")
_sql_output = gr.Textbox(label="Generated SQL Query")

# Wire the callback to the two text boxes.
iface = gr.Interface(
    fn=sql_query_interface,
    inputs=_query_input,
    outputs=_sql_output,
    title="AI-Powered SQL Query Generator",
)

# Start the web UI only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()