nilotpaldhar2004 committed on
Commit
af59526
·
unverified ·
1 Parent(s): 7c20e07

Update model to defog/sqlcoder-7b-2 and adjust settings

Browse files
Files changed (1) hide show
  1. app.py +59 -27
app.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
- app.py — Model: google/flan-t5-large (Text-to-SQL)
3
- HuggingFace Space: Free Tier (CPU)
 
 
4
  """
5
 
6
  import os
@@ -15,26 +17,40 @@ from fastapi.staticfiles import StaticFiles
15
  from fastapi.responses import FileResponse, JSONResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel
18
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
19
  import torch
20
 
21
  # ── Config ────────────────────────────────────────────────────────────────────
22
- MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly
23
- MAX_NEW_TOKENS = 256
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
25
 
26
- # ── Load model once at startup ─────────────────────────────────────────────────
27
  print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
 
 
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
30
  model.eval()
31
  print("[INFO] Model ready.")
32
 
33
- # ── In-memory DB store ─────────────────────────────────────────────────────────
34
- _db_store: dict[str, bytes] = {} # session_id → sqlite db bytes
35
- _schema_store: dict[str, str] = {} # session_id → schema string
36
 
37
- app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
38
 
39
  app.add_middleware(
40
  CORSMiddleware,
@@ -43,7 +59,6 @@ app.add_middleware(
43
  allow_headers=["*"],
44
  )
45
 
46
- # ── Static frontend ────────────────────────────────────────────────────────────
47
  app.mount("/static", StaticFiles(directory="static"), name="static")
48
 
49
  @app.get("/")
@@ -53,8 +68,6 @@ def root():
53
 
54
  # ── Helpers ────────────────────────────────────────────────────────────────────
55
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
56
- """Convert DataFrame → SQLite DB bytes."""
57
- buf = io.BytesIO()
58
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
59
  tmp_path = tmp.name
60
  conn = sqlite3.connect(tmp_path)
@@ -67,7 +80,6 @@ def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
67
 
68
 
69
  def get_schema(db_bytes: bytes) -> str:
70
- """Extract CREATE TABLE schema from DB bytes."""
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp.write(db_bytes)
73
  tmp_path = tmp.name
@@ -80,42 +92,65 @@ def get_schema(db_bytes: bytes) -> str:
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def generate_sql(question: str, schema: str) -> str:
84
- """Run T5 inference to produce SQL."""
85
  # Extract table name from schema
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
  table_name = table_match.group(1) if table_match else "data"
88
  quoted = f'"{table_name}"'
89
 
90
- prompt = f"tables:\n{schema}\nquery for: {question}"
91
  inputs = tokenizer(
92
  prompt,
93
  return_tensors="pt",
94
  truncation=True,
95
- max_length=512,
96
  ).to(DEVICE)
 
 
97
  with torch.no_grad():
98
  outputs = model.generate(
99
  **inputs,
100
  max_new_tokens=MAX_NEW_TOKENS,
101
  num_beams=4,
102
  early_stopping=True,
 
103
  )
104
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
105
 
106
- # Fix 1: replace any FROM/JOIN table reference (quoted or unquoted) with correct table
 
 
 
 
 
 
 
 
107
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
108
  sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
109
 
110
- # Fix 2: strip junk tokens after table name before LIMIT/WHERE/ORDER etc.
111
- # e.g. FROM "city_day" Datetime LIMIT 10 → FROM "city_day" LIMIT 10
112
  sql = re.sub(
113
  r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
114
  r'\1',
115
  sql, flags=re.IGNORECASE
116
  )
117
 
118
- # Fix 3: fallback if no SELECT at all
119
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
120
  sql = f'SELECT * FROM {quoted} LIMIT 10'
121
 
@@ -123,7 +158,6 @@ def generate_sql(question: str, schema: str) -> str:
123
 
124
 
125
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
126
- """Run SQL against the in-memory SQLite DB."""
127
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
128
  tmp.write(db_bytes)
129
  tmp_path = tmp.name
@@ -149,7 +183,6 @@ class QueryRequest(BaseModel):
149
 
150
  @app.post("/upload")
151
  async def upload_csv(file: UploadFile = File(...)):
152
- """Upload CSV → parse → store as SQLite → return session_id & preview."""
153
  if not file.filename.endswith(".csv"):
154
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
155
  contents = await file.read()
@@ -182,9 +215,8 @@ async def upload_csv(file: UploadFile = File(...)):
182
 
183
  @app.post("/query")
184
  async def query(req: QueryRequest):
185
- """Natural language question → SQL → execute → return results."""
186
  if req.session_id not in _db_store:
187
- raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
188
  schema = _schema_store[req.session_id]
189
  sql = generate_sql(req.question, schema)
190
  results = execute_sql(sql, _db_store[req.session_id])
 
1
  """
