# Page-capture residue (not program text): "Spaces: Runtime error"; File size: 5,699 Bytes; hashes f7514ec 2d57502 f7514ec.
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
import time
def make_query(context, question):
    """Assemble the prompt text that is tokenized and passed to the model.

    Args:
        context: Database schema text placed under the schema heading.
        question: Natural-language question placed under the question heading.

    Returns:
        The prompt as a single string. It ends with the ``### SQL Query: ``
        marker, trailing space included.
    """
    prompt_lines = (
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.",
        "### Database schema (PostgreSQL):",
        f"{context}",
        "### Question:",
        f"{question}",
        "### SQL Query: ",
    )
    return "\n".join(prompt_lines)
# Browser-tab metadata and the on-page heading. The app turns a natural-language
# question into SQL, so the tab title reads "Text-to-SQL" (it previously said
# "SQL-to-Text", the reverse direction).
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🚀")
st.title("SQL Query Generator with T5")
# Preset (question, table schema) pairs; the UI offers one button per preset.
_EXAMPLE_ROWS = (
    (
        "Show all products with price greater than 100.",
        "CREATE TABLE products (product_name VARCHAR, price INTEGER)",
    ),
    (
        "What is the average salary of employees in the Sales department?",
        "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)",
    ),
    (
        "Which students have a GPA higher than 3.5?",
        "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)",
    ),
    (
        "List all orders made by customer with ID 12345.",
        "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)",
    ),
    (
        "How many books were published after 2000?",
        "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)",
    ),
    (
        "What is the total revenue from all completed transactions?",
        "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)",
    ),
    (
        "Which cities have a population between 1 million and 2 million?",
        "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)",
    ),
    (
        "List all movies with rating higher than 8.0 released in 2020.",
        "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)",
    ),
    (
        "What is the most common job title in the company?",
        "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)",
    ),
    (
        "Which products are out of stock (quantity = 0)?",
        "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)",
    ),
)

# Each preset is exposed as a dict with "question" and "description" keys,
# in that order.
examples = [
    {"question": question_text, "description": schema_sql}
    for question_text, schema_sql in _EXAMPLE_ROWS
]
@st.cache_resource
def load_model():
    """Load the seq2seq model and its tokenizer.

    The model is read from the ``model`` directory located next to this
    script; the tokenizer is loaded under the name ``google-t5/t5-base``.
    The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple; ``model.eval()`` has been called on
        the model.

    Raises:
        FileNotFoundError: If the ``model`` directory is missing.
        RuntimeError: If loading the tokenizer or the model fails. The
            underlying exception is chained as ``__cause__`` so the original
            traceback is not lost.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")
    # isdir rather than exists: a regular file that happens to be named
    # "model" should be reported as a missing directory.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")
    try:
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer
# Load the model once at start-up; on any failure show the message in the page
# and halt the script instead of continuing without a model.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error!s}")
    st.stop()
# Initial form contents, written to session state only when the key is absent
# so that values already stored there are left untouched.
_SESSION_DEFAULTS = (
    (
        "current_description",
        "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)",
    ),
    (
        "current_question",
        "Which Played has a Points of 2, and a Position smaller than 8?",
    ),
)
for _state_key, _initial_text in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_text
def load_example(example):
    """Store a preset's schema and question in session state.

    Args:
        example: Mapping with ``"description"`` and ``"question"`` entries;
            they are written to ``current_description`` and
            ``current_question`` respectively.
    """
    for state_key, field in (
        ("current_description", "description"),
        ("current_question", "question"),
    ):
        st.session_state[state_key] = example[field]
# Preset picker: one button per example, alternating between two columns.
st.subheader("Примеры:")
example_columns = st.columns(2)
for index, sample in enumerate(examples):
    target_column = example_columns[index % 2]
    # The label shows the first 30 characters of the question.
    button_label = f"Пример {index+1}: {sample['question'][:30]}..."
    clicked = target_column.button(button_label, key=f"example_{index}")
    if clicked:
        # Store the preset, then rerun the script right away.
        load_example(sample)
        st.rerun()
# Schema and question inputs, pre-filled from session state.
with st.form("query_form"):
    description = st.text_area(
        "Описание таблицы (столбцы и их типы):",
        st.session_state.current_description,
        height=150,
        key="desc_input",
    )
    question = st.text_input(
        "Ваш вопрос:",
        st.session_state.current_question,
        key="question_input",
    )
    submitted = st.form_submit_button("Сгенерировать запрос")

if submitted:
    # Whitespace-only input counts as empty; the text handed to the prompt
    # builder is left unmodified.
    if description.strip() and question.strip():
        input_text = make_query(description, question)
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")
            animation_placeholder = st.empty()
            try:
                # Short cosmetic spinner (about one second) before generation.
                for frame in ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]:
                    animation_placeholder.markdown(f"`{frame}` Подготовка к генерации...")
                    time.sleep(0.1)
                animation_placeholder.markdown("`⏳` Генерация SQL-запроса...")
                # NOTE(review): top_p is a sampling parameter; with beam search
                # and sampling not enabled here it presumably has no effect —
                # confirm against the model's generation config before removing.
                outputs = model.generate(
                    input_ids,
                    max_length=200,
                    num_beams=5,
                    top_p=0.95,
                    early_stopping=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            finally:
                # Clear the progress line even when generation raises, so it
                # does not stay on screen next to the error message.
                animation_placeholder.empty()
            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
        except Exception as e:
            # UI boundary: report the failure in the page instead of crashing.
            st.error(f"Ошибка при генерации: {str(e)}")
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")