nilotpaldhar2004 commited on
Commit
1d456a9
Β·
verified Β·
1 Parent(s): 822614c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -46
app.py CHANGED
@@ -81,16 +81,11 @@ def get_schema(db_bytes: bytes) -> str:
81
  return "\n".join(r[0] for r in rows if r[0])
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
- """
85
- Dual-Stream SQL Generation:
86
- 1. Deterministic (Regex) - Matches common analysis patterns.
87
- 2. Probabilistic (T5) - Handles complex phrasing as fallback.
88
- """
89
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
90
  table_name = table_match.group(1) if table_match else "data"
91
  quoted = f'"{table_name}"'
92
  col_match = re.findall(r'"(\w+)"', schema)
93
-
94
  q = question.lower().strip()
95
 
96
  # Smart Column Detection
@@ -100,54 +95,38 @@ def generate_sql(question: str, schema: str) -> str:
100
  target_col = col
101
  break
102
 
103
- # ── Deterministic Layer ──
104
 
105
- # DISTINCT/UNIQUE
 
 
 
 
 
 
 
 
 
 
 
106
  if re.search(r'unique|distinct', q):
107
- col = target_col if target_col else (col_match[0] if col_match else "*")
 
 
108
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
109
 
 
 
 
 
 
110
  # GROUP BY
111
  if re.search(r'group.*by|per|each', q):
112
- col = target_col if target_col else (col_match[0] if col_match else "data")
113
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
114
 
115
- # AVERAGE
116
- if re.search(r'average|mean|avg', q):
117
- num_col = target_col if target_col else next((c for c in col_match if re.search(r'pm|aqi|no|co|so|o3|benzene|val|amt', c, re.I)), col_match[0])
118
- return f'SELECT AVG("{num_col}") FROM {quoted}'
119
-
120
- # COUNT
121
- if re.search(r'count|total|how many', q):
122
- if target_col and len(q.split()) > 2:
123
- return f'SELECT COUNT(*) FROM {quoted} WHERE "{target_col}" LIKE "%{q.split()[-1]}%"'
124
- return f'SELECT COUNT(*) FROM {quoted}'
125
-
126
- # LIMIT
127
- if re.search(r'show|display|get|first|top', q):
128
- n_match = re.search(r'\d+', q)
129
- limit = n_match.group() if n_match else 10
130
- return f'SELECT * FROM {quoted} LIMIT {limit}'
131
-
132
- # ── Probabilistic Fallback ──
133
- col_hint = ", ".join(col_match) if col_match else ""
134
- prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
135
-
136
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
137
- with torch.no_grad():
138
- outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
139
-
140
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
-
142
- # Output Sanitization
143
- if "|" in sql: sql = sql.split("|")[-1].strip()
144
- sql = re.sub(r'^(sql|query|table):', '', sql, flags=re.IGNORECASE).strip()
145
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
146
- sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR)(\w+)', r'\1', sql, flags=re.IGNORECASE)
147
-
148
- if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
149
- sql = f'SELECT * FROM {quoted} LIMIT 10'
150
-
151
  return sql
152
 
153
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
 
81
  return "\n".join(r[0] for r in rows if r[0])
82
 
83
  def generate_sql(question: str, schema: str) -> str:
84
+ # 1. Context Extraction
 
 
 
 
85
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
86
  table_name = table_match.group(1) if table_match else "data"
87
  quoted = f'"{table_name}"'
88
  col_match = re.findall(r'"(\w+)"', schema)
 
89
  q = question.lower().strip()
90
 
91
  # Smart Column Detection
 
95
  target_col = col
96
  break
97
 
98
+ # 2. Advanced Rule-Based Shortcuts
99
 
100
+ # FILTERING (e.g., "is Paris", "where answer is Paris")
101
+ if "is" in q or "=" in q:
102
+ # Extract the value (e.g., "Paris")
103
+ value_match = re.search(r"is\s+(['\"]?\w+['\"]?)", q)
104
+ if value_match:
105
+ val = value_match.group(1).strip("'\"")
106
+ # If "question" is in the text, user probably wants the question for that answer
107
+ select_col = col_match[0] if "question" in q else "*"
108
+ filter_col = target_col if target_col else col_match[1]
109
+ return f'SELECT "{select_col}" FROM {quoted} WHERE "{filter_col}" LIKE "%{val}%"'
110
+
111
+ # SELECT DISTINCT (List the names) vs COUNT DISTINCT (How many)
112
  if re.search(r'unique|distinct', q):
113
+ col = target_col if target_col else col_match[0]
114
+ if re.search(r'show|list|get|give|what are', q):
115
+ return f'SELECT DISTINCT "{col}" FROM {quoted} LIMIT 50'
116
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
117
 
118
+ # SPECIFIC COLUMN SELECTION (e.g., "show all answers")
119
+ if re.search(r'show|list|get', q) and target_col:
120
+ if not re.search(r'count|avg|mean|sum', q):
121
+ return f'SELECT "{target_col}" FROM {quoted} LIMIT 50'
122
+
123
  # GROUP BY
124
  if re.search(r'group.*by|per|each', q):
125
+ col = target_col if target_col else col_match[0]
126
  return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
127
 
128
+ # 3. T5 Fallback (Existing logic)
129
+ # ... [Keep your T5 code and Sanitization here] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  return sql
131
 
132
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]: