nl2sql-api / app.py
Chaitanya182004's picture
Update app.py
9a6bae9 verified
Raw
History Blame Contribute Delete
1.24 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Checkpoint identifier handed to both from_pretrained() calls below.
MODEL_NAME = "gaussalgo/T5-LM-Large-text2sql-spider"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the seq2seq weights and move them to the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
print(f"Model ready on {device}")
def generate_sql(question: str, context: str) -> str:
    """Translate a natural-language question into a SQL string.

    The prompt given to the model is ``"<question> | <context>"``. It is
    truncated to 512 tokens, and the answer is produced with 4-beam search
    limited to 128 newly generated tokens.

    Args:
        question: The natural-language question to translate. ``None`` or a
            blank string short-circuits to an empty result.
        context: Text placed after the ``|`` separator (presumably the
            database schema -- confirm with the caller). ``None`` is treated
            as an empty string.

    Returns:
        The decoded model output with special tokens removed, or ``""`` when
        there is no question to translate.
    """
    # Guard clause: without a question there is nothing to translate, so skip
    # the (expensive) beam search instead of returning SQL invented from an
    # empty prompt.
    if question is None or not question.strip():
        return ""
    # A missing context would otherwise be formatted as the literal "None".
    if context is None:
        context = ""

    # NOTE(review): confirm this "question | context" layout against the
    # prompt format documented for the MODEL_NAME checkpoint.
    input_text = f"{question} | {context}"
    inputs = tokenizer(
        input_text,
        return_tensors='pt',
        max_length=512,
        truncation=True,
    ).to(device)  # inputs must live on the same device as the model

    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        num_beams=4,
        early_stopping=True,
    )
    # Only the first returned sequence is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio UI: two input textboxes, one output textbox and a submit button.
# NOTE(review): the original indentation was lost; this layout assumes only
# the two input boxes sit inside the Row -- confirm against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("# NL2SQL API")

    with gr.Row():
        question = gr.Textbox(label="Question")
        context = gr.Textbox(label="Context")

    output = gr.Textbox(label="SQL")
    btn = gr.Button("Submit")

    # Clicking the button runs generate_sql on both inputs and writes the
    # returned string into the SQL box.
    btn.click(generate_sql, inputs=[question, context], outputs=output)

# Bind to all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)