Spaces:
Sleeping
Sleeping
Enhance SQL generation by adding schema formatting function and increasing semantic search context
Browse files
app.py
CHANGED
|
@@ -4,7 +4,6 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
| 4 |
from sentence_transformers import SentenceTransformer, util
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
|
| 7 |
-
# UPDATE THIS to the new model you just trained
|
| 8 |
FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"
|
| 9 |
|
| 10 |
print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
|
|
@@ -16,6 +15,25 @@ try:
|
|
| 16 |
except Exception as e:
|
| 17 |
print(f"CRITICAL ERROR LOADING MODELS: {e}")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def get_sql_pipeline(question, all_columns_str):
|
| 20 |
print(f"Input Q: {question}")
|
| 21 |
|
|
@@ -26,23 +44,25 @@ def get_sql_pipeline(question, all_columns_str):
|
|
| 26 |
# 2. Schema Linking (Embeddings)
|
| 27 |
question_embedding = embedder.encode(question, convert_to_tensor=True)
|
| 28 |
column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
|
| 31 |
|
| 32 |
-
# 3. Formulate Prompt (
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
input_text = f"translate English to SQL: {question} </s> {schema_context}"
|
| 36 |
print(f"Prompt: {input_text}")
|
| 37 |
|
| 38 |
# 4. Generate
|
| 39 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
|
| 40 |
|
| 41 |
-
# Use beam search for better accuracy (slower but worth it)
|
| 42 |
outputs = model.generate(
|
| 43 |
input_ids,
|
| 44 |
max_length=128,
|
| 45 |
-
num_beams=4,
|
| 46 |
early_stopping=True
|
| 47 |
)
|
| 48 |
|
|
@@ -54,6 +74,5 @@ def get_sql_pipeline(question, all_columns_str):
|
|
| 54 |
except Exception as e:
|
| 55 |
return f"Error: {str(e)}"
|
| 56 |
|
| 57 |
-
# Simpler Interface
|
| 58 |
iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text")
|
| 59 |
iface.launch()
|
|
|
|
| 4 |
from sentence_transformers import SentenceTransformer, util
|
| 5 |
|
| 6 |
# --- CONFIGURATION ---
|
|
|
|
| 7 |
FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"
|
| 8 |
|
| 9 |
print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
|
|
|
|
| 15 |
except Exception as e:
|
| 16 |
print(f"CRITICAL ERROR LOADING MODELS: {e}")
|
| 17 |
|
| 18 |
+
def format_schema_like_training(raw_column_list):
    """Group qualified column names by table into the training-time schema string.

    Transforms ['api_customer.name', 'api_customer.city', 'api_order.id']
    into: "api_customer: name, city | api_order: id"

    This matches the "table: col1, col2 | table2: col3" schema pattern the
    fine-tuned T5 model saw during training.

    Args:
        raw_column_list: iterable of strings, each expected in the form
            "table.column". Items without a "." carry no table information
            and are skipped (they cannot be placed in the grouped format).
            Only the FIRST "." splits table from column, so column names
            containing dots are preserved intact.

    Returns:
        A single string of " | "-joined "table: col, col" segments, in
        first-seen table order. Empty string for an empty (or all-unqualified)
        input list.
    """
    schema_map = {}
    for item in raw_column_list:
        if "." in item:
            # Split on the first dot only: "t.a.b" -> table "t", column "a.b".
            table, col = item.split('.', 1)
            # setdefault replaces the manual "if table not in ..." dance.
            schema_map.setdefault(table, []).append(col)

    # Join nicely: one "table: col1, col2" segment per table.
    parts = [f"{table}: {', '.join(cols)}" for table, cols in schema_map.items()]
    return " | ".join(parts)
|
| 36 |
+
|
| 37 |
def get_sql_pipeline(question, all_columns_str):
|
| 38 |
print(f"Input Q: {question}")
|
| 39 |
|
|
|
|
| 44 |
# 2. Schema Linking (Embeddings)
|
| 45 |
question_embedding = embedder.encode(question, convert_to_tensor=True)
|
| 46 |
column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
|
| 47 |
+
|
| 48 |
+
# Increase Top-K to 10 to ensure we get enough context from the right table
|
| 49 |
+
hits = util.semantic_search(question_embedding, column_embeddings, top_k=10)
|
| 50 |
relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
|
| 51 |
|
| 52 |
+
# 3. Formulate Prompt (CRITICAL FIX HERE)
|
| 53 |
+
# We re-format the list to look like "table: col1, col2"
|
| 54 |
+
schema_context = format_schema_like_training(relevant_cols)
|
| 55 |
+
|
| 56 |
input_text = f"translate English to SQL: {question} </s> {schema_context}"
|
| 57 |
print(f"Prompt: {input_text}")
|
| 58 |
|
| 59 |
# 4. Generate
|
| 60 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
|
| 61 |
|
|
|
|
| 62 |
outputs = model.generate(
|
| 63 |
input_ids,
|
| 64 |
max_length=128,
|
| 65 |
+
num_beams=4,
|
| 66 |
early_stopping=True
|
| 67 |
)
|
| 68 |
|
|
|
|
| 74 |
except Exception as e:
|
| 75 |
return f"Error: {str(e)}"
|
| 76 |
|
|
|
|
| 77 |
iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text")
|
| 78 |
iface.launch()
|