nilotpaldhar2004 committed on
Commit
471250f
·
unverified ·
1 Parent(s): 7073cc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -86,9 +86,34 @@ def generate_sql(question: str, schema: str) -> str:
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
  table_name = table_match.group(1) if table_match else "data"
88
  quoted = f'"{table_name}"'
89
-
90
- # Extract column names to inject into prompt — helps T5-small stay grounded
91
  col_match = re.findall(r'"(\w+)"', schema)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  col_hint = ", ".join(col_match) if col_match else ""
93
  prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
94
  inputs = tokenizer(
@@ -197,4 +222,3 @@ async def query(req: QueryRequest):
197
  @app.get("/health")
198
  def health():
199
  return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}
200
-
 
86
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
87
  table_name = table_match.group(1) if table_match else "data"
88
  quoted = f'"{table_name}"'
 
 
89
  col_match = re.findall(r'"(\w+)"', schema)
90
+
91
+ # ── Rule-based shortcuts (fast + accurate) ────────────────────────────────
92
+ q = question.lower().strip()
93
+ if re.search(r'show.*(first|top).*\d+|first.*\d+.*row|top.*\d+', q):
94
+ n = re.search(r'\d+', q)
95
+ return f'SELECT * FROM {quoted} LIMIT {n.group() if n else 10}'
96
+ if re.search(r'(show|display|get|give).*(first|all).*row|first.*row|show.*row', q):
97
+ return f'SELECT * FROM {quoted} LIMIT 10'
98
+ if re.search(r'count.*(total|all|record|row)|total.*(record|row|count)|how many', q):
99
+ return f'SELECT COUNT(*) FROM {quoted}'
100
+ if re.search(r'show.*(all|every).*row|all.*row|select all', q):
101
+ return f'SELECT * FROM {quoted} LIMIT 50'
102
+ if re.search(r'average|avg', q) and col_match:
103
+ num_col = next((c for c in col_match if re.search(r'num|price|val|amt|count|qty|sal|rev|cost|pm|aqi|no|co|so|o3', c, re.I)), col_match[1] if len(col_match) > 1 else col_match[0])
104
+ return f'SELECT AVG("{num_col}") FROM {quoted}'
105
+ if re.search(r'unique|distinct', q) and col_match:
106
+ return f'SELECT COUNT(DISTINCT "{col_match[0]}") FROM {quoted}'
107
+ if re.search(r'group by', q) and col_match:
108
+ return f'SELECT "{col_match[0]}", COUNT(*) FROM {quoted} GROUP BY "{col_match[0]}"'
109
+ if re.search(r'max|maximum|highest', q) and col_match:
110
+ num_col = col_match[1] if len(col_match) > 1 else col_match[0]
111
+ return f'SELECT MAX("{num_col}") FROM {quoted}'
112
+ if re.search(r'min|minimum|lowest', q) and col_match:
113
+ num_col = col_match[1] if len(col_match) > 1 else col_match[0]
114
+ return f'SELECT MIN("{num_col}") FROM {quoted}'
115
+
116
+ # ── T5 model fallback ─────────────────────────────────────────────────────
117
  col_hint = ", ".join(col_match) if col_match else ""
118
  prompt = f"tables:\n{schema}\ncolumns: {col_hint}\nquery for: {question}"
119
  inputs = tokenizer(
 
222
  @app.get("/health")
223
  def health():
224
  return {"status": "ok", "model": MODEL_NAME, "device": DEVICE}