nilotpaldhar2004 commited on
Commit
c53c8b6
Β·
unverified Β·
1 Parent(s): af59526

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -93
app.py CHANGED
@@ -1,14 +1,6 @@
1
- """
2
- app.py β€” Model: defog/sqlcoder-7b-2 (Text-to-SQL)
3
- HuggingFace Space: Free Tier (needs GPU Space or patience on CPU)
4
- NOTE: 7B model β€” use HF Spaces with GPU (T4 small) if available.
5
- On CPU it will be slow (~60-120s per query) but will work.
6
- """
7
-
8
  import os
9
  import re
10
  import io
11
- import json
12
  import sqlite3
13
  import tempfile
14
  import pandas as pd
@@ -17,40 +9,44 @@ from fastapi.staticfiles import StaticFiles
17
  from fastapi.responses import FileResponse, JSONResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel
20
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
21
  import torch
22
 
23
  # ── Config ────────────────────────────────────────────────────────────────────
24
  MODEL_NAME = "defog/sqlcoder-7b-2"
25
  MAX_NEW_TOKENS = 300
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
- LOAD_IN_8BIT = False # set True if bitsandbytes is available on GPU space
28
 
29
- # ── Load model once ────────────────────────────────────────────────────────────
30
- print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
31
- print("[INFO] This may take a few minutes on first load...")
 
 
 
 
 
 
 
 
32
 
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
34
 
35
- model_kwargs = {
36
- "torch_dtype": torch.float16 if DEVICE == "cuda" else torch.float32,
37
- "device_map": "auto" if DEVICE == "cuda" else None,
38
- "low_cpu_mem_usage": True,
39
- }
40
- if LOAD_IN_8BIT and DEVICE == "cuda":
41
- model_kwargs["load_in_8bit"] = True
42
-
43
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)
44
- if DEVICE == "cpu":
45
- model = model.to(DEVICE)
46
  model.eval()
47
- print("[INFO] Model ready.")
48
 
49
  # ── In-memory store ────────────────────────────────────────────────────────────
50
  _db_store: dict[str, bytes] = {}
51
  _schema_store: dict[str, str] = {}
52
 
53
- app = FastAPI(title="CSV-to-SQL Chat (SQLCoder-7B)", version="1.0.0")
54
 
