import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os


def make_query(context, question):
    """Build the text prompt that is fed to the seq2seq model.

    Args:
        context: Database schema text (e.g. a ``CREATE TABLE`` statement).
        question: Natural-language question to be answered with SQL.

    Returns:
        The prompt string: an instruction followed by the schema, the
        question, and a trailing ``### SQL Query: `` marker.
    """
    # NOTE(review): the sections are joined with single spaces, exactly as the
    # template appears in this file. If the fine-tuning prompt used line
    # breaks between sections, confirm and adjust here.
    return (
        "You are a SQL expert with extensive experience, "
        "you need to create a query to answer the question. "
        f"### Database schema (PostgreSQL): {context} "
        f"### Question: {question} "
        "### SQL Query: "
    )


# The app turns a text question into SQL, so the tab title says "Text-to-SQL".
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🤖")
st.title("SQL Query Generator with T5")

# Ready-made (question, schema) pairs offered to the user as one-click examples.
examples = [
    {
        "question": "Who was the music director in 1971 for the movie Kalyani?",
        "description": "CREATE TABLE table_name_7 (Music VARCHAR, year VARCHAR, movie__in_kannada__ VARCHAR)",
    },
    {
        "question": "What's the highest with a capacity of greater than 4,000 and an average of 615?",
        "description": "CREATE TABLE table_name_20 (highest INTEGER, average VARCHAR, capacity VARCHAR)",
    },
    {
        "question": "If the letters is φαν, what is the founding date?",
        "description": "CREATE TABLE table_2538117_7 (founding_date VARCHAR, letters VARCHAR)",
    },
    {
        "question": "How many weeks had a game on November 26, 1978, and an attendance higher than 26,248?",
        "description": "CREATE TABLE table_name_88 (week VARCHAR, date VARCHAR, attendance VARCHAR)",
    },
    {
        "question": "How many television service are in italian and n°is greater than 856.0?",
        "description": "CREATE TABLE table_15887683_15 (television_service VARCHAR, language VARCHAR, n° VARCHAR)",
    },
    {
        "question": "What date was Bury the home team?",
        "description": "CREATE TABLE table_name_67 (date VARCHAR, away_team VARCHAR)",
    },
    {
        "question": "What regular season result had an average attendance less than 942?",
        "description": "CREATE TABLE table_name_16 (reg_season VARCHAR, avg_attendance INTEGER)",
    },
    {
        "question": "What is the value for 2011 when `a` is the value for 2009, and 4r is the value for 2013?",
        "description": "CREATE TABLE table_name_89 (Id VARCHAR)",
    },
    {
        "question": "Who wrote episode with production code 1.01?",
        "description": "CREATE TABLE table_28089666_1 (written_by VARCHAR, production_code VARCHAR)",
    },
    {
        "question": "find the names of museums which have more staff than the minimum staff number of all museums opened after 2010.",
        "description": "CREATE TABLE museum (name VARCHAR, num_of_staff INTEGER, open_year INTEGER)",
    },
]


@st.cache_resource
def load_model():
    """Load the model from the local ``model`` directory plus its tokenizer.

    The model weights are read from a ``model`` directory located next to
    this script; the tokenizer is loaded under the ``google-t5/t5-small``
    identifier. The result is cached by ``st.cache_resource``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is switched to eval mode.

    Raises:
        FileNotFoundError: If the ``model`` directory does not exist.
        RuntimeError: If loading the tokenizer or the model fails; the
            original exception is chained as the cause.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "model")

    # The message promises a directory, so check for one explicitly.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")

    try:
        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        model.eval()
    except Exception as exc:
        # Chain the cause so the underlying loading error stays in the traceback.
        raise RuntimeError(f"Error loading model: {exc}") from exc
    return model, tokenizer


# Nothing below can work without the model: report the failure and halt.
try:
    model, tokenizer = load_model()
except Exception as exc:
    st.error(f"Failed to load model: {exc}")
    st.stop()

# Default form content for a fresh browser session.
if 'current_description' not in st.session_state:
    st.session_state.current_description = """CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)"""

if 'current_question' not in st.session_state:
    st.session_state.current_question = "Which Played has a Points of 2, and a Position smaller than 8?"
def load_example(example):
    """Store an example's schema and question as the current form content.

    Args:
        example: Mapping with ``"description"`` and ``"question"`` entries.
    """
    st.session_state.current_description = example["description"]
    st.session_state.current_question = example["question"]


def _sync_form_fields():
    """Copy the current description/question into the form widgets' state keys."""
    st.session_state["desc_input"] = st.session_state.current_description
    st.session_state["question_input"] = st.session_state.current_question


# The form widgets are driven through their Session State keys rather than a
# changing ``value=`` argument: a keyed widget is not guaranteed to pick up a
# new default once it exists, whereas a value written to its key before the
# widget is instantiated is always honoured.
if "desc_input" not in st.session_state:
    st.session_state["desc_input"] = st.session_state.current_description
if "question_input" not in st.session_state:
    st.session_state["question_input"] = st.session_state.current_question

st.subheader("Примеры:")
cols = st.columns(2)

# Lay the example buttons out alternately across the two columns.
for i, example in enumerate(examples):
    col = cols[i % 2]
    clicked = col.button(
        f"Пример {i+1}: {example['question'][:30]}...",
        key=f"example_{i}",
    )
    if clicked:
        load_example(example)
        # Must run before the form widgets are created in this script run.
        _sync_form_fields()
        st.rerun()

with st.form("query_form"):
    description = st.text_area(
        "Описание таблицы (столбцы и их типы):",
        height=150,
        key="desc_input",
    )
    question = st.text_input(
        "Ваш вопрос:",
        key="question_input",
    )
    submitted = st.form_submit_button("Сгенерировать запрос")

if submitted:
    # Whitespace-only input counts as empty; the text passed on is unchanged.
    if not (description.strip() and question.strip()):
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")
    else:
        input_text = make_query(description, question)
        try:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")

            with st.spinner("Генерация запроса..."):
                outputs = model.generate(
                    input_ids,
                    max_length=200,
                    num_beams=5,
                    early_stopping=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode the first returned sequence, dropping special tokens.
            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
        except Exception as e:
            # UI boundary: show the failure to the user instead of crashing.
            st.error(f"Ошибка при генерации: {str(e)}")