import gradio as gr import torch from transformers import T5Tokenizer, T5ForConditionalGeneration from sentence_transformers import SentenceTransformer, util # --- CONFIGURATION --- FINE_TUNED_MODEL_ID = "hmyunis/t5-base-sql-custom" print(f"Loading Model: {FINE_TUNED_MODEL_ID}...") try: tokenizer = T5Tokenizer.from_pretrained(FINE_TUNED_MODEL_ID) model = T5ForConditionalGeneration.from_pretrained(FINE_TUNED_MODEL_ID) embedder = SentenceTransformer('all-MiniLM-L6-v2') print("Models loaded successfully.") except Exception as e: print(f"CRITICAL ERROR LOADING MODELS: {e}") def format_schema_like_training(raw_column_list): """ Transforms ['api_customer.name', 'api_customer.city', 'api_order.id'] Into: "api_customer: name, city | api_order: id" This matches the pattern the model saw during training. """ schema_map = {} for item in raw_column_list: if "." in item: table, col = item.split('.', 1) if table not in schema_map: schema_map[table] = [] schema_map[table].append(col) # Join nicely parts = [f"{table}: {', '.join(cols)}" for table, cols in schema_map.items()] return " | ".join(parts) def get_sql_pipeline(question, all_columns_str): print(f"Input Q: {question}") try: # 1. Parse Columns all_columns = eval(all_columns_str) # 2. Schema Linking (Embeddings) question_embedding = embedder.encode(question, convert_to_tensor=True) column_embeddings = embedder.encode(all_columns, convert_to_tensor=True) # Increase Top-K to 10 to ensure we get enough context from the right table hits = util.semantic_search(question_embedding, column_embeddings, top_k=10) relevant_cols = [all_columns[hit['corpus_id']] for hit in hits[0]] # 3. Formulate Prompt (CRITICAL FIX HERE) # We re-format the list to look like "table: col1, col2" schema_context = format_schema_like_training(relevant_cols) input_text = f"translate English to SQL: {question} {schema_context}" print(f"Prompt: {input_text}") # 4. Generate input_ids = tokenizer(input_text, return_tensors="pt").input_ids outputs = model.generate( input_ids, max_length=128, num_beams=4, early_stopping=True ) generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"Output: '{generated_sql}'") return generated_sql except Exception as e: return f"Error: {str(e)}" iface = gr.Interface(fn=get_sql_pipeline, inputs=["text", "text"], outputs="text") iface.launch()