nilotpaldhar2004 commited on
Commit
e1f4b42
·
unverified ·
1 Parent(s): 5170b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -375
app.py CHANGED
@@ -1,9 +1,7 @@
1
  """
2
- Nilotpal SQL Bot Telegram Bot + Web App
3
- FastAPI backend serving:
4
- - Telegram Bot (standard messages + inline buttons)
5
- - Telegram Web App (full HTML/CSS/JS UI via /webapp)
6
- Model: cssupport/t5-small-awesome-text-to-sql (CPU-friendly)
7
  """
8
 
9
  import os
@@ -12,411 +10,180 @@ import io
12
  import json
13
  import sqlite3
14
  import tempfile
15
- import hashlib
16
  import pandas as pd
17
- from fastapi import FastAPI, File, UploadFile, HTTPException, Request
 
18
  from fastapi.staticfiles import StaticFiles
19
- from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
20
  from fastapi.middleware.cors import CORSMiddleware
21
  from pydantic import BaseModel
22
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
23
- import torch
24
- import httpx
25
 
26
- # ── Config ────────────────────────────────────────────────────────────────────
27
- MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
28
- MAX_NEW_TOKENS = 256
29
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
- BOT_TOKEN = os.getenv("BOT_TOKEN", "") # set in HF Space secrets
31
- WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET", "nilotpalsqlbot")
32
- SPACE_URL = os.getenv("SPACE_URL", "") # e.g. https://nilotpaldhar2004-nilotpal-sql-bot.hf.space
33
 
34
- TELEGRAM_API = f"https://api.telegram.org/bot{BOT_TOKEN}"
 
35
 
36
- # ── Load model ────────────────────────────────────────────────────────────────
37
- print(f"[INFO] Loading {MODEL_NAME} on {DEVICE}...")
38
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
39
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
40
- model.eval()
41
- print("[INFO] Model ready.")
42
-
43
- # ── In-memory stores ──────────────────────────────────────────────────────────
44
- _db_store: dict[str, bytes] = {} # session_id → sqlite bytes
45
- _schema_store: dict[str, str] = {} # session_id → schema string
46
- _col_store: dict[str, list] = {} # session_id → column list
47
- _table_store: dict[str, str] = {} # session_id → table name
48
- _user_session: dict[int, str] = {} # telegram user_id → session_id
49
-
50
- app = FastAPI(title="Nilotpal SQL Bot", version="1.0.0")
51
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
52
- app.mount("/static", StaticFiles(directory="static"), name="static")
53
-
54
-
55
- # ── Helpers ───────────────────────────────────────────────────────────────────
56
- def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
57
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
58
- tmp_path = tmp.name
59
- conn = sqlite3.connect(tmp_path)
60
- df.to_sql(table_name, conn, if_exists="replace", index=False)
61
- conn.close()
62
- with open(tmp_path, "rb") as f:
63
- db_bytes = f.read()
64
- os.unlink(tmp_path)
65
- return db_bytes
66
 
 
 
 
67
 
68
- def get_schema(db_bytes: bytes) -> str:
69
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
70
- tmp.write(db_bytes)
71
- tmp_path = tmp.name
72
- conn = sqlite3.connect(tmp_path)
73
- cur = conn.cursor()
74
- cur.execute("SELECT sql FROM sqlite_master WHERE type='table'")
75
- rows = cur.fetchall()
76
- conn.close()
77
- os.unlink(tmp_path)
78
- return "\n".join(r[0] for r in rows if r[0])
79
 
 
 
 
 
 
 
 
80
 
81
- def generate_sql(question: str, schema: str, table_name: str) -> str:
82
- quoted = f'"{table_name}"'
83
  q = question.lower().strip()
