nilotpaldhar2004 committed on
Commit
b7d2418
·
unverified ·
1 Parent(s): 3671635

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -37
app.py CHANGED
@@ -1,8 +1,15 @@
 
 
 
 
 
 
1
  import os
2
  import re
3
  import io
4
- import json
5
  import sqlite3
 
6
  import pandas as pd
7
  import urllib.request
8
  from fastapi import FastAPI, File, UploadFile, HTTPException
@@ -12,51 +19,49 @@ from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel
13
 
14
  # ── Configuration ──────────────────────────────────────────────────────────────
 
15
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
16
 
17
  _db_store = {}
18
  _schema_store = {}
19
 
20
- app = FastAPI(title="QueryMind Gemini", version="3.0.3")
21
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
 
 
 
 
22
 
23
  class QueryRequest(BaseModel):
24
  session_id: str
25
  question: str
26
 
27
- # ── Logic Helpers ──────────────────────────────────────────────────────────────
28
-
29
- def _find_col(question: str, columns: list) -> str or None:
30
- q = question.lower()
31
- for col in sorted(columns, key=len, reverse=True):
32
- if col.lower() in q: return col
33
- return None
34
-
35
- # ── Improved Heuristics (Less aggressive) ──────────────────────────────────────
36
 
37
  def _heuristic_sql(question: str, table: str, columns: list) -> str or None:
 
38
  q = question.lower().strip()
39
  t = f'"{table}"'
40
 
41
- # Only trigger heuristics for VERY specific, simple patterns
42
- # "how many records" or "total rows"
43
  if re.fullmatch(r'(how many records|total rows|count rows|count total)', q):
44
  return f'SELECT COUNT(*) AS total_rows FROM {t}'
45
 
46
- # "preview" or "show head"
47
- if re.fullmatch(r'(preview|show head|data preview)', q):
48
  return f'SELECT * FROM {t} LIMIT 10'
49
 
50
- # If it's anything else, return None so Gemini can handle the logic
51
  return None
52
 
53
- # ── Improved Gemini Prompt (For Hard Queries) ──────────────────────────────────
54
 
55
  def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
56
- if not GEMINI_API_KEY: return ""
 
 
57
 
58
  col_list = ", ".join(columns[:30])
