Spaces:
Sleeping
Sleeping
File size: 2,797 Bytes
b105a98 6366587 b105a98 6366587 b7d2fa2 b105a98 4b5aaf7 6366587 8dfd7bb 6366587 b7d2fa2 6366587 4b5aaf7 b7d2fa2 4b5aaf7 6366587 8dfd7bb 6366587 b7d2fa2 8dfd7bb b7d2fa2 4b5aaf7 6366587 b7d2fa2 8dfd7bb b7d2fa2 6366587 b7d2fa2 6366587 b7d2fa2 6366587 b7d2fa2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import gradio as gr
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer, util
# --- CONFIGURATION ---
FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom"

# Load the tokenizer, the T5 model and the sentence embedder once, at import
# time, so that every request handled below reuses the same objects.
print(f"Loading Model: {FINE_TUNED_MODEL_ID}...")
try:
    tokenizer = T5Tokenizer.from_pretrained(FINE_TUNED_MODEL_ID)
    model = T5ForConditionalGeneration.from_pretrained(FINE_TUNED_MODEL_ID)
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
except Exception as load_error:
    # Best effort: the failure is only reported here. Any of the three names
    # that was not assigned before the failure stays undefined.
    print(f"CRITICAL ERROR LOADING MODELS: {load_error}")
else:
    print("Models loaded successfully.")
def format_schema_like_training(raw_column_list):
    """
    Group "table.column" names by table and render them as one schema string.

    Example:
        ['api_customer.name', 'api_customer.city', 'api_order.id']
        becomes "api_customer: name, city | api_order: id"

    Entries that contain no "." are skipped. Only the first "." separates the
    table from the column, so "t.a.b" is column "a.b" of table "t". Tables are
    emitted in first-seen order, and so are the columns inside each table.
    Per the original author's note, this layout mirrors the prompt pattern the
    model was trained on.
    """
    columns_by_table = {}
    for qualified_name in raw_column_list:
        if "." not in qualified_name:
            continue
        table_name, _, column_name = qualified_name.partition(".")
        columns_by_table.setdefault(table_name, []).append(column_name)

    # Render each table as "table: col1, col2", then join tables with " | ".
    rendered_tables = []
    for table_name, column_names in columns_by_table.items():
        rendered_tables.append(table_name + ": " + ", ".join(column_names))
    return " | ".join(rendered_tables)
def _parse_column_list(all_columns_str):
    """Parse the user-supplied column list text without executing it.

    Args:
        all_columns_str: Text of a Python list/tuple literal of strings,
            e.g. "['api_customer.name', 'api_order.id']".

    Returns:
        The column names as a list of strings.

    Raises:
        ValueError: If the text is not a literal, or is not a non-empty
            list/tuple that contains only strings.
    """
    # Function-scope import so this block does not depend on the file header.
    import ast

    help_text = (
        "Columns must be a non-empty Python list of strings, "
        "e.g. ['table.column', 'table.other']"
    )
    if not isinstance(all_columns_str, str):
        raise ValueError(help_text)
    try:
        # SECURITY: this text comes straight from a public text box. It used
        # to be passed to eval(), which executes arbitrary code; literal_eval
        # only accepts plain literals.
        parsed = ast.literal_eval(all_columns_str.strip())
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        raise ValueError(help_text) from None
    if not isinstance(parsed, (list, tuple)) or not parsed:
        raise ValueError(help_text)
    if not all(isinstance(name, str) for name in parsed):
        raise ValueError(help_text)
    return list(parsed)


def get_sql_pipeline(question, all_columns_str):
    """Turn an English question into SQL text using the loaded models.

    Steps: parse the column list, pick the columns most similar to the
    question with the sentence embedder, format them as a schema string,
    build the prompt, and decode the model's beam-search output.

    Args:
        question: The natural-language question.
        all_columns_str: Text of a Python list literal of "table.column"
            strings.

    Returns:
        The generated SQL string, or "Error: <message>" if any step fails.
        Exceptions are never propagated to the caller.
    """
    print(f"Input Q: {question}")
    try:
        # 1. Parse Columns (safe literal parsing, validated)
        all_columns = _parse_column_list(all_columns_str)

        # 2. Schema Linking (Embeddings)
        question_embedding = embedder.encode(question, convert_to_tensor=True)
        column_embeddings = embedder.encode(all_columns, convert_to_tensor=True)
        # Top-K of 10 to ensure we get enough context from the right table
        hits = util.semantic_search(question_embedding, column_embeddings, top_k=10)
        # hits[0] holds the matches for the single query; corpus_id indexes
        # back into all_columns.
        relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]]

        # 3. Formulate Prompt
        # Re-format the list to look like "table: col1, col2"
        schema_context = format_schema_like_training(relevant_cols)
        input_text = f"translate English to SQL: {question} </s> {schema_context}"
        print(f"Prompt: {input_text}")

        # 4. Generate
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(
            input_ids,
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )
        generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Output: '{generated_sql}'")
        return generated_sql
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error: {str(e)}"
# Web UI: two free-text inputs (the question and the column-list text) are
# passed to get_sql_pipeline, and its return value is shown as text.
iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text")
iface.launch()