"""
app.py — Model: cssupport/t5-small-awesome-text-to-sql (Text-to-SQL)
HuggingFace Space: Free Tier (CPU)
"""
import os
import re
import io
import json
import sqlite3
import tempfile
import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# ── Config ────────────────────────────────────────────────────────────────────
MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly
MAX_NEW_TOKENS = 256  # cap on tokens generated by the model fallback in generate_sql()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # GPU when present, else CPU
# ── Load model once at startup ─────────────────────────────────────────────────
# Import-time load: pays the download/initialization cost once so requests
# don't re-load the model. First start may take a while on a cold HF Space.
print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference mode: disables dropout / training-only behavior
print("[INFO] Model ready.")
# ── In-memory DB store ─────────────────────────────────────────────────────────
# NOTE(review): entries are never evicted, so memory grows with every upload —
# acceptable for a demo Space, but confirm before heavier use.
_db_store: dict[str, bytes] = {} # session_id → sqlite db bytes
_schema_store: dict[str, str] = {} # session_id → schema string
app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
# Wide-open CORS: any web origin may call the API (fine for a public demo).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Static frontend ────────────────────────────────────────────────────────────
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def root():
    """Serve the single-page frontend shell."""
    index_path = "static/index.html"
    return FileResponse(index_path)
# ── Helpers ────────────────────────────────────────────────────────────────────
def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
    """Serialize a DataFrame into the bytes of a SQLite database file.

    The frame is written to a temporary on-disk SQLite file (sqlite3 has no
    portable way to dump an in-memory DB to bytes here), read back, and the
    temp file is removed — even if writing fails.

    Args:
        df: Source data; column names become SQLite column names.
        table_name: Table to create; replaced if it already exists.

    Returns:
        Raw bytes of a complete SQLite database file.
    """
    # FIX: removed unused `buf = io.BytesIO()`; added try/finally so the
    # temp file cannot leak when to_sql()/read fails.
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            df.to_sql(table_name, conn, if_exists="replace", index=False)
        finally:
            conn.close()
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)  # always clean up the temp file
def get_schema(db_bytes: bytes) -> str:
    """Extract the CREATE TABLE statements from raw SQLite DB bytes.

    Writes the bytes to a temporary file so sqlite3 can open them; the file
    is always removed, even if the query raises.

    Args:
        db_bytes: Complete SQLite database file contents.

    Returns:
        Newline-joined CREATE TABLE SQL for every table in the database.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    # FIX: try/finally so the temp file cannot leak on a sqlite error.
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            rows = conn.execute(
                "SELECT sql FROM sqlite_master WHERE type='table'"
            ).fetchall()
        finally:
            conn.close()
    finally:
        os.unlink(tmp_path)
    # sqlite_master.sql can be NULL for some internal entries — skip those.
    return "\n".join(r[0] for r in rows if r[0])
def generate_sql(question: str, schema: str) -> str:
    """
    Enhanced Hybrid SQL Engine.
    Priority 1: Smart Regex (Deterministic & Instant)
    Priority 2: T5 Transformer (Probabilistic Fallback)

    Args:
        question: User's natural-language question.
        schema: CREATE TABLE text produced by get_schema().

    Returns:
        A single SQL statement; always contains SELECT.
    """
    # 1. Context Extraction — table name and quoted column names from schema.
    table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
    table_name = table_match.group(1) if table_match else "data"
    quoted = f'"{table_name}"'
    col_match = re.findall(r'"(\w+)"', schema)
    q = question.lower().strip()
    # 2. Smart Column Detection — first schema column mentioned in the question.
    target_col = next((c for c in col_match if c.lower() in q), None)
    # 3. Rule-based shortcuts, checked in priority order.
    # DISTINCT / UNIQUE count
    if re.search(r'unique|distinct', q):
        if target_col:
            return f'SELECT COUNT(DISTINCT "{target_col}") FROM {quoted}'
        if col_match:
            return f'SELECT COUNT(DISTINCT "{col_match[0]}") FROM {quoted}'
        # FIX: COUNT(DISTINCT "*") is invalid SQLite; degrade to COUNT(*).
        return f'SELECT COUNT(*) FROM {quoted}'
    # GROUP BY — FIX: \b guards so 'per'/'each' only match whole words;
    # previously 'per' matched inside e.g. "temperature" and hijacked
    # AVERAGE questions into this branch.
    if re.search(r'group.*by|\bper\b|\beach\b', q):
        col = target_col if target_col else (col_match[0] if col_match else "data")
        return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
    # AVERAGE — prefer a numeric-looking column when none was named.
    # FIX: guard on col_match — original raised IndexError for empty schemas.
    if re.search(r'average|avg|mean', q) and (target_col or col_match):
        num_col = target_col if target_col else next(
            (c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)),
            col_match[2] if len(col_match) > 2 else col_match[0],
        )
        return f'SELECT AVG("{num_col}") FROM {quoted}'
    # TOTAL RECORDS
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'
    # LIMIT / TOP ROWS
    if re.search(r'show|display|get|first|top', q):
        n_match = re.search(r'\d+', q)
        limit = n_match.group() if n_match else 10
        return f'SELECT * FROM {quoted} LIMIT {limit}'
    # 4. T5 Model Fallback
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Post-inference cleaning (crucial for SQLite stability): force the
    # model's FROM clause onto the real table and strip stray trailing aliases.
    sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        sql = f'SELECT * FROM {quoted} LIMIT 10'
    return sql
def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
    """Run a SQL statement against SQLite DB bytes and return the rows.

    Args:
        sql: SQL statement to execute (typically from generate_sql()).
        db_bytes: Complete SQLite database file contents.

    Returns:
        One dict per result row, keyed by column name.

    Raises:
        HTTPException: 400 when SQLite rejects or fails the statement.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    # FIX: single try/finally cleanup path replaces the duplicated
    # close/unlink in success and error branches; the temp file is now
    # removed even on unexpected exceptions.
    try:
        conn = sqlite3.connect(tmp_path)
        conn.row_factory = sqlite3.Row  # rows addressable by column name
        try:
            cur = conn.execute(sql)
            return [dict(r) for r in cur.fetchall()]
        except Exception as e:
            # Surface SQL failures to the client as a 400, not a 500.
            raise HTTPException(status_code=400, detail=f"SQL error: {e}")
        finally:
            conn.close()
    finally:
        os.unlink(tmp_path)
