nilotpaldhar2004 commited on
Commit
822614c
Β·
verified Β·
1 Parent(s): bf7ba46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -22
app.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
- QueryMind β€” CSV-to-SQL Engine
3
  Model: T5-Small Hybrid (Regex + Transformer)
4
- Target Hardware: HuggingFace Free Tier (CPU)
5
  """
6
 
7
  import os
@@ -26,13 +26,14 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ── Model Initialization ──────────────────────────────────────────────────────
28
  print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
29
- # Force use_fast=False to avoid the sentencepiece backend error
 
30
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
31
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
32
  model.eval()
33
  print("[INFO] Model ready.")
34
 
35
- # ── State Management ──────────────────────────────────────────────────────────
36
  _db_store: dict[str, bytes] = {} # session_id -> sqlite db bytes
37
  _schema_store: dict[str, str] = {} # session_id -> create table schema
38
 
@@ -54,11 +55,10 @@ def root():
54
 
55
  # ── Logic Helpers ──────────────────────────────────────────────────────────────
56
  def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
57
- """Safely converts a Pandas DataFrame into a SQLite binary blob."""
58
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
59
  tmp_path = tmp.name
60
  conn = sqlite3.connect(tmp_path)
61
- # Ensure the table name is safe for SQL
62
  safe_table = re.sub(r"[^a-zA-Z0-9_]", "_", table_name)
63
  df.to_sql(safe_table, conn, if_exists="replace", index=False)
64
  conn.close()
@@ -68,7 +68,7 @@ def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
68
  return db_bytes
69
 
70
  def get_schema(db_bytes: bytes) -> str:
71
- """Extracts the exact SQL schema used to create the SQLite table."""
72
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
73
  tmp.write(db_bytes)
74
  tmp_path = tmp.name
@@ -81,8 +81,11 @@ def get_schema(db_bytes: bytes) -> str:
81
  return "\n".join(r[0] for r in rows if r[0])
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
- """Hybrid Engine: Uses smart regex first, falls back to T5 with sanitization."""
85
- # 1. Schema Context
 
 
 
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
  table_name = table_match.group(1) if table_match else "data"
88
  quoted = f'"{table_name}"'
@@ -90,14 +93,14 @@ def generate_sql(question: str, schema: str) -> str:
90
 
91
  q = question.lower().strip()
92
 
93
- # 2. Smart Column Detection (Matches user words to schema)
94
  target_col = None
95
  for col in col_match:
96
  if col.lower() in q:
97
  target_col = col
98
  break
99
 
100
- # 3. Deterministic Regex Layer (High Accuracy, Zero Latency)
101
 
102
  # DISTINCT/UNIQUE
103
  if re.search(r'unique|distinct', q):
@@ -114,20 +117,19 @@ def generate_sql(question: str, schema: str) -> str:
114
  num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[0])
115
  return f'SELECT AVG("{num_col}") FROM {quoted}'
116
 
117
- # COUNT/HOW MANY
118
  if re.search(r'count|total|how many', q):
119
- # Handle word searches (e.g. "count Paris")
120
  if target_col and len(q.split()) > 2:
121
  return f'SELECT COUNT(*) FROM {quoted} WHERE "{target_col}" LIKE "%{q.split()[-1]}%"'
122
  return f'SELECT COUNT(*) FROM {quoted}'
123
 
124
- # LIMIT/TOP
125
  if re.search(r'show|display|get|first|top', q):
126
  n_match = re.search(r'\d+', q)
127
  limit = n_match.group() if n_match else 10
128
  return f'SELECT * FROM {quoted} LIMIT {limit}'
129
 
130
- # 4. Transformer Fallback (Probabilistic Reasoning)
131
  col_hint = ", ".join(col_match) if col_match else ""
132
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
133
 
@@ -137,23 +139,19 @@ def generate_sql(question: str, schema: str) -> str:
137
 
138
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
139
 
140
- # ── Output Sanitization Guardrails ──
141
- # Remove T5 artifacts (pipes, prompt echoes)
142
  if "|" in sql: sql = sql.split("|")[-1].strip()
143
  sql = re.sub(r'^(sql|query|table):', '', sql, flags=re.IGNORECASE).strip()
144
-
145
- # Force correct table references
146
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
147
  sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
148
 
149
- # Final check for valid SELECT
150
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
151
  sql = f'SELECT * FROM {quoted} LIMIT 10'
152
 
153
  return sql
154
 
155
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
156
- """Runs SQL against the binary blob by creating a temporary local SQLite DB."""
157
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
158
  tmp.write(db_bytes)
159
  tmp_path = tmp.name
@@ -187,7 +185,6 @@ async def upload_csv(file: UploadFile = File(...)):
187
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
188
 
189
  session_id = os.urandom(8).hex()
190
- # Clean the filename to create a valid SQLite table name
191
  raw_name = os.path.splitext(file.filename)[0]
192
  table_name = re.sub(r"[^a-zA-Z0-9_]", "_", raw_name)[:32] or "data"
193
  if table_name[0].isdigit(): table_name = "t_" + table_name
 
1
  """
