Spaces:
Sleeping
Sleeping
File size: 7,574 Bytes
"""
SQLForge — demo server.

A FastAPI app that loads the fine-tuned text-to-SQL model and turns a
natural-language question plus a database schema into a SQL query. When a real
SQLite database is available it also *runs* the query, shows the results, and
uses self-correction (feeding the database error back to the model and
retrying) so you can watch the agent fix its own mistakes.

Run locally:

    uvicorn app.server:app --reload --port 8000

Then open http://localhost:8000
"""
import os
import sqlite3
import threading
import time
from contextlib import asynccontextmanager
from pathlib import Path

# the model + tokenizer are local; don't reach for the (flaky) HF CDN at serve time.
os.environ.setdefault("HF_HUB_OFFLINE", "1")

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from sqlforge.exec_eval import run_sql, schema_from_sqlite
from sqlforge.inference import generate_sql, generate_sql_with_retry, load_model
_HERE = Path(__file__).resolve().parent  # directory that contains this module
ROOT = _HERE.parent                      # project root, one level above it
STATIC = _HERE / "static"                # frontend assets

# --- config (all overridable by env, so HF Spaces / CI can swap paths) ----------
# A base-model value that names an existing directory under ROOT is anchored
# there; otherwise the value is used exactly as given.
BASE_MODEL = os.getenv("SQLFORGE_BASE", "models/qwen2.5-coder-1.5b")
_base_under_root = ROOT / BASE_MODEL
if _base_under_root.is_dir():
    BASE_MODEL = str(_base_under_root)
ADAPTER = os.getenv("SQLFORGE_ADAPTER", str(ROOT / "outputs" / "qwen2.5-coder-1.5b-sql"))
# Only the exact string "1" enables 4-bit loading.
FOUR_BIT = os.getenv("SQLFORGE_4BIT", "1") == "1"
DB_DIR = Path(os.getenv(
    "SQLFORGE_DB_DIR",
    str(ROOT / "data" / "spider_raw" / "spider_data" / "database"),
))

# generation / response limits (fixed constants, not read from the environment)
MAX_ROWS = 100     # cap result rows sent to the browser
GEN_TOKENS = 192   # SQL is short; cap new tokens for a snappier demo
MAX_RETRIES = 1    # one self-correction retry (worst case ~2 generations, not 3)
# Curated example databases. Every question below is verified to run cleanly
# against the real DB; the last car question deliberately triggers
# self-correction (a JOIN the model fixes itself) to showcase the agentic
# recovery succeeding.
EXAMPLES = [
    {"db_id": db_id, "label": label, "questions": list(questions)}
    for db_id, label, questions in (
        ("concert_singer", "Concerts & Singers", (
            "How many singers do we have?",
            "What is the average, minimum, and maximum age of all singers?",
            "Show the name and country of all singers ordered by age from oldest to youngest.",
        )),
        ("pets_1", "Students & Pets", (
            "How many pets are there?",
            "What is the average weight of all pets?",
            "How many students are there?",
        )),
        ("world_1", "World (countries)", (
            "What are the names of all countries that became independent after 1950?",
            "How many countries have a republic as their form of government?",
            "What is the average life expectancy of countries in Africa?",
        )),
        ("car_1", "Cars", (
            "How many continents are there?",
            "What is the maximum horsepower of any car?",
            "How many countries does each continent have? List continent id, name and count.",
        )),
        ("student_transcripts_tracking", "Student Transcripts", (
            "How many courses in total are listed?",
            "How many students are there?",
            "List the first and last name of every student.",
        )),
    )
]
# --- model state (loaded in the background so the server starts instantly) ------
# Status returned by /api/health; _load() fills in device/status/error.
STATE = dict(status="loading", model=ADAPTER, device=None, error=None)
# Model + tokenizer slots, populated by _load() and read by /api/generate.
_MODEL = dict(model=None, tok=None)
_LOCK = threading.Lock()  # serialize generations: one at a time (single GPU)
def _load():
    """Load the base model + adapter and publish the result in _MODEL / STATE.

    Started on a background thread at startup. On success it records the device
    and then flips STATE["status"] to "online"; any exception is recorded as
    status "error" with its message instead of propagating.
    """
    try:
        loaded, tokenizer = load_model(BASE_MODEL, adapter_path=ADAPTER, four_bit=FOUR_BIT)
        _MODEL.update(model=loaded, tok=tokenizer)
        # device is published before status, so STATE never reads "online"
        # while device is still unset
        STATE["device"] = str(getattr(loaded, "device", "cuda"))
        STATE["status"] = "online"
        print(f"[sqlforge] model online ({STATE['device']}, 4bit={FOUR_BIT})")
    except Exception as exc:  # noqa: BLE001
        STATE.update(status="error", error=str(exc))
        print(f"[sqlforge] model failed to load: {exc}")
