nilotpaldhar2004 committed on
Commit
fff6817
·
verified ·
1 Parent(s): 1e5473b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -91
app.py CHANGED
@@ -1,13 +1,12 @@
1
  """
2
- QueryMind — CSV-to-SQL Engine (Final Production Version)
3
- Model: T5-Small Hybrid (Regex + Transformer)
4
- Hardware: HuggingFace Free Tier (CPU)
5
  """
6
 
7
  import os
8
  import re
9
  import io
10
- import json
11
  import sqlite3
12
  import tempfile
13
  import pandas as pd
@@ -16,28 +15,31 @@ from fastapi.staticfiles import StaticFiles
16
  from fastapi.responses import FileResponse, JSONResponse
17
  from fastapi.middleware.cors import CORSMiddleware
18
  from pydantic import BaseModel
19
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
20
  import torch
 
21
 
22
  # ── Configuration ─────────────────────────────────────────────────────────────
23
- MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
24
- MAX_NEW_TOKENS = 256
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ── Model Initialization ──────────────────────────────────────────────────────
28
- print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
29
-
30
- # CRITICAL: use_fast=False fixes the sentencepiece/backend tokenizer error on CPU
31
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
32
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
 
 
 
33
  model.eval()
34
- print("[INFO] Model ready.")
35
 
36
- # ── State Management (In-Memory) ──────────────────────────────────────────────
37
- _db_store: dict[str, bytes] = {} # session_id -> sqlite db bytes
38
- _schema_store: dict[str, str] = {} # session_id -> create table schema
39
 
40
- app = FastAPI(title="QueryMind Engine", version="1.1.0")
41
 
