nilotpaldhar2004 committed on
Commit
1e5473b
·
verified ·
1 Parent(s): 1d456a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -21
app.py CHANGED
@@ -97,36 +97,55 @@ def generate_sql(question: str, schema: str) -> str:
97
 
98
  # 2. Advanced Rule-Based Shortcuts
99
 
100
- # FILTERING (e.g., "is Paris", "where answer is Paris")
101
- if "is" in q or "=" in q:
102
- # Extract the value (e.g., "Paris")
103
- value_match = re.search(r"is\s+(['\"]?\w+['\"]?)", q)
104
- if value_match:
105
- val = value_match.group(1).strip("'\"")
106
- # If "question" is in the text, user probably wants the question for that answer
107
- select_col = col_match[0] if "question" in q else "*"
108
- filter_col = target_col if target_col else col_match[1]
 
 
 
109
  return f'SELECT "{select_col}" FROM {quoted} WHERE "{filter_col}" LIKE "%{val}%"'
110
 
111
- # SELECT DISTINCT (List the names) vs COUNT DISTINCT (How many)
112
  if re.search(r'unique|distinct', q):
113
  col = target_col if target_col else col_match[0]
114
- if re.search(r'show|list|get|give|what are', q):
115
  return f'SELECT DISTINCT "{col}" FROM {quoted} LIMIT 50'
116
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
117
 
118
- # SPECIFIC COLUMN SELECTION (e.g., "show all answers")
119
- if re.search(r'show|list|get', q) and target_col:
120
- if not re.search(r'count|avg|mean|sum', q):
121
- return f'SELECT "{target_col}" FROM {quoted} LIMIT 50'
122
 
123
- # GROUP BY
124
- if re.search(r'group.*by|per|each', q):
125
- col = target_col if target_col else col_match[0]
126
- return f'SELECT "{col}", COUNT(*) FROM {quoted} GROUP BY "{col}"'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- # 3. T5 Fallback (Existing logic)
129
- # ... [Keep your T5 code and Sanitization here] ...
130
  return sql
131
 
132
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]:
 
97
 
98
  # 2. Advanced Rule-Based Shortcuts
99
 
100
+ # FILTERING (e.g., "ans is Asia")
101
+ if "is" in q or "where" in q:
102
+ # Improved value extraction: look for the last word in the sentence
103
+ words = q.split()
104
+ val = words[-1].strip("?.!")
105
+
106
+ # Determine columns
107
+ select_col = col_match[0] if "question" in q else "*"
108
+ filter_col = target_col if target_col else (col_match[1] if len(col_match)>1 else col_match[0])
109
+
110
+ # Don't trigger if the 'value' is just a common instruction word
111
+ if val not in ["null", "not", "the", "average", "rows"]:
112
  return f'SELECT "{select_col}" FROM {quoted} WHERE "{filter_col}" LIKE "%{val}%"'
113
 
114
+ # SELECT DISTINCT (List) vs COUNT DISTINCT (Number)
115
  if re.search(r'unique|distinct', q):
116
  col = target_col if target_col else col_match[0]
117
+ if re.search(r'show|list|get|give', q):
118
  return f'SELECT DISTINCT "{col}" FROM {quoted} LIMIT 50'
119
  return f'SELECT COUNT(DISTINCT "{col}") FROM {quoted}'
120
 
121
+ # AGGREGATIONS
122
+ if re.search(r'average|mean|avg', q):
123
+ num_col = target_col if target_col else (col_match[1] if len(col_match)>1 else col_match[0])
124
+ return f'SELECT AVG("{num_col}") FROM {quoted}'
125
 
126
+ # LIMIT/SHOW
127
+ if re.search(r'show|display|get|first|top', q) and not target_col:
128
+ n_match = re.search(r'\d+', q)
129
+ return f'SELECT * FROM {quoted} LIMIT {n_match.group() if n_match else 10}'
130
+
131
+ # 3. Transformer Fallback (MANDATORY FIX)
132
+ # Ensure this part is NOT skipped
133
+ col_hint = ", ".join(col_match) if col_match else ""
134
+ prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
135
+
136
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
137
+ with torch.no_grad():
138
+ outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, num_beams=4, early_stopping=True)
139
+
140
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
+
142
+ # Sanitization
143
+ if "|" in sql: sql = sql.split("|")[-1].strip()
144
+ sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
145
+
146
+ if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
147
+ sql = f'SELECT * FROM {quoted} LIMIT 10'
148
 
 
 
149
  return sql
150
 
151
  def execute_sql(sql: str, db_bytes: bytes) -> list[dict]: