File size: 5,699 Bytes
f7514ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d57502
f7514ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
import time

def make_query(context, question):
    """Assemble the prompt text that is passed to the SQL-generating model.

    Args:
        context: Database schema text, placed under the
            "### Database schema (PostgreSQL):" heading.
        question: Natural-language question, placed under the
            "### Question:" heading.

    Returns:
        The complete prompt string, which ends with the
        "### SQL Query: " marker.
    """
    prompt_lines = (
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.",
        # This line consists of four spaces; it is part of the prompt text
        # and must not be stripped.
        "    ",
        "### Database schema (PostgreSQL):",
        f"{context}",
        "### Question:",
        f"{question}",
        "### SQL Query: ",
    )
    return "\n".join(prompt_lines)

# Page metadata. The page title reads "Text-to-SQL" because the app turns a
# natural-language question into a SQL query; the earlier "SQL-to-Text"
# wording had the direction reversed.
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🚀")
# Heading shown at the top of the page.
st.title("SQL Query Generator with T5")

# Ready-made inputs offered to the user, written as
# (question, table schema) pairs.
_EXAMPLE_PAIRS = (
    (
        "Show all products with price greater than 100.",
        "CREATE TABLE products (product_name VARCHAR, price INTEGER)",
    ),
    (
        "What is the average salary of employees in the Sales department?",
        "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)",
    ),
    (
        "Which students have a GPA higher than 3.5?",
        "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)",
    ),
    (
        "List all orders made by customer with ID 12345.",
        "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)",
    ),
    (
        "How many books were published after 2000?",
        "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)",
    ),
    (
        "What is the total revenue from all completed transactions?",
        "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)",
    ),
    (
        "Which cities have a population between 1 million and 2 million?",
        "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)",
    ),
    (
        "List all movies with rating higher than 8.0 released in 2020.",
        "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)",
    ),
    (
        "What is the most common job title in the company?",
        "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)",
    ),
    (
        "Which products are out of stock (quantity = 0)?",
        "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)",
    ),
)

# Each entry is a dict with a "question" key and a "description" key
# (the table schema), in the same order as the pairs above.
examples = [
    {"question": question_text, "description": schema_text}
    for question_text, schema_text in _EXAMPLE_PAIRS
]

@st.cache_resource
def load_model():
    """Load the seq2seq model and its tokenizer.

    The model is read from the ``model`` directory located next to this
    script. The tokenizer is loaded under the name ``"google-t5/t5-base"``
    rather than from that directory.
    NOTE(review): presumably the local directory holds only the fine-tuned
    weights and no tokenizer files -- confirm.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model.eval()`` has been called on
        the model before it is returned.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or the model fails; the
            original exception is attached as ``__cause__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")

    # isdir rather than exists: a regular file that happens to be named
    # "model" must be reported as a missing directory, not passed on to the
    # loader where it would fail with a less helpful message.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    try:
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause of this error.
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer

# Load the model at startup. On failure, report the error on the page and
# stop the script so the rest of the UI is not built without a model.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error}")
    st.stop()

# Default form contents. A key is only written when it is not already present
# in the session state, so values stored earlier are left untouched.
_DEFAULT_STATE = {
    "current_description": "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)",
    "current_question": "Which Played has a Points of 2, and a Position smaller than 8?",
}
for _state_key, _state_default in _DEFAULT_STATE.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

def load_example(example):
    """Copy an example's schema and question into the session state.

    Args:
        example: Mapping with "description" and "question" entries; they are
            stored under the "current_description" and "current_question"
            session-state keys respectively.
    """
    # (session-state key, example key) pairs, applied in this order.
    field_map = (
        ("current_description", "description"),
        ("current_question", "question"),
    )
    for state_key, example_key in field_map:
        st.session_state[state_key] = example[example_key]

# Example picker: one button per example, alternating between two columns.
st.subheader("Примеры:")
left_col, right_col = st.columns(2)
for idx, example in enumerate(examples):
    # Even indices go to the left column, odd indices to the right one.
    target_col = right_col if idx % 2 else left_col
    # The label shows the first 30 characters of the question plus "...".
    button_label = f"Пример {idx + 1}: {example['question'][:30]}..."
    if target_col.button(button_label, key=f"example_{idx}"):
        # Store the chosen example, then rerun the script.
        load_example(example)
        st.rerun()

# Input form with the table description, the question and the submit button.
# Both fields take their initial text from the session state.
with st.form("query_form"):
    description = st.text_area(
        label="Описание таблицы (столбцы и их типы):",
        value=st.session_state["current_description"],
        height=150,
        key="desc_input",
    )

    question = st.text_input(
        label="Ваш вопрос:",
        value=st.session_state["current_question"],
        key="question_input",
    )

    submitted = st.form_submit_button("Сгенерировать запрос")

# Handle a form submission: validate the inputs, run the model, show the SQL.
if submitted:
    # Whitespace-only fields count as empty; otherwise the model would be
    # asked to answer a blank question or work from a blank schema.
    if not ((description or "").strip() and (question or "").strip()):
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")
    else:
        input_text = make_query(description, question)
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")

            # Placeholder element used to show progress text while working.
            animation_placeholder = st.empty()
            try:
                # Short spinner animation (10 frames, 0.1 s each) shown
                # before generation starts.
                for frame in ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]:
                    animation_placeholder.markdown(f"`{frame}` Подготовка к генерации...")
                    time.sleep(0.1)

                animation_placeholder.markdown("`⏳` Генерация SQL-запроса...")
                # NOTE(review): top_p is a sampling parameter; with beam
                # search and no do_sample it is presumably ignored unless
                # the model's own generation config enables sampling --
                # confirm whether it is needed.
                outputs = model.generate(
                    input_ids,
                    max_length=200,
                    num_beams=5,
                    top_p=0.95,
                    early_stopping=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            finally:
                # Always remove the progress text, including when generation
                # fails, so it does not stay on screen next to the error.
                animation_placeholder.empty()

            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")

        except Exception as e:
            st.error(f"Ошибка при генерации: {e}")