nilotpaldhar2004 commited on
Commit
ff03062
·
unverified ·
1 Parent(s): 10f0e68

Add files via upload

Browse files
Files changed (1) hide show
  1. app.py +377 -0
app.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Nilotpal SQL Bot — Telegram Bot + Web App
3
+ FastAPI backend serving:
4
+ - Telegram Bot (standard messages + inline buttons)
5
+ - Telegram Web App (full HTML/CSS/JS UI via /webapp)
6
+ Model: cssupport/t5-small-awesome-text-to-sql (CPU-friendly)
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import io
12
+ import json
13
+ import sqlite3
14
+ import tempfile
15
+ import hashlib
16
+ import pandas as pd
17
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Request
18
+ from fastapi.staticfiles import StaticFiles
19
+ from fastapi.responses import FileResponse, JSONResponse, HTMLResponse
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+ from pydantic import BaseModel
22
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
23
+ import torch
24
+ import httpx
25
+
26
+ # ── Config ────────────────────────────────────────────────────────────────────
27
+ MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
28
+ MAX_NEW_TOKENS = 256
29
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
30
+ BOT_TOKEN = os.getenv("BOT_TOKEN", "") # set in HF Space secrets
31
+ WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET", "nilotpalsqlbot")
32
+ SPACE_URL = os.getenv("SPACE_URL", "") # e.g. https://nilotpaldhar2004-nilotpal-sql-bot.hf.space
33
+
34
+ TELEGRAM_API = f"https://api.telegram.org/bot{BOT_TOKEN}"
35
+
36
+ # ── Load model ────────────────────────────────────────────────────────────────
37
+ print(f"[INFO] Loading {MODEL_NAME} on {DEVICE}...")
38
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
39
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
40
+ model.eval()
41
+ print("[INFO] Model ready.")
42
+
43
+ # ── In-memory stores ──────────────────────────────────────────────────────────
44
+ _db_store: dict[str, bytes] = {} # session_id → sqlite bytes
45
+ _schema_store: dict[str, str] = {} # session_id → schema string
46
+ _col_store: dict[str, list] = {} # session_id → column list
47
+ _table_store: dict[str, str] = {} # session_id → table name
48
+ _user_session: dict[int, str] = {} # telegram user_id → session_id
49
+
50
+ app = FastAPI(title="Nilotpal SQL Bot", version="1.0.0")
51
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
52
+ app.mount("/static", StaticFiles(directory="static"), name="static")
53
+
54
+
55
+ # ── Helpers ───────────────────────────────────────────────────────────────────
56
+ def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
57
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
58
+ tmp_path = tmp.name
59
+ conn = sqlite3.connect(tmp_path)
60
+ df.to_sql(table_name, conn, if_exists="replace", index=False)
61
+ conn.close()
62
+ with open(tmp_path, "rb") as f:
63
+ db_bytes = f.read()
64
+ os.unlink(tmp_path)
65
+ return db_bytes
66
+
67
+
68
+ def get_schema(db_bytes: bytes) -> str:
69
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
70
+ tmp.write(db_bytes)
71
+ tmp_path = tmp.name
72
+ conn = sqlite3.connect(tmp_path)
73
+ cur = conn.cursor()
74
+ cur.execute("SELECT sql FROM sqlite_master WHERE type='table'")
75
+ rows = cur.fetchall()
76
+ conn.close()
77
+ os.unlink(tmp_path)
78
+ return "\n".join(r[0] for r in rows if r[0])
79
+
80
+
81
+ def generate_sql(question: str, schema: str, table_name: str) -> str:
82
+ quoted = f'"{table_name}"'
83
+ col_match = re.findall(r'"(\w+)"', schema)
84
+ col_hint = ", ".join(col_match)
85
+ prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
86
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
87
+ with torch.no_grad():
88
+ outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
89
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
90
+ sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
91
+ sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
92
+ sql = re.sub(
93
+ r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
94
+ r'\1', sql, flags=re.IGNORECASE
95
+ )
96
+ if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
97
+ sql = f'SELECT * FROM {quoted} LIMIT 10'
98
+ return sql
99
+
100
+
101
+ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
102
+ with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
103
+ tmp.write(db_bytes)
104
+ tmp_path = tmp.name
105
+ conn = sqlite3.connect(tmp_path)
106
+ conn.row_factory = sqlite3.Row
107
+ try:
108
+ cur = conn.execute(sql)
109
+ rows = [dict(r) for r in cur.fetchall()]
110
+ except Exception as e:
111
+ conn.close(); os.unlink(tmp_path)
112
+ raise HTTPException(status_code=400, detail=f"SQL error: {e}")
113
+ conn.close(); os.unlink(tmp_path)
114
+ return rows
115
+
116
+
117
+ def format_table(rows: list[dict]) -> str:
118
+ """Format query results as plain text for Telegram."""
119
+ if not rows:
120
+ return "No rows returned."
121
+ cols = list(rows[0].keys())
122
+ # Simple text table
123
+ lines = [" | ".join(cols)]
124
+ lines.append("-" * len(lines[0]))
125
+ for r in rows[:20]:
126
+ lines.append(" | ".join(str(r[c]) if r[c] is not None else "null" for c in cols))
127
+ if len(rows) > 20:
128
+ lines.append(f"... ({len(rows)} rows total, showing 20)")
129
+ return "\n".join(lines)
130
+
131
+
132
+ # ── Telegram API helpers ───────────────────────────────────────────────────────
133
+ async def tg(method: str, **kwargs):
134
+ async with httpx.AsyncClient(timeout=30) as client:
135
+ r = await client.post(f"{TELEGRAM_API}/{method}", json=kwargs)
136
+ return r.json()
137
+
138
+
139
+ async def send_msg(chat_id: int, text: str, reply_markup=None, parse_mode="Markdown"):
140
+ payload = dict(chat_id=chat_id, text=text, parse_mode=parse_mode)
141
+ if reply_markup:
142
+ payload["reply_markup"] = reply_markup
143
+ return await tg("sendMessage", **payload)
144
+
145
+
146
+ async def send_doc_request(chat_id: int):
147
+ """Ask user to send a CSV file."""
148
+ await send_msg(
149
+ chat_id,
150
+ "📂 *Send me a CSV file* to get started!\n\nI'll convert your questions to SQL and query it instantly.",
151
+ reply_markup={
152
+ "inline_keyboard": [[
153
+ {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
154
+ ]]
155
+ }
156
+ )
157
+
158
+
159
+ # ── REST: CSV Upload (used by both bot and webapp) ────────────────────────────
160
+ @app.post("/upload")
161
+ async def upload_csv(file: UploadFile = File(...), user_id: int = 0):
162
+ if not file.filename.endswith(".csv"):
163
+ raise HTTPException(status_code=400, detail="Only CSV files accepted.")
164
+ contents = await file.read()
165
+ try:
166
+ df = pd.read_csv(io.BytesIO(contents))
167
+ except Exception as e:
168
+ raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
169
+
170
+ session_id = hashlib.md5(contents[:1024]).hexdigest()[:12]
171
+ table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
172
+ if table_name[0].isdigit():
173
+ table_name = "t_" + table_name
174
+
175
+ db_bytes = csv_to_sqlite(df, table_name)
176
+ schema = get_schema(db_bytes)
177
+ columns = list(df.columns)
178
+
179
+ _db_store[session_id] = db_bytes
180
+ _schema_store[session_id] = schema
181
+ _col_store[session_id] = columns
182
+ _table_store[session_id] = table_name
183
+ if user_id:
184
+ _user_session[user_id] = session_id
185
+
186
+ return JSONResponse({
187
+ "session_id": session_id,
188
+ "table_name": table_name,
189
+ "columns": columns,
190
+ "row_count": len(df),
191
+ "schema": schema,
192
+ "preview": df.head(5).to_dict(orient="records"),
193
+ })
194
+
195
+
196
+ # ── REST: Query (used by both bot and webapp) ─────────────────────────────────
197
+ class QueryRequest(BaseModel):
198
+ session_id: str
199
+ question: str
200
+
201
+ @app.post("/query")
202
+ async def query(req: QueryRequest):
203
+ if req.session_id not in _db_store:
204
+ raise HTTPException(status_code=404, detail="Session not found. Upload CSV first.")
205
+ schema = _schema_store[req.session_id]
206
+ table_name = _table_store[req.session_id]
207
+ sql = generate_sql(req.question, schema, table_name)
208
+ results = execute_sql(sql, _db_store[req.session_id])
209
+ return JSONResponse({"sql": sql, "results": results})
210
+
211
+
212
+ # ── Web App route ──────────────────────────────────────────────────────────────
213
+ @app.get("/webapp", response_class=HTMLResponse)
214
+ async def webapp():
215
+ return FileResponse("static/webapp.html")
216
+
217
+
218
+ @app.get("/")
219
+ async def root():
220
+ return FileResponse("static/webapp.html")
221
+
222
+
223
+ # ── Health ────────────────────────────────────────────────────────────────────
224
+ @app.get("/health")
225
+ def health():
226
+ return {"status": "ok", "model": MODEL_NAME, "device": DEVICE, "bot": bool(BOT_TOKEN)}
227
+
228
+
229
+ # ── Telegram Webhook ──────────────────────────────────────────────────────────
230
+ @app.post(f"/webhook/{WEBHOOK_SECRET}")
231
+ async def webhook(request: Request):
232
+ update = await request.json()
233
+
234
+ # Handle document (CSV upload via bot)
235
+ msg = update.get("message", {})
236
+ if not msg:
237
+ msg = update.get("edited_message", {})
238
+ chat_id = msg.get("chat", {}).get("id")
239
+ user_id = msg.get("from", {}).get("id", 0)
240
+ text = msg.get("text", "").strip()
241
+
242
+ # ── /start ──
243
+ if text in ["/start", "/help"]:
244
+ await send_msg(
245
+ chat_id,
246
+ "👋 *Nilotpal SQL Bot*\n\n"
247
+ "I convert plain English questions into SQL and query your CSV data.\n\n"
248
+ "📌 *How to use:*\n"
249
+ "1️⃣ Send a CSV file\n"
250
+ "2️⃣ Ask me anything about your data\n\n"
251
+ "Or use the Web App for a richer experience ↓",
252
+ reply_markup={
253
+ "inline_keyboard": [[
254
+ {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
255
+ ]]
256
+ }
257
+ )
258
+ return {"ok": True}
259
+
260
+ # ── CSV Document ──
261
+ doc = msg.get("document")
262
+ if doc and doc.get("file_name", "").endswith(".csv"):
263
+ await send_msg(chat_id, "⏳ Processing your CSV...")
264
+ # Download file from Telegram
265
+ file_info = await tg("getFile", file_id=doc["file_id"])
266
+ file_path = file_info["result"]["file_path"]
267
+ async with httpx.AsyncClient() as client:
268
+ file_resp = await client.get(f"https://api.telegram.org/file/bot{BOT_TOKEN}/{file_path}")
269
+ contents = file_resp.content
270
+ try:
271
+ df = pd.read_csv(io.BytesIO(contents))
272
+ except Exception as e:
273
+ await send_msg(chat_id, f"❌ Could not parse CSV: {e}")
274
+ return {"ok": True}
275
+
276
+ fname = doc["file_name"]
277
+ session_id = hashlib.md5(contents[:1024]).hexdigest()[:12]
278
+ table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(fname)[0])[:32] or "data"
279
+ if table_name[0].isdigit():
280
+ table_name = "t_" + table_name
281
+
282
+ db_bytes = csv_to_sqlite(df, table_name)
283
+ schema = get_schema(db_bytes)
284
+ columns = list(df.columns)
285
+
286
+ _db_store[session_id] = db_bytes
287
+ _schema_store[session_id] = schema
288
+ _col_store[session_id] = columns
289
+ _table_store[session_id] = table_name
290
+ _user_session[user_id] = session_id
291
+
292
+ col_preview = ", ".join(columns[:8]) + ("..." if len(columns) > 8 else "")
293
+ await send_msg(
294
+ chat_id,
295
+ f"✅ *Loaded:* `{fname}`\n"
296
+ f"📊 *{len(df):,} rows · {len(columns)} columns*\n"
297
+ f"📋 *Columns:* `{col_preview}`\n\n"
298
+ f"Now ask me anything about your data!\n"
299
+ f'Example: _"Show first 10 rows"_',
300
+ reply_markup={
301
+ "inline_keyboard": [
302
+ [{"text": "📊 Show first 10 rows", "callback_data": f"q:{session_id}:Show the first 10 rows"}],
303
+ [{"text": "🔢 Count total records", "callback_data": f"q:{session_id}:Count total number of records"}],
304
+ [{"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}],
305
+ ]
306
+ }
307
+ )
308
+ return {"ok": True}
309
+
310
+ # ── Text question ──
311
+ if text and not text.startswith("/"):
312
+ sid = _user_session.get(user_id)
313
+ if not sid or sid not in _db_store:
314
+ await send_msg(
315
+ chat_id,
316
+ "📂 Please send a CSV file first so I can query it for you.",
317
+ reply_markup={
318
+ "inline_keyboard": [[
319
+ {"text": "🌐 Open Web App", "web_app": {"url": f"{SPACE_URL}/webapp"}}
320
+ ]]
321
+ }
322
+ )
323
+ return {"ok": True}
324
+
325
+ await tg("sendChatAction", chat_id=chat_id, action="typing")
326
+ try:
327
+ schema = _schema_store[sid]
328
+ table_name = _table_store[sid]
329
+ sql = generate_sql(text, schema, table_name)
330
+ results = execute_sql(sql, _db_store[sid])
331
+ table_txt = format_table(results)
332
+ reply = f"🔍 *Query*\n```sql\n{sql}\n```\n\n📋 *Results*\n```\n{table_txt}\n```"
333
+ except Exception as e:
334
+ reply = f"⚠️ Error: {e}"
335
+
336
+ await send_msg(chat_id, reply, parse_mode="Markdown")
337
+ return {"ok": True}
338
+
339
+ # ── Callback query (button press) ──
340
+ cb = update.get("callback_query", {})
341
+ if cb:
342
+ cb_id = cb["id"]
343
+ cb_data = cb.get("data", "")
344
+ cb_chat = cb["message"]["chat"]["id"]
345
+ cb_user = cb["from"]["id"]
346
+
347
+ if cb_data.startswith("q:"):
348
+ _, sid, question = cb_data.split(":", 2)
349
+ if sid not in _db_store:
350
+ await tg("answerCallbackQuery", callback_query_id=cb_id, text="Session expired. Re-upload CSV.")
351
+ return {"ok": True}
352
+ await tg("answerCallbackQuery", callback_query_id=cb_id, text="Running query...")
353
+ await tg("sendChatAction", chat_id=cb_chat, action="typing")
354
+ try:
355
+ schema = _schema_store[sid]
356
+ table_name = _table_store[sid]
357
+ sql = generate_sql(question, schema, table_name)
358
+ results = execute_sql(sql, _db_store[sid])
359
+ table_txt = format_table(results)
360
+ reply = f"🔍 *Query*\n```sql\n{sql}\n```\n\n📋 *Results*\n```\n{table_txt}\n```"
361
+ except Exception as e:
362
+ reply = f"⚠️ Error: {e}"
363
+ await send_msg(cb_chat, reply, parse_mode="Markdown")
364
+
365
+ return {"ok": True}
366
+
367
+
368
+ # ── Startup: register webhook ─────────────────────────────────────────────────
369
+ @app.on_event("startup")
370
+ async def set_webhook():
371
+ if BOT_TOKEN and SPACE_URL:
372
+ url = f"{SPACE_URL}/webhook/{WEBHOOK_SECRET}"
373
+ async with httpx.AsyncClient() as client:
374
+ r = await client.post(f"{TELEGRAM_API}/setWebhook", json={"url": url})
375
+ print(f"[INFO] Webhook set: {r.json()}")
376
+ else:
377
+ print("[WARN] BOT_TOKEN or SPACE_URL not set — webhook not registered.")