File size: 3,220 Bytes
8833bb2
 
 
 
dfae60a
 
 
 
 
 
 
8833bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
import requests
from transformers import T5Tokenizer, T5ForConditionalGeneration
# URL da API (substitua pelo seu endpoint de API, ex: ngrok)
@st.cache_data
def model1()
    model_name = "Text-to-SQL T5 modelo 1"  # Você pode trocar para o modelo que está usando, como "t5-base", "t5-large", etc.
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return model,tokenizer
model,tokenizer = model1()
st.set_page_config(layout="wide")
st.title("Text-to-SQL app")
def preditor_sql(texto):
    # Prefixo para a tarefa de geração de SQL
    input_text = "gerar SQL: " + texto
    
    # Tokenizar a entrada
    inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True)

    # Gerar a saída do modelo
    outputs = model.generate(inputs, max_length=300, num_beams=5, early_stopping=True)

    # Decodificar a sequência gerada
    predicao = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return predicao
# CSS customizado para posicionar e estilizar as mensagens
def apply_custom_css():
    st.markdown(
        """
        <style>
        .user-message {
            background-color: #DCF8C6;
            color: black;
            padding: 8px;
            border-radius: 10px;
            max-width: 70%;
            margin-left: auto;
            margin-right: 10px;
            margin-top: 5px;
            text-align: right;
            font-size: 16px;
        }
        .assistant-message {
            background-color: #ECECEC;
            color: black;
            padding: 8px;
            border-radius: 10px;
            max-width: 70%;
            margin-left: 10px;
            margin-top: 5px;
            text-align: left;
            font-size: 16px;
        }
        </style>
        """,
        unsafe_allow_html=True
    )

# Chamar a função para aplicar o CSS
apply_custom_css()

# Inicializar o estado da sessão para armazenar o histórico de mensagens
if "messages" not in st.session_state:
    st.session_state.messages = []

# Função para exibir o chat
def display_chat():
    for message in st.session_state.messages:
        if message["role"] == "user":
            st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
        else:
            st.markdown(f'<div class="assistant-message">{message["content"]}</div>', unsafe_allow_html=True)

# Exibir o histórico de mensagens
display_chat()

# Input do chat para o usuário
if prompt := st.chat_input("Say something"):
    # Armazenar a mensagem do usuário no estado da sessão
    st.session_state.messages.append({"role": "user", "content": prompt})
    
    # Exibir a mensagem do usuário imediatamente
    st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True)

    # Fazer uma requisição POST para a API
    
    
    # Armazenar a resposta da API no estado da sessão
    prediction = preditor_sql(prompt)
    st.session_state.messages.append({"role": "assistant", "content": prediction})
    
    # Exibir a resposta do assistente
    st.markdown(f'<div class="assistant-message">{prediction}</div>', unsafe_allow_html=True)