42
  app.add_middleware(
43
  CORSMiddleware,
@@ -46,7 +48,6 @@ app.add_middleware(
46
  allow_headers=["*"],
47
  )
48
 
49
- # ── Static Frontend ───────────────────────────────────────────────────────────
50
  app.mount("/static", StaticFiles(directory="static"), name="static")
51
 
52
  @app.get("/")
@@ -55,7 +56,6 @@ def root():
55
 
56
  # ── Logic Helpers ──────────────────────────────────────────────────────────────
57
  def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
58
- """Converts Pandas DataFrame into a portable SQLite binary blob."""
59
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
60
  tmp_path = tmp.name
61
  conn = sqlite3.connect(tmp_path)
@@ -68,7 +68,6 @@ def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
68
  return db_bytes
69
 
70
  def get_schema(db_bytes: bytes) -> str:
71
- """Extracts the SQL schema used to create the table."""
72
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
73
  tmp.write(db_bytes)
74
  tmp_path = tmp.name
@@ -80,76 +79,49 @@ def get_schema(db_bytes: bytes) -> str:
80
  os.unlink(tmp_path)
81
  return "\n".join(r[0] for r in rows if r[0])
82
 
 
 
83
  def generate_sql(question: str, schema: str) -> str:
84
- # 1. Context Extraction
 
 
85
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
86
  table_name = table_match.group(1) if table_match else "data"
87
- quoted = f'"{table_name}"'
88
- col_match = re.findall(r'"(\w+)"', schema)
89
- q = question.lower().strip()
90
-
91
- # Smart Column Detection
92
- target_col = None
93
- for col in col_match:
94
- if col.lower() in q:
95
- target_col = col
96
- break
97
-
98
- # 2. Advanced Rule-Based Shortcuts
99
 
100
- # FILTERING (e.g., "ans is Asia")
101
- if "is" in q or "where" in q:
102
- # Improved value extraction: look for the last word in the sentence
103
- words = q.split()
104
- val = words[-1].strip("?.!")
105
-
106
- # Determine columns
107
- select_col = col_match[0] if "question" in q else "*"
108
- filter_col = target_col if target_col else (col_match[1] if len(col_match)>1 else col_match[0])
109
-
110
- # Don't trigger if the 'value' is just a common instruction word
111
- if val not in ["null", "not", "the", "average", "rows"]:
112
- return f'SELECT "{select_col}" FROM {quoted} WHERE "{filter_col}" LIKE "%{val}%"'
113
-
114
- # SELECT DISTINCT (List) vs COUNT DISTINCT (Number)
115
- if re.search(r'unique|distinct', q):
116
- col = target_col if target_col else col_match[0]
117
- if re.search(r'show|list|get|give', q):
118
- return f'SELECT DISTINCT "{col}" FROM {quoted} LIMIT 50'
119
- return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
120
-
121
- # AGGREGATIONS
122
- if re.search(r'average|mean|avg', q):
123
- num_col = target_col if target_col else (col_match[1] if len(col_match)>1 else col_match[0])
124
- return f'SELECT AVG("{num_col}") FROM {quoted}'
125
-
126
- # LIMIT/SHOW
127
- if re.search(r'show|display|get|first|top', q) and not target_col:
128
- n_match = re.search(r'\d+', q)
129
- return f'SELECT * FROM {quoted} LIMIT {n_match.group() if n_match else 10}'
130
-
131
- # 3. Transformer Fallback (MANDATORY FIX)
132
- # Ensure this part is NOT skipped
133
- col_hint = ", ".join(col_match) if col_match else ""
134
- prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
135
 
136
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
137
  with torch.no_grad():
138
- outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
 
 
 
 
 
139
 
140
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
-
142
- # Sanitization
143
- if "|" in sql: sql = sql.split("|")[-1].strip()
144
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
 
 
145
 
146
- if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
147
- sql = f'SELECT * FROM {quoted} LIMIT 10'
148
-
149
  return sql
150
 
151
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
152
- """Runs SQL against the binary blob via a temporary SQLite instance."""
153
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
154
  tmp.write(db_bytes)
155
  tmp_path = tmp.name
@@ -161,7 +133,7 @@ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
161
  except Exception as e:
162
  conn.close()
163
  os.unlink(tmp_path)
164
- raise HTTPException(status_code=400, detail=f"SQL error: {e}")
165
  conn.close()
166
  os.unlink(tmp_path)
167
  return rows
@@ -177,10 +149,7 @@ async def upload_csv(file: UploadFile = File(...)):
177
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
178
 
179
  contents = await file.read()
180
- try:
181
- df = pd.read_csv(io.BytesIO(contents))
182
- except Exception as e:
183
- raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
184
 
185
  session_id = os.urandom(8).hex()
186
  raw_name = os.path.splitext(file.filename)[0]
@@ -188,24 +157,21 @@ async def upload_csv(file: UploadFile = File(...)):
188
  if table_name[0].isdigit(): table_name = "t_" + table_name
189
 
190
  db_bytes = csv_to_sqlite(df, table_name)
191
- schema = get_schema(db_bytes)
192
-
193
  _db_store[session_id] = db_bytes
194
- _schema_store[session_id] = schema
195
 
196
  return JSONResponse({
197
  "session_id": session_id,
198
- "table_name": table_name,
199
  "columns": list(df.columns),
200
  "row_count": len(df),
201
  "preview": df.head(5).to_dict(orient="records"),
202
- "schema": schema,
203
  })
204
 
205
  @app.post("/query")
206
  async def query(req: QueryRequest):
207
  if req.session_id not in _db_store:
208
- raise HTTPException(status_code=404, detail="Session expired or not found.")
209
 
210
  schema = _schema_store[req.session_id]
211
  sql = generate_sql(req.question, schema)
@@ -215,4 +181,4 @@ async def query(req: QueryRequest):
215
 
216
  @app.get("/health")
217
  def health():
218
- return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
 
1
  """
2
+ QueryMind — CSV-to-SQL Engine (High Performance 7B Version)
3
+ Model: SQLCoder-7B-2 (State-of-the-art Text-to-SQL)
4
+ Hardware: HuggingFace ZeroGPU (Free A100 Tier)
5
  """
6
 
7
  import os
8
  import re
9
  import io
 
10
  import sqlite3
11
  import tempfile
12
  import pandas as pd
 
15
  from fastapi.responses import FileResponse, JSONResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM
19
  import torch
20
+ import spaces # Required for HuggingFace ZeroGPU
21
 
22
  # ── Configuration ─────────────────────────────────────────────────────────────
23
+ MODEL_ID = "defog/sqlcoder-7b-2"
 
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  # ── Model Initialization ──────────────────────────────────────────────────────
27
+ print(f"[INFO] Loading 7B Model: {MODEL_ID}")
28
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
29
+ model = AutoModelForCausalLM.from_pretrained(
30
+ MODEL_ID,
31
+ trust_remote_code=True,
32
+ torch_dtype=torch.float16,
33
+ device_map="auto" if torch.cuda.is_available() else None
34
+ )
35
  model.eval()
36
+ print("[INFO] 7B Model ready.")
37
 
38
+ # ── State Management ──────────────────────────────────────────────────────────
39
+ _db_store: dict[str, bytes] = {}
40
+ _schema_store: dict[str, str] = {}
41
 
42
+ app = FastAPI(title="QueryMind 7B", version="2.0.0")
43
 
44
  app.add_middleware(
45
  CORSMiddleware,
 
48
  allow_headers=["*"],
49
  )
50
 
 
51
  app.mount("/static", StaticFiles(directory="static"), name="static")
52
 
53
  @app.get("/")
 
56
 
57
  # ── Logic Helpers ──────────────────────────────────────────────────────────────
58
  def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
 
59
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
60
  tmp_path = tmp.name
61
  conn = sqlite3.connect(tmp_path)
 
68
  return db_bytes
69
 
70
  def get_schema(db_bytes: bytes) -> str:
 
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp.write(db_bytes)
73
  tmp_path = tmp.name
 
79
  os.unlink(tmp_path)
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
82
+ # ── 7B Inference with ZeroGPU Decorator ──────────────────────────────────────
83
+ @spaces.GPU(duration=60) # <── This is the secret for Free GPU access
84
  def generate_sql(question: str, schema: str) -> str:
85
+ """Uses SQLCoder-7B to generate high-accuracy SQL."""
86
+
87
+ # Extract table name for the prompt
88
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
89
  table_name = table_match.group(1) if table_match else "data"
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ # Prompt format required by SQLCoder
92
+ prompt = f"""### Task
93
+ Generate a SQL query to answer the question based on the table schema.
94
+
95
+ ### Schema
96
+ {schema}
97
+
98
+ ### Question
99
+ {question}
100
+
101
+ ### SQL
102
+ """
103
+
104
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
106
  with torch.no_grad():
107
+ outputs = model.generate(
108
+ **inputs,
109
+ max_new_tokens=200,
110
+ do_sample=False, # Use greedy decoding for SQL accuracy
111
+ num_beams=1
112
+ )
113
 
114
+ full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
115
+
116
+ # Extract only the SQL part after the prompt
117
+ sql = full_output.split("### SQL")[-1].strip()
118
+
119
+ # Basic cleanup
120
+ sql = sql.split(';')[0].strip() # Take only the first statement
121
 
 
 
 
122
  return sql
123
 
124
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
 
125
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
126
  tmp.write(db_bytes)
127
  tmp_path = tmp.name
 
133
  except Exception as e:
134
  conn.close()
135
  os.unlink(tmp_path)
136
+ raise HTTPException(status_code=400, detail=f"SQL Error: {e}")
137
  conn.close()
138
  os.unlink(tmp_path)
139
  return rows
 
149
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
150
 
151
  contents = await file.read()
152
+ df = pd.read_csv(io.BytesIO(contents))
 
 
 
153
 
154
  session_id = os.urandom(8).hex()
155
  raw_name = os.path.splitext(file.filename)[0]
 
157
  if table_name[0].isdigit(): table_name = "t_" + table_name
158
 
159
  db_bytes = csv_to_sqlite(df, table_name)
 
 
160
  _db_store[session_id] = db_bytes
161
+ _schema_store[session_id] = get_schema(db_bytes)
162
 
163
  return JSONResponse({
164
  "session_id": session_id,
 
165
  "columns": list(df.columns),
166
  "row_count": len(df),
167
  "preview": df.head(5).to_dict(orient="records"),
168
+ "schema": _schema_store[session_id],
169
  })
170
 
171
  @app.post("/query")
172
  async def query(req: QueryRequest):
173
  if req.session_id not in _db_store:
174
+ raise HTTPException(status_code=404, detail="Session expired.")
175
 
176
  schema = _schema_store[req.session_id]
177
  sql = generate_sql(req.question, schema)
 
181
 
182
  @app.get("/health")
183
  def health():
184
+ return {"status": "ok", "model": MODEL_ID, "device": DEVICE}