nilotpaldhar2004 commited on
Commit
d39d3d8
Β·
unverified Β·
1 Parent(s): f06394a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -79
app.py CHANGED
@@ -1,6 +1,12 @@
 
 
 
 
 
1
  import os
2
  import re
3
  import io
 
4
  import sqlite3
5
  import tempfile
6
  import pandas as pd
@@ -9,44 +15,26 @@ from fastapi.staticfiles import StaticFiles
9
  from fastapi.responses import FileResponse, JSONResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
13
  import torch
14
 
15
  # ── Config ────────────────────────────────────────────────────────────────────
16
- MODEL_NAME = "defog/sqlcoder-7b-2"
17
- MAX_NEW_TOKENS = 300
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- # ── Memory-Optimized Model Loading ───────────────────────────────────────────
21
- print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
22
- print("[INFO] Applying 4-bit quantization to fit within 16Gi RAM limit...")
23
-
24
- # Configure 4-bit quantization for memory efficiency
25
- quant_config = BitsAndBytesConfig(
26
- load_in_4bit=True,
27
- bnb_4bit_quant_type="nf4",
28
- bnb_4bit_use_double_quant=True,
29
- bnb_4bit_compute_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
30
- )
31
-
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
33
-
34
- # Load model with quantization and low memory usage settings
35
- model = AutoModelForCausalLM.from_pretrained(
36
- MODEL_NAME,
37
- quantization_config=quant_config,
38
- device_map="auto",
39
- low_cpu_mem_usage=True,
40
- trust_remote_code=True
41
- )
42
  model.eval()
43
- print("[INFO] Model loaded successfully.")
44
 
45
- # ── In-memory store ────────────────────────────────────────────────────────────
46
- _db_store: dict[str, bytes] = {}
47
- _schema_store: dict[str, str] = {}
48
 
49
- app = FastAPI(title="SQLCoder CSV Chat", version="1.1.0")
50
 
