# Commit 4f43f55 (nilotpaldhar2004): Enhance SQL generation in generate_sql function
"""
app.py — Model: google/flan-t5-large (Text-to-SQL)
HuggingFace Space: Free Tier (CPU)
"""
import os
import re
import io
import json
import sqlite3
import tempfile
import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"  # T5-based text→SQL, CPU-friendly
MAX_NEW_TOKENS = 256  # generation budget for the SQL decoder
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ── Load model once at startup ─────────────────────────────────────────────────
# Loading at import time means the first import pays the download/deserialize
# cost once; every request afterwards reuses the same in-memory model.
print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference only — disables dropout etc.
print("[INFO] Model ready.")
# ── In-memory DB store ─────────────────────────────────────────────────────────
# Process-local session state: entries are lost on restart and are NOT shared
# across multiple workers/replicas — fine for a single-process demo Space.
_db_store: dict[str, bytes] = {}  # session_id → sqlite db bytes
_schema_store: dict[str, str] = {}  # session_id → schema string
app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
# Wide-open CORS — acceptable for a public demo, tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Static frontend ────────────────────────────────────────────────────────────
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def root():
return FileResponse("static/index.html")
# ── Helpers ────────────────────────────────────────────────────────────────────
def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
    """Convert a DataFrame into the raw bytes of a single-table SQLite DB.

    Args:
        df: Data to persist; column names become SQLite column names.
        table_name: Name of the table to create (replaced if it exists).

    Returns:
        The full SQLite database file as bytes.
    """
    # sqlite3 needs a real file path, so round-trip through a temp file.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            df.to_sql(table_name, conn, if_exists="replace", index=False)
        finally:
            conn.close()  # always release the handle, even if to_sql fails
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)  # never leak the temp file, success or failure
def get_schema(db_bytes: bytes) -> str:
    """Extract the CREATE TABLE statements from raw SQLite DB bytes.

    Args:
        db_bytes: A complete SQLite database file as bytes.

    Returns:
        Newline-joined CREATE TABLE SQL for all user tables (internal
        ``sqlite_*`` bookkeeping tables are excluded).
    """
    # Materialize the bytes to a temp file so sqlite3 can open them.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            cur = conn.execute(
                "SELECT sql FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            rows = cur.fetchall()
        finally:
            conn.close()
    finally:
        os.unlink(tmp_path)  # guaranteed cleanup even on sqlite errors
    # Some rows can have NULL sql; skip them defensively.
    return "\n".join(r[0] for r in rows if r[0])
def generate_sql(question: str, schema: str) -> str:
    """Translate a natural-language question into SQLite SQL.

    Hybrid engine:
      Priority 1 — smart regex shortcuts (deterministic & instant);
      Priority 2 — T5 transformer fallback (probabilistic).

    Args:
        question: Natural-language question from the user.
        schema: CREATE TABLE DDL string (output of ``get_schema``).

    Returns:
        A SQL string targeting the table found in the schema.
    """
    # 1. Context extraction: table name + double-quoted identifiers.
    table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
    table_name = table_match.group(1) if table_match else "data"
    quoted = f'"{table_name}"'
    # NOTE: this also captures the table name itself when it is quoted in
    # the DDL; downstream branches only treat it as a candidate column.
    col_match = re.findall(r'"(\w+)"', schema)
    q = question.lower().strip()

    # 2. Smart column detection: first schema column mentioned in the question.
    target_col = None
    for col in col_match:
        if col.lower() in q:
            target_col = col
            break

    # 3. Rule-based shortcuts (checked in priority order).
    # DISTINCT/UNIQUE COUNT
    if re.search(r'unique|distinct', q):
        col = target_col or (col_match[0] if col_match else None)
        if col is None:
            # No usable column: COUNT(DISTINCT "*") would be invalid SQL.
            return f'SELECT COUNT(*) FROM {quoted}'
        return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
    # GROUP BY — only meaningful when a real column exists to group on.
    if re.search(r'group.*by|per|each', q):
        col = target_col or (col_match[0] if col_match else None)
        if col is not None:
            return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
    # AVERAGE — guard against an empty column list (previously IndexError).
    if re.search(r'average|avg|mean', q) and col_match:
        num_col = target_col or next(
            # Heuristic: prefer columns that look numeric (air-quality style).
            (c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)),
            col_match[2] if len(col_match) > 2 else col_match[0],
        )
        return f'SELECT AVG("{num_col}") FROM {quoted}'
    # TOTAL RECORDS
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'
    # LIMIT/TOP ROWS
    if re.search(r'show|display|get|first|top', q):
        n_match = re.search(r'\d+', q)
        limit = n_match.group() if n_match else 10
        return f'SELECT * FROM {quoted} LIMIT {limit}'

    # 4. T5 model fallback for anything the rules don't cover.
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Post-inference cleaning (crucial for SQLite stability):
    # force the FROM clause onto the real table, strip stray trailing tokens.
    sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        # Model produced something unusable — fall back to a safe preview.
        sql = f'SELECT * FROM {quoted} LIMIT 10'
    return sql
def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
    """Run SQL against a temp SQLite DB materialized from ``db_bytes``.

    Args:
        sql: The SQL statement to execute.
        db_bytes: A complete SQLite database file as bytes.

    Returns:
        Result rows as a list of column-name → value dicts.

    Raises:
        HTTPException: 400 with the SQLite error message on failure.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    conn.row_factory = sqlite3.Row  # rows become dict-convertible
    try:
        cur = conn.execute(sql)
        rows = [dict(r) for r in cur.fetchall()]
    except Exception as e:
        # Surface SQL errors to the client as a 400 rather than a 500.
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    finally:
        # Single cleanup path: no leaked connection/temp file on any outcome.
        conn.close()
        os.unlink(tmp_path)
    return rows
# ── Routes ─────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    session_id: str  # session identifier returned by /upload
    question: str  # natural-language question to translate into SQL
@app.post("/upload")
async def upload_csv(file: UploadFile = File(...)):
"""Upload CSV → parse → store as SQLite → return session_id & preview."""
if not file.filename.endswith(".csv"):
raise HTTPException(status_code=400, detail="Only CSV files accepted.")
contents = await file.read()
try:
df = pd.read_csv(io.BytesIO(contents))
except Exception as e:
raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
session_id = os.urandom(8).hex()
table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
if table_name[0].isdigit():
table_name = "t_" + table_name
db_bytes = csv_to_sqlite(df, table_name)
schema = get_schema(db_bytes)
_db_store[session_id] = db_bytes
_schema_store[session_id] = schema
preview = df.head(5).to_dict(orient="records")
columns = list(df.columns)
return JSONResponse({
"session_id": session_id,
"table_name": table_name,
"columns": columns,
"row_count": len(df),
"preview": preview,
"schema": schema,
})
@app.post("/query")
async def query(req: QueryRequest):
"""Natural language question → SQL → execute → return results."""
if req.session_id not in _db_store:
raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
schema = _schema_store[req.session_id]
sql = generate_sql(req.question, schema)
results = execute_sql(sql, _db_store[req.session_id])
return JSONResponse({"sql": sql, "results": results})
@app.get("/health")
def health():
return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}