nilotpaldhar2004 commited on
Commit
b56ee61
Β·
verified Β·
1 Parent(s): a02ba1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -55
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
- app.py β€” Model: google/flan-t5-large (Text-to-SQL)
3
- HuggingFace Space: Free Tier (CPU)
 
4
  """
5
 
6
  import os
@@ -18,23 +19,23 @@ from pydantic import BaseModel
18
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
19
  import torch
20
 
21
- # ── Config ────────────────────────────────────────────────────────────────────
22
- MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly
23
  MAX_NEW_TOKENS = 256
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
- # ── Load model once at startup ─────────────────────────────────────────────────
27
- print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
30
  model.eval()
31
  print("[INFO] Model ready.")
32
 
33
- # ── In-memory DB store ─────────────────────────────────────────────────────────
34
- _db_store: dict[str, bytes] = {} # session_id β†’ sqlite db bytes
35
- _schema_store: dict[str, str] = {} # session_id β†’ schema string
36
 
37
- app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
38
 
39
  app.add_middleware(
40
  CORSMiddleware,
@@ -43,31 +44,30 @@ app.add_middleware(
43
  allow_headers=["*"],
44
  )
45
 
46
- # ── Static frontend ────────────────────────────────────────────────────────────
47
  app.mount("/static", StaticFiles(directory="static"), name="static")
48
 
49
  @app.get("/")
50
  def root():
51
  return FileResponse("static/index.html")
52
 
53
-
54
- # ── Helpers ────────────────────────────────────────────────────────────────────
55
- def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
56
- """Convert DataFrame β†’ SQLite DB bytes."""
57
- buf = io.BytesIO()
58
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
59
  tmp_path = tmp.name
60
  conn = sqlite3.connect(tmp_path)
61
- df.to_sql(table_name, conn, if_exists="replace", index=False)
 
 
62
  conn.close()
63
  with open(tmp_path, "rb") as f:
64
  db_bytes = f.read()
65
  os.unlink(tmp_path)
66
  return db_bytes
67
 
68
-
69
  def get_schema(db_bytes: bytes) -> str:
70
- """Extract CREATE TABLE schema from DB bytes."""
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp.write(db_bytes)
73
  tmp_path = tmp.name
@@ -79,42 +79,54 @@ def get_schema(db_bytes: bytes) -> str:
79
  os.unlink(tmp_path)
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
82
-
83
  def generate_sql(question: str, schema: str) -> str:
84
- # 1. Context Extraction (Same as before)
 
85
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
86
  table_name = table_match.group(1) if table_match else "data"
87
  quoted = f'"{table_name}"'
88
  col_match = re.findall(r'"(\w+)"', schema)
 
89
  q = question.lower().strip()
90
 
91
- # 2. Smart Column Detection
92
  target_col = None
93
  for col in col_match:
94
  if col.lower() in q:
95
  target_col = col
96
  break
97
 
98
- # 3. Enhanced Rule-Based Shortcuts (Deterministic)
 
 
99
  if re.search(r'unique|distinct', q):
100
  col = target_col if target_col else (col_match[0] if col_match else "*")
101
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
102
 
 
103
  if re.search(r'group.*by|per|each', q):
104
  col = target_col if target_col else (col_match[0] if col_match else "data")
105
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
106
 
107
- if re.search(r'count.*(total|all|record|row|paris)|how many', q):
108
- # Special case for "Count the Paris" -> We search for 'Paris' in all columns
109
- if "paris" in q:
110
- return f'SELECT COUNT(*) FROM {quoted} WHERE "answer" LIKE "%Paris%" OR "question" LIKE "%Paris%"'
 
 
 
 
 
 
111
  return f'SELECT COUNT(*) FROM {quoted}'
112
 
 
113
  if re.search(r'show|display|get|first|top', q):
114
  n_match = re.search(r'\d+', q)
115
- return f'SELECT * FROM {quoted} LIMIT {n_match.group() if n_match else 10}'
 
116
 
117
- # 4. T5 Model Fallback
118
  col_hint = ", ".join(col_match) if col_match else ""
119
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
120
 
@@ -124,23 +136,23 @@ def generate_sql(question: str, schema: str) -> str:
124
 
125
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
126
 
127
- # ── CRITICAL CLEANING GUARDRAIL ──
128
- # This removes hallucinations like "Table | SQL | Columns" from the output
129
- if "|" in sql:
130
- sql = sql.split("|")[-1].strip() # Take only the part after the last pipe
131
-
132
- # Remove common prefix hallucinations
133
- sql = re.sub(r'^(sql|query|result|table):', '', sql, flags=re.IGNORECASE).strip()
134
-
135
- # Force Table and SELECT
136
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
 
 
 
137
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
138
  sql = f'SELECT * FROM {quoted} LIMIT 10'
139
 
140
  return sql
141
 
142
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
143
- """Run SQL against the in-memory SQLite DB."""
144
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
145
  tmp.write(db_bytes)
146
  tmp_path = tmp.name
@@ -157,18 +169,16 @@ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
157
  os.unlink(tmp_path)
158
  return rows
159
 
160
-
161
- # ── Routes ─────────────────────────────────────────────────────────────────────
162
  class QueryRequest(BaseModel):
163
  session_id: str
164
  question: str
165
 
166
-
167
  @app.post("/upload")
168
  async def upload_csv(file: UploadFile = File(...)):
169
- """Upload CSV β†’ parse β†’ store as SQLite β†’ return session_id & preview."""
170
  if not file.filename.endswith(".csv"):
171
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
 
172
  contents = await file.read()
173
  try:
174
  df = pd.read_csv(io.BytesIO(contents))
@@ -176,38 +186,37 @@ async def upload_csv(file: UploadFile = File(...)):
176
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
177
 
178
  session_id = os.urandom(8).hex()
179
- table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
180
- if table_name[0].isdigit():
181
- table_name = "t_" + table_name
 
 
182
  db_bytes = csv_to_sqlite(df, table_name)
183
  schema = get_schema(db_bytes)
184
 
185
  _db_store[session_id] = db_bytes
186
  _schema_store[session_id] = schema
187
 
188
- preview = df.head(5).to_dict(orient="records")
189
- columns = list(df.columns)
190
  return JSONResponse({
191
  "session_id": session_id,
192
  "table_name": table_name,
193
- "columns": columns,
194
  "row_count": len(df),
195
- "preview": preview,
196
  "schema": schema,
197
  })
198
 
199
-
200
  @app.post("/query")
201
  async def query(req: QueryRequest):
202
- """Natural language question β†’ SQL β†’ execute β†’ return results."""
203
  if req.session_id not in _db_store:
204
- raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
 
205
  schema = _schema_store[req.session_id]
206
  sql = generate_sql(req.question, schema)
207
  results = execute_sql(sql, _db_store[req.session_id])
 
208
  return JSONResponse({"sql": sql, "results": results})
209
 
210
-
211
  @app.get("/health")
212
  def health():
213
- return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
 
1
  """