59
- # Improved prompt to handle filtering, grouping, and ordering
60
  prompt = (
61
  f"You are a SQLite expert. Convert the question into a single valid SQL query.\n"
62
  f"Table: {table}\n"
@@ -65,7 +70,7 @@ def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
65
  f"Question: {question}\n\n"
66
  f"Rules:\n"
67
  f"1. Use double quotes for table and column names.\n"
68
- f"2. Output ONLY the SQL code.\n"
69
  f"3. If the question asks for 'the first', use LIMIT 1.\n"
70
  f"4. If filtering by text, use the LIKE operator for flexibility.\n\n"
71
  f"SQL:"
@@ -73,7 +78,7 @@ def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
73
 
74
  payload = json.dumps({
75
  "contents": [{"parts": [{"text": prompt}]}],
76
- "generationConfig": {"temperature": 0.1} # Slight temperature for better reasoning
77
  }).encode("utf-8")
78
 
79
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
@@ -84,49 +89,57 @@ def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
84
  data = json.loads(resp.read())
85
  sql = data["candidates"][0]["content"]["parts"][0]["text"].strip()
86
 
87
- # Clean response
88
  sql = sql.replace("```sql", "").replace("```", "").strip().split(";")[0]
89
- # Force correct table name
90
  sql = re.sub(r'\bFROM\s+["\'\w\.]+', f'FROM "{table}"', sql, flags=re.IGNORECASE)
91
  return sql
92
  except Exception as e:
93
- print(f"Gemini Error: {e}")
94
  return ""
95
 
96
- # 🔥 FIXED: Execution logic to handle bytes correctly
 
97
  def execute_sql(sql, db_bytes):
98
- # Create a fresh in-memory DB and load the bytes into it
 
99
  conn = sqlite3.connect(":memory:")
100
- db_file = io.BytesIO(db_bytes)
101
- # We use a temporary disk file because SQLite backup doesn't like raw BytesIO directly
102
  with tempfile.NamedTemporaryFile() as f:
103
  f.write(db_bytes)
104
  f.flush()
105
  disk_conn = sqlite3.connect(f.name)
106
- disk_conn.backup(conn)
107
- disk_conn.close()
 
 
108
 
109
  conn.row_factory = sqlite3.Row
110
  try:
111
  cur = conn.execute(sql)
112
  return [dict(r) for r in cur.fetchall()]
113
- finally: conn.close()
 
 
 
114
 
115
  # ── API Endpoints ─────────────────────────────────────────────────────────────
116
- import tempfile # Added this import
117
 
118
  @app.post("/upload")
119
  async def upload_csv(file: UploadFile = File(...)):
 
120
  try:
121
  contents = await file.read()
122
  df = pd.read_csv(io.BytesIO(contents)).dropna(how='all')
123
 
124
  session_id = os.urandom(8).hex()
 
125
  clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', os.path.splitext(file.filename)[0])
126
  if clean_name[0].isdigit(): clean_name = "t_" + clean_name
127
  table_name = clean_name[:32]
128
 
129
- # 🔥 FIXED: Use a temporary file to bridge Pandas to SQLite bytes
130
  with tempfile.NamedTemporaryFile() as tf:
131
  conn = sqlite3.connect(tf.name)
132
  df.to_sql(table_name, conn, index=False, if_exists="replace")
@@ -136,6 +149,7 @@ async def upload_csv(file: UploadFile = File(...)):
136
  with open(tf.name, "rb") as f:
137
  db_data = f.read()
138
 
 
139
  _db_store[session_id] = {
140
  "bytes": db_data,
141
  "table": table_name,
@@ -143,6 +157,7 @@ async def upload_csv(file: UploadFile = File(...)):
143
  }
144
  _schema_store[session_id] = schema
145
 
 
146
  return {
147
  "session_id": session_id,
148
  "columns": list(df.columns),
@@ -156,17 +171,35 @@ async def upload_csv(file: UploadFile = File(...)):
156
 
157
  @app.post("/query")
158
  async def query(req: QueryRequest):
 
159
  data = _db_store.get(req.session_id)
160
- if not data: raise HTTPException(status_code=404)
161
- sql = _heuristic_sql(req.question, data["table"], data["cols"]) or _call_gemini(req.question, _schema_store[req.session_id], data["cols"], data["table"])
162
- return {"sql": sql, "results": execute_sql(sql, data["bytes"])}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  @app.get("/health")
165
  def health():
 
166
  return {"status": "ok", "model": "gemini-1.5-flash"}
167
 
 
168
  app.mount("/static", StaticFiles(directory="static"), name="static")
169
 
170
  @app.get("/")
171
  def root():
 
172
  return FileResponse("static/webapp.html")
 
1
+ """
2
+ QueryMind — CSV-to-SQL Engine (v3.0.4)
3
+ Final Production Build: Gemini 1.5 Flash + Hybrid Heuristics
4
+ Author: Nilotpal Dhar
5
+ """
6
+
7
  import os
8
  import re
9
  import io
10
+ import json
11
  import sqlite3
12
+ import tempfile
13
  import pandas as pd
14
  import urllib.request
15
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
19
  from pydantic import BaseModel
20
 
21
  # ── Configuration ──────────────────────────────────────────────────────────────
22
+ # In HF Spaces, set this in Settings -> Variables and Secrets
23
  GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
24
 
25
  _db_store = {}
26
  _schema_store = {}
27
 
28
+ app = FastAPI(title="QueryMind Gemini", version="3.0.4")
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"],
32
+ allow_methods=["*"],
33
+ allow_headers=["*"]
34
+ )
35
 
36
  class QueryRequest(BaseModel):
37
  session_id: str
38
  question: str
39
 
40
+ # ── Heuristic Logic (Instant Speed Layer) ─────────────────────────────────────
 
 
 
 
 
 
 
 
41
 
42
  def _heuristic_sql(question: str, table: str, columns: list) -> str or None:
43
+ """Handles basic queries locally without calling Gemini to save time/quota."""
44
  q = question.lower().strip()
45
  t = f'"{table}"'
46
 
47
+ # Simple counting
 
48
  if re.fullmatch(r'(how many records|total rows|count rows|count total)', q):
49
  return f'SELECT COUNT(*) AS total_rows FROM {t}'
50
 
51
+ # Simple data preview
52
+ if re.fullmatch(r'(preview|show head|data preview|show all)', q):
53
  return f'SELECT * FROM {t} LIMIT 10'
54
 
 
55
  return None
56
 
57
+ # ── Gemini API Call (Neural Logic Layer) ─────────────────────────────────────
58
 
59
  def _call_gemini(question: str, schema: str, columns: list, table: str) -> str:
60
+ """Calls Gemini 1.5 Flash to translate Natural Language into SQLite."""
61
+ if not GEMINI_API_KEY:
62
+ return ""
63
 
64
  col_list = ", ".join(columns[:30])
 
65
  prompt = (
66
  f"You are a SQLite expert. Convert the question into a single valid SQL query.\n"
67
  f"Table: {table}\n"
 
70
  f"Question: {question}\n\n"
71
  f"Rules:\n"
72
  f"1. Use double quotes for table and column names.\n"
73
+ f"2. Output ONLY the SQL code. No markdown, no explanation.\n"
74
  f"3. If the question asks for 'the first', use LIMIT 1.\n"
75
  f"4. If filtering by text, use the LIKE operator for flexibility.\n\n"
76
  f"SQL:"
 
78
 
79
  payload = json.dumps({
80
  "contents": [{"parts": [{"text": prompt}]}],
81
+ "generationConfig": {"temperature": 0.1, "maxOutputTokens": 300}
82
  }).encode("utf-8")
83
 
84
  url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={GEMINI_API_KEY}"
 
89
  data = json.loads(resp.read())
90
  sql = data["candidates"][0]["content"]["parts"][0]["text"].strip()
91
 
92
+ # Clean potential markdown artifacts
93
  sql = sql.replace("```sql", "").replace("```", "").strip().split(";")[0]
94
+ # Safety check: Force the correct table name from our store
95
  sql = re.sub(r'\bFROM\s+["\'\w\.]+', f'FROM "{table}"', sql, flags=re.IGNORECASE)
96
  return sql
97
  except Exception as e:
98
+ print(f"[GEMINI ERROR] {e}")
99
  return ""
100
 
101
+ # ── Database Management ───────────────────────────────────────────────────────
102
+
103
  def execute_sql(sql, db_bytes):
104
+ """Restores the SQLite database from memory and executes the query."""
105
+ # Create an empty in-memory database
106
  conn = sqlite3.connect(":memory:")
107
+
108
+ # We use a temporary file to bridge bytes back into a SQLite connection
109
  with tempfile.NamedTemporaryFile() as f:
110
  f.write(db_bytes)
111
  f.flush()
112
  disk_conn = sqlite3.connect(f.name)
113
+ try:
114
+ disk_conn.backup(conn)
115
+ finally:
116
+ disk_conn.close()
117
 
118
  conn.row_factory = sqlite3.Row
119
  try:
120
  cur = conn.execute(sql)
121
  return [dict(r) for r in cur.fetchall()]
122
+ except Exception as e:
123
+ return [{"error": str(e)}]
124
+ finally:
125
+ conn.close()
126
 
127
  # ── API Endpoints ─────────────────────────────────────────────────────────────
 
128
 
129
  @app.post("/upload")
130
  async def upload_csv(file: UploadFile = File(...)):
131
+ """Receives CSV, creates session, and prepares in-memory SQL database."""
132
  try:
133
  contents = await file.read()
134
  df = pd.read_csv(io.BytesIO(contents)).dropna(how='all')
135
 
136
  session_id = os.urandom(8).hex()
137
+ # Clean table name for SQLite safety
138
  clean_name = re.sub(r'[^a-zA-Z0-9_]', '_', os.path.splitext(file.filename)[0])
139
  if clean_name[0].isdigit(): clean_name = "t_" + clean_name
140
  table_name = clean_name[:32]
141
 
142
+ # Build the SQL database using a temp file to get raw bytes
143
  with tempfile.NamedTemporaryFile() as tf:
144
  conn = sqlite3.connect(tf.name)
145
  df.to_sql(table_name, conn, index=False, if_exists="replace")
 
149
  with open(tf.name, "rb") as f:
150
  db_data = f.read()
151
 
152
+ # Store session data globally (Shared with bot.py)
153
  _db_store[session_id] = {
154
  "bytes": db_data,
155
  "table": table_name,
 
157
  }
158
  _schema_store[session_id] = schema
159
 
160
+ # This response is synchronized with webapp.html logic
161
  return {
162
  "session_id": session_id,
163
  "columns": list(df.columns),
 
171
 
172
  @app.post("/query")
173
  async def query(req: QueryRequest):
174
+ """Main query handler: Heuristics -> Gemini -> SQL Execution."""
175
  data = _db_store.get(req.session_id)
176
+ if not data:
177
+ raise HTTPException(status_code=404, detail="Session expired. Please re-upload your file.")
178
+
179
+ # 1. Try Heuristics
180
+ sql = _heuristic_sql(req.question, data["table"], data["cols"])
181
+
182
+ # 2. Try Gemini
183
+ if not sql:
184
+ sql = _call_gemini(req.question, _schema_store[req.session_id], data["cols"], data["table"])
185
+
186
+ if not sql:
187
+ raise HTTPException(status_code=400, detail="I couldn't translate that question into a SQL query.")
188
+
189
+ results = execute_sql(sql, data["bytes"])
190
+ return {"sql": sql, "results": results}
191
+
192
+ # ── Health & Static Assets ──
193
 
194
  @app.get("/health")
195
  def health():
196
+ """Health check endpoint for Hugging Face and the Web UI."""
197
  return {"status": "ok", "model": "gemini-1.5-flash"}
198
 
199
+ # Mount the static directory to serve webapp.html and CSS
200
  app.mount("/static", StaticFiles(directory="static"), name="static")
201
 
202
  @app.get("/")
203
  def root():
204
+ """Serves the main frontend dashboard."""
205
  return FileResponse("static/webapp.html")