2
+ QueryMind β€” CSV-to-SQL Engine (Final Production Version)
3
  Model: T5-Small Hybrid (Regex + Transformer)
4
+ Hardware: HuggingFace Free Tier (CPU)
5
  """
6
 
7
  import os
 
26
 
27
  # ── Model Initialization ──────────────────────────────────────────────────────
28
  print(f"[INFO] Loading model: {MODEL_NAME} | device: {DEVICE}")
29
+
30
+ # CRITICAL: use_fast=False fixes the sentencepiece/backend tokenizer error on CPU
31
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
32
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(DEVICE)
33
  model.eval()
34
  print("[INFO] Model ready.")
35
 
36
+ # ── State Management (In-Memory) ──────────────────────────────────────────────
37
  _db_store: dict[str, bytes] = {} # session_id -> sqlite db bytes
38
  _schema_store: dict[str, str] = {} # session_id -> create table schema
39
 
 
55
 
56
  # ── Logic Helpers ──────────────────────────────────────────────────────────────
57
  def csv_to_sqlite(df: pd.DataFrame, table_name: str) -> bytes:
58
+ """Converts Pandas DataFrame into a portable SQLite binary blob."""
59
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
60
  tmp_path = tmp.name
61
  conn = sqlite3.connect(tmp_path)
 
62
  safe_table = re.sub(r"[^a-zA-Z0-9_]", "_", table_name)
63
  df.to_sql(safe_table, conn, if_exists="replace", index=False)
64
  conn.close()
 
68
  return db_bytes
69
 
70
  def get_schema(db_bytes: bytes) -> str:
71
+ """Extracts the SQL schema used to create the table."""
72
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
73
  tmp.write(db_bytes)
74
  tmp_path = tmp.name
 
81
  return "\n".join(r[0] for r in rows if r[0])
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
+ """
85
+ Dual-Stream SQL Generation:
86
+ 1. Deterministic (Regex) - Matches common analysis patterns.
87
+ 2. Probabilistic (T5) - Handles complex phrasing as fallback.
88
+ """
89
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
90
  table_name = table_match.group(1) if table_match else "data"
91
  quoted = f'"{table_name}"'
 
93
 
94
  q = question.lower().strip()
95
 
96
+ # Smart Column Detection
97
  target_col = None
98
  for col in col_match:
99
  if col.lower() in q:
100
  target_col = col
101
  break
102
 
103
+ # ── Deterministic Layer ──
104
 
105
  # DISTINCT/UNIQUE
106
  if re.search(r'unique|distinct', q):
 
117
  num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[0])
118
  return f'SELECT AVG("{num_col}") FROM {quoted}'
119
 
120
+ # COUNT
121
  if re.search(r'count|total|how many', q):
 
122
  if target_col and len(q.split()) > 2:
123
  return f'SELECT COUNT(*) FROM {quoted} WHERE "{target_col}" LIKE "%{q.split()[-1]}%"'
124
  return f'SELECT COUNT(*) FROM {quoted}'
125
 
126
+ # LIMIT
127
  if re.search(r'show|display|get|first|top', q):
128
  n_match = re.search(r'\d+', q)
129
  limit = n_match.group() if n_match else 10
130
  return f'SELECT * FROM {quoted} LIMIT {limit}'
131
 
132
+ # ── Probabilistic Fallback ──
133
  col_hint = ", ".join(col_match) if col_match else ""
134
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
135
 
 
139
 
140
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
 
142
+ # Output Sanitization
 
143
  if "|" in sql: sql = sql.split("|")[-1].strip()
144
  sql = re.sub(r'^(sql|query|table):', '', sql, flags=re.IGNORECASE).strip()
 
 
145
  sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
146
  sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
147
 
 
148
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
149
  sql = f'SELECT * FROM {quoted} LIMIT 10'
150
 
151
  return sql
152
 
153
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
154
+ """Runs SQL against the binary blob via a temporary SQLite instance."""
155
  with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
156
  tmp.write(db_bytes)
157
  tmp_path = tmp.name
 
185
  raise HTTPException(status_code=400, detail=f"CSV parse error: {e}")
186
 
187
  session_id = os.urandom(8).hex()
 
188
  raw_name = os.path.splitext(file.filename)[0]
189
  table_name = re.sub(r"[^a-zA-Z0-9_]", "_", raw_name)[:32] or "data"
190
  if table_name[0].isdigit(): table_name = "t_" + table_name