# ── Routes ─────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    session_id: str  # token returned by POST /upload
    question: str  # natural-language question about the uploaded CSV
@app.post("/upload")
async def upload_csv(file: UploadFile = File(...)):
    """Upload CSV → parse → store as SQLite → return session_id & preview.

    Returns:
        JSON with session_id, sanitized table_name, column list, row count,
        a 5-row preview, and the CREATE TABLE schema.

    Raises:
        HTTPException: 400 for non-CSV uploads or unparseable CSV content.
    """
    # FIX: case-insensitive extension check, and guard against a missing
    # filename (previously ".CSV" was rejected and None crashed with
    # AttributeError instead of a clean 400).
    if not (file.filename or "").lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files accepted.")
    contents = await file.read()
    try:
        df = pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
    session_id = os.urandom(8).hex()  # random session token
    # Sanitize the file stem into a legal SQLite identifier, capped at 32 chars.
    table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
    if table_name[0].isdigit():
        table_name = "t_" + table_name  # identifiers must not start with a digit
    db_bytes = csv_to_sqlite(df, table_name)
    schema = get_schema(db_bytes)
    _db_store[session_id] = db_bytes
    _schema_store[session_id] = schema
    preview = df.head(5).to_dict(orient="records")
    columns = list(df.columns)
    return JSONResponse({
        "session_id": session_id,
        "table_name": table_name,
        "columns": columns,
        "row_count": len(df),
        "preview": preview,
        "schema": schema,
    })
@app.post("/query")
async def query(req: QueryRequest):
    """Natural language question → SQL → execute → return results."""
    # EAFP lookup: a missing session surfaces as a clean 404.
    try:
        db_bytes = _db_store[req.session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
    table_schema = _schema_store[req.session_id]
    generated_sql = generate_sql(req.question, table_schema)
    rows = execute_sql(generated_sql, db_bytes)
    return JSONResponse({"sql": generated_sql, "results": rows})
@app.get("/health")
def health():
    """Liveness probe: reports the loaded model name and compute device."""
    return dict(status="ok", model=MODEL_NAME, device=DEVICE)