51
  app.add_middleware(
52
  CORSMiddleware,
@@ -56,18 +44,17 @@ app.add_middleware(
56
  )
57
 
58
  # ── Static frontend ────────────────────────────────────────────────────────────
59
- # Ensure your index.html is in a folder named 'static'
60
- if not os.path.exists("static"):
61
- os.makedirs("static")
62
-
63
  app.mount("/static", StaticFiles(directory="static"), name="static")
64
 
65
  @app.get("/")
66
  def root():
67
  return FileResponse("static/index.html")
68
 
 
69
  # ── Helpers ────────────────────────────────────────────────────────────────────
70
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
 
 
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp_path = tmp.name
73
  conn = sqlite3.connect(tmp_path)
@@ -78,7 +65,9 @@ def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
78
  os.unlink(tmp_path)
79
  return db_bytes
80
 
 
81
  def get_schema(db_bytes: bytes) -> str:
 
82
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
83
  tmp.write(db_bytes)
84
  tmp_path = tmp.name
@@ -90,50 +79,54 @@ def get_schema(db_bytes: bytes) -> str:
90
  os.unlink(tmp_path)
91
  return "\n".join(r[0] for r in rows if r[0])
92
 
93
- def build_prompt(question: str, schema: str) -> str:
94
- """SQLCoder specific prompt format for better accuracy."""
95
- return f"""### Task
96
- Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
97
-
98
- ### Database Schema
99
- The query will run on a database with the following schema:
100
- {schema}
101
-
102
- ### Answer
103
- Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
104
- [SQL]
105
- """
106
 
107
  def generate_sql(question: str, schema: str) -> str:
 
 
108
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
109
- table_name = table_match.group(1) if table_match else "user_data"
110
-
111
- prompt = build_prompt(question, schema)
112
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
113
-
 
 
 
 
 
 
 
 
114
  with torch.no_grad():
115
  outputs = model.generate(
116
  **inputs,
117
  max_new_tokens=MAX_NEW_TOKENS,
118
- do_sample=False,
119
- num_beams=1,
120
- eos_token_id=tokenizer.eos_token_id,
121
- pad_token_id=tokenizer.eos_token_id
122
  )
 
123
 
124
- # Decode newly generated tokens
125
- generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
126
- sql = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
127
 
128
- # Post-processing and cleaning
129
- sql = sql.split("[/SQL]")[0].strip()
130
- sql = re.sub(r"```sql|```", "", sql).strip()
 
 
 
 
 
 
 
 
131
 
132
-
133
- sql = re.sub(r'\bFROM\s+(\w+)', f'FROM "{table_name}"', sql, flags=re.IGNORECASE)
134
  return sql
135
 
 
136
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
 
137
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
138
  tmp.write(db_bytes)
139
  tmp_path = tmp.name
@@ -143,50 +136,64 @@ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
143
  cur = conn.execute(sql)
144
  rows = [dict(r) for r in cur.fetchall()]
145
  except Exception as e:
146
- raise HTTPException(status_code=400, detail=f"Execution error: {e}")
147
- finally:
148
  conn.close()
149
  os.unlink(tmp_path)
 
 
 
150
  return rows
151
 
 
152
  # ── Routes ─────────────────────────────────────────────────────────────────────
153
  class QueryRequest(BaseModel):
154
  session_id: str
155
  question: str
156
 
 
157
  @app.post("/upload")
158
  async def upload_csv(file: UploadFile = File(...)):
 
159
  if not file.filename.endswith(".csv"):
160
- raise HTTPException(status_code=400, detail="Invalid file type. Upload a CSV.")
161
-
162
  contents = await file.read()
163
- df = pd.read_csv(io.BytesIO(contents))
164
-
 
 
 
165
  session_id = os.urandom(8).hex()
166
- table_name = "user_data" # Standardized for internal SQL logic
 
 
167
  db_bytes = csv_to_sqlite(df, table_name)
168
  schema = get_schema(db_bytes)
169
 
170
  _db_store[session_id] = db_bytes
171
  _schema_store[session_id] = schema
172
 
173
- return {
 
 
174
  "session_id": session_id,
175
- "columns": list(df.columns),
176
- "preview": df.head(3).to_dict(orient="records")
177
- }
 
 
 
 
178
 
179
  @app.post("/query")
180
  async def query(req: QueryRequest):
 
181
  if req.session_id not in _db_store:
182
- raise HTTPException(status_code=404, detail="Session expired.")
183
-
184
  schema = _schema_store[req.session_id]
185
  sql = generate_sql(req.question, schema)
186
  results = execute_sql(sql, _db_store[req.session_id])
187
-
188
- return {"sql": sql, "results": results}
189
 
190
  @app.get("/health")
191
  def health():
192
- return {"status": "running", "quantization": "4-bit"}
 
1
+ """
2
+ app.py β€” Model: google/flan-t5-large (Text-to-SQL)
3
+ HuggingFace Space: Free Tier (CPU)
4
+ """
5
+
6
  import os
7
  import re
8
  import io
9
+ import json
10
  import sqlite3
11
  import tempfile
12
  import pandas as pd
 
15
  from fastapi.responses import FileResponse, JSONResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel
18
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
19
  import torch
20
 
21
  # ── Config ────────────────────────────────────────────────────────────────────
22
+ MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly
23
+ MAX_NEW_TOKENS = 256
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
+ # ── Load model once at startup ─────────────────────────────────────────────────
27
+ print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
 
 
 
 
 
 
 
 
 
 
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
 
 
 
 
 
 
 
 
30
  model.eval()
31
+ print("[INFO] Model ready.")
32
 
33
+ # ── In-memory DB store ─────────────────────────────────────────────────────────
34
+ _db_store: dict[str, bytes] = {} # session_id β†’ sqlite db bytes
35
+ _schema_store: dict[str, str] = {} # session_id β†’ schema string
36
 
37
+ app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
38
 
39
  app.add_middleware(
40
  CORSMiddleware,
 
44
  )
45
 
46
  # ── Static frontend ────────────────────────────────────────────────────────────
 
 
 
 
47
  app.mount("/static", StaticFiles(directory="static"), name="static")
48
 
49
  @app.get("/")
50
  def root():
51
  return FileResponse("static/index.html")
52
 
53
+
54
  # ── Helpers ────────────────────────────────────────────────────────────────────
55
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
56
+ """Convert DataFrame β†’ SQLite DB bytes."""
57
+ buf = io.BytesIO()
58
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
59
  tmp_path = tmp.name
60
  conn = sqlite3.connect(tmp_path)
 
65
  os.unlink(tmp_path)
66
  return db_bytes
67
 
68
+
69
  def get_schema(db_bytes: bytes) -> str:
70
+ """Extract CREATE TABLE schema from DB bytes."""
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp.write(db_bytes)
73
  tmp_path = tmp.name
 
79
  os.unlink(tmp_path)
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
+ """Run T5 inference to produce SQL."""
85
+ # Extract table name from schema
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
+ table_name = table_match.group(1) if table_match else "data"
88
+ quoted = f'"{table_name}"'
89
+
90
+ # Extract column names to inject into prompt β€” helps T5-small stay grounded
91
+ col_match = re.findall(r'"(\w+)"', schema)
92
+ col_hint = ", ".join(col_match) if col_match else ""
93
+ prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
94
+ inputs = tokenizer(
95
+ prompt,
96
+ return_tensors="pt",
97
+ truncation=True,
98
+ max_length=512,
99
+ ).to(DEVICE)
100
  with torch.no_grad():
101
  outputs = model.generate(
102
  **inputs,
103
  max_new_tokens=MAX_NEW_TOKENS,
104
+ num_beams=4,
105
+ early_stopping=True,
 
 
106
  )
107
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
108
 
109
+ # Fix 1: replace any FROM/JOIN table reference (quoted or unquoted) with correct table
110
+ sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
111
+ sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
112
 
113
+ # Fix 2: strip junk tokens after table name before LIMIT/WHERE/ORDER etc.
114
+ # e.g. FROM "city_day" Datetime LIMIT 10 β†’ FROM "city_day" LIMIT 10
115
+ sql = re.sub(
116
+ r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
117
+ r'\1',
118
+ sql, flags=re.IGNORECASE
119
+ )
120
+
121
+ # Fix 3: fallback if no SELECT at all
122
+ if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
123
+ sql = f'SELECT * FROM {quoted} LIMIT 10'
124
 
 
 
125
  return sql
126
 
127
+
128
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
129
+ """Run SQL against the in-memory SQLite DB."""
130
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
131
  tmp.write(db_bytes)
132
  tmp_path = tmp.name
 
136
  cur = conn.execute(sql)
137
  rows = [dict(r) for r in cur.fetchall()]
138
  except Exception as e:
 
 
139
  conn.close()
140
  os.unlink(tmp_path)
141
+ raise HTTPException(status_code=400, detail=f"SQL error: {e}")
142
+ conn.close()
143
+ os.unlink(tmp_path)
144
  return rows
145
 
146
+
147
  # ── Routes ─────────────────────────────────────────────────────────────────────
148
  class QueryRequest(BaseModel):
149
  session_id: str
150
  question: str
151
 
152
+
153
  @app.post("/upload")
154
  async def upload_csv(file: UploadFile = File(...)):
155
+ """Upload CSV β†’ parse β†’ store as SQLite β†’ return session_id & preview."""
156
  if not file.filename.endswith(".csv"):
157
+ raise HTTPException(status_code=400, detail="Only CSV files accepted.")
 
158
  contents = await file.read()
159
+ try:
160
+ df = pd.read_csv(io.BytesIO(contents))
161
+ except Exception as e:
162
+ raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
163
+
164
  session_id = os.urandom(8).hex()
165
+ table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
166
+ if table_name[0].isdigit():
167
+ table_name = "t_" + table_name
168
  db_bytes = csv_to_sqlite(df, table_name)
169
  schema = get_schema(db_bytes)
170
 
171
  _db_store[session_id] = db_bytes
172
  _schema_store[session_id] = schema
173
 
174
+ preview = df.head(5).to_dict(orient="records")
175
+ columns = list(df.columns)
176
+ return JSONResponse({
177
  "session_id": session_id,
178
+ "table_name": table_name,
179
+ "columns": columns,
180
+ "row_count": len(df),
181
+ "preview": preview,
182
+ "schema": schema,
183
+ })
184
+
185
 
186
  @app.post("/query")
187
  async def query(req: QueryRequest):
188
+ """Natural language question β†’ SQL β†’ execute β†’ return results."""
189
  if req.session_id not in _db_store:
190
+ raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
 
191
  schema = _schema_store[req.session_id]
192
  sql = generate_sql(req.question, schema)
193
  results = execute_sql(sql, _db_store[req.session_id])
194
+ return JSONResponse({"sql": sql, "results": results})
195
+
196
 
197
  @app.get("/health")
198
  def health():
199
+ return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}