"""Interface de chat com modelo de linguagem."""
import streamlit as st
import base64
from pathlib import Path
from backend import load_model, ChatModel
from config import get_model_options
from send_to_hub import send_message_buckt
import uuid
import logging
# Root logger: timestamped INFO-level records for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Configure the Streamlit page before rendering anything.
st.set_page_config(
    page_title="Small LLM - Chat",
    layout="wide",
    initial_sidebar_state="collapsed",
)
def set_new_message_uuid():
    """Generate a fresh UUID for the current exchange and store it in session state.

    The UUID is read back by ``save_messages`` and sent alongside the
    conversation when it is uploaded.
    """
    st.session_state.message_uuid = uuid.uuid4()
    # Lazy %-formatting: the argument is only rendered if INFO is enabled.
    logging.info('[NEW MESSAGE UUID] %s', st.session_state.message_uuid)
def save_messages():
    """Persist the current conversation via ``send_message_buckt``.

    No-op when the chat model is absent or the conversation is empty.
    Model metadata is stringified (keys and values) so the payload is
    serialization-safe, and the model name is always attached.
    """
    if "chat_model" not in st.session_state or not st.session_state.chat_model.conversation.messages:
        return
    messages = [
        {'role': m.role, 'content': m.content}
        for m in st.session_state.chat_model.conversation.messages
    ]
    # Stringify every metadata entry; missing model_info yields an empty dict.
    new_data = {
        str(k): str(v)
        for k, v in st.session_state.get('model_info', {}).items()
    }
    new_data['model_name'] = str(st.session_state.get('model_name', 'unknown'))
    send_message_buckt(
        messages,
        st.session_state.message_uuid,
        new_data,
        REPO_ID="CEIA-POSITIVO2/smol_chat",
    )
# Logo path, resolved relative to the project root (one level above this file).
PROJECT_ROOT = Path(__file__).parent.parent
LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"

# Build an inline <img> tag (base64 data URI) for the page header.
# Best-effort: fall back to an empty string when the logo file is
# missing or unreadable.
logo_html = ""
if LOGO_PATH.exists():
    try:
        encoded = base64.b64encode(LOGO_PATH.read_bytes()).decode()
        logo_html = f'<img src="data:image/png;base64,{encoded}" />'
    except Exception:
        logo_html = ""
st.markdown(f"""
<style>
.logo-header {{
display: flex;
align-items: center;
gap: 20px;
margin-bottom: 0.5rem;
}}
.logo-header img {{
width: 90px;
height: 90px;
object-fit: contain;
flex-shrink: 0;
}}
</style>
<div class="logo-header">
{logo_html}
<h1 style="margin: 0; padding: 0; display: inline-block;">Small LLM - Chat</h1>
</div>
""", unsafe_allow_html=True)
def handle_model_load_error(model_name: str, error_msg: str):
    """Render a user-facing error banner for a failed model load."""
    # NOTE(review): an earlier comment referenced GATED_MODELS handling,
    # but no such check exists here — only the banner is shown.
    st.error(f"❌ Erro ao carregar modelo {model_name}: {error_msg}")
# ============================================================================
# AUTOMATIC MODEL LOADING (replaces the sidebar selector)
# ============================================================================
if "chat_model" not in st.session_state:
    with st.spinner("Inicializando o sistema e carregando o modelo..."):
        try:
            # 1. Fetch configured model options: sequence of (label, config).
            model_options = get_model_options()
            if not model_options:
                st.error("Nenhum modelo encontrado nas configurações.")
                st.stop()
            # 2. Automatically select the FIRST model in the list.
            # BUGFIX: previously indexed [2], which contradicted the stated
            # intent ("first model") and raised IndexError when fewer than
            # three models were configured.
            selected_label, selected_model_config = model_options[0]
            logging.info('[MODEL] Loading %s %s', selected_label, selected_model_config)
            # Prepare the configuration: model_id is consumed here, the rest
            # is forwarded to load_model as keyword arguments.
            model_name = selected_model_config.pop('model_id')
            selected_model_config['dtype'] = 'bfloat16'
            # Use logging instead of print, consistent with the rest of the file.
            logging.info('Carregando modelo automático: %s', model_name)
            # 3. Load the model pipeline.
            pipeline, model_info = load_model(
                model_name,
                **selected_model_config
            )
            # 4. Store everything in the session.
            # NOTE(review): model_info returned by load_model is discarded;
            # the session stores the config dict instead — confirm intent.
            st.session_state.chat_model = ChatModel(pipeline)
            st.session_state.model_info = selected_model_config
            st.session_state.model_name = model_name
            set_new_message_uuid()
            # Rerun so the UI renders without the loading spinner.
            st.rerun()
        except Exception as e:
            handle_model_load_error(model_name if 'model_name' in locals() else "Unknown", str(e))
            st.stop()
# ============================================================================
# CHAT INTERFACE
# ============================================================================
chat_model = st.session_state.chat_model

if "messages" not in st.session_state:
    st.session_state.messages = []

# Mirror the chat_model's conversation into session_state whenever the
# two histories diverge in length.
if len(chat_model.conversation.messages) != len(st.session_state.messages):
    st.session_state.messages = [
        {"role": m.role, "content": m.content}
        for m in chat_model.conversation.messages
    ]
# Render the stored history inside a clearable placeholder.
chat_placeholder = st.empty()
with chat_placeholder.container():
    for entry in st.session_state.messages:
        # System messages are internal context and never displayed.
        if entry["role"] == "system":
            continue
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])
if user_input := st.chat_input("Digite sua mensagem..."):
    # Clear the history: every submitted prompt starts a brand-new,
    # single-turn conversation (previous turns are not kept as context).
    st.session_state.messages = []
    chat_model.conversation.clear()
    chat_placeholder.empty()
    set_new_message_uuid()
    chat_model.add_user_message(user_input)
    save_messages()  # persist the user turn immediately
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        full_response = ""
        try:
            # Stream tokens, re-rendering the partial answer as they arrive.
            for token in chat_model.generate_streaming(max_new_tokens=4096):
                full_response += token
                response_placeholder.markdown(full_response)
            chat_model.add_assistant_message(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
            save_messages()  # persist again with the assistant turn included
        except Exception as e:
            # NOTE(review): the error message is appended to session messages
            # but not to chat_model.conversation, so the length-based sync
            # above will drop it on the next rerun — confirm this is intended.
            error_msg = f"Erro na geração: {str(e)}"
            st.error(error_msg)
            st.session_state.messages.append({"role": "assistant", "content": error_msg})