# text2sql / app.py
# Last commit cc6747b by adamboom111: "Update app.py".
import json
from collections.abc import Mapping

import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Hugging Face Hub checkpoint to load (a T5-LM-Large text-to-SQL model
# fine-tuned on Spider, judging by its repository name).
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"
# Build the seq2seq model and then its tokenizer from the same checkpoint,
# once, at import time; both stay module-level globals reused by every call.
model, tokenizer = (
    AutoModelForSeq2SeqLM.from_pretrained(model_path),
    AutoTokenizer.from_pretrained(model_path),
)
def generate_sql(payload):
    """Generate SQL from a natural-language question and a database schema.

    Args:
        payload: The value received from the JSON input. It should be a dict
            (or a JSON-encoded string of one) with optional ``"question"``
            and ``"schema"`` keys. Missing keys and JSON ``null`` values are
            treated as empty strings. A non-string ``"schema"`` (for example
            a nested JSON object) is serialized as JSON text before it is
            placed in the prompt.

    Returns:
        The decoded model output with special tokens removed. This is
        expected to be a SQL query.

    Raises:
        ValueError: If ``payload`` is a string that is not valid JSON.
        TypeError: If ``payload`` is not a JSON object and does not decode
            to one, e.g. ``None`` when the input box is left empty.
    """
    full_prompt = _build_prompt(payload)
    inputs = tokenizer(full_prompt, return_tensors="pt")
    # For an encoder-decoder model, max_length bounds the generated (decoder)
    # sequence, not the prompt.
    outputs = model.generate(**inputs, max_length=512)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _build_prompt(payload):
    """Validate ``payload`` and format it as ``"Question: ... Schema: ..."``."""
    if isinstance(payload, str):
        # API clients may send the JSON document as a raw string.
        payload = json.loads(payload)
    if not isinstance(payload, Mapping):
        raise TypeError(
            "Expected a JSON object with 'question' and 'schema' keys, "
            f"got {type(payload).__name__}."
        )
    question = _as_text(payload.get("question"))
    schema = _as_text(payload.get("schema"))
    return f"Question: {question} Schema: {schema}"


def _as_text(value):
    """Render one payload field as prompt text (``None`` becomes ``""``)."""
    if value is None:
        # JSON null must not turn into the literal word "None" in the prompt.
        return ""
    if isinstance(value, str):
        return value
    # Structured values are rendered as JSON text, not Python repr
    # (which would use single quotes, True/None, ...).
    return json.dumps(value, ensure_ascii=False)
# One JSON input component (expected to hold an object with "question" and
# "schema" keys) is passed to generate_sql; its result is shown as text.
demo = gr.Interface(
    title="Text-to-SQL Generator",
    description=(
        "Input a JSON with your natural language question and database"
        " schema. Output is SQL."
    ),
    inputs=gr.JSON(label="Input JSON (with 'question' and 'schema')"),
    outputs="text",
    fn=generate_sql,
)
# Start the Gradio web server.
demo.launch()