hmyunis committed on
Commit
4b5aaf7
·
1 Parent(s): b485595

Refactor model loading and error handling in app.py for improved clarity and robustness

Browse files
Files changed (1) hide show
  1. app.py +27 -34
app.py CHANGED
@@ -4,61 +4,54 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
  # --- CONFIGURATION ---
7
- # REPLACE THIS with your fine-tuned model ID from Colab/Phase 1
8
  FINE_TUNED_MODEL_ID = "hmyunis/t5-sql-finetuned"
9
 
10
- # Load the Semantic Embedding Model (for Schema Linking)
11
- # This creates vectors to find the right columns mathematically
12
- print("Loading Embedding Model...")
13
- embedder = SentenceTransformer('all-MiniLM-L6-v2')
14
-
15
- # Load your Fine-Tuned T5 Model
16
- print("Loading T5 Model...")
17
- tokenizer = T5Tokenizer.from_pretrained(FINE_TUNED_MODEL_ID)
18
- model = T5ForConditionalGeneration.from_pretrained(FINE_TUNED_MODEL_ID)
19
 
20
  def get_sql_pipeline(question, all_columns_str):
21
- """
22
- 1. Receives question + ALL columns in the database (as a string list).
23
- 2. Uses Vector Search to find the top 5 relevant columns.
24
- 3. Generates SQL using the Fine-Tuned T5 model.
25
- """
26
  try:
27
- # Convert string representation of list back to list
28
- # Expected input format: "['table.col1', 'table.col2', ...]"
29
  all_columns = eval(all_columns_str)
30
 
31
- # --- NLP LAYER 1: SEMANTIC SCHEMA LINKING ---
32
- # 1. Encode the user's question into a vector
33
  question_embedding = embedder.encode(question, convert_to_tensor=True)
34
-
35
- # 2. Encode all database columns into vectors
36
  column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
37
-
38
- # 3. Calculate Cosine Similarity to find relevant columns
39
  hits = util.semantic_search(question_embedding, column_embeddings, top_k=6)
40
-
41
- # 4. Extract the top matching columns
42
  relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]
43
 
44
- # Format for T5: "table: ..., columns: ..."
45
- # We simplify here to just a comma-separated list for the model context
46
  schema_context = ", ".join(relevant_cols)
47
-
48
- # --- NLP LAYER 2: GENERATION ---
49
- # The prompt must match exactly how we trained it in Colab
50
  input_text = f"translate to SQL: {question} </s> {schema_context}"
 
51
 
 
52
  input_ids = tokenizer(input_text, return_tensors="pt").input_ids
53
-
54
- # Generate
55
  outputs = model.generate(input_ids, max_length=128)
56
  generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
 
 
 
 
 
 
 
58
  return generated_sql
59
 
60
  except Exception as e:
61
- return f"Error in HF Space: {str(e)}"
 
 
62
 
63
  # Launch Gradio
64
  iface = gr.Interface(
@@ -66,4 +59,4 @@ iface = gr.Interface(
66
  inputs=["text", "text"],
67
  outputs="text"
68
  )
69
- iface.launch()
 
4
  from sentence_transformers import SentenceTransformer, util
5
 
# --- CONFIGURATION ---
# Fine-tuned text-to-SQL checkpoint hosted on the Hugging Face Hub.
FINE_TUNED_MODEL_ID = "hmyunis/t5-sql-finetuned"

print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
try:
    tokenizer = T5Tokenizer.from_pretrained(FINE_TUNED_MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(FINE_TUNED_MODEL_ID)
    # Sentence embedder used for semantic schema linking at request time.
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    print("Models loaded successfully.")
except Exception as e:
    # Without these globals every request would fail later with a NameError;
    # log the cause for the Space logs, then fail fast instead of limping on.
    print(f"CRITICAL ERROR LOADING MODELS: {e}")
    raise
19
def get_sql_pipeline(question, all_columns_str):
    """Translate a natural-language question into a SQL query.

    Pipeline: parse the column list, pick the 6 schema columns most
    semantically similar to the question, then prompt the fine-tuned T5
    model with question + selected columns.

    Args:
        question: The user's natural-language question.
        all_columns_str: String representation of a Python list of
            fully qualified column names, e.g.
            "['table.col1', 'table.col2']".

    Returns:
        The generated SQL string, a hard-coded fallback query if the
        model emits nothing, or an error message string on failure.
    """
    import ast

    print(f"Received Question: {question}")
    print(f"Received Columns Length: {len(all_columns_str)}")

    try:
        # 1. Parse Columns.
        # SECURITY: this string comes straight from a Gradio textbox, so it is
        # untrusted input. literal_eval only accepts Python literals, unlike
        # eval(), which would execute arbitrary attacker-supplied code.
        all_columns = ast.literal_eval(all_columns_str)

        # 2. Schema Linking (Embeddings): rank every column by cosine
        # similarity to the question and keep the top 6 hits.
        question_embedding = embedder.encode(question, convert_to_tensor=True)
        column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
        hits = util.semantic_search(question_embedding, column_embeddings, top_k=6)
        relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]

        # 3. Formulate Prompt — must match the training format exactly.
        schema_context = ", ".join(relevant_cols)
        input_text = f"translate to SQL: {question} </s> {schema_context}"
        print(f"Prompt sent to T5: {input_text}")

        # 4. Generate
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(input_ids, max_length=128)
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print(f"Generated Output: '{generated_sql}'")

        # Fallback if empty (Model produced nothing)
        if not generated_sql:
            return "SELECT * FROM api_customer -- Model returned empty, defaulting."

        return generated_sql

    except Exception as e:
        error_msg = f"Error in HF Space: {str(e)}"
        print(error_msg)
        return error_msg
55
 
56
  # Launch Gradio
57
  iface = gr.Interface(
 
59
  inputs=["text", "text"],
60
  outputs="text"
61
  )
62
+ iface.launch()