oiisa committed on
Commit
f7514ec
·
verified ·
1 Parent(s): aa070f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
+ import os
5
+ import time
6
+
7
def make_query(context, question):
    """Build the text prompt that is fed to the model.

    Args:
        context: Database schema text (e.g. ``CREATE TABLE`` statements),
            inserted under the "Database schema" heading.
        question: Natural-language question, inserted under the
            "Question" heading.

    Returns:
        The prompt string, ending with the "### SQL Query: " cue so the
        model's output continues from there.
    """
    prompt_lines = (
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.",
        "",
        "### Database schema (PostgreSQL):",
        f"{context}",
        "### Question:",
        f"{question}",
        "### SQL Query: ",
    )
    return "\n".join(prompt_lines)
16
+
17
# Configure the browser tab (title and icon) and render the page heading.
# The app turns a natural-language question into SQL, so the tab title is
# "Text-to-SQL" to agree with the on-page heading.
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🚀")
st.title("SQL Query Generator with T5")
19
+
20
# Ready-made (question, schema) pairs offered to the user as one-click examples.
_EXAMPLE_PAIRS = (
    (
        "Show all products with price greater than 100.",
        "CREATE TABLE products (product_name VARCHAR, price INTEGER)",
    ),
    (
        "What is the average salary of employees in the Sales department?",
        "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)",
    ),
    (
        "Which students have a GPA higher than 3.5?",
        "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)",
    ),
    (
        "List all orders made by customer with ID 12345.",
        "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)",
    ),
    (
        "How many books were published after 2000?",
        "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)",
    ),
    (
        "What is the total revenue from all completed transactions?",
        "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)",
    ),
    (
        "Which cities have a population between 1 million and 2 million?",
        "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)",
    ),
    (
        "List all movies with rating higher than 8.0 released in 2020.",
        "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)",
    ),
    (
        "What is the most common job title in the company?",
        "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)",
    ),
    (
        "Which products are out of stock (quantity = 0)?",
        "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)",
    ),
)

# Each example is a dict with a "question" and the table "description" (schema).
examples = [
    {"question": question_text, "description": schema_text}
    for question_text, schema_text in _EXAMPLE_PAIRS
]
62
+
63
@st.cache_resource
def load_model():
    """Load the seq2seq model and its tokenizer.

    The model weights are read from the ``model`` directory that sits next
    to this script; the tokenizer is loaded under the name
    ``google-t5/t5-small``. The function is wrapped in ``st.cache_resource``
    so Streamlit can reuse the loaded objects across script reruns.

    Returns:
        A ``(model, tokenizer)`` tuple, with the model switched to eval mode.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or the model fails; the
            original exception is attached as ``__cause__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")

    # isdir (not exists): a stray regular file called "model" is not a usable
    # checkpoint directory and should produce the same clear error.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    try:
        # NOTE(review): the tokenizer is taken from the base t5-small name
        # rather than from model_path — presumably the local directory holds
        # only weights; confirm the vocabularies match.
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
        return model, tokenizer
    except Exception as e:
        # Chain the cause so the underlying traceback is kept for debugging.
        raise RuntimeError(f"Error loading model: {e}") from e
78
+
79
# Load the model once at startup; if that fails, show the error in the page
# and halt the script so none of the UI below is rendered without a model.
try:
    loaded_pair = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error!s}")
    st.stop()
model, tokenizer = loaded_pair
84
+
85
# Text shown in the form before the user picks an example or types anything.
_DEFAULT_DESCRIPTION = """CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)"""
_DEFAULT_QUESTION = "Which Played has a Points of 2, and a Position smaller than 8?"

# Seed Session State on the first run of a session only, so later reruns keep
# whatever the user typed or selected.
if 'current_description' not in st.session_state:
    st.session_state.current_description = _DEFAULT_DESCRIPTION
if 'current_question' not in st.session_state:
    st.session_state.current_question = _DEFAULT_QUESTION

# The form inputs are bound to these keys. Their content is driven only
# through Session State (no `value=` default on the widgets), so there is a
# single source of truth for what the inputs display.
if 'desc_input' not in st.session_state:
    st.session_state.desc_input = st.session_state.current_description
if 'question_input' not in st.session_state:
    st.session_state.question_input = st.session_state.current_question


def load_example(example):
    """Put an example's schema and question into the form inputs.

    Intended to run as a button ``on_click`` callback, i.e. before the form
    widgets are created in the rerun triggered by the click.

    Args:
        example: Dict with "description" (table schema) and "question" keys.
    """
    st.session_state.current_description = example["description"]
    st.session_state.current_question = example["question"]
    # Write the widget keys directly: passing a changed default alongside a
    # key leaves it to Streamlit how the new default is reconciled with the
    # widget's stored state, whereas an explicit Session State write is
    # unambiguous.
    st.session_state.desc_input = example["description"]
    st.session_state.question_input = example["question"]


st.subheader("Примеры:")
cols = st.columns(2)
for i, example in enumerate(examples):
    # Alternate the buttons between the two columns. The callback runs as
    # part of the click-triggered rerun, so no explicit st.rerun() is needed.
    cols[i % 2].button(
        f"Пример {i+1}: {example['question'][:30]}...",
        key=f"example_{i}",
        on_click=load_example,
        args=(example,),
    )

with st.form("query_form"):
    description = st.text_area(
        "Описание таблицы (столбцы и их типы):",
        height=150,
        key="desc_input",
    )

    question = st.text_input(
        "Ваш вопрос:",
        key="question_input",
    )

    submitted = st.form_submit_button("Сгенерировать запрос")
120
+
121
# Runs only on the rerun triggered by the form's submit button: builds the
# prompt, generates SQL with the model and shows the result (or an error).
if submitted:
    # Whitespace-only input carries no schema/question, so treat it as empty.
    if description.strip() and question.strip():
        input_text = make_query(description, question)

        # Created before the try block so it can always be cleared afterwards.
        animation_placeholder = st.empty()
        generated_sql = None
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")

            # Short cosmetic spinner (about one second) before generation.
            for frame in ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]:
                animation_placeholder.markdown(f"`{frame}` Подготовка к генерации...")
                time.sleep(0.1)

            animation_placeholder.markdown("`⏳` Генерация SQL-запроса...")
            # NOTE(review): top_p normally only matters when sampling is
            # enabled; with beam search it is presumably ignored — confirm
            # against the model's generation config before removing it.
            outputs = model.generate(
                input_ids,
                max_length=200,
                num_beams=5,
                top_p=0.95,
                early_stopping=True,
                pad_token_id=tokenizer.eos_token_id,
            )

            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            st.error(f"Ошибка при генерации: {e}")
        finally:
            # Remove the progress message on failure as well, so a stale
            # "generating..." line is not left next to the error.
            animation_placeholder.empty()

        if generated_sql is not None:
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")