nilotpaldhar2004 committed on
Commit
7073cc4
·
unverified ·
1 Parent(s): 3d13366

Refactor app.py for model update and code clarity

Browse files

Updated model name and improved comments for clarity. Adjusted table name handling and SQL generation logic.

Files changed (1) hide show
  1. app.py +55 -48
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- app.py — Model: T5-Small (Text-to-SQL)
3
  HuggingFace Space: Free Tier (CPU)
4
  """
5
 
@@ -19,20 +19,20 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
19
  import torch
20
 
21
  # ── Config ────────────────────────────────────────────────────────────────────
22
- MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql"
23
  MAX_NEW_TOKENS = 256
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  # ── Load model once at startup ─────────────────────────────────────────────────
27
- print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
30
  model.eval()
31
  print("[INFO] Model ready.")
32
 
33
  # ── In-memory DB store ─────────────────────────────────────────────────────────
34
- _db_store: dict[str, bytes] = {}
35
- _schema_store: dict[str, str] = {}
36
 
37
  app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
38
 
@@ -50,21 +50,22 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
50
  def root():
51
  return FileResponse("static/index.html")
52
 
 
53
  # ── Helpers ────────────────────────────────────────────────────────────────────
54
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
55
  """Convert DataFrame → SQLite DB bytes."""
 
56
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
57
  tmp_path = tmp.name
58
  conn = sqlite3.connect(tmp_path)
59
- # Ensure table name is clean
60
- clean_table = re.sub(r"[^a-zA-Z0-9_]", "_", table_name)
61
- df.to_sql(clean_table, conn, if_exists="replace", index=False)
62
  conn.close()
63
  with open(tmp_path, "rb") as f:
64
  db_bytes = f.read()
65
  os.unlink(tmp_path)
66
  return db_bytes
67
 
 
68
  def get_schema(db_bytes: bytes) -> str:
69
  """Extract CREATE TABLE schema from DB bytes."""
70
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
@@ -78,52 +79,57 @@ def get_schema(db_bytes: bytes) -> str:
78
  os.unlink(tmp_path)
79
  return "\n".join(r[0] for r in rows if r[0])
80
 
 
81
  def generate_sql(question: str, schema: str) -> str:
82
- """Run T5 inference with strict case-sensitivity fixes."""
83
- # 1. Force lowercase table name detection from schema
84
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
85
- # We explicitly lowercase this to match the SQLite storage
86
- table_name = table_match.group(1).lower() if table_match else "city_day"
87
  quoted = f'"{table_name}"'
88
 
89
- # 2. Build the prompt with explicit lowercase hints
90
  col_match = re.findall(r'"(\w+)"', schema)
91
  col_hint = ", ".join(col_match) if col_match else ""
92
- prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
93
-
94
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
95
-
 
 
 
96
  with torch.no_grad():
97
- outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
98
-
 
 
 
 
99
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
100
 
101
- # --- THE CRITICAL FIXES ---
 
 
102
 
103
- # Fix 1: Force the table name to be the lowercase version we found in Step 1
104
- # This stops the "City_day" vs "city_day" conflict.
105
- sql = re.sub(r'\bFROM\s+["\w]+', f'FROM {quoted}', sql, flags=re.IGNORECASE)
106
- sql = re.sub(r'\bJOIN\s+["\w]+', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
107
-
108
- # Fix 2: Remove junk tokens that T5 inserts after the table name
109
- sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR|UNION)(\w+)', r'\1', sql, flags=re.IGNORECASE)
110
-
111
- # Fix 3: Standardize common column case issues
112
- # If the model writes "City", we make sure it matches the schema's "City"
113
- for col in col_match:
114
- sql = re.sub(rf'\b{col}\b', f'"{col}"', sql, flags=re.IGNORECASE)
115
 
 
116
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
117
  sql = f'SELECT * FROM {quoted} LIMIT 10'
118
 
119
  return sql
120
 
 
121
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
122
  """Run SQL against the in-memory SQLite DB."""
123
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
124
  tmp.write(db_bytes)
125
  tmp_path = tmp.name
126
-
127
  conn = sqlite3.connect(tmp_path)
128
  conn.row_factory = sqlite3.Row
129
  try:
@@ -132,23 +138,23 @@ def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
132
  except Exception as e:
133
  conn.close()
134
  os.unlink(tmp_path)
135
- # Return error as a list to be caught by JSONResponse
136
- raise HTTPException(status_code=400, detail=f"SQL error: {str(e)}")
137
-
138
  conn.close()
139
  os.unlink(tmp_path)
140
  return rows
141
 
 
142
  # ── Routes ─────────────────────────────────────────────────────────────────────
143
  class QueryRequest(BaseModel):
144
  session_id: str
145
  question: str
146
 
 
147
  @app.post("/upload")
148
  async def upload_csv(file: UploadFile = File(...)):
 
149
  if not file.filename.endswith(".csv"):
150
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
151
-
152
  contents = await file.read()
153
  try:
154
  df = pd.read_csv(io.BytesIO(contents))
@@ -156,38 +162,39 @@ async def upload_csv(file: UploadFile = File(...)):
156
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
157
 
158
  session_id = os.urandom(8).hex()
159
- # Clean table name from filename
160
- raw_name = os.path.splitext(file.filename)[0]
161
- table_name = re.sub(r"[^a-zA-Z0-9_]", "_", raw_name)[:32] or "data"
162
-
163
  db_bytes = csv_to_sqlite(df, table_name)
164
  schema = get_schema(db_bytes)
165
 
166
  _db_store[session_id] = db_bytes
167
  _schema_store[session_id] = schema
168
 
 
 
169
  return JSONResponse({
170
  "session_id": session_id,
171
  "table_name": table_name,
172
- "columns": list(df.columns),
173
  "row_count": len(df),
174
- "preview": df.head(5).to_dict(orient="records"),
175
  "schema": schema,
176
  })
177
 
 
178
  @app.post("/query")
179
  async def query(req: QueryRequest):
 
180
  if req.session_id not in _db_store:
181
- raise HTTPException(status_code=404, detail="Session not found. Upload CSV first.")
182
-
183
  schema = _schema_store[req.session_id]
184
  sql = generate_sql(req.question, schema)
185
-
186
- # This is where your previous code was likely failing
187
  results = execute_sql(sql, _db_store[req.session_id])
188
-
189
  return JSONResponse({"sql": sql, "results": results})
190
 
 
191
  @app.get("/health")
192
  def health():
193
  return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
 
 
1
  """
