hmyunis commited on
Commit
8dfd7bb
·
1 Parent(s): b7d2fa2

Enhance SQL generation by adding schema formatting function and increasing semantic search context

Browse files
Files changed (1) hide show
  1. app.py +27 -8
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
  # --- CONFIGURATION ---
7
- # UPDATE THIS to the new model you just trained
8
  FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"
9
 
10
  print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
@@ -16,6 +15,25 @@ try:
16
  except Exception as e:
17
  print(f"CRITICAL ERROR LOADING MODELS: {e}")
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_sql_pipeline(question, all_columns_str):
20
  print(f"Input Q: {question}")
21
 
@@ -26,23 +44,25 @@ def get_sql_pipeline(question, all_columns_str):
26
  # 2. Schema Linking (Embeddings)
27
  question_embedding = embedder.encode(question, convert_to_tensor=True)
28
  column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
29
- hits = util.semantic_search(question_embedding, column_embeddings, top_k=6)
 
 
30
  relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
31
 
32
- # 3. Formulate Prompt (MATCHES TRAINING EXACTLY)
33
- schema_context = ", ".join(relevant_cols)
34
- # Note the prefix change: "translate English to SQL"
 
35
  input_text = f"translate English to SQL: {question} </s> {schema_context}"
36
  print(f"Prompt: {input_text}")
37
 
38
  # 4. Generate
39
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
40
 
41
- # Use beam search for better accuracy (slower but worth it)
42
  outputs = model.generate(
43
  input_ids,
44
  max_length=128,
45
- num_beams=4, # Inspects 4 possible paths
46
  early_stopping=True
47
  )
48
 
@@ -54,6 +74,5 @@ def get_sql_pipeline(question, all_columns_str):
54
  except Exception as e:
55
  return f"Error: {str(e)}"
56
 
57
- # Simpler Interface
58
  iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text")
59
  iface.launch()
 
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
  # --- CONFIGURATION ---
 
7
  FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"
8
 
9
  print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
 
15
  except Exception as e:
16
  print(f"CRITICAL ERROR LOADING MODELS: {e}")
17
 
18
def format_schema_like_training(raw_column_list):
    """
    Group dotted column names by table and render them as one schema string.

    Example: ['api_customer.name', 'api_customer.city', 'api_order.id']
    becomes "api_customer: name, city | api_order: id" — the exact layout
    the model saw during fine-tuning. Entries without a '.' separator carry
    no table prefix and are skipped.
    """
    grouped = {}
    for entry in raw_column_list:
        if "." not in entry:
            continue  # no table prefix -> cannot place it in the schema
        table, column = entry.split(".", 1)
        grouped.setdefault(table, []).append(column)

    # One "table: col1, col2" segment per table, joined with " | ".
    return " | ".join(
        f"{name}: {', '.join(columns)}" for name, columns in grouped.items()
    )
36
+
37
  def get_sql_pipeline(question, all_columns_str):
38
  print(f"Input Q: {question}")
39
 
 
44
  # 2. Schema Linking (Embeddings)
45
  question_embedding = embedder.encode(question, convert_to_tensor=True)
46
  column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
47
+
48
+ # Increase Top-K to 10 to ensure we get enough context from the right table
49
+ hits = util.semantic_search(question_embedding, column_embeddings, top_k=10)
50
  relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
51
 
52
+ # 3. Formulate Prompt (CRITICAL FIX HERE)
53
+ # We re-format the list to look like "table: col1, col2"
54
+ schema_context = format_schema_like_training(relevant_cols)
55
+
56
  input_text = f"translate English to SQL: {question} </s> {schema_context}"
57
  print(f"Prompt: {input_text}")
58
 
59
  # 4. Generate
60
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
61
 
 
62
  outputs = model.generate(
63
  input_ids,
64
  max_length=128,
65
+ num_beams=4,
66
  early_stopping=True
67
  )
68
 
 
74
  except Exception as e:
75
  return f"Error: {str(e)}"
76
 
 
77
  iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text")
78
  iface.launch()