84
-
85
- # ── Rule-based shortcuts (fast + accurate) ────────────────────────────
86
- if re.search(r'show.*(first|top).*\d+|first.*\d+.*row|top.*\d+', q):
87
- n = re.search(r'\d+', q)
88
- return f'SELECT * FROM {quoted} LIMIT {n.group() if n else 10}'
89
- if re.search(r'(show|display|get|give).*(first|all).*row|first.*row|show.*row', q):
90
- return f'SELECT * FROM {quoted} LIMIT 10'
91
- if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
92
- return f'SELECT COUNT(*) FROM {quoted}'
93
- if re.search(r'show.*(all|every).*row|all.*row|select all', q):
94
- return f'SELECT * FROM {quoted} LIMIT 50'
95
- if re.search(r'average|avg', q):
96
- col_match = re.findall(r'"(\w+)"', schema)
97
- # find numeric-looking column
98
- num_col = next((c for c in col_match if re.search(r'num|price|val|amt|count|qty|sal|rev|cost|pm|aqi|no|co|so|o3', c, re.I)), col_match[1] if len(col_match) > 1 else col_match[0])
99
- return f'SELECT AVG("{num_col}") FROM {quoted}'
100
- if re.search(r'unique|distinct', q):
101
- col_match = re.findall(r'"(\w+)"', schema)
102
- return f'SELECT COUNT(DISTINCT "{col_match[0]}") FROM {quoted}'
103
- if re.search(r'group by', q):
104
- col_match = re.findall(r'"(\w+)"', schema)
105
- return f'SELECT "{col_match[0]}", COUNT(*) FROM {quoted} GROUP BY "{col_match[0]}"'
106
- if re.search(r'max|maximum|highest', q):
107
- col_match = re.findall(r'"(\w+)"', schema)
108
- num_col = col_match[1] if len(col_match) > 1 else col_match[0]
109
- return f'SELECT MAX("{num_col}") FROM {quoted}'
110
- if re.search(r'min|minimum|lowest', q):
111
- col_match = re.findall(r'"(\w+)"', schema)
112
- num_col = col_match[1] if len(col_match) > 1 else col_match[0]
113
- return f'SELECT MIN("{num_col}") FROM {quoted}'
114
-
115
- # ── T5 model fallback ─────────────────────────────────────────────────
116
- col_match = re.findall(r'"(\w+)"', schema)
117
- col_hint = ", ".join(col_match)
118
- prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
119
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
120
- with torch.no_grad():
121
- outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
122
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
123
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
124
- sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
125
- sql = re.sub(
126
- r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
127
- r'\1', sql, flags=re.IGNORECASE
128
  )
129
- if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
130
- sql = f'SELECT * FROM {quoted} LIMIT 10'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  return sql
132
 
 
133
 
