nilotpaldhar2004 committed on
Commit
128de03
·
unverified ·
1 Parent(s): 3247bdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +184 -264
app.py CHANGED
@@ -9,41 +9,23 @@ import urllib.error
9
 
10
  import pandas as pd
11
 
12
- from fastapi import (
13
- FastAPI,
14
- File,
15
- UploadFile,
16
- HTTPException
17
- )
18
-
19
  from fastapi.staticfiles import StaticFiles
20
  from fastapi.responses import FileResponse
21
  from fastapi.middleware.cors import CORSMiddleware
22
  from pydantic import BaseModel
23
 
24
 
25
- # ── Configuration ──────────────────────────────────────────────
 
 
26
 
27
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
28
 
29
- if GEMINI_API_KEY:
30
- print("✅ GEMINI_API_KEY Loaded")
31
- else:
32
- print("❌ GEMINI_API_KEY Missing")
33
-
34
-
35
- # ── In-Memory Stores ───────────────────────────────────────────
36
-
37
  _db_store = {}
38
  _schema_store = {}
39
 
40
-
41
- # ── FastAPI Setup ──────────────────────────────────────────────
42
-
43
- app = FastAPI(
44
- title="QueryMind Gemini",
45
- version="4.0.0"
46
- )
47
 
48
  app.add_middleware(
49
  CORSMiddleware,
@@ -52,82 +34,61 @@ app.add_middleware(
52
  allow_headers=["*"]
53
  )
54
 
55
-
56
- # ── Request Model ──────────────────────────────────────────────
57
-
58
  class QueryRequest(BaseModel):
59
  session_id: str
60
  question: str
61
 
62
 
63
- # ── Heuristic SQL Engine ───────────────────────────────────────
 
 
64
 
65
- def _heuristic_sql(question: str, table: str, columns: list):
66
-
67
- q = question.lower().strip()
68
-
69
- t = f'"{table}"'
70
 
71
- # ── Count Queries ─────────────────────────────
 
72
 
73
- if any(x in q for x in [
74
- "count total",
75
- "count records",
76
- "total records",
77
- "how many records",
78
- "total rows",
79
- "count rows"
80
- ]):
81
 
82
- return f'SELECT COUNT(*) AS total_rows FROM {t}'
 
83
 
84
- # ── Preview Queries ───────────────────────────
 
85
 
86
- if any(x in q for x in [
87
- "preview",
88
- "show head",
89
- "data preview",
90
- "show first",
91
- "first 10",
92
- "show rows"
93
- ]):
94
 
95
- return f'SELECT * FROM {t} LIMIT 10'
96
 
97
- # ── Show Columns ──────────────────────────────
 
 
98
 
99
- if "columns" in q or "column names" in q:
100
-
101
- cols = ", ".join(columns)
102
-
103
- return f"SELECT '{cols}' AS columns_list"
104
-
105
- # ── Unique Values ─────────────────────────────
106
 
107
- if "unique values in" in q:
 
108
 
109
- col = q.replace("unique values in", "").strip()
 
110
 
111
- if col in columns:
 
112
 
113
- return f'''
114
- SELECT DISTINCT "{col}"
115
- FROM {t}
116
- LIMIT 100
117
- '''
118
 
119
- # ── Group By ──────────────────────────────────
 
 
 
120
 
121
  if "group by" in q:
122
-
123
- match = re.search(r'group by\s+(\w+)', q)
124
-
125
  if match:
126
-
127
  col = match.group(1)
128
-
129
  if col in columns:
130
-
131
  return f'''
132
  SELECT "{col}", COUNT(*) AS count
133
  FROM {t}
@@ -138,246 +99,205 @@ def _heuristic_sql(question: str, table: str, columns: list):
138
  return None
139
 
140
 
141
- # ── Gemini SQL Generator ──────────────────────────────────────
 
 
142
 
143
- def _call_gemini(
144
- question: str,
145
- schema: str,
146
- columns: list,
147
- table: str
148
- ):
149
 
150
  if not GEMINI_API_KEY:
151
-
152
- print("❌ GEMINI_API_KEY Missing")
153
-
154
  return ""
155
 
156
- col_list = ", ".join(columns[:30])
157
-
158
  prompt = f"""
159
- You are a SQLite expert.
160
-
161
- Convert the natural language question into a valid SQLite query.
162
 
163
  Rules:
164
  - Output ONLY SQL
165
- - Use ONLY the provided table
166
- - Do NOT explain anything
167
- - Do NOT use markdown
168
-
169
- Table Name:
170
- {table}
171
 
172
- Columns:
173
- {col_list}
 
174
 
175
- Schema:
176
- {schema}
177
-
178
- Question:
179
- {question}
180
  """
181
 
182
  payload = json.dumps({
183
- "contents": [
184
- {
185
- "parts": [
186
- {
187
- "text": prompt
188
- }
189
- ]
190
- }
191
- ]
192
  }).encode("utf-8")
193
 
194
- # Correct Gemini endpoint
195
  url = (
196
  "https://generativelanguage.googleapis.com/"
197
  f"v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
198
  )
199
 
200
  try:
201
-
202
  req = urllib.request.Request(
203
  url,
204
  data=payload,
205
- headers={
206
- "Content-Type": "application/json"
207
- }
208
- )
209
-
210
- with urllib.request.urlopen(req, timeout=20) as resp:
211
-
212
- data = json.loads(resp.read())
213
-
214
- sql = (
215
- data["candidates"][0]
216
- ["content"]["parts"][0]
217
- ["text"]
218
- .strip()
219
  )
220
 
221
- # Cleanup
222
- sql = (
223
- sql
224
- .replace("```sql", "")
225
- .replace("```", "")
226
- .strip()
227
- .split(";")[0]
228
- )
229
 
230
- # Force correct table
231
- sql = re.sub(
232
- r'\bFROM\s+["\'\w\.]+',
233
- f'FROM "{table}"',
234
- sql,
235
- flags=re.IGNORECASE
236
- )
237
 
238
- return sql
239
 
240
- except urllib.error.HTTPError as e:
241
-
242
- error_body = e.read().decode()
243
-
244
- print(f"❌ GEMINI HTTP ERROR: {e.code}")
245
- print(error_body)
246
-
247
- return ""
248
 
249
  except Exception as e:
250
-
251
- print(f"❌ GEMINI ERROR: {str(e)}")
252
-
253
  return ""
254
 
255
 
256
- # ── SQLite Execution ──────────────────────────────────────────
 
 
257
 
258
  def execute_sql(sql, db_bytes):
259
 
260
  conn = sqlite3.connect(":memory:")
261
 
262
  with tempfile.NamedTemporaryFile(delete=False) as f:
263
-
264
  f.write(db_bytes)
265
-
266
  f.flush()
 
267
 
268
- temp_name = f.name
 
 
 
 
 
 
 
 
269
 
270
  try:
 
 
 
 
 
 
271
 
272
- disk_conn = sqlite3.connect(temp_name)
273
 
274
- disk_conn.backup(conn)
 
 
275
 
276
- disk_conn.close()
277
 
278
- finally:
 
279
 
280
- if os.path.exists(temp_name):
281
- os.remove(temp_name)
 
 
282
 
283
- conn.row_factory = sqlite3.Row
284
 
285
- try:
 
 
286
 
287
- cur = conn.execute(sql)
288
 
289
- rows = cur.fetchall()
 
290
 
291
- return [dict(r) for r in rows]
 
292
 
293
- except Exception as e:
 
 
294
 
295
- return [{"error": str(e)}]
 
296
 
297
- finally:
 
 
298
 
299
- conn.close()
 
 
 
300
 
 
 
 
 
 
 
301
 
302
- # ── Upload Endpoint ───────────────────────────────────────────
 
303
 
304
- @app.post("/upload")
305
- async def upload_csv(file: UploadFile = File(...)):
306
 
307
- try:
 
308
 
309
- contents = await file.read()
310
 
311
- df = pd.read_csv(
312
- io.BytesIO(contents)
313
- ).dropna(how='all')
314
 
315
- session_id = os.urandom(8).hex()
 
316
 
317
- # Clean table name
318
- clean_name = re.sub(
319
- r'[^a-zA-Z0-9_]',
320
- '_',
321
- os.path.splitext(file.filename)[0]
322
- )
323
 
324
- if clean_name[0].isdigit():
325
- clean_name = "t_" + clean_name
326
 
327
- table_name = clean_name[:32]
328
 
329
- with tempfile.NamedTemporaryFile(delete=False) as tf:
330
 
331
- conn = sqlite3.connect(tf.name)
 
332
 
333
- df.to_sql(
334
- table_name,
335
- conn,
336
- index=False,
337
- if_exists="replace"
338
- )
339
 
340
- schema = conn.execute(
341
- "SELECT sql FROM sqlite_master WHERE type='table'"
342
- ).fetchone()[0]
343
 
344
- conn.close()
345
 
346
- with open(tf.name, "rb") as f:
347
- db_data = f.read()
348
 
349
- if os.path.exists(tf.name):
350
- os.remove(tf.name)
 
351
 
352
- _db_store[session_id] = {
353
- "bytes": db_data,
354
- "table": table_name,
355
- "cols": list(df.columns)
356
- }
357
 
358
- _schema_store[session_id] = schema
359
 
360
- return {
361
- "session_id": session_id,
362
- "columns": list(df.columns),
363
- "row_count": len(df),
364
- "table_name": table_name,
365
- "preview": df.head(5).to_dict(
366
- orient="records"
367
- )
368
- }
369
 
370
- except Exception as e:
 
 
 
 
371
 
372
- print(f"❌ UPLOAD ERROR: {e}")
373
 
374
- raise HTTPException(
375
- status_code=500,
376
- detail=str(e)
377
- )
 
378
 
379
 
380
- # ── Query Endpoint ───────────────────────────────────────────
 
 
381
 
382
  @app.post("/query")
383
  async def query(req: QueryRequest):
@@ -385,24 +305,17 @@ async def query(req: QueryRequest):
385
  data = _db_store.get(req.session_id)
386
 
387
  if not data:
388
-
389
- raise HTTPException(
390
- status_code=404,
391
- detail="Invalid session_id"
392
- )
393
 
394
  schema = _schema_store.get(req.session_id)
395
 
396
- # Step 1 → Heuristic SQL
397
- sql = _heuristic_sql(
398
- req.question,
399
- data["table"],
400
- data["cols"]
401
- )
402
 
403
- # Step 2 → Gemini fallback
404
- if not sql:
405
 
 
406
  sql = _call_gemini(
407
  req.question,
408
  schema,
@@ -410,46 +323,53 @@ async def query(req: QueryRequest):
410
  data["table"]
411
  )
412
 
413
- # Step 3 → Failure
414
  if not sql:
 
415
 
416
- raise HTTPException(
417
- status_code=400,
418
- detail="Failed to generate SQL query"
419
- )
420
 
421
- results = execute_sql(
422
- sql,
423
- data["bytes"]
424
- )
 
 
 
 
 
 
 
 
 
425
 
426
  return {
 
427
  "sql": sql,
428
- "results": results
 
 
429
  }
430
 
431
 
432
- # ── Health Endpoint ───────────────────────────────────────────
 
 
433
 
434
  @app.get("/health")
435
  def health():
436
-
437
  return {
438
  "status": "ok",
439
- "model": "gemini-1.5-flash"
440
  }
441
 
442
 
443
- # ── Static Frontend ───────────────────────────────────────────
444
-
445
- app.mount(
446
- "/static",
447
- StaticFiles(directory="static"),
448
- name="static"
449
- )
450
 
 
451
 
452
  @app.get("/")
453
  def root():
454
-
455
  return FileResponse("static/webapp.html")
 
9
 
10
  import pandas as pd
11
 
12
+ from fastapi import FastAPI, File, UploadFile, HTTPException
 
 
 
 
 
 
13
  from fastapi.staticfiles import StaticFiles
14
  from fastapi.responses import FileResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from pydantic import BaseModel
17
 
18
 
19
+ # ─────────────────────────────────────────────
20
+ # CONFIG
21
+ # ─────────────────────────────────────────────
22
 
23
# Gemini API key read from the environment; an empty string means "not set".
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

# Per-session state, keyed by the session_id handed out by /upload.
# _db_store[session_id]     -> {"bytes": ..., "table": ..., "cols": ...}
# _schema_store[session_id] -> schema text read from sqlite_master
_db_store = {}
_schema_store = {}

app = FastAPI(
    title="AI Data Analyst Agent",
    version="5.0.0",
)
 
 
 
 
 
 
29
 
30
  app.add_middleware(
31
  CORSMiddleware,
 
34
  allow_headers=["*"]
35
  )
36
 
 
 
 
37
class QueryRequest(BaseModel):
    """Request body accepted by the POST /query endpoint."""

    # Session identifier used to look up the stored database.
    session_id: str

    # Natural-language question to answer against that database.
    question: str
40
 
41
 
42
+ # ─────────────────────────────────────────────
43
+ # AGENT INTENT ENGINE
44
+ # ─────────────────────────────────────────────
45
 
46
def agent_think(question: str):
    """Classify a natural-language question into a coarse intent label.

    The lower-cased question is tested against keyword groups in priority
    order and the label of the first matching group is returned:
    ``"VISUALIZE"``, ``"ANALYZE"``, ``"EXPLAIN"``, then ``"COUNT"``.
    ``"SQL"`` is the fallback when no group matches.

    Keywords are matched as whole words (plus common inflections) instead of
    raw substrings, so "country" is no longer read as "count", "summary" as
    "sum", or "administrator" as "min".

    Args:
        question: The user's question.

    Returns:
        One of ``"VISUALIZE"``, ``"ANALYZE"``, ``"EXPLAIN"``, ``"COUNT"``,
        ``"SQL"``.
    """
    import re  # function-scope import keeps this block self-contained

    q = question.lower()

    # Ordered (label, pattern) pairs: earlier entries take precedence.
    intent_patterns = (
        ("VISUALIZE", r"\b(?:chart|graph|plot)(?:s|ed|ted|ing|ting)?\b"),
        ("ANALYZE", r"\b(?:averages?|avg|sums?|max(?:imum)?|min(?:imum)?)\b"),
        ("EXPLAIN", r"\b(?:why|explain(?:s|ed|ing)?|explanations?|reasons?)\b"),
        ("COUNT", r"\b(?:counts?|counting|how\s+many)\b"),
    )

    for label, pattern in intent_patterns:
        if re.search(pattern, q):
            return label

    return "SQL"
 
 
 
 
 
 
 
62
 
 
63
 
64
+ # ─────────────────────────────────────────────
65
+ # HEURISTIC SQL ENGINE
66
+ # ─────────────────────────────────────────────
67
 
68
+ def _heuristic_sql(question: str, table: str, columns: list):
 
 
 
 
 
 
69
 
70
+ q = question.lower()
71
+ t = f'"{table}"'
72
 
73
+ if "count" in q or "how many" in q:
74
+ return f"SELECT COUNT(*) AS total_rows FROM {t}"
75
 
76
+ if any(x in q for x in ["first row", "show first", "preview"]):
77
+ return f"SELECT * FROM {t} LIMIT 10"
78
 
79
+ if "last row" in q:
80
+ return f"SELECT * FROM {t} ORDER BY rowid DESC LIMIT 1"
 
 
 
81
 
82
+ if "unique values" in q:
83
+ for col in columns:
84
+ if col.lower() in q:
85
+ return f'SELECT DISTINCT "{col}" FROM {t} LIMIT 100'
86
 
87
  if "group by" in q:
88
+ match = re.search(r'group by (\w+)', q)
 
 
89
  if match:
 
90
  col = match.group(1)
 
91
  if col in columns:
 
92
  return f'''
93
  SELECT "{col}", COUNT(*) AS count
94
  FROM {t}
 
99
  return None
100
 
101
 
102
+ # ─────────────────────────────────────────────
103
+ # GEMINI SQL GENERATOR
104
+ # ─────────────────────────────────────────────
105
 
106
def _call_gemini(question, schema, columns, table):
    """Ask Gemini to translate a question into a single SQLite statement.

    Args:
        question: Natural-language question from the user.
        schema: Schema text embedded in the prompt.
        columns: Column names embedded in the prompt.
        table: Table name embedded in the prompt.

    Returns:
        The first statement of the model's reply, with Markdown code fences
        and surrounding whitespace removed. Returns ``""`` when no API key is
        configured or when the request or response handling fails.
    """
    if not GEMINI_API_KEY:
        return ""

    prompt = f"""
You are a strict SQLite expert.

Rules:
- Output ONLY SQL
- Use only given table and columns
- No explanation

Table: {table}
Columns: {columns}
Schema: {schema}

Question: {question}
"""

    payload = json.dumps({
        "contents": [{"parts": [{"text": prompt}]}]
    }).encode("utf-8")

    url = (
        "https://generativelanguage.googleapis.com/"
        f"v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
    )

    try:
        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"}
        )

        # The context manager closes the HTTP response even when reading or
        # decoding fails; previously the response object was never closed.
        with urllib.request.urlopen(req, timeout=20) as res:
            data = json.loads(res.read())

        sql = data["candidates"][0]["content"]["parts"][0]["text"]
        sql = sql.replace("```sql", "").replace("```", "").strip()

        # Keep only the first statement and drop whitespace left before ';'.
        # NOTE(review): a ';' inside a string literal would also truncate the
        # statement here.
        return sql.split(";")[0].strip()

    except Exception as e:
        # Best-effort call: any failure (HTTP error, timeout, unexpected
        # response shape) is logged and reported to the caller as "no SQL".
        print("Gemini Error:", e)
        return ""