55
  app.add_middleware(
56
  CORSMiddleware,
@@ -59,13 +55,17 @@ app.add_middleware(
59
  allow_headers=["*"],
60
  )
61
 
 
 
 
 
 
62
  app.mount("/static", StaticFiles(directory="static"), name="static")
63
 
64
  @app.get("/")
65
  def root():
66
  return FileResponse("static/index.html")
67
 
68
-
69
  # ── Helpers ────────────────────────────────────────────────────────────────────
70
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
@@ -78,7 +78,6 @@ def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
78
  os.unlink(tmp_path)
79
  return db_bytes
80
 
81
-
82
  def get_schema(db_bytes: bytes) -> str:
83
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
84
  tmp.write(db_bytes)
@@ -91,9 +90,8 @@ def get_schema(db_bytes: bytes) -> str:
91
  os.unlink(tmp_path)
92
  return "\n".join(r[0] for r in rows if r[0])
93
 
94
-
95
  def build_prompt(question: str, schema: str) -> str:
96
- """SQLCoder uses a specific prompt format."""
97
  return f"""### Task
98
  Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
99
 
@@ -106,57 +104,35 @@ Given the database schema, here is the SQL query that [QUESTION]{question}[/QUES
106
  [SQL]
107
  """
108
 
109
-
110
  def generate_sql(question: str, schema: str) -> str:
111
- # Extract table name from schema
112
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
113
- table_name = table_match.group(1) if table_match else "data"
114
- quoted = f'"{table_name}"'
115
-
116
  prompt = build_prompt(question, schema)
117
- inputs = tokenizer(
118
- prompt,
119
- return_tensors="pt",
120
- truncation=True,
121
- max_length=1024,
122
- ).to(DEVICE)
123
-
124
- eos_token_id = tokenizer.eos_token_id
125
  with torch.no_grad():
126
  outputs = model.generate(
127
  **inputs,
128
  max_new_tokens=MAX_NEW_TOKENS,
129
- num_beams=4,
130
- early_stopping=True,
131
- pad_token_id=eos_token_id,
 
132
  )
133
 
134
- # Decode only newly generated tokens
135
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
136
- sql = tokenizer.decode(generated_ids, skip_special_tokens=True)
137
 
138
- # Clean SQLCoder artifacts
139
  sql = sql.split("[/SQL]")[0].strip()
140
  sql = re.sub(r"```sql|```", "", sql).strip()
141
 
142
- # Fix 1: replace any FROM/JOIN table reference with correct table
143
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
144
- sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
145
-
146
- # Fix 2: strip junk tokens after table name
147
- sql = re.sub(
148
- r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
149
- r'\1',
150
- sql, flags=re.IGNORECASE
151
- )
152
-
153
- # Fix 3: fallback if no SELECT
154
- if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
155
- sql = f'SELECT * FROM {quoted} LIMIT 10'
156
-
157
  return sql
158
 
159
-
160
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
161
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
162
  tmp.write(db_bytes)
@@ -167,62 +143,50 @@ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
167
  cur = conn.execute(sql)
168
  rows = [dict(r) for r in cur.fetchall()]
169
  except Exception as e:
 
 
170
  conn.close()
171
  os.unlink(tmp_path)
172
- raise HTTPException(status_code=400, detail=f"SQL error: {e}")
173
- conn.close()
174
- os.unlink(tmp_path)
175
  return rows
176
 
177
-
178
  # ── Routes ─────────────────────────────────────────────────────────────────────
179
  class QueryRequest(BaseModel):
180
  session_id: str
181
  question: str
182
 
183
-
184
  @app.post("/upload")
185
  async def upload_csv(file: UploadFile = File(...)):
186
  if not file.filename.endswith(".csv"):
187
- raise HTTPException(status_code=400, detail="Only CSV files accepted.")
 
188
  contents = await file.read()
189
- try:
190
- df = pd.read_csv(io.BytesIO(contents))
191
- except Exception as e:
192
- raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
193
-
194
  session_id = os.urandom(8).hex()
195
- table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
196
- if table_name[0].isdigit():
197
- table_name = "t_" + table_name
198
  db_bytes = csv_to_sqlite(df, table_name)
199
  schema = get_schema(db_bytes)
200
 
201
  _db_store[session_id] = db_bytes
202
  _schema_store[session_id] = schema
203
 
204
- preview = df.head(5).to_dict(orient="records")
205
- columns = list(df.columns)
206
- return JSONResponse({
207
  "session_id": session_id,
208
- "table_name": table_name,
209
- "columns": columns,
210
- "row_count": len(df),
211
- "preview": preview,
212
- "schema": schema,
213
- })
214
-
215
 
216
  @app.post("/query")
217
  async def query(req: QueryRequest):
218
  if req.session_id not in _db_store:
219
- raise HTTPException(status_code=404, detail="Session not found. Upload CSV first.")
 
220
  schema = _schema_store[req.session_id]
221
  sql = generate_sql(req.question, schema)
222
  results = execute_sql(sql, _db_store[req.session_id])
223
- return JSONResponse({"sql": sql, "results": results})
224
-
225
 
226
  @app.get("/health")
227
  def health():
228
- return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import io
 
4
  import sqlite3
5
  import tempfile
6
  import pandas as pd
 
9
  from fastapi.responses import FileResponse, JSONResponse
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
13
  import torch
14
 
15
  # ── Config ────────────────────────────────────────────────────────────────────
16
  MODEL_NAME = "defog/sqlcoder-7b-2"
17
  MAX_NEW_TOKENS = 300
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
19
 
20
+ # ── Memory-Optimized Model Loading ───────────────────────────────────────────
21
+ print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
22
+ print("[INFO] Applying 4-bit quantization to fit within 16Gi RAM limit...")
23
+
24
+ # Configure 4-bit quantization for memory efficiency
25
+ quant_config = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_quant_type="nf4",
28
+ bnb_4bit_use_double_quant=True,
29
+ bnb_4bit_compute_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
30
+ )
31
 
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
33
 
34
+ # Load model with quantization and low memory usage settings
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ MODEL_NAME,
37
+ quantization_config=quant_config,
38
+ device_map="auto",
39
+ low_cpu_mem_usage=True,
40
+ trust_remote_code=True
41
+ )
 
 
 
42
  model.eval()
43
+ print("[INFO] Model loaded successfully.")
44
 
45
  # ── In-memory store ────────────────────────────────────────────────────────────
46
  _db_store: dict[str, bytes] = {}
47
  _schema_store: dict[str, str] = {}
