daniel05 commited on
Commit
8833bb2
·
verified ·
1 Parent(s): 712381a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py CHANGED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
4
+ # URL da API (substitua pelo seu endpoint de API, ex: ngrok)
5
+ model_name = "Text-to-SQL T5 modelo 1" # Você pode trocar para o modelo que está usando, como "t5-base", "t5-large", etc.
6
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
7
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
8
+ st.set_page_config(layout="wide")
9
+ st.title("Text-to-SQL app")
10
+ def preditor_sql(texto):
11
+ # Prefixo para a tarefa de geração de SQL
12
+ input_text = "gerar SQL: " + texto
13
+
14
+ # Tokenizar a entrada
15
+ inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True)
16
+
17
+ # Gerar a saída do modelo
18
+ outputs = model.generate(inputs, max_length=300, num_beams=5, early_stopping=True)
19
+
20
+ # Decodificar a sequência gerada
21
+ predicao = tokenizer.decode(outputs[0], skip_special_tokens=True)
22
+
23
+ return predicao
24
+ # CSS customizado para posicionar e estilizar as mensagens
25
+ def apply_custom_css():
26
+ st.markdown(
27
+ """
28
+ <style>
29
+ .user-message {
30
+ background-color: #DCF8C6;
31
+ color: black;
32
+ padding: 8px;
33
+ border-radius: 10px;
34
+ max-width: 70%;
35
+ margin-left: auto;
36
+ margin-right: 10px;
37
+ margin-top: 5px;
38
+ text-align: right;
39
+ font-size: 16px;
40
+ }
41
+ .assistant-message {
42
+ background-color: #ECECEC;
43
+ color: black;
44
+ padding: 8px;
45
+ border-radius: 10px;
46
+ max-width: 70%;
47
+ margin-left: 10px;
48
+ margin-top: 5px;
49
+ text-align: left;
50
+ font-size: 16px;
51
+ }
52
+ </style>
53
+ """,
54
+ unsafe_allow_html=True
55
+ )
56
+
57
+ # Chamar a função para aplicar o CSS
58
+ apply_custom_css()
59
+
60
+ # Inicializar o estado da sessão para armazenar o histórico de mensagens
61
+ if "messages" not in st.session_state:
62
+ st.session_state.messages = []
63
+
64
+ # Função para exibir o chat
65
+ def display_chat():
66
+ for message in st.session_state.messages:
67
+ if message["role"] == "user":
68
+ st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
69
+ else:
70
+ st.markdown(f'<div class="assistant-message">{message["content"]}</div>', unsafe_allow_html=True)
71
+
72
+ # Exibir o histórico de mensagens
73
+ display_chat()
74
+
75
+ # Input do chat para o usuário
76
+ if prompt := st.chat_input("Say something"):
77
+ # Armazenar a mensagem do usuário no estado da sessão
78
+ st.session_state.messages.append({"role": "user", "content": prompt})
79
+
80
+ # Exibir a mensagem do usuário imediatamente
81
+ st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True)
82
+
83
+ # Fazer uma requisição POST para a API
84
+
85
+
86
+ # Armazenar a resposta da API no estado da sessão
87
+ prediction = preditor_sql(prompt)
88
+ st.session_state.messages.append({"role": "assistant", "content": prediction})
89
+
90
+ # Exibir a resposta do assistente
91
+ st.markdown(f'<div class="assistant-message">{prediction}</div>', unsafe_allow_html=True)