154
 
155
 
156
+ # ────────────────────────────────────────────
157
+ # SQL EXECUTION
158
+ # ─────────────────────────────────────────────
159
 
160
def execute_sql(sql, db_bytes):
    """Run one SQL statement against an in-memory copy of a SQLite database.

    The serialized database is written to a temporary file, copied into a
    ``:memory:`` connection with ``Connection.backup``, and the temporary
    file is removed before the statement runs.

    Args:
        sql: SQL text to execute.
        db_bytes: Raw bytes of a SQLite database file.

    Returns:
        A list with one ``dict`` per result row (column name -> value). If
        executing ``sql`` raises, a single-element list
        ``[{"error": "<message>"}]`` is returned instead.

    Raises:
        Exception: Errors while staging or loading ``db_bytes`` (before the
            statement is executed) propagate to the caller.
    """
    conn = sqlite3.connect(":memory:")
    try:
        staged = tempfile.NamedTemporaryFile(delete=False)
        try:
            # Leaving the `with` flushes and closes the file so SQLite can
            # open it by name.
            with staged:
                staged.write(db_bytes)

            disk = sqlite3.connect(staged.name)
            try:
                disk.backup(conn)
            finally:
                # Close the source connection even if the copy fails.
                disk.close()
        finally:
            if os.path.exists(staged.name):
                os.remove(staged.name)

        conn.row_factory = sqlite3.Row

        try:
            cur = conn.execute(sql)
            return [dict(r) for r in cur.fetchall()]
        except Exception as e:
            # Statement errors are returned as data, not raised.
            return [{"error": str(e)}]
    finally:
        # Runs on every path, including a failed load above, so the
        # in-memory connection is never leaked.
        conn.close()
