| import streamlit as st |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| import torch |
| import os |
| import time |
|
|
def make_query(context, question):
    """Assemble the text-to-SQL prompt that is fed to the seq2seq model.

    Args:
        context: Database schema text (e.g. ``CREATE TABLE`` statements).
        question: Natural-language question to translate into SQL.

    Returns:
        The prompt string: an instruction line, a schema section, a question
        section, and a trailing ``### SQL Query: `` cue for the model.
    """
    prompt_lines = [
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.",
        "",
        "### Database schema (PostgreSQL):",
        f"{context}",
        "",
        "### Question:",
        f"{question}",
        "",
        # Trailing space is part of the prompt format the model sees.
        "### SQL Query: ",
    ]
    return "\n".join(prompt_lines)
|
|
# Page metadata. The app turns a natural-language question into SQL, so the
# browser-tab title reads "Text-to-SQL" (it previously said the reverse).
# Kept as the first Streamlit call; older Streamlit versions require that.
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🚀")
st.title("SQL Query Generator with T5")
|
|
# Sample (question, table schema) pairs shown as one-click example buttons.
_EXAMPLE_PAIRS = (
    ("Show all products with price greater than 100.",
     "CREATE TABLE products (product_name VARCHAR, price INTEGER)"),
    ("What is the average salary of employees in the Sales department?",
     "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)"),
    ("Which students have a GPA higher than 3.5?",
     "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)"),
    ("List all orders made by customer with ID 12345.",
     "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)"),
    ("How many books were published after 2000?",
     "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)"),
    ("What is the total revenue from all completed transactions?",
     "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)"),
    ("Which cities have a population between 1 million and 2 million?",
     "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)"),
    ("List all movies with rating higher than 8.0 released in 2020.",
     "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)"),
    ("What is the most common job title in the company?",
     "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)"),
    ("Which products are out of stock (quantity = 0)?",
     "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)"),
)

# One dict per example with keys "question" and "description", in that order.
examples = [
    {"question": question_text, "description": schema_sql}
    for question_text, schema_sql in _EXAMPLE_PAIRS
]
|
|
@st.cache_resource
def load_model():
    """Load the seq2seq model and tokenizer, cached via ``st.cache_resource``.

    The model weights are read from the ``model`` directory located next to
    this script. The tokenizer is loaded by the ``google-t5/t5-small``
    identifier.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or model fails. The underlying
            exception is chained as ``__cause__`` so its traceback is kept.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")

    # isdir rather than exists: a stray *file* named "model" must not pass.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    try:
        # NOTE(review): the tokenizer comes from the Hub, not model_path. If the
        # fine-tuned checkpoint saved its own tokenizer, loading it locally
        # would avoid a network call and any vocab mismatch — confirm.
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        # Inference only: disables dropout and other training-time behavior.
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    return model, tokenizer
|
|
# Load once (cached). On failure, show the reason and halt this script run,
# because nothing below can work without a model.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error}")
    st.stop()
|
|
# Initial form contents, seeded only on a session's first run so values set
# later (e.g. by the example buttons) are not overwritten on reruns.
_SESSION_DEFAULTS = {
    "current_description": "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)",
    "current_question": "Which Played has a Points of 2, and a Position smaller than 8?",
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
|
def load_example(example):
    """Copy an example's schema and question into session state.

    Args:
        example: Mapping with ``"description"`` and ``"question"`` keys, as in
            the ``examples`` list.
    """
    state = st.session_state
    state["current_description"] = example["description"]
    state["current_question"] = example["question"]
|
|
# Example picker: two columns of buttons, each loading one example into the form.
_EXAMPLE_LABEL_CHARS = 30  # max question characters shown on a button

st.subheader("Примеры:")
cols = st.columns(2)
for i, example in enumerate(examples):
    question_text = example["question"]
    # Append an ellipsis only when the question was actually truncated.
    if len(question_text) > _EXAMPLE_LABEL_CHARS:
        question_text = f"{question_text[:_EXAMPLE_LABEL_CHARS]}..."
    # on_click runs before the next script run, so the form below is built
    # with the updated session state without an extra st.rerun() cycle.
    cols[i % 2].button(
        f"Пример {i+1}: {question_text}",
        key=f"example_{i}",
        on_click=load_example,
        args=(example,),
    )
|
|
# Input form. Widget values are only sent back when the submit button is
# pressed, so typing does not trigger a rerun.
# NOTE(review): whether a changed `value` re-seeds these keyed widgets after an
# example is loaded depends on the Streamlit version's widget-identity rules —
# confirm the example buttons actually update these fields.
with st.form("query_form"):
    description = st.text_area(
        label="Описание таблицы (столбцы и их типы):",
        value=st.session_state.current_description,
        height=150,
        key="desc_input",
    )

    question = st.text_input(
        label="Ваш вопрос:",
        value=st.session_state.current_question,
        key="question_input",
    )

    submitted = st.form_submit_button("Сгенерировать запрос")
|
|
# Run generation when the form is submitted and show the resulting SQL.
if submitted:
    # Whitespace-only input is treated the same as empty input.
    if description.strip() and question.strip():
        input_text = make_query(description, question)
        try:
            # tokenizer(...) returns input_ids plus attention_mask; passing both
            # means generate() does not have to infer the mask.
            inputs = tokenizer(input_text, return_tensors="pt")

            # st.spinner shows only while generation actually runs and clears
            # itself on exit, including when generate() raises. This replaces a
            # fixed 1-second sleep loop and a placeholder that stayed on screen
            # after errors.
            with st.spinner("Генерация SQL-запроса..."):
                outputs = model.generate(
                    **inputs,
                    max_length=200,
                    num_beams=5,
                    early_stopping=True,
                    # T5 has a dedicated pad token; eos-as-pad is a workaround
                    # for decoder-only models. top_p was dropped: it only applies
                    # with do_sample=True and was ignored under beam search.
                    pad_token_id=tokenizer.pad_token_id,
                )

            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")

        except Exception as e:
            # UI boundary: report any tokenization/generation failure to the user.
            st.error(f"Ошибка при генерации: {e}")
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")