kkhushisaid commited on
Commit
2116f61
·
verified ·
1 Parent(s): 3423ae7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py CHANGED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ from sentence_transformers import SentenceTransformer
4
+ import gradio as gr
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from groq import Groq
7
+ import pandas as pd
8
+
9
+ load_dotenv()
10
+
11
+ groq_api_key = os.getenv("groq_api_key")
12
+
13
+ def load_dataset_metadata(dataset_folder):
14
+ """Loads metadata from all CSV files in the dataset folder."""
15
+ dataframes = []
16
+ metadata_list = []
17
+
18
+ for file in os.listdir(dataset_folder):
19
+ if file.endswith(".csv"):
20
+ df = pd.read_csv(os.path.join(dataset_folder, file))
21
+ dataframes.append((file, df))
22
+
23
+ # Generate table metadata
24
+ columns = df.columns.tolist()
25
+ table_metadata = f"""
26
+ Table: {file.replace('.csv', '')}
27
+ Columns:
28
+ {', '.join(columns)}
29
+ """
30
+ metadata_list.append(table_metadata)
31
+
32
+ return dataframes, metadata_list
33
+
34
+ def create_metadata_embeddings(metadata_list):
35
+ """Creates embeddings for all table metadata."""
36
+ model = SentenceTransformer('all-MiniLM-L6-v2')
37
+ embeddings = model.encode(metadata_list)
38
+ return embeddings, model
39
+
40
+ def find_best_fit(embeddings, model, user_query, metadata_list):
41
+ """Finds the best matching table based on user query."""
42
+ query_embedding = model.encode([user_query])
43
+ similarities = cosine_similarity(query_embedding, embeddings)
44
+ best_match_index = similarities.argmax()
45
+ return metadata_list[best_match_index]
46
+
47
+ def create_prompt(user_query, table_metadata):
48
+ """Generates a prompt for the AI model."""
49
+ system_prompt = """
50
+ You are a SQL query generator specialized in generating SQL queries for a single table at a time.
51
+ Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.
52
+
53
+ Rules:
54
+ - Assume all queries relate to a single table provided in the metadata. Ignore references to other tables.
55
+ - Ensure the generated query matches the table name, columns, and data types in the metadata.
56
+ - Capture filters, sorting, or aggregations as per user intent.
57
+ - Use standard SQL syntax.
58
+
59
+ Input:
60
+ User Query: {user_query}
61
+ Table Metadata: {table_metadata}
62
+
63
+ Output:
64
+ - Provide only the SQL query in a single line. No extra words.
65
+ """
66
+ return system_prompt
67
+
68
+ def generate_sql_query(system_prompt):
69
+ """Uses Groq API to generate an SQL query."""
70
+ client = Groq(api_key=groq_api_key)
71
+ chat_completion = client.chat.completions.create(
72
+ messages=[{"role": "system", "content": system_prompt}],
73
+ model="llama3-70b-8192"
74
+ )
75
+ result = chat_completion.choices[0].message.content.strip()
76
+ return result if result.lower().startswith("select") else "Can't perform the task at the moment."
77
+
78
+ def response(user_query, dataset_folder):
79
+ """Processes the user query and returns an SQL query."""
80
+ dataframes, metadata_list = load_dataset_metadata(dataset_folder)
81
+ embeddings, model = create_metadata_embeddings(metadata_list)
82
+ table_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
83
+ system_prompt = create_prompt(user_query, table_metadata)
84
+ return generate_sql_query(system_prompt)
85
+
86
+ # Example usage:
87
+ dataset_folder = r"C:\\Users\\khuma\\startups"
88
+ user_query = "Show me the top 10 startups with the highest funding."
89
+ print(response(user_query, dataset_folder))