nilotpaldhar2004 commited on
Commit
2ea4813
·
verified ·
1 Parent(s): 4ebb141

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -25
app.py CHANGED
@@ -81,53 +81,38 @@ def get_schema(db_bytes: bytes) -> str:
81
 
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
- """
85
- Enhanced Hybrid SQL Engine.
86
- Priority 1: Smart Regex (Deterministic & Instant)
87
- Priority 2: T5 Transformer (Probabilistic Fallback)
88
- """
89
- # 1. Context Extraction
90
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
91
  table_name = table_match.group(1) if table_match else "data"
92
  quoted = f'"{table_name}"'
93
  col_match = re.findall(r'"(\w+)"', schema)
94
-
95
  q = question.lower().strip()
96
 
97
  # 2. Smart Column Detection
98
- # Searches for a column name from the schema within the user's question
99
  target_col = None
100
  for col in col_match:
101
  if col.lower() in q:
102
  target_col = col
103
  break
104
 
105
- # 3. Enhanced Rule-Based Shortcuts (Smart Logic)
106
-
107
- # DISTINCT/UNIQUE COUNT
108
  if re.search(r'unique|distinct', q):
109
  col = target_col if target_col else (col_match[0] if col_match else "*")
110
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
111
 
112
- # GROUP BY
113
  if re.search(r'group.*by|per|each', q):
114
  col = target_col if target_col else (col_match[0] if col_match else "data")
115
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
116
 
117
- # AVERAGE (With semantic fallback for your city_day dataset)
118
- if re.search(r'average|avg|mean', q):
119
- num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[2] if len(col_match)>2 else col_match[0])
120
- return f'SELECT AVG("{num_col}") FROM {quoted}'
121
-
122
- # TOTAL RECORDS
123
- if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
124
  return f'SELECT COUNT(*) FROM {quoted}'
125
 
126
- # LIMIT/TOP ROWS
127
  if re.search(r'show|display|get|first|top', q):
128
  n_match = re.search(r'\d+', q)
129
- limit = n_match.group() if n_match else 10
130
- return f'SELECT * FROM {quoted} LIMIT {limit}'
131
 
132
  # 4. T5 Model Fallback
133
  col_hint = ", ".join(col_match) if col_match else ""
@@ -139,10 +124,16 @@ def generate_sql(question: str, schema: str) -> str:
139
 
140
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
 
142
- # Post-inference cleaning (Crucial for SQLite stability)
143
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
144
- sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
 
 
 
 
145
 
 
 
146
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
147
  sql = f'SELECT * FROM {quoted} LIMIT 10'
148
 
 
81
 
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
+ # 1. Context Extraction (Same as before)
 
 
 
 
 
85
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
86
  table_name = table_match.group(1) if table_match else "data"
87
  quoted = f'"{table_name}"'
88
  col_match = re.findall(r'"(\w+)"', schema)
 
89
  q = question.lower().strip()
90
 
91
  # 2. Smart Column Detection
 
92
  target_col = None
93
  for col in col_match:
94
  if col.lower() in q:
95
  target_col = col
96
  break
97
 
98
+ # 3. Enhanced Rule-Based Shortcuts (Deterministic)
 
 
99
  if re.search(r'unique|distinct', q):
100
  col = target_col if target_col else (col_match[0] if col_match else "*")
101
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
102
 
 
103
  if re.search(r'group.*by|per|each', q):
104
  col = target_col if target_col else (col_match[0] if col_match else "data")
105
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
106
 
107
+ if re.search(r'count.*(total|all|record|row|paris)|how many', q):
108
+ # Special case for "Count the Paris" -> We search for 'Paris' in all columns
109
+ if "paris" in q:
110
+ return f'SELECT COUNT(*) FROM {quoted} WHERE "answer" LIKE "%Paris%" OR "question" LIKE "%Paris%"'
 
 
 
111
  return f'SELECT COUNT(*) FROM {quoted}'
112
 
 
113
  if re.search(r'show|display|get|first|top', q):
114
  n_match = re.search(r'\d+', q)
115
+ return f'SELECT * FROM {quoted} LIMIT {n_match.group() if n_match else 10}'
 
116
 
117
  # 4. T5 Model Fallback
118
  col_hint = ", ".join(col_match) if col_match else ""
 
124
 
125
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
126
 
127
+ # ── CRITICAL CLEANING GUARDRAIL ──
128
+ # This removes hallucinations like "Table | SQL | Columns" from the output
129
+ if "|" in sql:
130
+ sql = sql.split("|")[-1].strip() # Take only the part after the last pipe
131
+
132
+ # Remove common prefix hallucinations
133
+ sql = re.sub(r'^(sql|query|result|table):', '', sql, flags=re.IGNORECASE).strip()
134
 
135
+ # Force Table and SELECT
136
+ sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
137
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
138
  sql = f'SELECT * FROM {quoted} LIMIT 10'
139