| """ |
| app.py — Model: google/flan-t5-large (Text-to-SQL) |
| HuggingFace Space: Free Tier (CPU) |
| """ |
|
|
| import os |
| import re |
| import io |
| import json |
| import sqlite3 |
| import tempfile |
| import pandas as pd |
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import FileResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
|
|
| |
| MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" |
| MAX_NEW_TOKENS = 256 |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| |
| print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}") |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE) |
| model.eval() |
| print("[INFO] Model ready.") |
|
|
| |
# In-memory per-session state (lost on process restart — acceptable for a
# free-tier demo; no persistence layer).
_db_store: dict[str, bytes] = {}      # session_id -> raw SQLite DB file bytes
_schema_store: dict[str, str] = {}    # session_id -> CREATE TABLE schema text


app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")


# Wide-open CORS: the frontend may be served from a different origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# Serve the bundled single-page frontend assets from ./static.
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| @app.get("/") |
| def root(): |
| return FileResponse("static/index.html") |
|
|
|
|
| |
def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
    """Serialize a DataFrame into the bytes of a SQLite database file.

    The frame is written to a single table through a temporary on-disk DB
    (sqlite3 works with file paths, not memory buffers), the file is read
    back, and the temp file is always removed — even if writing fails.

    Args:
        df: Data to persist.
        table_name: Table to create; replaced if it already exists.

    Returns:
        Raw bytes of a complete SQLite database file.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            df.to_sql(table_name, conn, if_exists="replace", index=False)
        finally:
            # Close before reading so all pages are flushed to disk.
            conn.close()
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        # Never leak the temp file, even when to_sql raises.
        os.unlink(tmp_path)
|
|
|
|
def get_schema(db_bytes: bytes) -> str:
    """Extract the CREATE TABLE statements from a SQLite DB given as bytes.

    The bytes are written to a temp file (sqlite3 needs a file path),
    `sqlite_master` is queried, and the temp file is always removed — the
    original leaked it if connect/execute raised.

    Args:
        db_bytes: Raw bytes of a SQLite database file.

    Returns:
        Newline-joined CREATE TABLE SQL for every table in the DB.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    try:
        conn = sqlite3.connect(tmp_path)
        try:
            cur = conn.execute("SELECT sql FROM sqlite_master WHERE type='table'")
            rows = cur.fetchall()
        finally:
            conn.close()
    finally:
        os.unlink(tmp_path)
    # sql can be NULL for some internal tables; skip those.
    return "\n".join(r[0] for r in rows if r[0])
|
|
|
|
| def generate_sql(question: str, schema: str) -> str: |
| """ |
| Enhanced Hybrid SQL Engine. |
| Priority 1: Smart Regex (Deterministic & Instant) |
| Priority 2: T5 Transformer (Probabilistic Fallback) |
| """ |
| |
| table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE) |
| table_name = table_match.group(1) if table_match else "data" |
| quoted = f'"{table_name}"' |
| col_match = re.findall(r'"(\w+)"', schema) |
| |
| q = question.lower().strip() |
|
|
| |
| |
| target_col = None |
| for col in col_match: |
| if col.lower() in q: |
| target_col = col |
| break |
|
|
| |
| |
| |
| if re.search(r'unique|distinct', q): |
| col = target_col if target_col else (col_match[0] if col_match else "*") |
| return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}' |
|
|
| |
| if re.search(r'group.*by|per|each', q): |
| col = target_col if target_col else (col_match[0] if col_match else "data") |
| return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"' |
|
|
| |
| if re.search(r'average|avg|mean', q): |
| num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[2] if len(col_match)>2 else col_match[0]) |
| return f'SELECT AVG("{num_col}") FROM {quoted}' |
|
|
| |
| if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q): |
| return f'SELECT COUNT(*) FROM {quoted}' |
|
|
| |
| if re.search(r'show|display|get|first|top', q): |
| n_match = re.search(r'\d+', q) |
| limit = n_match.group() if n_match else 10 |
| return f'SELECT * FROM {quoted} LIMIT {limit}' |
|
|
| |
| col_hint = ", ".join(col_match) if col_match else "" |
| prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}" |
| |
| inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE) |
| with torch.no_grad(): |
| outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True) |
| |
| sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip() |
|
|
| |
| sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE) |
| sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE) |
| |
| if not re.search(r'\bSELECT\b', sql, re.IGNORECASE): |
| sql = f'SELECT * FROM {quoted} LIMIT 10' |
|
|
| return sql |
|
|
def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
    """Execute SQL against a SQLite DB supplied as raw bytes.

    The bytes are written to a temp file (sqlite3 needs a file path),
    the statement is run, and the connection and file are always cleaned
    up via finally — the original duplicated the cleanup on both paths.

    Args:
        sql: Statement to execute.
        db_bytes: Raw bytes of a SQLite database file.

    Returns:
        Result rows as a list of column-name → value dicts.

    Raises:
        HTTPException(400): if SQLite rejects the statement.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    conn.row_factory = sqlite3.Row  # rows addressable by column name
    try:
        cur = conn.execute(sql)
        return [dict(r) for r in cur.fetchall()]
    except sqlite3.Error as e:
        # Narrowed from bare Exception: every SQL failure is sqlite3.Error.
        raise HTTPException(status_code=400, detail=f"SQL error: {e}") from e
    finally:
        conn.close()
        os.unlink(tmp_path)
|
|
|
|
| |
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    session_id: str  # id returned by /upload; keys the in-memory stores
    question: str    # natural-language question about the uploaded CSV
|
|
|
|
| @app.post("/upload") |
| async def upload_csv(file: UploadFile = File(...)): |
| """Upload CSV → parse → store as SQLite → return session_id & preview.""" |
| if not file.filename.endswith(".csv"): |
| raise HTTPException(status_code=400, detail="Only CSV files accepted.") |
| contents = await file.read() |
| try: |
| df = pd.read_csv(io.BytesIO(contents)) |
| except Exception as e: |
| raise HTTPException(status_code=400, detail=f"CSV parse error: {e}") |
|
|
| session_id = os.urandom(8).hex() |
| table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data" |
| if table_name[0].isdigit(): |
| table_name = "t_" + table_name |
| db_bytes = csv_to_sqlite(df, table_name) |
| schema = get_schema(db_bytes) |
|
|
| _db_store[session_id] = db_bytes |
| _schema_store[session_id] = schema |
|
|
| preview = df.head(5).to_dict(orient="records") |
| columns = list(df.columns) |
| return JSONResponse({ |
| "session_id": session_id, |
| "table_name": table_name, |
| "columns": columns, |
| "row_count": len(df), |
| "preview": preview, |
| "schema": schema, |
| }) |
|
|
|
|
| @app.post("/query") |
| async def query(req: QueryRequest): |
| """Natural language question → SQL → execute → return results.""" |
| if req.session_id not in _db_store: |
| raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.") |
| schema = _schema_store[req.session_id] |
| sql = generate_sql(req.question, schema) |
| results = execute_sql(sql, _db_store[req.session_id]) |
| return JSONResponse({"sql": sql, "results": results}) |
|
|
|
|
| @app.get("/health") |
| def health(): |
| return {"status": "ok", "model": MODEL_NAME, "device": DEVICE} |
|
|