# text2sql — app.py
# Hugging Face Space by Sid26Roy (commit 495c53e, "Create app.py", 3.25 kB)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import gradio as gr
# Hugging Face model id of the text-to-SQL model (Llama-3-based SQLCoder, 8B params).
model_name = "defog/llama-3-sqlcoder-8b"
# Tokenizer is loaded eagerly at import time (downloads weights/vocab on first run).
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Check GPU memory if available, otherwise default to 4-bit mode
def get_model():
    """Load the SQLCoder model, choosing precision by available GPU memory.

    Returns:
        A `transformers` causal-LM: loaded in float16 when the first CUDA
        device reports more than ~20 GB of total memory, otherwise loaded
        4-bit quantized (requires the `bitsandbytes` package at runtime).
    """
    # Original used a bare `except:` around the CUDA query, which would
    # also swallow KeyboardInterrupt/SystemExit; an explicit availability
    # check expresses the same fallback without a blanket handler.
    if torch.cuda.is_available():
        available_memory = torch.cuda.get_device_properties(0).total_memory
    else:
        available_memory = 0

    if available_memory > 20e9:
        # Enough VRAM for the full fp16 weights.
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
        )

    # NOTE(review): `load_in_4bit=True` is deprecated in recent
    # transformers releases in favor of passing a `BitsAndBytesConfig`
    # via `quantization_config=`; kept as-is to avoid new dependencies.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        load_in_4bit=True,
        device_map="auto",
        use_cache=True,
    )


model = get_model()
# Llama-3 chat-template prompt: the user turn carries the question plus the
# DDL of a fixed expense-tracking schema (expenses, categories, users,
# budgets) and join hints; the assistant turn is primed to open a ```sql
# fenced block, which generate_query() later extracts by splitting on the
# fences. `{question}` is filled via str.format (appears twice).
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{question}`
DDL statements:
CREATE TABLE expenses (
id INTEGER PRIMARY KEY,
date DATE NOT NULL,
amount DECIMAL(10,2) NOT NULL,
category VARCHAR(50) NOT NULL,
description TEXT,
payment_method VARCHAR(20),
user_id INTEGER
);
CREATE TABLE categories (
id INTEGER PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
description TEXT
);
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE budgets (
id INTEGER PRIMARY KEY,
user_id INTEGER,
category VARCHAR(50),
amount DECIMAL(10,2) NOT NULL,
period VARCHAR(20) DEFAULT 'monthly',
start_date DATE,
end_date DATE
);
-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{question}`:
```sql
"""
def generate_query(question):
    """Generate a pretty-printed SQL query answering `question`.

    Fills the module-level prompt template, runs greedy decoding on the
    SQLCoder model, and extracts the SQL from the ```sql fenced block.

    Args:
        question: Natural-language question about the predefined schema.

    Returns:
        The reindented SQL string, or the raw model output prefixed with
        an error notice when no fenced SQL block can be found.
    """
    formatted_prompt = prompt.format(question=question)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    # Greedy decoding (do_sample=False, num_beams=1) is deterministic.
    # The original also passed temperature=0.0 / top_p=1, which are
    # ignored under greedy search and rejected (temperature must be > 0)
    # by recent transformers validation — removed.
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    try:
        # The prompt ends with "```sql", so element [1] of the split is the
        # generated query, terminated by the model's closing fence.
        sql_code = output.split("```sql")[1].split("```")[0].strip()
    except IndexError:
        # Narrowed from a bare `except:` — only a missing fence is expected here.
        return "SQL could not be parsed. Raw Output:\n\n" + output
    return sqlparse.format(sql_code, reindent=True)
# Gradio Interface
# Single free-text input wired straight to generate_query; output is plain text.
question_input = gr.Textbox(
    lines=3, placeholder="Enter your natural language question..."
)
iface = gr.Interface(
    fn=generate_query,
    inputs=question_input,
    outputs="text",
    title="LLaMA 3 SQLCoder 🦙",
    description="Enter a natural language question and get a SQL query based on predefined tables.",
)
# Launch the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()