Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import requests | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| # URL da API (substitua pelo seu endpoint de API, ex: ngrok) | |
| def model1() | |
| model_name = "Text-to-SQL T5 modelo 1" # Você pode trocar para o modelo que está usando, como "t5-base", "t5-large", etc. | |
| tokenizer = T5Tokenizer.from_pretrained(model_name) | |
| model = T5ForConditionalGeneration.from_pretrained(model_name) | |
| return model,tokenizer | |
| model,tokenizer = model1() | |
| st.set_page_config(layout="wide") | |
| st.title("Text-to-SQL app") | |
| def preditor_sql(texto): | |
| # Prefixo para a tarefa de geração de SQL | |
| input_text = "gerar SQL: " + texto | |
| # Tokenizar a entrada | |
| inputs = tokenizer.encode(input_text, return_tensors="pt", truncation=True) | |
| # Gerar a saída do modelo | |
| outputs = model.generate(inputs, max_length=300, num_beams=5, early_stopping=True) | |
| # Decodificar a sequência gerada | |
| predicao = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| return predicao | |
| # CSS customizado para posicionar e estilizar as mensagens | |
| def apply_custom_css(): | |
| st.markdown( | |
| """ | |
| <style> | |
| .user-message { | |
| background-color: #DCF8C6; | |
| color: black; | |
| padding: 8px; | |
| border-radius: 10px; | |
| max-width: 70%; | |
| margin-left: auto; | |
| margin-right: 10px; | |
| margin-top: 5px; | |
| text-align: right; | |
| font-size: 16px; | |
| } | |
| .assistant-message { | |
| background-color: #ECECEC; | |
| color: black; | |
| padding: 8px; | |
| border-radius: 10px; | |
| max-width: 70%; | |
| margin-left: 10px; | |
| margin-top: 5px; | |
| text-align: left; | |
| font-size: 16px; | |
| } | |
| </style> | |
| """, | |
| unsafe_allow_html=True | |
| ) | |
| # Chamar a função para aplicar o CSS | |
| apply_custom_css() | |
| # Inicializar o estado da sessão para armazenar o histórico de mensagens | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| # Função para exibir o chat | |
| def display_chat(): | |
| for message in st.session_state.messages: | |
| if message["role"] == "user": | |
| st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True) | |
| else: | |
| st.markdown(f'<div class="assistant-message">{message["content"]}</div>', unsafe_allow_html=True) | |
| # Exibir o histórico de mensagens | |
| display_chat() | |
| # Input do chat para o usuário | |
| if prompt := st.chat_input("Say something"): | |
| # Armazenar a mensagem do usuário no estado da sessão | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| # Exibir a mensagem do usuário imediatamente | |
| st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True) | |
| # Fazer uma requisição POST para a API | |
| # Armazenar a resposta da API no estado da sessão | |
| prediction = preditor_sql(prompt) | |
| st.session_state.messages.append({"role": "assistant", "content": prediction}) | |
| # Exibir a resposta do assistente | |
| st.markdown(f'<div class="assistant-message">{prediction}</div>', unsafe_allow_html=True) |