text2sql-base / app.py
oiisa's picture
Update app.py
2d57502 verified
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
import time
def make_query(context, question):
    """Build the prompt that asks the model for a SQL query.

    Args:
        context: Database schema text, inserted under the schema heading.
        question: Natural-language question, inserted under the question
            heading.

    Returns:
        The assembled prompt string, ending with the ``### SQL Query: `` cue.
    """
    return (
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.\n"
        "### Database schema (PostgreSQL):\n"
        f"{context}\n"
        "### Question:\n"
        f"{question}\n"
        "### SQL Query: "
    )
# Browser-tab title and icon. The app turns a question into SQL, so the tab is
# labelled "Text-to-SQL" (the previous label had the direction inverted).
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🚀")
# Heading rendered at the top of the page.
st.title("SQL Query Generator with T5")
# (question, schema) pairs behind the example buttons. They are expanded below
# into dicts with "question" and "description" keys.
_EXAMPLE_PAIRS = (
    (
        "Show all products with price greater than 100.",
        "CREATE TABLE products (product_name VARCHAR, price INTEGER)",
    ),
    (
        "What is the average salary of employees in the Sales department?",
        "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)",
    ),
    (
        "Which students have a GPA higher than 3.5?",
        "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)",
    ),
    (
        "List all orders made by customer with ID 12345.",
        "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)",
    ),
    (
        "How many books were published after 2000?",
        "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)",
    ),
    (
        "What is the total revenue from all completed transactions?",
        "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)",
    ),
    (
        "Which cities have a population between 1 million and 2 million?",
        "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)",
    ),
    (
        "List all movies with rating higher than 8.0 released in 2020.",
        "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)",
    ),
    (
        "What is the most common job title in the company?",
        "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)",
    ),
    (
        "Which products are out of stock (quantity = 0)?",
        "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)",
    ),
)
examples = [
    {"question": question_text, "description": schema_text}
    for question_text, schema_text in _EXAMPLE_PAIRS
]
@st.cache_resource
def load_model():
    """Load the seq2seq model and its tokenizer.

    The model weights are read from the ``model`` directory located next to
    this script; the tokenizer is loaded under the ``google-t5/t5-base``
    identifier. The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple, with the model switched to eval mode.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or the model fails. The
            underlying exception is attached as ``__cause__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")
    try:
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
    except Exception as e:
        # Chain explicitly so the traceback marks the loader failure as the
        # direct cause instead of an error raised while handling another one.
        raise RuntimeError(f"Error loading model: {str(e)}") from e
    return model, tokenizer
# Obtain the model and tokenizer up front. On failure, show the reason on the
# page and stop the script here.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error!s}")
    st.stop()
# Seed the two form values in session state. Keys that are already present
# are left untouched.
_initial_fields = {
    "current_description": "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)",
    "current_question": "Which Played has a Points of 2, and a Position smaller than 8?",
}
for _field_name, _initial_value in _initial_fields.items():
    if _field_name not in st.session_state:
        st.session_state[_field_name] = _initial_value
def load_example(example):
    """Copy an example's schema and question into session state.

    Args:
        example: Mapping with ``"description"`` and ``"question"`` keys.
    """
    state = st.session_state
    state["current_description"] = example["description"]
    state["current_question"] = example["question"]
st.subheader("Примеры:")
cols = st.columns(2)
# One button per example, alternating between the two columns.
for index, sample in enumerate(examples):
    button_label = f"Пример {index + 1}: {sample['question'][:30]}..."
    was_clicked = cols[index % 2].button(button_label, key=f"example_{index}")
    if was_clicked:
        # Store the chosen example, then rerun the script so the form below
        # is rebuilt from the updated session state.
        load_example(sample)
        st.rerun()
# Input form; both fields are pre-filled from session state.
with st.form(key="query_form"):
    description = st.text_area(
        label="Описание таблицы (столбцы и их типы):",
        value=st.session_state["current_description"],
        height=150,
        key="desc_input",
    )
    question = st.text_input(
        label="Ваш вопрос:",
        value=st.session_state["current_question"],
        key="question_input",
    )
    submitted = st.form_submit_button("Сгенерировать запрос")
# Runs only on the script run triggered by the form's submit button.
if submitted:
    # Whitespace-only fields give the model nothing to work with, so they are
    # treated the same as empty ones.
    if description.strip() and question.strip():
        input_text = make_query(description, question)
        # Created outside the try block so the finally clause can always clear
        # it, including when encoding or generation fails.
        animation_placeholder = st.empty()
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")
            # Short cosmetic spinner (ten frames, 0.1 s each) before generation.
            for frame in ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]:
                animation_placeholder.markdown(f"`{frame}` Подготовка к генерации...")
                time.sleep(0.1)
            animation_placeholder.markdown("`⏳` Генерация SQL-запроса...")
            # Beam-search decoding. No `top_p` here: it only applies when
            # sampling is enabled, which this call does not do.
            outputs = model.generate(
                input_ids,
                max_length=200,
                num_beams=5,
                early_stopping=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
        except Exception as e:
            st.error(f"Ошибка при генерации: {str(e)}")
        finally:
            # Remove the progress message on success and on failure alike.
            animation_placeholder.empty()
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")