def _startup():
    """Start model loading on a daemon thread so the server answers immediately."""
    threading.Thread(target=_load, daemon=True).start()


@asynccontextmanager
async def _lifespan(_app: FastAPI):
    """App lifespan: kick off the background model load, then serve until shutdown.

    Replaces the deprecated ``@app.on_event("startup")`` hook with FastAPI's
    lifespan handler; the startup work itself is unchanged.
    """
    _startup()
    yield


app = FastAPI(title="SQLForge", description="Fine-tuned text-to-SQL demo",
              lifespan=_lifespan)
def _db_path(db_id: str) -> Path | None:
"""Resolve a known example db_id to its .sqlite file (no path traversal)."""
if not db_id or "/" in db_id or "\\" in db_id or ".." in db_id:
return None
p = DB_DIR / db_id / f"{db_id}.sqlite"
return p if p.exists() else None
# --- API ------------------------------------------------------------------------
# Request body for POST /api/generate. Either schema_text or a resolvable db_id
# must be supplied; the endpoint rejects the request otherwise.
class GenerateRequest(BaseModel):
    question: str  # natural-language question; must not be blank
    schema_text: str | None = None  # raw CREATE TABLE text (custom mode)
    db_id: str | None = None  # example DB id (executes + self-corrects)
    self_correct: bool = True  # retry with DB-error feedback; only used when db_id resolves
@app.get("/api/health")
def health():
    """Report model-loading state: status, model, device and error."""
    snapshot = dict(STATE)
    return snapshot
@app.get("/api/examples")
def examples():
    """List the curated example databases whose .sqlite file is present on disk."""
    def _on_disk(entry):
        return _db_path(entry["db_id"]) is not None

    return list(filter(_on_disk, EXAMPLES))
@app.get("/api/schema")
def schema(db_id: str):
    """Return the schema text extracted from a database under DB_DIR; 404 if unknown."""
    if (db_file := _db_path(db_id)) is None:
        raise HTTPException(404, f"unknown database '{db_id}'")
    return {"db_id": db_id, "schema": schema_from_sqlite(db_file)}
# Wall-clock budget for running the final query against the real DB, enforced
# through SQLite's progress handler so a runaway JOIN or recursive CTE cannot
# tie up a request worker indefinitely.
_QUERY_TIMEOUT_S = 10.0

# Authorizer action codes a plain read query needs. Every other action
# (INSERT/UPDATE/DELETE, DDL, ATTACH, PRAGMA, ...) is denied at prepare time.
_READ_ONLY_ACTIONS = frozenset({
    sqlite3.SQLITE_SELECT,
    sqlite3.SQLITE_READ,
    sqlite3.SQLITE_FUNCTION,
    getattr(sqlite3, "SQLITE_RECURSIVE", 33),  # WITH RECURSIVE; name missing on old builds
})


def _read_only_authorizer(action, *_args):
    """sqlite3 authorizer callback: allow read-only actions, deny everything else."""
    return sqlite3.SQLITE_OK if action in _READ_ONLY_ACTIONS else sqlite3.SQLITE_DENY