186
 
 
187
 
188
+ # ─────────────────────────────────────────────
189
+ # ANALYSIS ENGINE
190
+ # ─────────────────────────────────────────────
191
 
192
def analyze_results(results):
    """Summarise a query result set for the API response.

    Args:
        results: List of row dicts, as produced by ``execute_sql``.

    Returns:
        * ``{"message": "No data found"}`` when ``results`` is empty/falsy.
        * ``{"message": "Query failed", "error": <message>}`` when
          ``results`` is the single-row error payload ``[{"error": ...}]``.
        * Otherwise ``{"rows_returned": <count>, "sample": <first 3 rows>}``.
    """
    if not results:
        return {"message": "No data found"}

    # execute_sql reports a failed statement as one row whose only key is
    # "error"; describing that as "1 row returned" would be misleading.
    # NOTE(review): a genuine one-row result with a single column named
    # "error" is indistinguishable from this payload.
    first = results[0]
    if len(results) == 1 and isinstance(first, dict) and set(first) == {"error"}:
        return {"message": "Query failed", "error": first["error"]}

    return {
        "rows_returned": len(results),
        "sample": results[:3]
    }
201
 
 
202
 
203
+ # ─────────────────────────────────────────────
204
+ # EXPLANATION ENGINE (Gemini)
205
+ # ─────────────────────────────────────────────
206
 