2
+ app.py — Model: defog/sqlcoder-7b-2 (Text-to-SQL)
3
+ HuggingFace Space: Free Tier (needs GPU Space or patience on CPU)
4
+ NOTE: 7B model — use HF Spaces with GPU (T4 small) if available.
5
+ On CPU it will be slow (~60-120s per query) but will work.
6
  """
7
 
8
  import os
 
17
  from fastapi.responses import FileResponse, JSONResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
  from pydantic import BaseModel
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
21
  import torch
22
 
23
  # ── Config ────────────────────────────────────────────────────────────────────
24
+ MODEL_NAME = "defog/sqlcoder-7b-2"
25
+ MAX_NEW_TOKENS = 300
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ LOAD_IN_8BIT = False # set True if bitsandbytes is available on GPU space
28
 
29
+ # ── Load model once ────────────────────────────────────────────────────────────
30
  print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
31
+ print("[INFO] This may take a few minutes on first load...")
32
+
33
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
34
+
35
+ model_kwargs = {
36
+ "torch_dtype": torch.float16 if DEVICE == "cuda" else torch.float32,
37
+ "device_map": "auto" if DEVICE == "cuda" else None,
38
+ "low_cpu_mem_usage": True,
39
+ }
40
+ if LOAD_IN_8BIT and DEVICE == "cuda":
41
+ model_kwargs["load_in_8bit"] = True
42
+
43
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)
44
+ if DEVICE == "cpu":
45
+ model = model.to(DEVICE)
46
  model.eval()
47
  print("[INFO] Model ready.")
48
 
49
+ # ── In-memory store ────────────────────────────────────────────────────────────
50
+ _db_store: dict[str, bytes] = {}
51
+ _schema_store: dict[str, str] = {}
52
 
53
+ app = FastAPI(title="CSV-to-SQL Chat (SQLCoder-7B)", version="1.0.0")
54
 
55
  app.add_middleware(
56
  CORSMiddleware,
 
59
  allow_headers=["*"],
60
  )
61
 
 
62
  app.mount("/static", StaticFiles(directory="static"), name="static")
63
 
64
  @app.get("/")
 
68
 
69
  # ── Helpers ────────────────────────────────────────────────────────────────────
70
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
 
 
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
72
  tmp_path = tmp.name
73
  conn = sqlite3.connect(tmp_path)
 
80
 
81
 
82
  def get_schema(db_bytes: bytes) -> str:
 
83
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
84
  tmp.write(db_bytes)
85
  tmp_path = tmp.name
 
92
  return "\n".join(r[0] for r in rows if r[0])
93
 
94
 
95
+ def build_prompt(question: str, schema: str) -> str:
96
+ """SQLCoder uses a specific prompt format."""
97
+ return f"""### Task
98
+ Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
99
+
100
+ ### Database Schema
101
+ The query will run on a database with the following schema:
102
+ {schema}
103
+
104
+ ### Answer
105
+ Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
106
+ [SQL]
107
+ """
108
+
109
+
110
  def generate_sql(question: str, schema: str) -> str:
 
111
  # Extract table name from schema
112
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
113
  table_name = table_match.group(1) if table_match else "data"
114
  quoted = f'"{table_name}"'
115
 
116
+ prompt = build_prompt(question, schema)
117
  inputs = tokenizer(
118
  prompt,
119
  return_tensors="pt",
120
  truncation=True,
121
+ max_length=1024,
122
  ).to(DEVICE)
123
+
124
+ eos_token_id = tokenizer.eos_token_id
125
  with torch.no_grad():
126
  outputs = model.generate(
127
  **inputs,
128
  max_new_tokens=MAX_NEW_TOKENS,
129
  num_beams=4,
130
  early_stopping=True,
131
+ pad_token_id=eos_token_id,
132
  )
 
133
 
134
+ # Decode only newly generated tokens
135
+ generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
136
+ sql = tokenizer.decode(generated_ids, skip_special_tokens=True)
137
+
138
+ # Clean SQLCoder artifacts
139
+ sql = sql.split("[/SQL]")[0].strip()
140
+ sql = re.sub(r"```sql|```", "", sql).strip()
141
+
142
+ # Fix 1: replace any FROM/JOIN table reference with correct table
143
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
144
  sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
145
 
146
+ # Fix 2: strip junk tokens after table name
 
147
  sql = re.sub(
148
  r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
149
  r'\1',
150
  sql, flags=re.IGNORECASE
151
  )
152
 
153
+ # Fix 3: fallback if no SELECT
154
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
155
  sql = f'SELECT * FROM {quoted} LIMIT 10'
156
 
 
158
 
159
 
160
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
 
161
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
162
  tmp.write(db_bytes)
163
  tmp_path = tmp.name
 
183
 
184
  @app.post("/upload")
185
  async def upload_csv(file: UploadFile = File(...)):
 
186
  if not file.filename.endswith(".csv"):
187
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
188
  contents = await file.read()
 
215
 
216
  @app.post("/query")
217
  async def query(req: QueryRequest):
 
218
  if req.session_id not in _db_store:
219
+ raise HTTPException(status_code=404, detail="Session not found. Upload CSV first.")
220
  schema = _schema_store[req.session_id]
221
  sql = generate_sql(req.question, schema)
222
  results = execute_sql(sql, _db_store[req.session_id])