"""Interface de chat com modelo de linguagem."""
import streamlit as st
import base64
from pathlib import Path
from backend import load_model, ChatModel
from config import get_model_options
from send_to_hub import send_message_buckt
import uuid
import logging
# Configure root logging and the Streamlit page before any UI is drawn.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
st.set_page_config(
    page_title="Small LLM - Chat",
    layout="wide",
    initial_sidebar_state="collapsed",
)
def set_new_message_uuid():
    """Generate a fresh UUID for the current conversation and store it in session state."""
    st.session_state.message_uuid = uuid.uuid4()
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info('[NEW MESSAGE UUID] %s', st.session_state.message_uuid)
def save_messages():
    """Persist the current conversation and model metadata to the remote hub bucket.

    No-op when no chat model is loaded or the conversation is still empty.
    """
    chat_model = st.session_state.get("chat_model")
    if chat_model is None or not chat_model.conversation.messages:
        return
    payload = [
        {'role': msg.role, 'content': msg.content}
        for msg in chat_model.conversation.messages
    ]
    # Stringify every metadata entry so the upload payload is JSON-safe.
    metadata = {
        str(key): str(value)
        for key, value in st.session_state.get("model_info", {}).items()
    }
    metadata['model_name'] = str(st.session_state.get('model_name', 'unknown'))
    send_message_buckt(
        payload,
        st.session_state.message_uuid,
        metadata,
        REPO_ID="CEIA-POSITIVO2/smol_chat",
    )
# Logo path, relative to the project root.
PROJECT_ROOT = Path(__file__).parent.parent
LOGO_PATH = PROJECT_ROOT / "positivo-logo.png"

# Header with logo and title, rendered as HTML for finer layout control.
# NOTE(review): the original HTML snippet was truncated in this file — the
# f-string literal on the `logo_html = f'` line was left unterminated, which
# is a SyntaxError. Reconstructed below as a minimal inline <img> data URI;
# confirm against the intended header design.
logo_html = ""
if LOGO_PATH.exists():
    try:
        with open(LOGO_PATH, "rb") as img_file:
            img_base64 = base64.b64encode(img_file.read()).decode()
        # Inline the image as a data URI so no static-file serving is needed.
        logo_html = (
            f'<img src="data:image/png;base64,{img_base64}" '
            f'alt="logo" style="height:48px;">'
        )
    except Exception:
        # Best-effort: a missing/unreadable logo must not break the app.
        logo_html = ""
st.markdown(f"""{logo_html}""", unsafe_allow_html=True)
def handle_model_load_error(model_name: str, error_msg: str):
    """Show a user-facing error when a model fails to load.

    NOTE: this assumes GATED_MODELS is defined in config or globally;
    otherwise a gated-model check here would need adjusting.
    """
    st.error(f"❌ Erro ao carregar modelo {model_name}: {error_msg}")
# ============================================================================
# AUTOMATIC MODEL LOADING (replaces the sidebar)
# ============================================================================
if "chat_model" not in st.session_state:
    with st.spinner("Inicializando o sistema e carregando o modelo..."):
        # Pre-initialize so the error handler below always has a name,
        # instead of the fragile `'model_name' in locals()` check.
        model_name = "Unknown"
        try:
            # 1. Fetch the configured model options: list of (model_id, config).
            model_options = get_model_options()
            if not model_options:
                st.error("Nenhum modelo encontrado nas configurações.")
                st.stop()
            # 2. Automatically select the FIRST model in the list.
            #    (Was `model_options[2]`, which contradicted this intent and
            #    raised IndexError for short lists.)
            selected_label, selected_model_config = model_options[0]
            logging.info('[MODEL] Loading %s %s', selected_label, selected_model_config)
            # pop() so the remaining dict holds only load_model kwargs.
            model_name = selected_model_config.pop('model_id')
            selected_model_config['dtype'] = 'bfloat16'
            print(f'Carregando modelo automático: {model_name}')
            # 3. Load the model pipeline.
            pipeline, model_info = load_model(
                model_name,
                **selected_model_config
            )
            # 4. Persist everything the chat loop needs in the session.
            st.session_state.chat_model = ChatModel(pipeline)
            st.session_state.model_info = selected_model_config
            st.session_state.model_name = model_name
            set_new_message_uuid()
            # Rerun so the interface renders without the spinner.
            st.rerun()
        except Exception as e:
            handle_model_load_error(model_name, str(e))
            st.stop()
# ============================================================================
# CHAT INTERFACE
# ============================================================================
chat_model = st.session_state.chat_model

if "messages" not in st.session_state:
    st.session_state.messages = []

# Keep the session-state history in sync with the model's conversation object.
if len(st.session_state.messages) != len(chat_model.conversation.messages):
    st.session_state.messages = [
        {"role": msg.role, "content": msg.content}
        for msg in chat_model.conversation.messages
    ]

chat_placeholder = st.empty()
with chat_placeholder.container():
    for entry in st.session_state.messages:
        if entry["role"] == "system":
            # System prompts are internal state; never rendered.
            continue
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])
# Handle a new user prompt. Each prompt starts a fresh conversation.
if user_input := st.chat_input("Digite sua mensagem..."):
    # Reset history on both sides: the UI copy and the model's conversation.
    st.session_state.messages = []
    chat_model.conversation.clear()
    chat_placeholder.empty()
    set_new_message_uuid()

    chat_model.add_user_message(user_input)
    save_messages()
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        chunks = []
        full_response = ""
        try:
            # Stream tokens into the placeholder as they arrive.
            for token in chat_model.generate_streaming(max_new_tokens=4096):
                chunks.append(token)
                full_response = "".join(chunks)
                response_placeholder.markdown(full_response)
            chat_model.add_assistant_message(full_response)
            st.session_state.messages.append(
                {"role": "assistant", "content": full_response}
            )
            save_messages()
        except Exception as e:
            error_msg = f"Erro na geração: {str(e)}"
            st.error(error_msg)
            st.session_state.messages.append(
                {"role": "assistant", "content": error_msg}
            )