""" app.py — Model: google/flan-t5-large (Text-to-SQL) HuggingFace Space: Free Tier (CPU) """ import os import re import io import json import sqlite3 import tempfile import pandas as pd from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch # ── Config ──────────────────────────────────────────────────────────────────── MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly MAX_NEW_TOKENS = 256 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ── Load model once at startup ───────────────────────────────────────────────── print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE) model.eval() print("[INFO] Model ready.") # ── In-memory DB store ───────────────────────────────────────────────────────── _db_store: dict[str, bytes] = {} # session_id → sqlite db bytes _schema_store: dict[str, str] = {} # session_id → schema string app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── Static frontend ──────────────────────────────────────────────────────────── app.mount("/static", StaticFiles(directory="static"), name="static") @app.get("/") def root(): return FileResponse("static/index.html") # ── Helpers ──────────────────────────────────────────────────────────────────── def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes: """Convert DataFrame → SQLite DB bytes.""" buf = io.BytesIO() with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: tmp_path = tmp.name conn = sqlite3.connect(tmp_path) df.to_sql(table_name, conn, if_exists="replace", index=False) conn.close() with open(tmp_path, "rb") as f: db_bytes = f.read() os.unlink(tmp_path) return db_bytes def get_schema(db_bytes: bytes) -> str: """Extract CREATE TABLE schema from DB bytes.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: tmp.write(db_bytes) tmp_path = tmp.name conn = sqlite3.connect(tmp_path) cur = conn.cursor() cur.execute("SELECT sql FROM sqlite_master WHERE type='table'") rows = cur.fetchall() conn.close() os.unlink(tmp_path) return "\n".join(r[0] for r in rows if r[0]) def generate_sql(question: str, schema: str) -> str: """ Enhanced Hybrid SQL Engine. Priority 1: Smart Regex (Deterministic & Instant) Priority 2: T5 Transformer (Probabilistic Fallback) """ # 1. Context Extraction table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE) table_name = table_match.group(1) if table_match else "data" quoted = f'"{table_name}"' col_match = re.findall(r'"(\w+)"', schema) q = question.lower().strip() # 2. Smart Column Detection # Searches for a column name from the schema within the user's question target_col = None for col in col_match: if col.lower() in q: target_col = col break # 3. 

    # 3. Enhanced rule-based shortcuts (smart logic)
    # DISTINCT / UNIQUE COUNT
    if re.search(r'unique|distinct', q):
        col = target_col or (col_match[0] if col_match else None)
        if col:
            return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
        return f'SELECT COUNT(*) FROM {quoted}'

    # GROUP BY
    if re.search(r'group.*by|per|each', q):
        col = target_col or (col_match[0] if col_match else None)
        if col:
            return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
        # No usable column: fall through to the other rules / the model.

    # AVERAGE (semantic fallback biased toward pollutant-style numeric columns,
    # e.g. a city_day air-quality dataset)
    if re.search(r'average|avg|mean', q):
        num_col = target_col or next(
            (c for c in col_match
             if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)),
            None,
        )
        if num_col is None and col_match:
            # Heuristic default: skip likely id/date columns at the front.
            num_col = col_match[2] if len(col_match) > 2 else col_match[0]
        if num_col:
            return f'SELECT AVG("{num_col}") FROM {quoted}'

    # TOTAL RECORDS
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'

    # LIMIT / TOP ROWS
    if re.search(r'show|display|get|first|top', q):
        n_match = re.search(r'\d+', q)
        limit = n_match.group() if n_match else 10
        return f'SELECT * FROM {quoted} LIMIT {limit}'

    # 4. T5 model fallback
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=4,
            early_stopping=True,
        )
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Post-inference cleaning (crucial for SQLite stability): force the FROM
    # clause onto the real table, then strip any alias the model invented.
    sql = re.sub(r'\bFROM\s+"?\w+"?', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(
        r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)\w+',
        r'\1',
        sql,
        flags=re.IGNORECASE,
    )
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        sql = f'SELECT * FROM {quoted} LIMIT 10'
    return sql


def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
    """Run SQL against a temporary copy of the SQLite DB."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.execute(sql)
        rows = [dict(r) for r in cur.fetchall()]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    finally:
        conn.close()
        os.unlink(tmp_path)
    return rows


# ── Routes ────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    session_id: str
    question: str


@app.post("/upload")
async def upload_csv(file: UploadFile = File(...)):
    """Upload CSV → parse → store as SQLite → return session_id & preview."""
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files accepted.")
    contents = await file.read()
    try:
        df = pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")

    session_id = os.urandom(8).hex()
    table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
    if table_name[0].isdigit():
        table_name = "t_" + table_name

    db_bytes = csv_to_sqlite(df, table_name)
    schema = get_schema(db_bytes)
    _db_store[session_id] = db_bytes
    _schema_store[session_id] = schema

    preview = df.head(5).to_dict(orient="records")
    columns = list(df.columns)
    return JSONResponse({
        "session_id": session_id,
        "table_name": table_name,
        "columns": columns,
        "row_count": len(df),
        "preview": preview,
        "schema": schema,
    })
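
# Example requests (a local run on the HF Spaces default port 7860 is assumed;
# the session_id "ab12..." is a placeholder returned by /upload):
#
#   curl -F "file=@city_day.csv" http://localhost:7860/upload
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "ab12...", "question": "average PM2_5 per City"}'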
@app.post("/query")
async def query(req: QueryRequest):
    """Natural-language question → SQL → execute → return results."""
    if req.session_id not in _db_store:
        raise HTTPException(status_code=404, detail="Session not found. Please upload a CSV first.")
    schema = _schema_store[req.session_id]
    sql = generate_sql(req.question, schema)
    results = execute_sql(sql, _db_store[req.session_id])
    return JSONResponse({"sql": sql, "results": results})


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
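
# ── Local dev entry point ─────────────────────────────────────────────────────
# A minimal sketch, assuming uvicorn is installed (it ships with FastAPI's
# standard extras). On HF Spaces the server is typically launched by the Space
# runtime instead, so this block only matters for local runs.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)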