207
def explain_results(question, sql, results):
    """Ask Gemini for a plain-language explanation of a query result.

    Args:
        question: The user's original question.
        sql: The SQL statement that was executed.
        results: Result rows; only the first five are sent in the prompt.

    Returns:
        The explanation text from the model, or ``None`` when no API key is
        configured or the request/response handling fails.
    """
    if not GEMINI_API_KEY:
        return None

    prompt = f"""
You are a data analyst.

Question: {question}
SQL: {sql}
Results: {results[:5]}

Explain this in simple words.
"""

    payload = json.dumps({
        "contents": [{"parts": [{"text": prompt}]}]
    }).encode("utf-8")

    url = (
        "https://generativelanguage.googleapis.com/"
        f"v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
    )

    try:
        req = urllib.request.Request(
            url,
            data=payload,
            headers={"Content-Type": "application/json"}
        )

        # A timeout keeps a stalled upstream call from blocking the request
        # indefinitely; the context manager closes the HTTP response.
        with urllib.request.urlopen(req, timeout=20) as res:
            data = json.loads(res.read())

        return data["candidates"][0]["content"]["parts"][0]["text"]

    except Exception as e:
        # The explanation is optional, so failures are logged and reported
        # as None. `except Exception` (not a bare `except`) lets
        # KeyboardInterrupt/SystemExit propagate.
        print("Gemini Error:", e)
        return None
