File size: 5,546 Bytes
6767d07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | import streamlit as st
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import os
def make_query(context, question):
    """Assemble the text-to-SQL prompt sent to the seq2seq model.

    The prompt is a fixed instruction line followed by three "###" sections:
    the table schema, the user's question and an open "SQL Query:" slot.
    Sections are separated by single newlines.

    Args:
        context: Table schema text (e.g. a CREATE TABLE statement); it is
            interpolated as-is.
        question: Natural-language question to answer; interpolated as-is.

    Returns:
        The prompt string, ending with "### SQL Query: " (trailing space).
    """
    prompt_lines = [
        "You are a SQL expert with extensive experience, you need to create a query to answer the question.",
        "### Database schema (PostgreSQL):",
        f"{context}",
        "### Question:",
        f"{question}",
        "### SQL Query: ",
    ]
    return "\n".join(prompt_lines)
# Page metadata is set before any other Streamlit element is rendered.
# The app turns a natural-language question into SQL, so the browser-tab
# title reads "Text-to-SQL", consistent with the on-page heading below.
st.set_page_config(page_title="Text-to-SQL with T5", page_icon="🤖")
st.title("SQL Query Generator with T5")
# Sample (question, table schema) pairs offered as one-click buttons.
# Each pair becomes a dict with the keys "question" and "description",
# in that order; "description" holds the schema shown in the form.
examples = [
    {"question": ex_question, "description": ex_schema}
    for ex_question, ex_schema in (
        (
            "Who was the music director in 1971 for the movie Kalyani?",
            "CREATE TABLE table_name_7 (Music VARCHAR, year VARCHAR, movie__in_kannada__ VARCHAR)",
        ),
        (
            "What's the highest with a capacity of greater than 4,000 and an average of 615?",
            "CREATE TABLE table_name_20 (highest INTEGER, average VARCHAR, capacity VARCHAR)",
        ),
        (
            "If the letters is φαν, what is the founding date?",
            "CREATE TABLE table_2538117_7 (founding_date VARCHAR, letters VARCHAR)",
        ),
        (
            "How many weeks had a game on November 26, 1978, and an attendance higher than 26,248?",
            "CREATE TABLE table_name_88 (week VARCHAR, date VARCHAR, attendance VARCHAR)",
        ),
        (
            "How many television service are in italian and n°is greater than 856.0?",
            "CREATE TABLE table_15887683_15 (television_service VARCHAR, language VARCHAR, n° VARCHAR)",
        ),
        # NOTE(review): this schema has no home_team column (only away_team);
        # kept verbatim — confirm it is intentional.
        (
            "What date was Bury the home team?",
            "CREATE TABLE table_name_67 (date VARCHAR, away_team VARCHAR)",
        ),
        (
            "What regular season result had an average attendance less than 942?",
            "CREATE TABLE table_name_16 (reg_season VARCHAR, avg_attendance INTEGER)",
        ),
        # NOTE(review): schema exposes only an Id column; kept verbatim.
        (
            "What is the value for 2011 when `a` is the value for 2009, and 4r is the value for 2013?",
            "CREATE TABLE table_name_89 (Id VARCHAR)",
        ),
        (
            "Who wrote episode with production code 1.01?",
            "CREATE TABLE table_28089666_1 (written_by VARCHAR, production_code VARCHAR)",
        ),
        (
            "find the names of museums which have more staff than the minimum staff number of all museums opened after 2010.",
            "CREATE TABLE museum (name VARCHAR, num_of_staff INTEGER, open_year INTEGER)",
        ),
    )
]
@st.cache_resource
def load_model(model_dir="model", tokenizer_name="google-t5/t5-small"):
    """Load the fine-tuned seq2seq model and its tokenizer.

    Decorated with ``st.cache_resource`` so the loaded objects are cached by
    Streamlit and reused across reruns instead of being reloaded each time.

    Args:
        model_dir: Directory holding the fine-tuned model weights. A relative
            path is resolved against the directory containing this script;
            an absolute path is used unchanged. Defaults to ``"model"``.
        tokenizer_name: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``. Defaults to
            ``"google-t5/t5-small"``.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.

    Raises:
        FileNotFoundError: If the resolved model path is not a directory.
        RuntimeError: If loading the tokenizer or the model fails; the
            underlying exception is chained as ``__cause__``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # os.path.join discards script_dir when model_dir is already absolute.
    model_path = os.path.join(script_dir, model_dir)
    # isdir rather than exists: a stray *file* named like the model directory
    # should fail here with a clear message, not deep inside from_pretrained.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"Model directory not found at {model_path}")
    try:
        # NOTE(review): the tokenizer is loaded from tokenizer_name, not from
        # model_path; this assumes the fine-tuned model kept the base
        # tokenizer's vocabulary — confirm.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading model: {e}") from e
    model.eval()
    return model, tokenizer
# Obtain the model and tokenizer (cached by load_model's decorator). If that
# fails, show the error on the page and stop this script run.
try:
    model, tokenizer = load_model()
except Exception as load_error:
    st.error(f"Failed to load model: {load_error}")
    st.stop()
# Seed the form defaults only when a key is absent from session state;
# values already stored (e.g. by load_example) are left untouched.
for _state_key, _initial_value in (
    (
        "current_description",
        "CREATE TABLE table_name_28 (played INTEGER, points VARCHAR, position VARCHAR)",
    ),
    (
        "current_question",
        "Which Played has a Points of 2, and a Position smaller than 8?",
    ),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def load_example(example):
    """Store an example's schema and question as the form's default values.

    Args:
        example: Mapping with "description" (table schema) and "question"
            keys, as in the module-level ``examples`` list.
    """
    # NOTE(review): this updates the values passed to the form widgets as
    # defaults, not the widget keys ("desc_input"/"question_input"). Whether
    # an already-edited form shows the new example may depend on how the
    # installed Streamlit version derives widget identity — verify.
    state = st.session_state
    state["current_description"] = example["description"]
    state["current_question"] = example["question"]
# One button per example, alternating between a left and a right column.
st.subheader("Примеры:")
left_col, right_col = st.columns(2)
for idx, sample in enumerate(examples):
    target_col = right_col if idx % 2 else left_col
    preview = sample["question"][:30]
    if target_col.button(f"Пример {idx + 1}: {preview}...", key=f"example_{idx}"):
        # Store the example, then rerun at once so the form below is rebuilt
        # with the new default values.
        load_example(sample)
        st.rerun()
# Input form: the schema and question are only sent to the script when the
# submit button is pressed.
with st.form("query_form"):
    description = st.text_area(
        "Описание таблицы (столбцы и их типы):",
        st.session_state.current_description,
        height=150,
        key="desc_input",
    )
    question = st.text_input(
        "Ваш вопрос:",
        st.session_state.current_question,
        key="question_input",
    )
    submitted = st.form_submit_button("Сгенерировать запрос")

if submitted:
    # Whitespace-only fields count as empty; otherwise the model would get a
    # prompt with a blank schema or question.
    if description.strip() and question.strip():
        input_text = make_query(description, question)
        try:
            # Calling the tokenizer (instead of .encode) also yields the
            # attention mask, which is forwarded to generate() with input_ids.
            inputs = tokenizer(input_text, return_tensors="pt")
            # Prefer the tokenizer's own pad token; using EOS as padding is a
            # workaround for tokenizers that lack one, so it is only a fallback.
            pad_id = (
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            )
            with st.spinner("Генерация запроса..."):
                outputs = model.generate(
                    **inputs,
                    max_length=200,
                    num_beams=5,
                    early_stopping=True,
                    pad_token_id=pad_id,
                )
            generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            # Top-level UI boundary: surface any tokenization/generation error.
            st.error(f"Ошибка при генерации: {str(e)}")
        else:
            st.subheader("Результат:")
            st.code(generated_sql, language="sql")
    else:
        st.warning("Пожалуйста, заполните описание таблицы и вопрос")