anjalirathore committed on
Commit
5f9b2f9
·
verified ·
1 Parent(s): 4e35995

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -38
app.py CHANGED
@@ -3,13 +3,22 @@ import os
3
  from sentence_transformers import SentenceTransformer
4
  import gradio as gr
5
  from sklearn.metrics.pairwise import cosine_similarity
6
- from groq import Groq
 
 
7
  import pandas as pd
8
  load_dotenv()
9
- groq_api_key = os.getenv("groq_api_key")
 
 
 
 
 
10
 
11
- dataset_folder = "./data"
 
12
 
 
13
  if not os.path.exists(dataset_folder):
14
  print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
15
  dataset_folder = "." # Fallback: Look in the current directory
@@ -25,28 +34,26 @@ warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
25
  # Load all CSV files in the dataset folder
26
  dataframes = []
27
  for file in os.listdir(dataset_folder):
28
- if file.endswith(".csv"):
29
  try:
30
- # Read first few rows to identify column names
31
- sample_df = pd.read_csv(
32
- os.path.join(dataset_folder, file),
33
- nrows=5, # Read only first 5 rows for column type inference
34
- encoding="utf-8",
35
- errors="replace" # Replace encoding errors with a placeholder
36
- )
37
 
38
  column_types = {col: str for col in sample_df.columns} # Force all columns to string
39
-
40
- # Read the entire file with enforced column types
41
- df = pd.read_csv(
42
- os.path.join(dataset_folder, file),
43
- dtype=column_types, # Apply enforced string types
44
- low_memory=False, # Avoid chunk-based reading issues
45
- encoding="utf-8",
46
- errors="replace"
47
- ).fillna('') # Fill NaN values with empty strings
48
-
49
- dataframes.append(df) # Append DataFrame to the list
50
  except Exception as e:
51
  print(f"Error reading {file}: {e}")
52
 
@@ -120,33 +127,32 @@ def create_prompt(user_query, table_metadata):
120
 
121
 
122
  def generate_sql_query(system_prompt):
123
- """Uses Groq API to generate an SQL query with better debugging."""
124
  try:
125
- client = Groq(api_key=groq_api_key)
126
- chat_completion = client.chat.completions.create(
127
- messages=[{"role": "system", "content": system_prompt}],
128
- model="llama3-70b-8192"
129
- )
130
 
131
- # Debug: Print entire response
132
- print("🔍 Full API Response:", chat_completion)
133
 
134
- # Extract AI response
135
- result = chat_completion.choices[0].message.content.strip()
136
- print(f"✅ AI Response: {result}") # Debugging
137
 
138
- # Check if the response starts with "SELECT"
 
 
 
 
139
  if result.lower().startswith("select"):
140
  return result
141
  else:
142
- print("⚠️ AI did not generate a valid SQL query!")
143
- return "⚠️ AI response is not a valid SQL query."
144
 
145
  except Exception as e:
146
- print(f"❌ API Error: {e}")
147
  return "⚠️ API failed. Check logs."
148
 
149
-
150
  def response(user_query, dataset_folder):
151
  """Processes the user query and returns an SQL query."""
152
  dataframes, metadata_list = load_dataset_metadata(dataset_folder)
 
3
from sentence_transformers import SentenceTransformer
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
import google.generativeai as genai
import os
from dotenv import load_dotenv
import pandas as pd

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Read the Gemini API key from the environment and warn early if it is missing,
# so a misconfigured deployment is visible at startup rather than on the first query.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not set in environment. Set it in your .env file or system env vars.")

genai.configure(api_key=GEMINI_API_KEY)

# Use the current directory for Hugging Face Spaces
dataset_folder = "./data"  # Assuming files are in a 'data/' folder

# Verify the folder exists
if not os.path.exists(dataset_folder):
    print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
    dataset_folder = "."  # Fallback: Look in the current directory
 
34
# Load all CSV files in the dataset folder into memory.
dataframes = []


def _read_csv_with_fallback(csv_path, **read_kwargs):
    """Read a CSV as utf-8, retrying once with latin1 on a decode error."""
    try:
        return pd.read_csv(csv_path, encoding="utf-8", **read_kwargs)
    except UnicodeDecodeError:
        return pd.read_csv(csv_path, encoding="latin1", **read_kwargs)


for file in os.listdir(dataset_folder):
    if not file.endswith(".csv"):
        continue  # Only CSV files are loaded
    try:
        path = os.path.join(dataset_folder, file)

        # Peek at the first few rows so every column can be forced to str
        # when the full file is read (avoids mixed-dtype surprises).
        sample_df = _read_csv_with_fallback(path, nrows=5)
        column_types = {col: str for col in sample_df.columns}

        # Read the entire file with the enforced string dtypes.
        df = _read_csv_with_fallback(path, dtype=column_types, low_memory=False)
        dataframes.append(df.fillna(''))  # Fill NaN values with empty strings
    except Exception as e:
        print(f"Error reading {file}: {e}")
59
 
 
127
 
128
 
129
def generate_sql_query(system_prompt):
    """Uses Gemini API to generate an SQL query."""
    try:
        # Initialize the Gemini model (use a reliable text model)
        model = genai.GenerativeModel("gemini-2.5-pro")

        # Generate content from the system prompt. The local is named
        # gemini_response to avoid shadowing the module-level response().
        gemini_response = model.generate_content(system_prompt)

        # Debug: print the full Gemini response
        print("🔍 Full API Response:", gemini_response)

        # Extract and trim the text payload of the reply.
        result = gemini_response.text.strip()
        print(f"✅ AI Response: {result}")

        # Only accept replies that look like a SELECT statement.
        if not result.lower().startswith("select"):
            print("⚠️ Gemini did not generate a valid SQL query.")
            return "⚠️ Invalid SQL query generated."
        return result

    except Exception as e:
        print(f"❌ Gemini API Error: {e}")
        return "⚠️ API failed. Check logs."
155
 
 
156
  def response(user_query, dataset_folder):
157
  """Processes the user query and returns an SQL query."""
158
  dataframes, metadata_list = load_dataset_metadata(dataset_folder)