| import streamlit as st
|
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
| import torch
|
| import os
|
|
|
def make_query(context, question):
    """Assemble the text-to-SQL prompt for a schema and a question.

    The prompt is made of a fixed instruction sentence, a schema section,
    a question section and a trailing ``### SQL Query: `` cue, with the
    sections separated by a blank line.

    Args:
        context: Schema text placed under the "Database schema" header.
        question: Question text placed under the "Question" header.

    Returns:
        The complete prompt string.
    """
    prompt_sections = (
        "You are a SQL expert with extensive experience, "
        "you need to create a query to answer the question.",
        f"### Database schema (PostgreSQL):\n{context}",
        f"### Question:\n{question}",
        # The trailing space after the colon is part of the prompt.
        "### SQL Query: ",
    )
    return "\n\n".join(prompt_sections)
|
|
|
# Browser-tab metadata and the on-page heading. The app produces SQL from a
# natural-language question, so the tab title reads "Text-to-SQL" (it used to
# say "SQL-to-Text", i.e. the opposite direction).
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🤖")
st.title("SQL Query Generator with T5")
|
|
|
# Sample inputs as (question, CREATE TABLE statement) pairs. They are expanded
# below into the list of {"question", "description"} dicts used by the script.
_EXAMPLE_PAIRS = (
    (
        "Who was the music director in 1971 for the movie Kalyani?",
        "CREATE TABLE table_name_7 (Music VARCHAR, year VARCHAR, movie__in_kannada__ VARCHAR)",
    ),
    (
        "What's the highest with a capacity of greater than 4,000 and an average of 615?",
        "CREATE TABLE table_name_20 (highest INTEGER, average VARCHAR, capacity VARCHAR)",
    ),
    (
        "If the letters is φαν, what is the founding date?",
        "CREATE TABLE table_2538117_7 (founding_date VARCHAR, letters VARCHAR)",
    ),
    (
        "How many weeks had a game on November 26, 1978, and an attendance higher than 26,248?",
        "CREATE TABLE table_name_88 (week VARCHAR, date VARCHAR, attendance VARCHAR)",
    ),
    (
        "How many television service are in italian and n°is greater than 856.0?",
        "CREATE TABLE table_15887683_15 (television_service VARCHAR, language VARCHAR, n° VARCHAR)",
    ),
    (
        "What date was Bury the home team?",
        "CREATE TABLE table_name_67 (date VARCHAR, away_team VARCHAR)",
    ),
    (
        "What regular season result had an average attendance less than 942?",
        "CREATE TABLE table_name_16 (reg_season VARCHAR, avg_attendance INTEGER)",
    ),
    (
        "What is the value for 2011 when `a` is the value for 2009, and 4r is the value for 2013?",
        "CREATE TABLE table_name_89 (Id VARCHAR)",
    ),
    (
        "Who wrote episode with production code 1.01?",
        "CREATE TABLE table_28089666_1 (written_by VARCHAR, production_code VARCHAR)",
    ),
    (
        "find the names of museums which have more staff than the minimum staff number of all museums opened after 2010.",
        "CREATE TABLE museum (name VARCHAR, num_of_staff INTEGER, open_year INTEGER)",
    ),
)

examples = [
    {"question": question_text, "description": schema_text}
    for question_text, schema_text in _EXAMPLE_PAIRS
]
|
|
|
@st.cache_resource
def load_model():
    """Load the seq2seq model and its tokenizer.

    The model weights are read from the ``model`` directory located next to
    this script, while the tokenizer is loaded by the ``google-t5/t5-small``
    identifier rather than from that directory. The model is switched to
    evaluation mode before being returned. The function is wrapped in
    ``st.cache_resource``, so Streamlit caches the returned objects.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or the model fails. The
            underlying exception is attached as ``__cause__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")

    # Checked outside the try block so that a missing directory surfaces as
    # FileNotFoundError instead of being wrapped in RuntimeError.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    try:
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
    except Exception as e:
        # Chain explicitly so the loader's own error is reported as the
        # direct cause in the traceback.
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer
|
|
|
# Load the model once at start-up. If anything goes wrong, show the reason on
# the page and halt this script run, since nothing below works without it.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error}")
    st.stop()
|
|
|
# The two form inputs below are bound to the Session State keys "desc_input"
# and "question_input". Their initial text is seeded here, and the widgets are
# created without a `value=` argument, so Session State is the single source
# of truth for what they display. Previously the text went through separate
# `current_*` entries passed as `value=`, which could not reset a keyed widget
# when the same example was clicked again after the user had edited the fields.
if "desc_input" not in st.session_state:
    st.session_state.desc_input = (
        "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)"
    )
if "question_input" not in st.session_state:
    st.session_state.question_input = (
        "Which Played has a Points of 2, and a Position smaller than 8?"
    )


def load_example(example):
    """Copy an example's schema and question into the form inputs.

    Intended to be used as a button ``on_click`` callback: it writes to the
    Session State keys the form widgets are bound to, so the inputs show the
    example's text when the script runs after the click.

    Args:
        example: Mapping with ``"description"`` and ``"question"`` entries.
    """
    st.session_state.desc_input = example["description"]
    st.session_state.question_input = example["question"]


st.subheader("Примеры:")
cols = st.columns(2)
for i, example in enumerate(examples):
    # Alternate the buttons between the two columns. A click already triggers
    # a script run, so no explicit st.rerun() is needed after the callback.
    cols[i % 2].button(
        f"Пример {i+1}: {example['question'][:30]}...",
        key=f"example_{i}",
        on_click=load_example,
        args=(example,),
    )

with st.form("query_form"):
    description = st.text_area(
        "Описание таблицы (столбцы и их типы):",
        height=150,
        key="desc_input",
    )

    question = st.text_input(
        "Ваш вопрос:",
        key="question_input",
    )

    submitted = st.form_submit_button("Сгенерировать запрос")
|
|
|
# Runs only on the script run triggered by the form's submit button.
if submitted:
    # Whitespace-only fields count as empty, so the model is never run on a
    # blank schema or question. The text itself is passed on unmodified.
    if description.strip() and question.strip():
        input_text = make_query(description, question)
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")
            with st.spinner("Генерация запроса..."):
                outputs = model.generate(
                    input_ids,
                    max_length=200,
                    num_beams=5,
                    early_stopping=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Only the first returned sequence is decoded and shown.
            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
        except Exception as e:
            # UI boundary: report any tokenisation/generation failure on the
            # page instead of crashing the script run.
            st.error(f"Ошибка при генерации: {e}")
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")