2
+ QueryMind β€” CSV-to-SQL Engine
3
+ Model: T5-Small Hybrid (Regex + Transformer)
4
+ Target Hardware: HuggingFace Free Tier (CPU)
5
  """
6
 
7
  import os
 
19
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
20
  import torch
21
 
22
+ # ── Configuration ─────────────────────────────────────────────────────────────
23
+ MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
24
  MAX_NEW_TOKENS = 256
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
+ # ── Model Initialization ──────────────────────────────────────────────────────
28
+ print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
29
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
30
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
31
  model.eval()
32
  print("[INFO] Model ready.")
33
 
34
+ # ── State Management ──────────────────────────────────────────────────────────
35
+ _db_store: dict[str, bytes] = {} # session_id -> sqlite db bytes
36
+ _schema_store: dict[str, str] = {} # session_id -> create table schema
37
 
38
+ app = FastAPI(title="QueryMind Engine", version="1.1.0")
39
 
40
  app.add_middleware(
41
  CORSMiddleware,
 
44
  allow_headers=["*"],
45
  )
46
 
47
+ # ── Static Frontend ───────────────────────────────────────────────────────────
48
  app.mount("/static", StaticFiles(directory="static"), name="static")
49
 
50
  @app.get("/")
51
  def root():
52
  return FileResponse("static/index.html")
53
 
54
+ # ── Logic Helpers ──────────────────────────────────────────────────────────────
55
+ def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
56
+ """Safely converts a Pandas DataFrame into a SQLite binary blob."""
 
 
57
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
58
  tmp_path = tmp.name
59
  conn = sqlite3.connect(tmp_path)
60
+ # Ensure the table name is safe for SQL
61
+ safe_table = re.sub(r"[^a-zA-Z0-9_]", "_", table_name)
62
+ df.to_sql(safe_table, conn, if_exists="replace", index=False)
63
  conn.close()
64
  with open(tmp_path, "rb") as f:
65
  db_bytes = f.read()
66
  os.unlink(tmp_path)
67
  return db_bytes
68
 
 
69
  def get_schema(db_bytes: bytes) -> str:
70
+ """Extracts the exact SQL schema used to create the SQLite table."""
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp.write(db_bytes)
73
  tmp_path = tmp.name
 
79
  os.unlink(tmp_path)
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
 
82
  def generate_sql(question: str, schema: str) -> str:
83
+ """Hybrid Engine: Uses smart regex first, falls back to T5 with sanitization."""
84
+ # 1. Schema Context
85
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
86
  table_name = table_match.group(1) if table_match else "data"
87
  quoted = f'"{table_name}"'
88
  col_match = re.findall(r'"(\w+)"', schema)
89
+
90
  q = question.lower().strip()
91
 
92
+ # 2. Smart Column Detection (Matches user words to schema)
93
  target_col = None
94
  for col in col_match:
95
  if col.lower() in q:
96
  target_col = col
97
  break
98
 
99
+ # 3. Deterministic Regex Layer (High Accuracy, Zero Latency)
100
+
101
+ # DISTINCT/UNIQUE
102
  if re.search(r'unique|distinct', q):
103
  col = target_col if target_col else (col_match[0] if col_match else "*")
104
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
105
 
106
+ # GROUP BY
107
  if re.search(r'group.*by|per|each', q):
108
  col = target_col if target_col else (col_match[0] if col_match else "data")
109
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
110
 
111
+ # AVERAGE
112
+ if re.search(r'average|mean|avg', q):
113
+ num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[0])
114
+ return f'SELECT AVG("{num_col}") FROM {quoted}'
115
+
116
+ # COUNT/HOW MANY
117
+ if re.search(r'count|total|how many', q):
118
+ # Handle word searches (e.g. "count Paris")
119
+ if target_col and len(q.split()) > 2:
120
+ return f'SELECT COUNT(*) FROM {quoted} WHERE "{target_col}" LIKE "%{q.split()[-1]}%"'
121
  return f'SELECT COUNT(*) FROM {quoted}'
122
 
123
+ # LIMIT/TOP
124
  if re.search(r'show|display|get|first|top', q):
125
  n_match = re.search(r'\d+', q)
126
+ limit = n_match.group() if n_match else 10
127
+ return f'SELECT * FROM {quoted} LIMIT {limit}'
128
 
129
+ # 4. Transformer Fallback (Probabilistic Reasoning)
130
  col_hint = ", ".join(col_match) if col_match else ""
131
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
132
 
 
136
 
137
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
138
 
139
+ # ── Output Sanitization Guardrails ──
140
+ # Remove T5 artifacts (pipes, prompt echoes)
141
+ if "|" in sql: sql = sql.split("|")[-1].strip()
142
+ sql = re.sub(r'^(sql|query|table):', '', sql, flags=re.IGNORECASE).strip()
143
+
144
+ # Force correct table references
 
 
 
145
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
146
+ sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
147
+
148
+ # Final check for valid SELECT
149
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
150
  sql = f'SELECT * FROM {quoted} LIMIT 10'
151
 
152
  return sql
153
 
154
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
155
+ """Runs SQL against the binary blob by creating a temporary local SQLite DB."""
156
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
157
  tmp.write(db_bytes)
158
  tmp_path = tmp.name
 
169
  os.unlink(tmp_path)
170
  return rows
171
 
172
+ # ── API Endpoints ─────────────────────────────────────────────────────────────
 
173
  class QueryRequest(BaseModel):
174
  session_id: str
175
  question: str
176
 
 
177
  @app.post("/upload")
178
  async def upload_csv(file: UploadFile = File(...)):
 
179
  if not file.filename.endswith(".csv"):
180
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
181
+
182
  contents = await file.read()
183
  try:
184
  df = pd.read_csv(io.BytesIO(contents))
 
186
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
187
 
188
  session_id = os.urandom(8).hex()
189
+ # Clean the filename to create a valid SQLite table name
190
+ raw_name = os.path.splitext(file.filename)[0]
191
+ table_name = re.sub(r"[^a-zA-Z0-9_]", "_", raw_name)[:32] or "data"
192
+ if table_name[0].isdigit(): table_name = "t_" + table_name
193
+
194
  db_bytes = csv_to_sqlite(df, table_name)
195
  schema = get_schema(db_bytes)
196
 
197
  _db_store[session_id] = db_bytes
198
  _schema_store[session_id] = schema
199
 
 
 
200
  return JSONResponse({
201
  "session_id": session_id,
202
  "table_name": table_name,
203
+ "columns": list(df.columns),
204
  "row_count": len(df),
205
+ "preview": df.head(5).to_dict(orient="records"),
206
  "schema": schema,
207
  })
208
 
 
209
  @app.post("/query")
210
  async def query(req: QueryRequest):
 
211
  if req.session_id not in _db_store:
212
+ raise HTTPException(status_code=404, detail="Session expired or not found.")
213
+
214
  schema = _schema_store[req.session_id]
215
  sql = generate_sql(req.question, schema)
216
  results = execute_sql(sql, _db_store[req.session_id])
217
+
218
  return JSONResponse({"sql": sql, "results": results})
219
 
 
220
  @app.get("/health")
221
  def health():
222
+ return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}