134
- def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
135
- with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
136
- tmp.write(db_bytes)
137
- tmp_path = tmp.name
138
- conn = sqlite3.connect(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
139
  conn.row_factory = sqlite3.Row
140
  try:
141
- cur = conn.execute(sql)
142
- rows = [dict(r) for r in cur.fetchall()]
 
 
143
  except Exception as e:
144
- conn.close(); os.unlink(tmp_path)
145
- raise HTTPException(status_code=400, detail=f"SQL error: {e}")
146
- conn.close(); os.unlink(tmp_path)
147
- return rows
148
-
149
-
150
- def format_table(rows: list[dict]) -> str:
151
- """Format query results as plain text for Telegram."""
152
- if not rows:
153
- return "No rows returned."
154
- cols = list(rows[0].keys())
155
- # Simple text table
156
- lines = [" | ".join(cols)]
157
- lines.append("-" * len(lines[0]))
158
- for r in rows[:20]:
159
- lines.append(" | ".join(str(r[c]) if r[c] is not None else "null" for c in cols))
160
- if len(rows) > 20:
161
- lines.append(f"... ({len(rows)} rows total, showing 20)")
162
- return "\n".join(lines)
163
 
 
164
 
165
- # ── Telegram API helpers ───────────────────────────────────────────────────────
166
- async def tg(method: str, **kwargs):
167
- try:
168
- async with httpx.AsyncClient(timeout=30) as client:
169
- r = await client.post(f"{TELEGRAM_API}/{method}", json=kwargs)
170
- return r.json()
171
- except Exception as e:
172
- print(f"[ERROR] Telegram API call failed ({method}): {e}")
173
- return {"ok": False, "error": str(e)}
174
-
175
-
176
- async def send_msg(chat_id: int, text: str, reply_markup=None, parse_mode="Markdown"):
177
- payload = dict(chat_id=chat_id, text=text, parse_mode=parse_mode)
178
- if reply_markup:
179
- payload["reply_markup"] = reply_markup
180
- return await tg("sendMessage", **payload)
181
-
182
-
183
- async def send_doc_request(chat_id: int):
184
- """Ask user to send a CSV file."""
185
- await send_msg(
186
- chat_id,
187
- "📂 *Send me a CSV file* to get started!\n\nI'll convert your questions to SQL and query it instantly.",
188
- reply_markup={
189
- "inline_keyboard": [[
190
- {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
191
- ]]
192
- }
193
- )
194
-
195
-
196
- # ── REST: CSV Upload (used by both bot and webapp) ────────────────────────────
197
  @app.post("/upload")
198
- async def upload_csv(file: UploadFile = File(...), user_id: int = 0):
199
- if not file.filename.endswith(".csv"):
200
- raise HTTPException(status_code=400, detail="Only CSV files accepted.")
201
  contents = await file.read()
202
- try:
203
- df = pd.read_csv(io.BytesIO(contents))
204
- except Exception as e:
205
- raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
206
-
207
- session_id = hashlib.md5(contents[:1024]).hexdigest()[:12]
208
- table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
209
- if table_name[0].isdigit():
210
- table_name = "t_" + table_name
211
-
212
- db_bytes = csv_to_sqlite(df, table_name)
213
- schema = get_schema(db_bytes)
214
- columns = list(df.columns)
215
-
216
- _db_store[session_id] = db_bytes
217
  _schema_store[session_id] = schema
218
- _col_store[session_id] = columns
219
- _table_store[session_id] = table_name
220
- if user_id:
221
- _user_session[user_id] = session_id
222
 
223
- return JSONResponse({
224
  "session_id": session_id,
225
- "table_name": table_name,
226
- "columns": columns,
227
- "row_count": len(df),
228
- "schema": schema,
229
- "preview": df.head(5).to_dict(orient="records"),
230
- })
231
-
232
-
233
- # ── REST: Query (used by both bot and webapp) ─────────────────────────────────
234
- class QueryRequest(BaseModel):
235
- session_id: str
236
- question: str
237
 
238
  @app.post("/query")
239
  async def query(req: QueryRequest):
240
  if req.session_id not in _db_store:
241
- raise HTTPException(status_code=404, detail="Session not found. Upload CSV first.")
242
- schema = _schema_store[req.session_id]
243
- table_name = _table_store[req.session_id]
244
- sql = generate_sql(req.question, schema, table_name)
245
- results = execute_sql(sql, _db_store[req.session_id])
246
- return JSONResponse({"sql": sql, "results": results})
247
-
 
 
 
 
 
 
 
 
248
 
249
- # ── Web App route ──────────────────────────────────────────────────────────────
250
- @app.get("/webapp", response_class=HTMLResponse)
251
- async def webapp():
252
- return FileResponse("static/webapp.html")
253
 
 
 
254
 
255
  @app.get("/")
256
- async def root():
257
- return FileResponse("static/webapp.html")
258
-
259
 
260
- # ── Health ────────────────────────────────────────────────────────────────────
261
  @app.get("/health")
262
  def health():
263
- return {"status": "ok", "model": MODEL_NAME, "device": DEVICE, "bot": bool(BOT_TOKEN)}
264
-
265
-
266
- # ── Telegram Webhook ──────────────────────────────────────────────────────────
267
- @app.post(f"/webhook/{WEBHOOK_SECRET}")
268
- async def webhook(request: Request):
269
- update = await request.json()
270
-
271
- # Handle document (CSV upload via bot)
272
- msg = update.get("message", {})
273
- if not msg:
274
- msg = update.get("edited_message", {})
275
- chat_id = msg.get("chat", {}).get("id")
276
- user_id = msg.get("from", {}).get("id", 0)
277
- text = msg.get("text", "").strip()
278
-
279
- # ── /start ──
280
- if text in ["/start", "/help"]:
281
- await send_msg(
282
- chat_id,
283
- "👋 *Nilotpal SQL Bot*\n\n"
284
- "I convert plain English questions into SQL and query your CSV data.\n\n"
285
- "📌 *How to use:*\n"
286
- "1️⃣ Send a CSV file\n"
287
- "2️⃣ Ask me anything about your data\n\n"
288
- "Or use the Web App for a richer experience ↓",
289
- reply_markup={
290
- "inline_keyboard": [[
291
- {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
292
- ]]
293
- }
294
- )
295
- return {"ok": True}
296
-
297
- # ── CSV Document ──
298
- doc = msg.get("document")
299
- if doc and doc.get("file_name", "").endswith(".csv"):
300
- await send_msg(chat_id, "⏳ Processing your CSV...")
301
- # Download file from Telegram
302
- file_info = await tg("getFile", file_id=doc["file_id"])
303
- file_path = file_info["result"]["file_path"]
304
- async with httpx.AsyncClient() as client:
305
- file_resp = await client.get(f"https://api.telegram.org/file/bot{BOT_TOKEN}/{file_path}")
306
- contents = file_resp.content
307
- try:
308
- df = pd.read_csv(io.BytesIO(contents))
309
- except Exception as e:
310
- await send_msg(chat_id, f"❌ Could not parse CSV: {e}")
311
- return {"ok": True}
312
-
313
- fname = doc["file_name"]
314
- session_id = hashlib.md5(contents[:1024]).hexdigest()[:12]
315
- table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(fname)[0])[:32] or "data"
316
- if table_name[0].isdigit():
317
- table_name = "t_" + table_name
318
-
319
- db_bytes = csv_to_sqlite(df, table_name)
320
- schema = get_schema(db_bytes)
321
- columns = list(df.columns)
322
-
323
- _db_store[session_id] = db_bytes
324
- _schema_store[session_id] = schema
325
- _col_store[session_id] = columns
326
- _table_store[session_id] = table_name
327
- _user_session[user_id] = session_id
328
-
329
- col_preview = ", ".join(columns[:8]) + ("..." if len(columns) > 8 else "")
330
- await send_msg(
331
- chat_id,
332
- f"✅ *Loaded:* `{fname}`\n"
333
- f"📊 *{len(df):,} rows · {len(columns)} columns*\n"
334
- f"📋 *Columns:* `{col_preview}`\n\n"
335
- f"Now ask me anything about your data!\n"
336
- f'Example: _"Show first 10 rows"_',
337
- reply_markup={
338
- "inline_keyboard": [
339
- [{"text": "📊 Show first 10 rows", "callback_data": f"q:{session_id}:Show the first 10 rows"}],
340
- [{"text": "🔢 Count total records", "callback_data": f"q:{session_id}:Count total number of records"}],
341
- [{"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}],
342
- ]
343
- }
344
- )
345
- return {"ok": True}
346
-
347
- # ── Text question ──
348
- if text and not text.startswith("/"):
349
- sid = _user_session.get(user_id)
350
- if not sid or sid not in _db_store:
351
- await send_msg(
352
- chat_id,
353
- "📂 Please send a CSV file first so I can query it for you.",
354
- reply_markup={
355
- "inline_keyboard": [[
356
- {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
357
- ]]
358
- }
359
- )
360
- return {"ok": True}
361
-
362
- await tg("sendChatAction", chat_id=chat_id, action="typing")
363
- try:
364
- schema = _schema_store[sid]
365
- table_name = _table_store[sid]
366
- sql = generate_sql(text, schema, table_name)
367
- results = execute_sql(sql, _db_store[sid])
368
- table_txt = format_table(results)
369
- reply = f"🔍 *Query*\n```sql\n{sql}\n```\n\n📋 *Results*\n```\n{table_txt}\n```"
370
- except Exception as e:
371
- reply = f"⚠️ Error: {e}"
372
-
373
- await send_msg(chat_id, reply, parse_mode="Markdown")
374
- return {"ok": True}
375
-
376
- # ── Callback query (button press) ──
377
- cb = update.get("callback_query", {})
378
- if cb:
379
- cb_id = cb["id"]
380
- cb_data = cb.get("data", "")
381
- cb_chat = cb["message"]["chat"]["id"]
382
- cb_user = cb["from"]["id"]
383
-
384
- if cb_data.startswith("q:"):
385
- _, sid, question = cb_data.split(":", 2)
386
- if sid not in _db_store:
387
- await tg("answerCallbackQuery", callback_query_id=cb_id, text="Session expired. Re-upload CSV.")
388
- return {"ok": True}
389
- await tg("answerCallbackQuery", callback_query_id=cb_id, text="Running query...")
390
- await tg("sendChatAction", chat_id=cb_chat, action="typing")
391
- try:
392
- schema = _schema_store[sid]
393
- table_name = _table_store[sid]
394
- sql = generate_sql(question, schema, table_name)
395
- results = execute_sql(sql, _db_store[sid])
396
- table_txt = format_table(results)
397
- reply = f"🔍 *Query*\n```sql\n{sql}\n```\n\n📋 *Results*\n```\n{table_txt}\n```"
398
- except Exception as e:
399
- reply = f"⚠️ Error: {e}"
400
- await send_msg(cb_chat, reply, parse_mode="Markdown")
401
-
402
- return {"ok": True}
403
-
404
-
405
- # ── Startup: register webhook ─────────────────────────────────────────────────
406
- @app.on_event("startup")
407
- async def set_webhook():
408
- if not BOT_TOKEN or not SPACE_URL:
409
- print("[WARN] BOT_TOKEN or SPACE_URL not set — webhook skipped.")
410
- return
411
- url = f"{SPACE_URL}/webhook/{WEBHOOK_SECRET}"
412
- for attempt in range(1, 4):
413
- try:
414
- async with httpx.AsyncClient(timeout=15) as client:
415
- r = await client.post(f"{TELEGRAM_API}/setWebhook", json={"url": url})
416
- print(f"[INFO] Webhook set: {r.json()}")
417
- return
418
- except Exception as e:
419
- print(f"[WARN] Webhook attempt {attempt}/3 failed: {e}")
420
- if attempt < 3:
421
- import asyncio; await asyncio.sleep(3)
422
- print("[WARN] Webhook registration failed — bot still runs, set webhook manually.")
 
1
  """
2
+ QueryMindCSV-to-SQL Engine (v3.0.0 - Gemini Powered)
3
+ Engine: Gemini 1.5 Flash + Heuristic Rules
4
+ Hardware: HuggingFace Free Tier (Ultra-Light)
 
 
5
  """
6
 
7
  import os
 
10
  import json
11
  import sqlite3
12
  import tempfile
 
13
  import pandas as pd
14
+ import urllib.request
15
+ from fastapi import FastAPI, File, UploadFile, HTTPException
16
  from fastapi.staticfiles import StaticFiles
17
+ from fastapi.responses import FileResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel
 
 
 
20
 
21
+ # ── Configuration ──────────────────────────────────────────────────────────────
22
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
 
 
 
 
 
23
 
24
+ _db_store = {}
25
+ _schema_store = {}
26
 
27
+ app = FastAPI(title="QueryMind Gemini", version="3.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ class QueryRequest(BaseModel):
31
+ session_id: str
32
+ question: str
33
 
34
+ # ── Heuristic Logic (Fast Layer) ──────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def _find_col(question: str, columns: list) -> str or None:
37
+ q = question.lower()
38
+ # Sort by length DESC so 'AQI_Bucket' matches before 'AQI'
39
+ for col in sorted(columns, key=len, reverse=True):
40
+ if col.lower() in q:
41
+ return col
42
+ return None
43
 
44
+ def _heuristic_sql(question: str, table: str, columns: list) -> str or None:
 
45
  q = question.lower().strip()
46
+ t = f'"{table}"'
47
+
48
+ if re.search(r'\bgroup\s+by\b', q):
49
+ col = _find_col(q, columns) or columns[0]
50
+ return f'SELECT "{col}", COUNT(*) AS count FROM {t} GROUP BY "{col}" ORDER BY count DESC'
51
+
52
+ if re.search(r'\bunique\b|\bdistinct\b', q):
53
+ col = _find_col(q, columns) or columns[0]
54
+ if re.search(r'\bhow many\b|\bcount\b', q):
55
+ return f'SELECT COUNT(DISTINCT "{col}") AS unique_count FROM {t}'
56
+ return f'SELECT DISTINCT "{col}" FROM {t} LIMIT 50'
57
+
58
+ if re.search(r'\bhow many\b|\bcount\b|\btotal\s+(records|rows)\b', q):
59
+ return f'SELECT COUNT(*) AS total_rows FROM {t}'
60
+
61
+ if re.search(r'\baverage\b|\bavg\b', q):
62
+ col = _find_col(q, columns) or columns[0]
63
+ return f'SELECT AVG(CAST("{col}" AS REAL)) AS average FROM {t}'
64
+
65
+ if re.search(r'\bfirst\b|\bpreview\b|\bshow\b|\bhead\b', q):
66
+ m = re.search(r'\b(\d+)\b', q)
67
+ return f'SELECT * FROM {t} LIMIT {int(m.group(1)) if m else 10}'
68
+
69
+ return None
70
+
71
+ # ── Gemini API Call (Neural Layer) ───────────────────────────────────────────
72
+
73
+ def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
74
+ if not GEMINI_API_KEY:
75
+ raise Exception("Gemini API Key missing")
76
+
77
+ col_list = ", ".join(columns[:30])
78
+ prompt = (
79
+ f"You are a SQLite expert. Output ONLY a single valid SQLite SELECT statement. "
80
+ f"No explanation, no backticks, no markdown.\n\n"
81
+ f"Table: {table}\n"
82
+ f"Columns: {col_list}\n"
83
+ f"Schema: {schema}\n\n"
84
+ f"Question: {question}\n\nSQL:"
 
 
 
 
 
85
  )
86
+
87
+ payload = json.dumps({
88
+ "contents": [{"parts": [{"text": prompt}]}],
89
+ "generationConfig": {"temperature": 0, "maxOutputTokens": 200}
90
+ }).encode("utf-8")
91
+
92
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
93
+
94
+ req = urllib.request.Request(url, data=payload, headers={"Content-Type": "application/json"})
95
+
96
+ with urllib.request.urlopen(req, timeout=10) as resp:
97
+ data = json.loads(resp.read())
98
+ sql = data["candidates"][0]["content"]["parts"][0]["text"].strip()
99
+
100
+ # Cleaning up common LLM artifacts
101
+ sql = sql.replace("```sql", "").replace("```", "").strip()
102
+ sql = sql.split(";")[0].strip()
103
+ # Force the correct table name into the generated SQL
104
+ sql = re.sub(r'\bFROM\s+["\'\w\.]+', f'FROM "{table}"', sql, flags=re.IGNORECASE)
105
  return sql
106
 
107
+ # ── Logic Helpers ──────────────────────────────────────────────────────────────
108
 
109
+ def csv_to_sqlite(df, table_name):
110
+ temp_db = io.BytesIO()
111
+ conn = sqlite3.connect(temp_db)
112
+ df.to_sql(table_name, conn, if_exists="replace", index=False)
113
+ # Extract schema string
114
+ schema = conn.execute("SELECT sql FROM sqlite_master WHERE type='table'").fetchone()[0]
115
+ conn.close()
116
+ return temp_db.getvalue(), schema
117
+
118
+ def execute_sql(sql, db_bytes):
119
+ # Load DB into memory for execution
120
+ conn = sqlite3.connect(":memory:")
121
+ source = sqlite3.connect(io.BytesIO(db_bytes))
122
+ source.backup(conn)
123
+ source.close()
124
+
125
  conn.row_factory = sqlite3.Row
126
  try:
127
+ cur = conn.execute(sql)
128
+ results = [dict(r) for r in cur.fetchall()]
129
+ conn.close()
130
+ return results
131
  except Exception as e:
132
+ conn.close()
133
+ raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ # ── API Endpoints ─────────────────────────────────────────────────────────────
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  @app.post("/upload")
138
+ async def upload_csv(file: UploadFile = File(...)):
 
 
139
  contents = await file.read()
140
+ df = pd.read_csv(io.BytesIO(contents)).dropna(how='all')
141
+
142
+ session_id = os.urandom(8).hex()
143
+ clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', os.path.splitext(file.filename)[0])
144
+ if clean_name[0].isdigit(): clean_name = "t_" + clean_name
145
+ table_name = clean_name[:32]
146
+
147
+ db_bytes, schema = csv_to_sqlite(df, table_name)
148
+ _db_store[session_id] = {"bytes": db_bytes, "table": table_name, "cols": list(df.columns)}
 
 
 
 
 
 
149
  _schema_store[session_id] = schema
 
 
 
 
150
 
151
+ return {
152
  "session_id": session_id,
153
+ "columns": list(df.columns),
154
+ "preview": df.head(5).to_dict(orient="records"),
155
+ "table_name": table_name
156
+ }
 
 
 
 
 
 
 
 
157
 
158
  @app.post("/query")
159
  async def query(req: QueryRequest):
160
  if req.session_id not in _db_store:
161
+ raise HTTPException(status_code=404, detail="Session expired.")
162
+
163
+ data = _db_store[req.session_id]
164
+ schema = _schema_store[req.session_id]
165
+
166
+ # 1. Try Fast Heuristics
167
+ sql = _heuristic_sql(req.question, data["table"], data["cols"])
168
+
169
+ # 2. Try Gemini
170
+ if not sql:
171
+ try:
172
+ sql = _call_gemini(req.question, schema, data["cols"], data["table"])
173
+ except Exception as e:
174
+ print(f"[API ERROR] {e}")
175
+ raise HTTPException(status_code=500, detail="Gemini API failed.")
176
 
177
+ results = execute_sql(sql, data["bytes"])
178
+ return {"sql": sql, "results": results}
 
 
179
 
180
+ # ── Static & Main ──
181
+ app.mount("/static", StaticFiles(directory="static"), name="static")
182
 
183
  @app.get("/")
184
+ def root():
185
+ return FileResponse("static/index.html")
 
186
 
 
187
  @app.get("/health")
188
  def health():
189
+ return {"status": "ok", "mode": "gemini-api"}