# text2sql / app.py
# Hugging Face Space by adamboom111 — revision 8f665dd (verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Defog's text-to-SQL model. Loaded once at import time so every Gradio
# request reuses the same weights instead of reloading them per call.
model_path = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# float16 halves the memory footprint; device_map="auto" lets accelerate
# place layers on whatever GPU/CPU is available. NOTE(review):
# trust_remote_code executes code shipped with the model repo — acceptable
# here only because the source ("defog/sqlcoder-7b-2") is a known publisher.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
def generate_sql(payload):
    """Generate a SQL query from a JSON payload using the sqlcoder model.

    Args:
        payload: dict with keys "question" (str), "schema" (str, DDL/schema
            text), and optional "sample_rows" (list; each row is rendered
            with str() into the prompt). A None payload is treated as empty.

    Returns:
        The generated SQL string — the text after the final "[SQL]" marker
        in the decoded model output.
    """
    payload = payload or {}  # the gr.JSON widget can deliver None
    question = payload.get("question", "")
    schema = payload.get("schema", "")
    sample_rows = payload.get("sample_rows", [])
    sample_str = "\n".join(str(row) for row in sample_rows) if sample_rows else ""
    prompt = f"""
### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
### Database Schema
The query will run on a database with the following schema:
{schema}
### Sample Rows
{sample_str}
### Answer
Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
[SQL]
""".strip()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # Fix: the original used max_length=512, which counts the PROMPT
        # tokens too — a long schema could leave no budget for the answer.
        # max_new_tokens bounds only the generated continuation.
        max_new_tokens=512,
        do_sample=False,  # deterministic beam search, no sampling
        num_beams=4,
        early_stopping=True,
    )
    sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The prompt ends with "[SQL]", so everything after the last marker is
    # the model's generated query.
    return sql.split("[SQL]")[-1].strip()
# Gradio UI: a single JSON input is forwarded to generate_sql and the
# resulting SQL string is shown as plain text.
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.JSON(label="Input JSON (question, schema, sample_rows)"),
    outputs="text",
    title="SQLCoder - Text to SQL",
    description="Enter a JSON object with 'question', 'schema', and optional 'sample_rows'. The model will generate SQL using Defog's sqlcoder-7b-2.",
)

if __name__ == "__main__":
    # Guard the launch so importing this module (e.g. in tests) does not
    # start the web server. HF Spaces executes app.py directly, so the
    # Space's runtime behavior is unchanged.
    demo.launch()