245
 
 
246
 
247
+ # ─────────────────────────────────────────────
248
+ # UPLOAD CSV
249
+ # ─────────────────────────────────────────────
250
 
251
@app.post("/upload")
async def upload_csv(file: UploadFile = File(...)):
    """Load an uploaded CSV into a SQLite snapshot and register a session.

    The CSV is parsed with pandas (rows that are entirely empty are dropped),
    written to a temporary SQLite database, and the database file's bytes,
    table name and column list are stored in ``_db_store`` under a new
    random session id. The table's schema text is stored in
    ``_schema_store``.

    Args:
        file: The uploaded CSV file.

    Returns:
        ``{"session_id": ..., "rows": <row count>, "columns": [...]}``.

    Raises:
        HTTPException: 400 when the upload cannot be parsed as CSV.
    """
    contents = await file.read()

    try:
        df = pd.read_csv(io.BytesIO(contents)).dropna(how="all")
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # A malformed upload is a client error, not a server crash.
        raise HTTPException(400, f"Could not parse CSV: {exc}") from exc

    session_id = os.urandom(8).hex()

    # Derive a table name from the file name. A missing or empty file name
    # falls back to "data" instead of raising TypeError/IndexError.
    table_name = re.sub(r"[^a-zA-Z0-9_]", "_", file.filename or "") or "data"

    if table_name[0].isdigit():
        table_name = "t_" + table_name

    table_name = table_name[:32]

    # Reserve a temp path, then close the handle before SQLite opens it.
    with tempfile.NamedTemporaryFile(delete=False) as tf:
        db_path = tf.name

    try:
        conn = sqlite3.connect(db_path)
        try:
            df.to_sql(table_name, conn, index=False, if_exists="replace")

            schema = conn.execute(
                "SELECT sql FROM sqlite_master WHERE type='table'"
            ).fetchone()[0]
        finally:
            conn.close()

        # `with` closes the handle; open(...).read() left it to the GC.
        with open(db_path, "rb") as fh:
            db_bytes = fh.read()
    finally:
        # Remove the temp file on failure as well as on success.
        if os.path.exists(db_path):
            os.remove(db_path)

    _db_store[session_id] = {
        "bytes": db_bytes,
        "table": table_name,
        "cols": list(df.columns)
    }

    _schema_store[session_id] = schema

    return {
        "session_id": session_id,
        "rows": len(df),
        "columns": list(df.columns)
    }