48
 
49
+ app = FastAPI(title="SQLCoder CSV Chat", version="1.1.0")
50
 
51
  app.add_middleware(
52
  CORSMiddleware,
 
55
  allow_headers=["*"],
56
  )
57
 
58
+ # ── Static frontend ────────────────────────────────────────────────────────────
59
+ # Ensure your index.html is in a folder named 'static'
60
+ if not os.path.exists("static"):
61
+ os.makedirs("static")
62
+
63
  app.mount("/static", StaticFiles(directory="static"), name="static")
64
 
65
  @app.get("/")
66
  def root():
67
  return FileResponse("static/index.html")
68
 
 
69
  # ── Helpers ────────────────────────────────────────────────────────────────────
70
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
 
78
  os.unlink(tmp_path)
79
  return db_bytes
80
 
 
81
  def get_schema(db_bytes: bytes) -> str:
82
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
83
  tmp.write(db_bytes)
 
90
  os.unlink(tmp_path)
91
  return "\n".join(r[0] for r in rows if r[0])
92
 
 
93
  def build_prompt(question: str, schema: str) -> str:
94
+ """SQLCoder specific prompt format for better accuracy."""
95
  return f"""### Task
96
  Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
97
 
 
104
  [SQL]
105
  """
106
 
 
107
  def generate_sql(question: str, schema: str) -> str:
 
108
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
109
+ table_name = table_match.group(1) if table_match else "user_data"
110
+
 
111
  prompt = build_prompt(question, schema)
112
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
113
+
 
 
 
 
 
 
114
  with torch.no_grad():
115
  outputs = model.generate(
116
  **inputs,
117
  max_new_tokens=MAX_NEW_TOKENS,
118
+ do_sample=False,
119
+ num_beams=1,
120
+ eos_token_id=tokenizer.eos_token_id,
121
+ pad_token_id=tokenizer.eos_token_id
122
  )
123
 
124
+ # Decode newly generated tokens
125
  generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
126
+ sql = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
127
 
128
+ # Post-processing and cleaning
129
  sql = sql.split("[/SQL]")[0].strip()
130
  sql = re.sub(r"```sql|```", "", sql).strip()
131
 
132
+
133
+ sql = re.sub(r'\bFROM\s+(\w+)', f'FROM "{table_name}"', sql, flags=re.IGNORECASE)
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  return sql
135
 
 
136
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
137
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
138
  tmp.write(db_bytes)
 
143
  cur = conn.execute(sql)
144
  rows = [dict(r) for r in cur.fetchall()]
145
  except Exception as e:
146
+ raise HTTPException(status_code=400, detail=f"Execution error: {e}")
147
+ finally:
148
  conn.close()
149
  os.unlink(tmp_path)
 
 
 
150
  return rows
151
 
 
152
  # ── Routes ─────────────────────────────────────────────────────────────────────
153
  class QueryRequest(BaseModel):
154
  session_id: str
155
  question: str
156
 
 
157
  @app.post("/upload")
158
  async def upload_csv(file: UploadFile = File(...)):
159
  if not file.filename.endswith(".csv"):
160
+ raise HTTPException(status_code=400, detail="Invalid file type. Upload a CSV.")
161
+
162
  contents = await file.read()
163
+ df = pd.read_csv(io.BytesIO(contents))
164
+
 
 
 
165
  session_id = os.urandom(8).hex()
166
+ table_name = "user_data" # Standardized for internal SQL logic
 
 
167
  db_bytes = csv_to_sqlite(df, table_name)
168
  schema = get_schema(db_bytes)
169
 
170
  _db_store[session_id] = db_bytes
171
  _schema_store[session_id] = schema
172
 
173
+ return {
 
 
174
  "session_id": session_id,
175
+ "columns": list(df.columns),
176
+ "preview": df.head(3).to_dict(orient="records")
177
+ }
 
 
 
 
178
 
179
  @app.post("/query")
180
  async def query(req: QueryRequest):
181
  if req.session_id not in _db_store:
182
+ raise HTTPException(status_code=404, detail="Session expired.")
183
+
184
  schema = _schema_store[req.session_id]
185
  sql = generate_sql(req.question, schema)
186
  results = execute_sql(sql, _db_store[req.session_id])
187
+
188
+ return {"sql": sql, "results": results}
189
 
190
  @app.get("/health")
191
  def health():
192
+ return {"status": "running", "quantization": "4-bit"}