def _json_cell(value):
    """Make one SQLite cell JSON-encodable.

    BLOB columns come back as ``bytes``. FastAPI's default encoder decodes bytes
    as strict UTF-8, which fails on arbitrary binary data. Decoding leniently here
    turns such blobs into text with replacement characters instead.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value


def _execute_preview(db_path: Path, sql: str) -> dict:
    """Run the final *sql* against *db_path* and return response fields to merge.

    The SQL is model-generated from user-supplied text, so the database is
    opened read-only, guarded by a read-only authorizer and a time budget.
    Always returns ``executed=True``. On success it also returns ``columns``,
    ``rows`` (at most MAX_ROWS), ``row_count`` (rows returned, not the full
    result size) and ``truncated``. On any failure ``error`` holds the message.
    """
    out: dict = {"executed": True}
    try:
        # mode=ro: even a write that got past the authorizer cannot change the file
        conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
    except Exception as exc:  # noqa: BLE001 - report to the UI instead of a 500
        out["error"] = str(exc)
        return out
    try:
        conn.set_authorizer(_read_only_authorizer)
        deadline = time.monotonic() + _QUERY_TIMEOUT_S
        # a truthy return from the handler aborts the running statement
        conn.set_progress_handler(lambda: time.monotonic() > deadline, 10_000)
        cur = conn.execute(sql)
        cols = [c[0] for c in cur.description] if cur.description else []
        rows = cur.fetchmany(MAX_ROWS + 1)  # one extra row reveals truncation
        out.update(columns=cols,
                   rows=[[_json_cell(v) for v in r] for r in rows[:MAX_ROWS]],
                   row_count=min(len(rows), MAX_ROWS),
                   truncated=len(rows) > MAX_ROWS)
    except Exception as exc:  # noqa: BLE001 - any failure is shown to the user
        out["error"] = str(exc)
    finally:
        conn.close()
    return out


@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Turn ``req.question`` into SQL, optionally validating and executing it.

    With a resolvable ``db_id`` the schema can come from the database itself.
    When ``self_correct`` is set, each candidate is validated against the DB and
    retried up to MAX_RETRIES times. The final query is then executed read-only
    and a result preview is included in the response.

    Raises:
        HTTPException 503: the model is not online yet.
        HTTPException 400: blank question, or no schema could be determined.
        HTTPException 404: ``db_id`` was given but names no known database.
    """
    if STATE["status"] != "online":
        raise HTTPException(503, f"model not ready ({STATE['status']})")
    if not req.question.strip():
        raise HTTPException(400, "question is required")

    db_path = _db_path(req.db_id) if req.db_id else None
    if req.db_id and db_path is None:
        # same answer /api/schema gives; ignoring the id would silently skip
        # execution and self-correction
        raise HTTPException(404, f"unknown database '{req.db_id}'")

    schema_text = req.schema_text
    if db_path and not schema_text:
        schema_text = schema_from_sqlite(db_path)
    if not schema_text:
        raise HTTPException(400, "provide a schema or pick an example database")

    model, tok = _MODEL["model"], _MODEL["tok"]
    trace: list = []
    with _LOCK:
        # time only the generation work, not the wait for the lock
        t0 = time.perf_counter()
        if db_path and req.self_correct:
            # NOTE(review): run_sql (sqlforge.exec_eval) also executes model SQL
            # against this DB during validation; confirm it opens it read-only too.
            sql, attempts = generate_sql_with_retry(
                model, tok, schema_text, req.question,
                validate=lambda s: run_sql(db_path, s)[1],
                max_retries=MAX_RETRIES, max_new_tokens=GEN_TOKENS, trace=trace)
        else:
            sql = generate_sql(model, tok, schema_text, req.question,
                               max_new_tokens=GEN_TOKENS)
            attempts = 1
        elapsed = round(time.perf_counter() - t0, 2)

    resp = {"sql": sql, "attempts": attempts, "trace": trace,
            "self_corrected": attempts > 1, "elapsed_s": elapsed,
            "executed": False, "columns": None, "rows": None,
            "row_count": None, "truncated": None, "error": None}
    # if we have the real DB, run the final query and return the result preview
    if db_path:
        resp.update(_execute_preview(db_path, sql))
    return resp
# --- static frontend (mounted last so /api/* wins) ------------------------------
@app.get("/")
def index():
    """Serve the single-page frontend."""
    page = STATIC / "index.html"
    return FileResponse(page)


# Catch-all for the remaining frontend assets. It is registered after every
# route above so those explicit routes are matched first.
app.mount("/", StaticFiles(directory=str(STATIC)), name="static")
|