"""
app.py — CSV-to-SQL chat API. Model: cssupport/t5-small-awesome-text-to-sql (T5-based text-to-SQL)
HuggingFace Space: Free Tier (CPU)
"""

import os
import re
import io
import sqlite3
import tempfile
import pandas as pd
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# ── Config ────────────────────────────────────────────────────────────────────
MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"   # T5-based text→SQL, CPU-friendly
MAX_NEW_TOKENS = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load model once at startup ─────────────────────────────────────────────────
print(f"[INFO] Loading model: {MODEL_NAME}  |  device: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
print("[INFO] Model ready.")

# ── In-memory DB store ─────────────────────────────────────────────────────────
_db_store: dict[str, bytes] = {}   # session_id → sqlite db bytes
_schema_store: dict[str, str] = {} # session_id → schema string
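# Note: both stores live in process memory only; sessions are lost whenever the worker restarts.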

app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Static frontend ────────────────────────────────────────────────────────────
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
def root():
    return FileResponse("static/index.html")


# ── Helpers ────────────────────────────────────────────────────────────────────
def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
    """Convert DataFrame → SQLite DB bytes."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    df.to_sql(table_name, conn, if_exists="replace", index=False)
    conn.close()
    with open(tmp_path, "rb") as f:
        db_bytes = f.read()
    os.unlink(tmp_path)
    return db_bytes
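
# The temp-file round trip is needed because sqlite3 cannot write a database straight into a
# BytesIO object; on Python 3.11+, Connection.serialize()/deserialize() could avoid the disk hop.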


def get_schema(db_bytes: bytes) -> str:
    """Extract CREATE TABLE schema from DB bytes."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    cur = conn.cursor()
    cur.execute("SELECT sql FROM sqlite_master WHERE type='table'")
    rows = cur.fetchall()
    conn.close()
    os.unlink(tmp_path)
    return "\n".join(r[0] for r in rows if r[0])


def generate_sql(question: str, schema: str) -> str:
    """
    Enhanced Hybrid SQL Engine.
    Priority 1: Smart Regex (Deterministic & Instant)
    Priority 2: T5 Transformer (Probabilistic Fallback)
    """
    # 1. Context Extraction
    table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
    table_name = table_match.group(1) if table_match else "data"
    quoted = f'"{table_name}"'
    col_match = re.findall(r'"(\w+)"', schema)
    
    q = question.lower().strip()

    # 2. Smart Column Detection
    # Searches for a column name from the schema within the user's question
    target_col = None
    for col in col_match:
        if col.lower() in q:
            target_col = col
            break

    # 3. Enhanced Rule-Based Shortcuts (Smart Logic)
    
    # DISTINCT/UNIQUE COUNT
    if re.search(r'unique|distinct', q) and (target_col or col_match):
        col = target_col or col_match[0]
        return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'

    # GROUP BY
    if re.search(r'group.*by|per|each', q) and (target_col or col_match):
        col = target_col or col_match[0]
        return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'

    # AVERAGE (semantic fallback prefers numeric-looking columns, e.g. pollutant readings)
    if re.search(r'average|avg|mean', q) and (target_col or col_match):
        num_col = target_col or next(
            (c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)),
            col_match[2] if len(col_match) > 2 else col_match[0],
        )
        return f'SELECT AVG("{num_col}") FROM {quoted}'

    # TOTAL RECORDS
    if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
        return f'SELECT COUNT(*) FROM {quoted}'

    # LIMIT/TOP ROWS
    if re.search(r'show|display|get|first|top', q):
        n_match = re.search(r'\d+', q)
        limit = n_match.group() if n_match else 10
        return f'SELECT * FROM {quoted} LIMIT {limit}'

    # 4. T5 Model Fallback
    col_hint = ", ".join(col_match) if col_match else ""
    prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
    
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
    
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Post-inference cleaning (Crucial for SQLite stability)
    sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
    sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
    
    if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
        sql = f'SELECT * FROM {quoted} LIMIT 10'

    return sql
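
# Illustrative shortcut behaviour for a table named "data" with an "AQI" column
# (hypothetical questions, traced through the rules above):
#   "how many unique AQI values?" -> SELECT COUNT(DISTINCT "AQI") FROM "data"
#   "average AQI"                 -> SELECT AVG("AQI") FROM "data"
#   "show the first 5 rows"       -> SELECT * FROM "data" LIMIT 5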

def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
    """Run SQL against the in-memory SQLite DB."""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
        tmp.write(db_bytes)
        tmp_path = tmp.name
    conn = sqlite3.connect(tmp_path)
    conn.row_factory = sqlite3.Row
    try:
        cur = conn.execute(sql)
        rows = [dict(r) for r in cur.fetchall()]
    except Exception as e:
        conn.close()
        os.unlink(tmp_path)
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    conn.close()
    os.unlink(tmp_path)
    return rows
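
# generate_sql output is not restricted to SELECT statements, but every query runs against a
# throw-away temp copy of the database, so the bytes kept in _db_store are never modified.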


# ── Routes ─────────────────────────────────────────────────────────────────────
class QueryRequest(BaseModel):
    session_id: str
    question: str


@app.post("/upload")
async def upload_csv(file: UploadFile = File(...)):
    """Upload CSV → parse → store as SQLite → return session_id & preview."""
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files accepted.")
    contents = await file.read()
    try:
        df = pd.read_csv(io.BytesIO(contents))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")

    session_id = os.urandom(8).hex()
    table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
    if table_name[0].isdigit():
        table_name = "t_" + table_name
    db_bytes = csv_to_sqlite(df, table_name)
    schema = get_schema(db_bytes)

    _db_store[session_id] = db_bytes
    _schema_store[session_id] = schema

    preview = df.head(5).to_dict(orient="records")
    columns = list(df.columns)
    return JSONResponse({
        "session_id": session_id,
        "table_name": table_name,
        "columns": columns,
        "row_count": len(df),
        "preview": preview,
        "schema": schema,
    })
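
# Example upload request (hypothetical file name; host and port are assumptions):
#   curl -F "file=@city_day.csv" http://localhost:7860/upload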


@app.post("/query")
async def query(req: QueryRequest):
    """Natural language question → SQL → execute → return results."""
    if req.session_id not in _db_store:
        raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
    schema = _schema_store[req.session_id]
    sql = generate_sql(req.question, schema)
    results = execute_sql(sql, _db_store[req.session_id])
    return JSONResponse({"sql": sql, "results": results})
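
# Example query request (session_id comes from /upload; host and port are assumptions):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<session_id>", "question": "average AQI"}'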


@app.get("/health")
def health():
    return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}