oiisa committed on
Commit
fbd0369
·
verified ·
1 Parent(s): c58803f

Add new examples

Browse files
Files changed (1) hide show
  1. app.py +140 -140
app.py CHANGED
@@ -1,141 +1,141 @@
1
- import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- import torch
4
- import os
5
-
6
- def make_query(context, question):
7
- result_query = f'''You are a SQL expert with extensive experience, you need to create a query to answer the question.
8
-
9
- ### Database schema (PostgreSQL):
10
- {context}
11
-
12
- ### Question:
13
- {question}
14
-
15
- ### SQL Query: '''
16
- return result_query
17
-
18
- st.set_page_config(page_title="SQL-to-Text with T5", page_icon="🤖")
19
- st.title("SQL Query Generator with T5")
20
-
21
- examples = [
22
- {
23
- "question": "Who was the music director in 1971 for the movie Kalyani?",
24
- "description": "CREATE TABLE table_name_7 (Music VARCHAR, year VARCHAR, movie__in_kannada__ VARCHAR)"
25
- },
26
- {
27
- "question": "What's the highest with a capacity of greater than 4,000 and an average of 615?",
28
- "description": "CREATE TABLE table_name_20 (highest INTEGER, average VARCHAR, capacity VARCHAR)"
29
- },
30
- {
31
- "question": "If the letters is φαν, what is the founding date?",
32
- "description": "CREATE TABLE table_2538117_7 (founding_date VARCHAR, letters VARCHAR)"
33
- },
34
- {
35
- "question": "How many weeks had a game on November 26, 1978, and an attendance higher than 26,248?",
36
- "description": "CREATE TABLE table_name_88 (week VARCHAR, date VARCHAR, attendance VARCHAR)"
37
- },
38
- {
39
- "question": "How many television service are in italian and n°is greater than 856.0?",
40
- "description": "CREATE TABLE table_15887683_15 (television_service VARCHAR, language VARCHAR, VARCHAR)"
41
- },
42
- {
43
- "question": "What date was Bury the home team?",
44
- "description": "CREATE TABLE table_name_67 (date VARCHAR, away_team VARCHAR)"
45
- },
46
- {
47
- "question": "What regular season result had an average attendance less than 942?",
48
- "description": "CREATE TABLE table_name_16 (reg_season VARCHAR, avg_attendance INTEGER)"
49
- },
50
- {
51
- "question": "What is the value for 2011 when `a` is the value for 2009, and 4r is the value for 2013?",
52
- "description": "CREATE TABLE table_name_89 (Id VARCHAR)"
53
- },
54
- {
55
- "question": "Who wrote episode with production code 1.01?",
56
- "description": "CREATE TABLE table_28089666_1 (written_by VARCHAR, production_code VARCHAR)"
57
- },
58
- {
59
- "question": "find the names of museums which have more staff than the minimum staff number of all museums opened after 2010.",
60
- "description": "CREATE TABLE museum (name VARCHAR, num_of_staff INTEGER, open_year INTEGER)"
61
- }
62
- ]
63
-
64
- @st.cache_resource
65
- def load_model():
66
- script_dir = os.path.dirname(os.path.abspath(__file__))
67
- model_path = os.path.join(script_dir, "model")
68
-
69
- if not os.path.exists(model_path):
70
- raise FileNotFoundError(f"Model directory not found at {model_path}")
71
-
72
- try:
73
- tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
74
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
75
- model.eval()
76
- return model, tokenizer
77
- except Exception as e:
78
- raise RuntimeError(f"Error loading model: {str(e)}")
79
-
80
- try:
81
- model, tokenizer = load_model()
82
- except Exception as e:
83
- st.error(f"Failed to load model: {str(e)}")
84
- st.stop()
85
-
86
- if 'current_description' not in st.session_state:
87
- st.session_state.current_description = """CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)"""
88
- if 'current_question' not in st.session_state:
89
- st.session_state.current_question = "Which Played has a Points of 2, and a Position smaller than 8?"
90
-
91
- def load_example(example):
92
- st.session_state.current_description = example["description"]
93
- st.session_state.current_question = example["question"]
94
-
95
- st.subheader("Примеры:")
96
- cols = st.columns(2)
97
- for i, example in enumerate(examples):
98
- col = cols[i % 2]
99
- if col.button(
100
- f"Пример {i+1}: {example['question'][:30]}...",
101
- key=f"example_{i}",
102
- ):
103
- load_example(example)
104
- st.rerun()
105
-
106
- with st.form("query_form"):
107
- description = st.text_area(
108
- "Описание таблицы (столбцы и их типы):",
109
- st.session_state.current_description,
110
- height=150,
111
- key="desc_input"
112
- )
113
-
114
- question = st.text_input(
115
- "Ваш вопрос:",
116
- st.session_state.current_question,
117
- key="question_input"
118
- )
119
-
120
- submitted = st.form_submit_button("Сгенерировать запрос")
121
-
122
- if submitted:
123
- if description and question:
124
- input_text = make_query(description, question)
125
- try:
126
- input_ids = tokenizer.encode(input_text, return_tensors="pt")
127
- with st.spinner("Генерация запроса..."):
128
- outputs = model.generate(
129
- input_ids,
130
- max_length=200,
131
- num_beams=5,
132
- early_stopping=True,
133
- pad_token_id=tokenizer.eos_token_id,
134
- )
135
- generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
- st.subheader("Результат:")
137
- st.code(generated_sql, language="sql")
138
- except Exception as e:
139
- st.error(f"Ошибка при генерации: {str(e)}")
140
- else:
141
  st.warning("Пожалуйста, заполните описание таблицы и вопрос")
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
+ import os
5
+
6
+ def make_query(context, question):
7
+ result_query = f'''You are a SQL expert with extensive experience, you need to create a query to answer the question.
8
+
9
+ ### Database schema (PostgreSQL):
10
+ {context}
11
+
12
+ ### Question:
13
+ {question}
14
+
15
+ ### SQL Query: '''
16
+ return result_query
17
+
18
+ st.set_page_config(page_title="SQL-to-Text with T5", page_icon="🚀")
19
+ st.title("SQL Query Generator with T5")
20
+
21
+ examples = [
22
+ {
23
+ "question": "Show all products with price greater than 100.",
24
+ "description": "CREATE TABLE products (product_name VARCHAR, price INTEGER)"
25
+ },
26
+ {
27
+ "question": "What is the average salary of employees in the Sales department?",
28
+ "description": "CREATE TABLE employees (employee_name VARCHAR, department VARCHAR, salary INTEGER)"
29
+ },
30
+ {
31
+ "question": "Which students have a GPA higher than 3.5?",
32
+ "description": "CREATE TABLE students (student_id INTEGER, student_name VARCHAR, gpa FLOAT)"
33
+ },
34
+ {
35
+ "question": "List all orders made by customer with ID 12345.",
36
+ "description": "CREATE TABLE orders (order_id INTEGER, customer_id INTEGER, order_date DATE)"
37
+ },
38
+ {
39
+ "question": "How many books were published after 2000?",
40
+ "description": "CREATE TABLE books (book_title VARCHAR, author VARCHAR, publication_year INTEGER)"
41
+ },
42
+ {
43
+ "question": "What is the total revenue from all completed transactions?",
44
+ "description": "CREATE TABLE transactions (transaction_id INTEGER, amount FLOAT, status VARCHAR)"
45
+ },
46
+ {
47
+ "question": "Which cities have a population between 1 million and 2 million?",
48
+ "description": "CREATE TABLE cities (city_name VARCHAR, country VARCHAR, population INTEGER)"
49
+ },
50
+ {
51
+ "question": "List all movies with rating higher than 8.0 released in 2020.",
52
+ "description": "CREATE TABLE movies (movie_title VARCHAR, release_year INTEGER, rating FLOAT)"
53
+ },
54
+ {
55
+ "question": "What is the most common job title in the company?",
56
+ "description": "CREATE TABLE staff (employee_id INTEGER, job_title VARCHAR, department VARCHAR)"
57
+ },
58
+ {
59
+ "question": "Which products are out of stock (quantity = 0)?",
60
+ "description": "CREATE TABLE inventory (product_id INTEGER, product_name VARCHAR, quantity INTEGER)"
61
+ }
62
+ ]
63
+
64
+ @st.cache_resource
65
+ def load_model():
66
+ script_dir = os.path.dirname(os.path.abspath(__file__))
67
+ model_path = os.path.join(script_dir, "model")
68
+
69
+ if not os.path.exists(model_path):
70
+ raise FileNotFoundError(f"Model directory not found at {model_path}")
71
+
72
+ try:
73
+ tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
74
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
75
+ model.eval()
76
+ return model, tokenizer
77
+ except Exception as e:
78
+ raise RuntimeError(f"Error loading model: {str(e)}")
79
+
80
+ try:
81
+ model, tokenizer = load_model()
82
+ except Exception as e:
83
+ st.error(f"Failed to load model: {str(e)}")
84
+ st.stop()
85
+
86
+ if 'current_description' not in st.session_state:
87
+ st.session_state.current_description = """CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)"""
88
+ if 'current_question' not in st.session_state:
89
+ st.session_state.current_question = "Which Played has a Points of 2, and a Position smaller than 8?"
90
+
91
+ def load_example(example):
92
+ st.session_state.current_description = example["description"]
93
+ st.session_state.current_question = example["question"]
94
+
95
+ st.subheader("Примеры:")
96
+ cols = st.columns(2)
97
+ for i, example in enumerate(examples):
98
+ col = cols[i % 2]
99
+ if col.button(
100
+ f"Пример {i+1}: {example['question'][:30]}...",
101
+ key=f"example_{i}",
102
+ ):
103
+ load_example(example)
104
+ st.rerun()
105
+
106
+ with st.form("query_form"):
107
+ description = st.text_area(
108
+ "Описание таблицы (столбцы и их типы):",
109
+ st.session_state.current_description,
110
+ height=150,
111
+ key="desc_input"
112
+ )
113
+
114
+ question = st.text_input(
115
+ "Ваш вопрос:",
116
+ st.session_state.current_question,
117
+ key="question_input"
118
+ )
119
+
120
+ submitted = st.form_submit_button("Сгенерировать запрос")
121
+
122
+ if submitted:
123
+ if description and question:
124
+ input_text = make_query(description, question)
125
+ try:
126
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
127
+ with st.spinner("Генерация запроса..."):
128
+ outputs = model.generate(
129
+ input_ids,
130
+ max_length=200,
131
+ num_beams=5,
132
+ early_stopping=True,
133
+ pad_token_id=tokenizer.eos_token_id,
134
+ )
135
+ generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
+ st.subheader("Результат:")
137
+ st.code(generated_sql, language="sql")
138
+ except Exception as e:
139
+ st.error(f"Ошибка при генерации: {str(e)}")
140
+ else:
141
  st.warning("Пожалуйста, заполните описание таблицы и вопрос")