nilotpaldhar2004 commited on
Commit
3d13366
·
unverified ·
1 Parent(s): e870039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -79,14 +79,16 @@ def get_schema(db_bytes: bytes) -> str:
79
  return "\n".join(r[0] for r in rows if r[0])
80
 
81
  def generate_sql(question: str, schema: str) -> str:
82
- """Run T5 inference with enhanced regex fixes."""
 
83
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
84
- table_name = table_match.group(1) if table_match else "data"
 
85
  quoted = f'"{table_name}"'
86
 
 
87
  col_match = re.findall(r'"(\w+)"', schema)
88
  col_hint = ", ".join(col_match) if col_match else ""
89
-
90
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
91
 
92
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
@@ -96,12 +98,20 @@ def generate_sql(question: str, schema: str) -> str:
96
 
97
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
98
 
99
- # FIXES
100
- sql = re.sub(r'\bFROM\s+("?\w+"?)', f'FROM {quoted}', sql, flags=re.IGNORECASE)
101
- sql = re.sub(r'\bJOIN\s+("?\w+"?)', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
 
 
 
102
 
103
- # Strip trailing junk
104
- sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|LEFT|RIGHT|INNER|ON|AND|OR|\d)(\w+)', r'\1', sql, flags=re.IGNORECASE)
 
 
 
 
 
105
 
106
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
107
  sql = f'SELECT * FROM {quoted} LIMIT 10'
 
79
  return "\n".join(r[0] for r in rows if r[0])
80
 
81
  def generate_sql(question: str, schema: str) -> str:
82
+ """Run T5 inference with strict case-sensitivity fixes."""
83
+ # 1. Force lowercase table name detection from schema
84
  table_match = re.search(r'CREATE TABLE\s+"?(\w+)"?', schema, re.IGNORECASE)
85
+ # We explicitly lowercase this to match the SQLite storage
86
+ table_name = table_match.group(1).lower() if table_match else "city_day"
87
  quoted = f'"{table_name}"'
88
 
89
+ # 2. Build the prompt with explicit lowercase hints
90
  col_match = re.findall(r'"(\w+)"', schema)
91
  col_hint = ", ".join(col_match) if col_match else ""
 
92
  prompt = f"Translate English to SQL: {question} | Table: {table_name} | Columns: {col_hint}"
93
 
94
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
 
98
 
99
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
100
 
101
+ # --- THE CRITICAL FIXES ---
102
+
103
+ # Fix 1: Force the table name to be the lowercase version we found in Step 1
104
+ # This stops the "City_day" vs "city_day" conflict.
105
+ sql = re.sub(r'\bFROM\s+["\w]+', f'FROM {quoted}', sql, flags=re.IGNORECASE)
106
+ sql = re.sub(r'\bJOIN\s+["\w]+', f'JOIN {quoted}', sql, flags=re.IGNORECASE)
107
 
108
+ # Fix 2: Remove junk tokens that T5 inserts after the table name
109
+ sql = re.sub(r'(FROM\s+"?\w+"?)\s+(?!WHERE|LIMIT|ORDER|GROUP|HAVING|JOIN|ON|AND|OR|UNION)(\w+)', r'\1', sql, flags=re.IGNORECASE)
110
+
111
+ # Fix 3: Standardize common column case issues
112
+ # If the model writes "City", we make sure it matches the schema's "City"
113
+ for col in col_match:
114
+ sql = re.sub(rf'\b{col}\b', f'"{col}"', sql, flags=re.IGNORECASE)
115
 
116
  if not re.search(r'\bSELECT\b', sql, re.IGNORECASE):
117
  sql = f'SELECT * FROM {quoted} LIMIT 10'