hmyunis committed on
Commit
b7d2fa2
·
1 Parent(s): 4b5aaf7

Update model ID and refine SQL generation process in app.py for improved accuracy and clarity

Browse files
Files changed (1) hide show
  1. app.py +27 -30
app.py CHANGED
@@ -4,8 +4,8 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
  # --- CONFIGURATION ---
7
- # Ensure this matches your ACTUAL model on Hugging Face
8
- FINE_TUNED_MODEL_ID = "hmyunis/t5-sql-finetuned"
9
 
10
  print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
11
  try:
@@ -17,46 +17,43 @@ except Exception as e:
17
  print(f"CRITICAL ERROR LOADING MODELS: {e}")
18
 
19
  def get_sql_pipeline(question, all_columns_str):
20
- print(f"Received Question: {question}")
21
- print(f"Received Columns Length: {len(all_columns_str)}")
22
-
23
  try:
24
  # 1. Parse Columns
25
- all_columns = eval(all_columns_str)
26
-
27
  # 2. Schema Linking (Embeddings)
28
  question_embedding = embedder.encode(question, convert_to_tensor=True)
29
  column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
30
  hits = util.semantic_search(question_embedding, column_embeddings, top_k=6)
31
  relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
32
-
33
- # 3. Formulate Prompt
34
  schema_context = ", ".join(relevant_cols)
35
- input_text = f"translate to SQL: {question} </s> {schema_context}"
36
- print(f"Prompt sent to T5: {input_text}")
37
-
 
38
  # 4. Generate
39
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
40
- outputs = model.generate(input_ids, max_length=128)
 
 
 
 
 
 
 
 
41
  generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
-
43
- print(f"Generated Output: '{generated_sql}'")
44
-
45
- # Fallback if empty (Model produced nothing)
46
- if not generated_sql:
47
- return "SELECT * FROM api_customer -- Model returned empty, defaulting."
48
-
49
  return generated_sql
50
 
51
  except Exception as e:
52
- error_msg = f"Error in HF Space: {str(e)}"
53
- print(error_msg)
54
- return error_msg
55
 
56
- # Launch Gradio
57
- iface = gr.Interface(
58
- fn=get_sql_pipeline,
59
- inputs=["text", "text"],
60
- outputs="text"
61
- )
62
- iface.launch()
 
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
  # --- CONFIGURATION ---
7
+ # UPDATE THIS to the new model you just trained
8
+ FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"
9
 
10
  print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
11
  try:
 
17
  print(f"CRITICAL ERROR LOADING MODELS: {e}")
18
 
19
def get_sql_pipeline(question, all_columns_str):
    """Translate a natural-language question into a SQL query.

    Pipeline: parse the column list, rank columns by embedding similarity
    to the question (schema linking), build the exact prompt format used
    at training time, then generate SQL with beam search.

    Args:
        question: Natural-language question typed by the user.
        all_columns_str: String encoding a Python literal list of column
            names, e.g. '["id", "name", "email"]'.

    Returns:
        The generated SQL string, or an "Error: ..." message on failure.
    """
    import ast  # local import: only needed for parsing the column list

    print(f"Input Q: {question}")

    try:
        # 1. Parse Columns.
        # SECURITY FIX: this string comes from an untrusted Gradio text box.
        # eval() would execute arbitrary code; literal_eval only accepts
        # Python literals (lists, strings, numbers, ...).
        all_columns = ast.literal_eval(all_columns_str)

        # 2. Schema Linking (Embeddings): keep the 6 columns most
        # semantically similar to the question.
        question_embedding = embedder.encode(question, convert_to_tensor=True)
        column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
        hits = util.semantic_search(question_embedding, column_embeddings, top_k=6)
        relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]

        # 3. Formulate Prompt (MATCHES TRAINING EXACTLY)
        schema_context = ", ".join(relevant_cols)
        # Note the prefix change: "translate English to SQL"
        input_text = f"translate English to SQL: {question} </s> {schema_context}"
        print(f"Prompt: {input_text}")

        # 4. Generate
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids

        # Use beam search for better accuracy (slower but worth it)
        outputs = model.generate(
            input_ids,
            max_length=128,
            num_beams=4,  # Inspects 4 possible paths
            early_stopping=True
        )

        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Output: '{generated_sql}'")

        return generated_sql

    except Exception as e:
        # Surface the failure to the UI rather than crashing the Space.
        return f"Error: {str(e)}"
 
 
56
 
57
# Gradio UI: two text inputs (question, column-list string) -> generated SQL.
iface = gr.Interface(
    fn=get_sql_pipeline,
    inputs=["text", "text"],
    outputs="text",
)
iface.launch()