2
+ app.py — Model: google/flan-t5-large (Text-to-SQL)
3
  HuggingFace Space: Free Tier (CPU)
4
  """
5
 
 
19
  import torch
20
 
21
  # ── Config ────────────────────────────────────────────────────────────────────
22
+ MODEL_NAME = "cssupport/t5-small-awesome-text-to-sql" # T5-based text→SQL, CPU-friendly
23
  MAX_NEW_TOKENS = 256
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  # ── Load model once at startup ─────────────────────────────────────────────────
27
+ print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
29
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
30
  model.eval()
31
  print("[INFO] Model ready.")
32
 
33
  # ── In-memory DB store ─────────────────────────────────────────────────────────
34
+ _db_store: dict[str, bytes] = {} # session_id → sqlite db bytes
35
+ _schema_store: dict[str, str] = {} # session_id → schema string
36
 
37
  app = FastAPI(title="CSV-to-SQL Chat", version="1.0.0")
38
 
 
50
  def root():
51
  return FileResponse("static/index.html")
52
 
53
+
54
  # ── Helpers ────────────────────────────────────────────────────────────────────
55
  def csv_to_sqlite(df: pd.DataFrame, table_name: str = "data") -> bytes:
56
  """Convert DataFrame → SQLite DB bytes."""
57
+ buf = io.BytesIO()
58
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
59
  tmp_path = tmp.name
60
  conn = sqlite3.connect(tmp_path)
61
+ df.to_sql(table_name, conn, if_exists="replace", index=False)
 
 
62
  conn.close()
63
  with open(tmp_path, "rb") as f:
64
  db_bytes = f.read()
65
  os.unlink(tmp_path)
66
  return db_bytes
67
 
68
+
69
  def get_schema(db_bytes: bytes) -> str:
70
  """Extract CREATE TABLE schema from DB bytes."""
71
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
 
79
  os.unlink(tmp_path)
80
  return "\n".join(r[0] for r in rows if r[0])
81
 
82
+
83
  def generate_sql(question: str, schema: str) -> str:
84
+ """Run T5 inference to produce SQL."""
85
+ # Extract table name from schema
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
+ table_name = table_match.group(1) if table_match else "data"
 
88
  quoted = f'"{table_name}"'
89
 
90
+ # Extract column names to inject into prompt — helps T5-small stay grounded
91
  col_match = re.findall(r'"(\w+)"', schema)
92
  col_hint = ", ".join(col_match) if col_match else ""
93
+ prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
94
+ inputs = tokenizer(
95
+ prompt,
96
+ return_tensors="pt",
97
+ truncation=True,
98
+ max_length=512,
99
+ ).to(DEVICE)
100
  with torch.no_grad():
101
+ outputs = model.generate(
102
+ **inputs,
103
+ max_new_tokens=MAX_NEW_TOKENS,
104
+ num_beams=4,
105
+ early_stopping=True,
106
+ )
107
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
108
 
109
+ # Fix 1: replace any FROM/JOIN table reference (quoted or unquoted) with correct table
110
+ sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
111
+ sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
112
 
113
+ # Fix 2: strip junk tokens after table name before LIMIT/WHERE/ORDER etc.
114
+ # e.g. FROM "city_day" Datetime LIMIT 10 → FROM "city_day" LIMIT 10
115
+ sql = re.sub(
116
+ r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)',
117
+ r'\1',
118
+ sql, flags=re.IGNORECASE
119
+ )
 
 
 
 
 
120
 
121
+ # Fix 3: fallback if no SELECT at all
122
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
123
  sql = f'SELECT * FROM {quoted} LIMIT 10'
124
 
125
  return sql
126
 
127
+
128
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
129
  """Run SQL against the in-memory SQLite DB."""
130
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
131
  tmp.write(db_bytes)
132
  tmp_path = tmp.name
 
133
  conn = sqlite3.connect(tmp_path)
134
  conn.row_factory = sqlite3.Row
135
  try:
 
138
  except Exception as e:
139
  conn.close()
140
  os.unlink(tmp_path)
141
+ raise HTTPException(status_code=400, detail=f"SQL error: {e}")
 
 
142
  conn.close()
143
  os.unlink(tmp_path)
144
  return rows
145
 
146
+
147
  # ── Routes ─────────────────────────────────────────────────────────────────────
148
  class QueryRequest(BaseModel):
149
  session_id: str
150
  question: str
151
 
152
+
153
  @app.post("/upload")
154
  async def upload_csv(file: UploadFile = File(...)):
155
+ """Upload CSV → parse → store as SQLite → return session_id & preview."""
156
  if not file.filename.endswith(".csv"):
157
  raise HTTPException(status_code=400, detail="Only CSV files accepted.")
 
158
  contents = await file.read()
159
  try:
160
  df = pd.read_csv(io.BytesIO(contents))
 
162
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
163
 
164
  session_id = os.urandom(8).hex()
165
+ table_name = re.sub(r"[^a-zA-Z0-9_]", "_", os.path.splitext(file.filename)[0])[:32] or "data"
166
+ if table_name[0].isdigit():
167
+ table_name = "t_" + table_name
 
168
  db_bytes = csv_to_sqlite(df, table_name)
169
  schema = get_schema(db_bytes)
170
 
171
  _db_store[session_id] = db_bytes
172
  _schema_store[session_id] = schema
173
 
174
+ preview = df.head(5).to_dict(orient="records")
175
+ columns = list(df.columns)
176
  return JSONResponse({
177
  "session_id": session_id,
178
  "table_name": table_name,
179
+ "columns": columns,
180
  "row_count": len(df),
181
+ "preview": preview,
182
  "schema": schema,
183
  })
184
 
185
+
186
  @app.post("/query")
187
  async def query(req: QueryRequest):
188
+ """Natural language question → SQL → execute → return results."""
189
  if req.session_id not in _db_store:
190
+ raise HTTPException(status_code=404, detail="Session not found. Please upload CSV first.")
 
191
  schema = _schema_store[req.session_id]
192
  sql = generate_sql(req.question, schema)
 
 
193
  results = execute_sql(sql, _db_store[req.session_id])
 
194
  return JSONResponse({"sql": sql, "results": results})
195
 
196
+
197
  @app.get("/health")
198
  def health():
199
  return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
200
+