296
 
297
 
298
+ # ─────────────────────────────────────────────
299
+ # QUERY ENGINE (AGENT CORE)
300
+ # ─────────────────────────────────────────────
301
 
302
  @app.post("/query")
303
  async def query(req: QueryRequest):
 
305
  data = _db_store.get(req.session_id)
306
 
307
  if not data:
308
+ raise HTTPException(404, "Invalid session")
 
 
 
 
309
 
310
  schema = _schema_store.get(req.session_id)
311
 
312
+ # 🧠 Agent thinking
313
+ intent = agent_think(req.question)
 
 
 
 
314
 
315
+ # 🗄️ SQL generation
316
+ sql = _heuristic_sql(req.question, data["table"], data["cols"])
317
 
318
+ if not sql:
319
  sql = _call_gemini(
320
  req.question,
321
  schema,
 
323
  data["table"]
324
  )
325
 
 
326
  if not sql:
327
+ raise HTTPException(400, "SQL generation failed")
328
 
329
+ # ⚡ Execute SQL
330
+ results = execute_sql(sql, data["bytes"])
 
 
331
 
332
+ # 📊 Analysis
333
+ analysis = analyze_results(results)
334
+
335
+ # 💬 Explanation (only for analytical intent)
336
+ explanation = None
337
+
338
+ if intent in ["ANALYZE", "EXPLAIN"]:
339
+
340
+ explanation = explain_results(
341
+ req.question,
342
+ sql,
343
+ results
344
+ )
345
 
346
  return {
347
+ "intent": intent,
348
  "sql": sql,
349
+ "results": results[:20],
350
+ "analysis": analysis,
351
+ "explanation": explanation
352
  }
353
 
354
 
355
+ # ─────────────────────────────────────────────
356
+ # HEALTH CHECK
357
+ # ─────────────────────────────────────────────
358
 
359
@app.get("/health")
def health():
    """Return a fixed status payload for health checks."""
    payload = {"status": "ok", "model": "AI Data Analyst Agent"}
    return payload
365
 
366
 
367
+ # ─────────────────────────────────────────────
368
+ # FRONTEND
369
+ # ─────────────────────────────────────────────
 
 
 
 
370
 
371
# Serve everything under ./static at the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)


@app.get("/")
def root():
    """Serve the frontend page for the site root."""
    page = FileResponse("static/webapp.html")
    return page